Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IOMMUFD support #56

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pkg/nvpci/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ func (m *MockNvpci) AddMockA100(address string, numaNode int, sriov *SriovInfo)
return err
}

vfioDev := filepath.Join(deviceDir, "vfio-dev")
vfioFD := filepath.Join(vfioDev, "vfio8")
err = os.MkdirAll(vfioFD, 0755)
if err != nil {
return err
}
Comment on lines +87 to +92
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is there a benefit in having two variables?

Suggested change
vfioDev := filepath.Join(deviceDir, "vfio-dev")
vfioFD := filepath.Join(vfioDev, "vfio8")
err = os.MkdirAll(vfioFD, 0755)
if err != nil {
return err
}
vfioFD := filepath.Join(deviceDir, "vfio-dev", "vfio8)
err = os.MkdirAll(vfioFD, 0755)
if err != nil {
return err
}


iommuGroup := 20
_, err = os.Create(filepath.Join(deviceDir, strconv.Itoa(iommuGroup)))
if err != nil {
Expand Down
32 changes: 32 additions & 0 deletions pkg/nvpci/nvpci.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ type NvidiaPCIDevice struct {
Device uint16
DeviceName string
Driver string
IommuFD int
IommuGroup int
NumaNode int
Config *ConfigSpace
Expand Down Expand Up @@ -290,6 +291,11 @@ func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevi
return nil, fmt.Errorf("unable to detect IOMMU group for %s: %w", address, err)
}

iommuFD, err := getIOMMUFD(devicePath)
if err != nil {
return nil, fmt.Errorf("unable to get IOMMU FD for %s: %w", address, err)
}

numa, err := os.ReadFile(path.Join(devicePath, "numa_node"))
if err != nil {
return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err)
Expand Down Expand Up @@ -373,6 +379,7 @@ func (p *nvpci) getGPUByPciBusID(address string, cache map[string]*NvidiaPCIDevi
Class: uint32(classID),
Device: uint16(deviceID),
Driver: driver,
IommuFD: int(iommuFD),
IommuGroup: int(iommuGroup),
NumaNode: int(numaNode),
Config: config,
Expand Down Expand Up @@ -521,6 +528,31 @@ func getDriver(devicePath string) (string, error) {
return "", err
}

// /sys/bus/pci/devices/0000:df:00.0/vfio-dev/vfio0.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a more descriptive docstring.

func getIOMMUFD(devicePath string) (int, error) {
vfioDevDir := filepath.Join(devicePath, "vfio-dev")
// Read the directory; expect exactly one entry
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't check for exactly one entry below. Should we?

entries, err := os.ReadDir(vfioDevDir)
switch {
case os.IsNotExist(err):
return -1, nil
case err == nil:
if len(entries) == 0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check for the vfio prefix in the entries? Does using filepath.Glob make sense here if we're expecting a particular pattern?

return -1, fmt.Errorf("no VFIO device file found in %s", vfioDevDir)
}
name := entries[0].Name()
// Strip the "vfio" prefix to get the numeric part
idxStr := strings.TrimPrefix(name, "vfio")
idx, err := strconv.Atoi(idxStr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Should we use ParseInt here and specify a 64 bit width as we do for the IOMMU Group?

if err != nil {
return -1, fmt.Errorf("failed to parse VFIO index from %q: %w", name, err)
}

return idx, nil
}
return -1, err
}

func getIOMMUGroup(devicePath string) (int64, error) {
var iommuGroup int64
iommu, err := filepath.EvalSymlinks(path.Join(devicePath, "iommu_group"))
Expand Down
27 changes: 27 additions & 0 deletions pkg/nvpci/nvpci_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,33 @@ func TestNvpci(t *testing.T) {
_, err = nvpci.GetGPUByIndex(1)
require.Error(t, err, "No error returned when getting GPU at invalid index")
}
func TestNvpciIOMMUFD(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we need a newline before this?

Suggested change
func TestNvpciIOMMUFD(t *testing.T) {
func TestNvpciIOMMUFD(t *testing.T) {

testCases := []struct {
Description string
IOMMUFD int
}{
{
Description: "IOMMUFD 8",
IOMMUFD: 8,
},
}

for _, tc := range testCases {
t.Run(tc.Description, func(t *testing.T) {
nvpci, err := NewMockNvpci()
require.Nil(t, err, "Error creating NewMockNvpci")
defer nvpci.Cleanup()

err = nvpci.AddMockA100("0000:80:05.1", 0, nil)
require.Nil(t, err, "Error adding Mock A100 device to MockNvpci")

devices, err := nvpci.GetGPUs()
require.Nil(t, err, "Error getting GPUs")
require.Equal(t, 1, len(devices), "Wrong number of GPU devices")
require.Equal(t, 8, devices[0].IommuFD, "Wrong IOMMUFD found for device")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
require.Equal(t, 8, devices[0].IommuFD, "Wrong IOMMUFD found for device")
require.Equal(t, tc.IOMMUFD, devices[0].IommuFD, "Wrong IOMMUFD found for device")

})
}
}

func TestNvpciNUMANode(t *testing.T) {
testCases := []struct {
Expand Down