Skip to content

Commit a96fab3

Browse files
authored
Give the EP all devices so it can create OrtEpDevice instances with full knowledge (#24568)
### Description <!-- Describe your changes. --> Replaces GetDeviceInfoIfSupported with GetSupportedDevices. The EP now sees all devices so it can make decisions with full knowledge. This is mainly applicable to GPU EPs like WebGPU. The EP has to iterate over the devices and call CreateEpDevice for each device it supports. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 4adef01 commit a96fab3

18 files changed

+454
-137
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

+52-13
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,9 @@ typedef struct OrtModelEditorApi OrtModelEditorApi;
700700
struct OrtCompileApi;
701701
typedef struct OrtCompileApi OrtCompileApi;
702702

703+
struct OrtEpApi;
704+
typedef struct OrtEpApi OrtEpApi;
705+
703706
/** \brief The helper interface to get the right version of OrtApi
704707
*
705708
* Get a pointer to this structure through ::OrtGetApiBase
@@ -5186,6 +5189,12 @@ struct OrtApi {
51865189
* \since Version 1.22.
51875190
*/
51885191
const OrtHardwareDevice*(ORT_API_CALL* EpDevice_Device)(_In_ const OrtEpDevice* ep_device);
5192+
5193+
/** \brief Get the OrtEpApi instance for implementing an execution provider.
5194+
*
5195+
* \since Version 1.22.
5196+
*/
5197+
const OrtEpApi*(ORT_API_CALL* GetEpApi)();
51895198
};
51905199

51915200
/*
@@ -5889,6 +5898,29 @@ struct OrtCompileApi {
58895898
ORT_RUNTIME_CLASS(Ep);
58905899
ORT_RUNTIME_CLASS(EpFactory);
58915900

5901+
/** \brief The OrtEpApi struct provides functions for use when implementing an execution provider. */
struct OrtEpApi {
5902+
/** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice.
5903+
* \param[in] ep_factory Execution provider factory that is creating the instance.
5904+
* \param[in] hardware_device Hardware device that the EP can utilize.
5905+
* \param[in] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used
5906+
* during execution provider selection and passed to CreateEp.
5907+
* ep_device will copy this instance and the user should call ReleaseKeyValuePairs.
5908+
* \param[in] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added
5909+
* to the Session configuration options if the execution provider is selected.
5910+
* ep_device will copy this instance and the user should call ReleaseKeyValuePairs.
5911+
* \param[out] ep_device OrtEpDevice that is created.
5912+
*
5913+
* \since Version 1.22.
5914+
*/
5915+
ORT_API2_STATUS(CreateEpDevice, _In_ OrtEpFactory* ep_factory,
5916+
_In_ const OrtHardwareDevice* hardware_device,
5917+
_In_opt_ const OrtKeyValuePairs* ep_metadata,
5918+
_In_opt_ const OrtKeyValuePairs* ep_options,
5919+
_Out_ OrtEpDevice** ep_device);
5920+
5921+
ORT_CLASS_RELEASE(EpDevice);
5922+
};
5923+
58925924
/**
58935925
* \brief The OrtEp struct provides functions to implement for an execution provider.
58945926
* \since Version 1.22.
@@ -5993,33 +6025,40 @@ struct OrtEpFactory {
59936025
/** \brief Get the OrtEpDevice instances for the OrtHardwareDevice instances that the execution provider supports.
59946026
*
59956027
* \param[in] this_ptr The OrtEpFactory instance.
5996-
* \param[in] device The OrtHardwareDevice instance.
5997-
* \param[out] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used
5998-
* during execution provider selection and/or CreateEp.
5999-
* \param[out] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added
6000-
* to the Session configuration options if the execution provider is selected.
6028+
* Non-const as the factory is passed through to the CreateEp call via the OrtEpDevice.
6029+
* \param[in] devices The OrtHardwareDevice instances that are available.
6030+
* \param[in] num_devices The number of OrtHardwareDevice instances.
6031+
* \param[out] ep_devices OrtEpDevice instances for each OrtHardwareDevice that the EP can use.
6032+
* The implementation should call OrtEpApi::CreateEpDevice to create, and add the OrtEpDevice
6033+
* instances to this pre-allocated array. ORT will take ownership of the values returned.
6034+
* i.e. usage is `ep_devices[0] = <ptr to OrtEpDevice created with OrtEpApi::CreateEpDevice>;`
6035+
* \param[in] max_ep_devices The maximum number of OrtEpDevices that can be added to ep_devices.
6036+
* Current default is 8. This can be increased if needed.
6037+
* \param[out] num_ep_devices The number of EP devices added to ep_devices.
60016038
* \snippet{doc} snippets.dox OrtStatus Return Value
60026039
*
60036040
* \note ORT will take ownership of the OrtEpDevice instances added to ep_devices.
60046041
*
60056042
* \since Version 1.22.
60066043
*/
6007-
bool(ORT_API_CALL* GetDeviceInfoIfSupported)(const OrtEpFactory* this_ptr,
6008-
_In_ const OrtHardwareDevice* device,
6009-
_Out_opt_ OrtKeyValuePairs** ep_metadata,
6010-
_Out_opt_ OrtKeyValuePairs** ep_options);
6044+
OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr,
6045+
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
6046+
_In_ size_t num_devices,
6047+
_Inout_ OrtEpDevice** ep_devices,
6048+
_In_ size_t max_ep_devices,
6049+
_Out_ size_t* num_ep_devices);
60116050

60126051
/** \brief Function to create an OrtEp instance for use in a Session.
60136052
*
60146053
* ORT will call ReleaseEp to release the instance when it is no longer needed.
60156054
*
60166055
* \param[in] this_ptr The OrtEpFactory instance.
60176056
* \param[in] devices The OrtHardwareDevice instances that the execution provider was selected to use.
6018-
* \param[in] ep_metadata_pairs Execution provider metadata that was returned in GetDeviceInfoIfSupported, for each
6057+
* \param[in] ep_metadata_pairs Execution provider metadata that was provided to OrtEpApi::CreateEpDevice, for each
60196058
* device.
60206059
* \param[in] num_devices The number of devices the execution provider was selected for.
60216060
* \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the
6022-
* session. This will include ep_options from GetDeviceInfoIfSupported as well as any
6061+
* session. This will include ep_options from GetSupportedDevices as well as any
60236062
* user provided overrides.
60246063
* Execution provider options will have been added with a prefix of 'ep.<ep name>.'.
60256064
* The OrtSessionOptions instance will NOT be valid after this call and should not be
@@ -6029,7 +6068,7 @@ struct OrtEpFactory {
60296068
*
60306069
* \snippet{doc} snippets.dox OrtStatus Return Value
60316070
*
6032-
* \since Version 1.22.
6071+
* \since Version <coming soon>. This is a placeholder.
60336072
*/
60346073
OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr,
60356074
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
@@ -6043,7 +6082,7 @@ struct OrtEpFactory {
60436082
* \param[in] this_ptr The OrtEpFactory instance.
60446083
* \param[in] ep The OrtEp instance to release.
60456084
*
6046-
* \since Version 1.22.
6085+
* \since Version <coming soon>. This is a placeholder.
60476086
*/
60486087
void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep);
60496088
};

include/onnxruntime/core/session/onnxruntime_cxx_api.h

+33-1
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ inline const OrtCompileApi& GetCompileApi() {
172172
return *api;
173173
}
174174

175+
/// <summary>
176+
/// This returns a reference to the ORT C EP API. Used if authoring a plugin execution provider.
177+
/// </summary>
178+
/// <returns>ORT C EP API reference</returns>
179+
inline const OrtEpApi& GetEpApi() {
180+
auto* api = GetApi().GetEpApi();
181+
if (api == nullptr) {
182+
// minimal build
183+
ORT_CXX_API_THROW("EP API is not available in this build", ORT_FAIL);
184+
}
185+
186+
return *api;
187+
}
188+
175189
/** \brief IEEE 754 half-precision floating point data type
176190
*
177191
* \details This struct is used for converting float to float16 and back
@@ -561,6 +575,7 @@ ORT_DEFINE_RELEASE(Graph);
561575
ORT_DEFINE_RELEASE(Model);
562576
ORT_DEFINE_RELEASE(KeyValuePairs)
563577
ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi);
578+
ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi);
564579

565580
#undef ORT_DEFINE_RELEASE
566581
#undef ORT_DEFINE_RELEASE_FROM_API_STRUCT
@@ -763,10 +778,16 @@ struct KeyValuePairs : detail::KeyValuePairsImpl<OrtKeyValuePairs> {
763778
/// Take ownership of a pointer created by C API
764779
explicit KeyValuePairs(OrtKeyValuePairs* p) : KeyValuePairsImpl<OrtKeyValuePairs>{p} {}
765780

781+
/// \brief Wraps OrtApi::CreateKeyValuePairs
766782
explicit KeyValuePairs();
783+
784+
/// \brief Wraps OrtApi::CreateKeyValuePairs and OrtApi::AddKeyValuePair
767785
explicit KeyValuePairs(const std::unordered_map<std::string, std::string>& kv_pairs);
768786

787+
/// \brief Wraps OrtApi::AddKeyValuePair
769788
void Add(const char* key, const char* value);
789+
790+
/// \brief Wraps OrtApi::RemoveKeyValuePair
770791
void Remove(const char* key);
771792

772793
ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; }
@@ -806,10 +827,21 @@ struct EpDeviceImpl : Ort::detail::Base<T> {
806827
} // namespace detail
807828

808829
/** \brief Wrapper around ::OrtEpDevice
809-
* \remarks EpDevice is always read-only for API users.
830+
* \remarks EpDevice is always read-only for ORT API users.
810831
*/
811832
using ConstEpDevice = detail::EpDeviceImpl<Ort::detail::Unowned<const OrtEpDevice>>;
812833

834+
/** \brief Mutable EpDevice that is created by EpApi users.
835+
* \remarks Constructing from an OrtEpDevice pointer takes ownership of that pointer.
*/
836+
struct EpDevice : detail::EpDeviceImpl<OrtEpDevice> {
837+
explicit EpDevice(std::nullptr_t) {} ///< No instance is created
838+
explicit EpDevice(OrtEpDevice* p) : EpDeviceImpl<OrtEpDevice>{p} {} ///< Take ownership of a pointer created by C API
839+
840+
/// \brief Wraps OrtEpApi::CreateEpDevice
/// \param ep_factory Factory passed through to OrtEpApi::CreateEpDevice.
/// \param hardware_device Hardware device passed through to OrtEpApi::CreateEpDevice.
/// \param ep_metadata Optional; a default-constructed ConstKeyValuePairs is used when omitted.
/// \param ep_options Optional; a default-constructed ConstKeyValuePairs is used when omitted.
841+
EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device,
842+
ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {});
843+
};
844+
813845
/** \brief The Env (Environment)
814846
*
815847
* The Env holds the logging state used by all other objects.

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

+5
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,11 @@ inline ConstHardwareDevice EpDeviceImpl<T>::Device() const {
593593
}
594594
} // namespace detail
595595

596+
// Creates the underlying OrtEpDevice through OrtEpApi::CreateEpDevice and stores it in p_.
// The returned status is handed to ThrowOnError.
inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device,
                          ConstKeyValuePairs ep_metadata, ConstKeyValuePairs ep_options) {
  const OrtEpApi& ep_api = GetEpApi();
  OrtStatus* create_status = ep_api.CreateEpDevice(&ep_factory, hardware_device, ep_metadata, ep_options, &p_);
  ThrowOnError(create_status);
}
600+
596601
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
597602
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
598603
if (strcmp(logid, "onnxruntime-node") == 0) {

onnxruntime/core/providers/cuda/cuda_provider_factory.cc

+20-12
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ struct CudaEpFactory : OrtEpFactory {
314314
CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} {
315315
GetName = GetNameImpl;
316316
GetVendor = GetVendorImpl;
317-
GetDeviceInfoIfSupported = GetDeviceInfoIfSupportedImpl;
317+
GetSupportedDevices = GetSupportedDevicesImpl;
318318
CreateEp = CreateEpImpl;
319319
ReleaseEp = ReleaseEpImpl;
320320
}
@@ -329,18 +329,26 @@ struct CudaEpFactory : OrtEpFactory {
329329
return factory->vendor.c_str();
330330
}
331331

332-
static bool GetDeviceInfoIfSupportedImpl(const OrtEpFactory* this_ptr,
333-
const OrtHardwareDevice* device,
334-
_Out_opt_ OrtKeyValuePairs** /*ep_metadata*/,
335-
_Out_opt_ OrtKeyValuePairs** /*ep_options*/) {
336-
const auto* factory = static_cast<const CudaEpFactory*>(this_ptr);
337-
338-
if (factory->ort_api.HardwareDevice_Type(device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU &&
339-
factory->ort_api.HardwareDevice_VendorId(device) == 0x10de) {
340-
return true;
332+
// Implementation of OrtEpFactory::GetSupportedDevices for the CUDA EP.
//
// Walks `devices` and creates an OrtEpDevice (via OrtEpApi::CreateEpDevice, with no EP metadata or EP options)
// for every GPU whose vendor id is 0x10de, stopping once `max_ep_devices` entries have been written.
//
// this_ptr          The CudaEpFactory instance the call was made on.
// devices           Hardware devices to consider.
// num_devices       Number of entries in `devices`.
// ep_devices        Caller-provided array that receives the created OrtEpDevice instances.
// max_ep_devices    Capacity of `ep_devices`.
// p_num_ep_devices  Receives the number of valid entries written to `ep_devices`.
//
// Returns nullptr on success. If CreateEpDevice fails, ORT_API_RETURN_IF_ERROR returns early and
// `*p_num_ep_devices` counts only the entries that were successfully created before the failure.
static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr,
                                          const OrtHardwareDevice* const* devices,
                                          size_t num_devices,
                                          OrtEpDevice** ep_devices,
                                          size_t max_ep_devices,
                                          size_t* p_num_ep_devices) {
  size_t& num_ep_devices = *p_num_ep_devices;
  // The count is an output parameter: start from zero instead of relying on the caller's initial value.
  num_ep_devices = 0;

  auto* factory = static_cast<CudaEpFactory*>(this_ptr);

  for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
    const OrtHardwareDevice& device = *devices[i];
    // 0x10de is the PCI vendor id registered to NVIDIA.
    if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU &&
        factory->ort_api.HardwareDevice_VendorId(&device) == 0x10de) {
      // Create into a local first so the slot and the count are only updated once creation has succeeded.
      OrtEpDevice* ep_device = nullptr;
      ORT_API_RETURN_IF_ERROR(
          factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, nullptr, &ep_device));
      ep_devices[num_ep_devices++] = ep_device;
    }
  }

  return nullptr;
}
}
345353

346354
static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/,
@@ -385,7 +393,7 @@ OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase
385393
}
386394

387395
// Destroys an OrtEpFactory owned by this library (presumably one created by CreateEpFactories).
// The pointer is cast back to CudaEpFactory so the delete expression operates on the derived type
// rather than on the OrtEpFactory base type it was passed as. The factory is deleted exactly once.
// Always returns nullptr.
OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) {
  delete static_cast<CudaEpFactory*>(factory);
  return nullptr;
}
391399
}

onnxruntime/core/session/environment.cc

+37-23
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
#include "core/session/environment.h"
55

6+
#include <array>
7+
68
#include "core/common/basic_types.h"
79
#include "core/framework/allocator_utils.h"
10+
#include "core/framework/error_code_helper.h"
811
#include "core/graph/constants.h"
912
#include "core/graph/op.h"
1013
#include "core/platform/device_discovery.h"
@@ -468,6 +471,28 @@ Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_nam
468471
return status;
469472
}
470473

474+
namespace {
475+
std::vector<const OrtHardwareDevice*> SortDevicesByType() {
476+
auto& devices = DeviceDiscovery::GetDevices();
477+
std::vector<const OrtHardwareDevice*> sorted_devices;
478+
sorted_devices.reserve(devices.size());
479+
480+
const auto select_by_type = [&](OrtHardwareDeviceType type) {
481+
for (const auto& device : devices) {
482+
if (device.type == type) {
483+
sorted_devices.push_back(&device);
484+
}
485+
}
486+
};
487+
488+
select_by_type(OrtHardwareDeviceType_NPU);
489+
select_by_type(OrtHardwareDeviceType_GPU);
490+
select_by_type(OrtHardwareDeviceType_CPU);
491+
492+
return sorted_devices;
493+
}
494+
} // namespace
495+
471496
Status Environment::EpInfo::Create(std::unique_ptr<EpLibrary> library_in, std::unique_ptr<EpInfo>& out,
472497
const std::vector<EpFactoryInternal*>& internal_factories) {
473498
if (!library_in) {
@@ -482,36 +507,25 @@ Status Environment::EpInfo::Create(std::unique_ptr<EpLibrary> library_in, std::u
482507
ORT_RETURN_IF_ERROR(instance.library->Load());
483508
const auto& factories = instance.library->GetFactories();
484509

510+
// OrtHardwareDevice instances to pass to GetSupportedDevices. sorted by type to be slightly more structured.
511+
// the set of hardware devices is static so this can also be static.
512+
const static std::vector<const OrtHardwareDevice*> sorted_devices = SortDevicesByType();
513+
485514
for (auto* factory_ptr : factories) {
486515
ORT_ENFORCE(factory_ptr != nullptr, "Factory pointer was null. EpLibrary should prevent this. Library:",
487516
instance.library->RegistrationName());
488517

489518
auto& factory = *factory_ptr;
490519

491-
// for each device
492-
for (const auto& device : DeviceDiscovery::GetDevices()) {
493-
OrtKeyValuePairs* ep_metadata = nullptr;
494-
OrtKeyValuePairs* ep_options = nullptr;
495-
496-
if (factory.GetDeviceInfoIfSupported(&factory, &device, &ep_metadata, &ep_options)) {
497-
auto ed = std::make_unique<OrtEpDevice>();
498-
ed->ep_name = factory.GetName(&factory);
499-
ed->ep_vendor = factory.GetVendor(&factory);
500-
ed->device = &device;
501-
502-
if (ep_metadata) {
503-
ed->ep_metadata = std::move(*ep_metadata);
504-
delete ep_metadata;
505-
}
506-
507-
if (ep_options) {
508-
ed->ep_options = std::move(*ep_options);
509-
delete ep_options;
510-
}
511-
512-
ed->ep_factory = &factory;
520+
std::array<OrtEpDevice*, 8> ep_devices{nullptr};
521+
size_t num_ep_devices = 0;
522+
ORT_RETURN_IF_ERROR(ToStatus(
523+
factory.GetSupportedDevices(&factory, sorted_devices.data(), sorted_devices.size(),
524+
ep_devices.data(), ep_devices.size(), &num_ep_devices)));
513525

514-
instance.execution_devices.push_back(std::move(ed));
526+
for (size_t i = 0; i < num_ep_devices; ++i) {
527+
if (ep_devices[i] != nullptr) { // should never happen but just in case...
528+
instance.execution_devices.emplace_back(ep_devices[i]); // take ownership
515529
}
516530
}
517531
}

0 commit comments

Comments
 (0)