From a29bfb16f321a5b123e4bf493e04c3386f93cd03 Mon Sep 17 00:00:00 2001 From: Cameron Gutman Date: Mon, 4 Mar 2024 18:43:16 -0600 Subject: [PATCH] Fix mismatched case and unhandled exception in open_drm_fd_for_cuda_device() --- src/platform/linux/cuda.cpp | 43 ++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/src/platform/linux/cuda.cpp b/src/platform/linux/cuda.cpp index 5b121c70691..33c939243f2 100644 --- a/src/platform/linux/cuda.cpp +++ b/src/platform/linux/cuda.cpp @@ -247,29 +247,38 @@ namespace cuda { // There's no way to directly go from CUDA to a DRM device, so we'll // use sysfs to look up the DRM device name from the PCI ID. - char pci_bus_id[13]; - CU_CHECK(cdf->cuDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), device), "Couldn't get CUDA device PCI bus ID"); - BOOST_LOG(debug) << "Found CUDA device with PCI bus ID: "sv << pci_bus_id; + std::array pci_bus_id; + CU_CHECK(cdf->cuDeviceGetPCIBusId(pci_bus_id.data(), pci_bus_id.size(), device), "Couldn't get CUDA device PCI bus ID"); + BOOST_LOG(debug) << "Found CUDA device with PCI bus ID: "sv << pci_bus_id.data(); + + // Linux uses lowercase hexadecimal while CUDA uses uppercase + std::transform(pci_bus_id.begin(), pci_bus_id.end(), pci_bus_id.begin(), + [](char c) { return std::tolower(c); }); // Look for the name of the primary node in sysfs - char sysfs_path[PATH_MAX]; - std::snprintf(sysfs_path, sizeof(sysfs_path), "/sys/bus/pci/devices/%s/drm", pci_bus_id); - fs::path sysfs_dir { sysfs_path }; - for (auto &entry : fs::directory_iterator { sysfs_dir }) { - auto file = entry.path().filename(); - auto filestring = file.generic_u8string(); - if (std::string_view { filestring }.substr(0, 4) != "card"sv) { - continue; - } + try { + char sysfs_path[PATH_MAX]; + std::snprintf(sysfs_path, sizeof(sysfs_path), "/sys/bus/pci/devices/%s/drm", pci_bus_id.data()); + fs::path sysfs_dir { sysfs_path }; + for (auto &entry : fs::directory_iterator { sysfs_dir }) { + auto file = entry.path().filename(); + auto filestring = file.generic_u8string(); + if (std::string_view { filestring }.substr(0, 4) != "card"sv) { + continue; + } - BOOST_LOG(debug) << "Found DRM primary node: "sv << filestring; + BOOST_LOG(debug) << "Found DRM primary node: "sv << filestring; - fs::path dri_path { "/dev/dri"sv }; - auto device_path = dri_path / file; - return open(device_path.c_str(), O_RDWR); + fs::path dri_path { "/dev/dri"sv }; + auto device_path = dri_path / file; + return open(device_path.c_str(), O_RDWR); + } + } + catch (const std::filesystem::filesystem_error &err) { + BOOST_LOG(error) << "Failed to read sysfs: "sv << err.what(); } - BOOST_LOG(error) << "Unable to find DRM device with PCI bus ID: "sv << pci_bus_id; + BOOST_LOG(error) << "Unable to find DRM device with PCI bus ID: "sv << pci_bus_id.data(); return -1; }