From abbec2e4fad1c610b5103ace142757d27979b3a7 Mon Sep 17 00:00:00 2001 From: Tom McDonald Date: Sun, 10 Sep 2023 23:48:42 -0400 Subject: [PATCH] Update OrtApi wrappers to use structs and IntPtr's --- .../NativeMethods.shared.cs | 17 ++++++++++------- .../SessionOptions.shared.cs | 16 +++++++++------- .../Training/NativeTrainingMethods.shared.cs | 11 ++++++++--- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 0490f4e26e18..ccedc77a07c8 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -7,7 +7,7 @@ namespace Microsoft.ML.OnnxRuntime { [StructLayout(LayoutKind.Sequential)] - public class OrtApiBase + public struct OrtApiBase { public IntPtr GetApi; public IntPtr GetVersionString; @@ -17,7 +17,7 @@ public class OrtApiBase // OrtApi ort_api_1_to_ (onnxruntime/core/session/onnxruntime_c_api.cc) // If syncing your new C API, any other C APIs before yours also need to be synced here if haven't [StructLayout(LayoutKind.Sequential)] - public class OrtApi + public struct OrtApi { public IntPtr CreateStatus; public IntPtr GetErrorCode; @@ -294,7 +294,7 @@ internal static class NativeMethods static OrtApi api_; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate OrtApi DOrtGetApi(UInt32 version); + public delegate IntPtr DOrtGetApi(UInt32 version); [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr DOrtGetVersionString(); @@ -303,11 +303,14 @@ internal static class NativeMethods static NativeMethods() { - DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetApi, typeof(DOrtGetApi)); + IntPtr ortApiBasePtr = OrtGetApiBase(); + OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase)); + DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetApi, typeof(DOrtGetApi)); // TODO: Make this save the pointer, and not copy the whole structure across - api_ = (OrtApi)OrtGetApi(14 /*ORT_API_VERSION*/); - OrtGetVersionString = (DOrtGetVersionString)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetVersionString, typeof(DOrtGetVersionString)); + IntPtr ortApiPtr = OrtGetApi(14 /*ORT_API_VERSION*/); + api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi)); + OrtGetVersionString = (DOrtGetVersionString)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetVersionString, typeof(DOrtGetVersionString)); OrtCreateEnv = (DOrtCreateEnv)Marshal.GetDelegateForFunctionPointer(api_.CreateEnv, typeof(DOrtCreateEnv)); OrtCreateEnvWithCustomLogger = (DOrtCreateEnvWithCustomLogger)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithCustomLogger, typeof(DOrtCreateEnvWithCustomLogger)); @@ -509,7 +512,7 @@ internal class NativeLib } [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] - public static extern OrtApiBase OrtGetApiBase(); + public static extern IntPtr OrtGetApiBase(); #region Runtime/Environment API diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 30951bae3f9f..b321d3d8c727 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -397,10 +397,10 @@ public void AppendExecutionProvider(string providerName, Dictionary /// path to the custom op library @@ -435,20 +435,22 @@ public void RegisterCustomOpLibraryV2(string libraryPath, out IntPtr libraryHand // SessionOptions.RegisterCustomOpLibrary calls NativeMethods.OrtRegisterCustomOpsLibrary_V2 // SessionOptions.RegisterCustomOpLibraryV2 calls NativeMethods.OrtRegisterCustomOpsLibrary var utf8Path = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath); - NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, utf8Path, + NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, utf8Path, out libraryHandle)); } /// /// Register the custom operators from the Microsoft.ML.OnnxRuntime.Extensions NuGet package. - /// A reference to Microsoft.ML.OnnxRuntime.Extensions must be manually added to your project. + /// A reference to Microsoft.ML.OnnxRuntime.Extensions must be manually added to your project. /// /// Throws if the extensions library is not found. public void RegisterOrtExtensions() { try { - var ortApiBase = NativeMethods.OrtGetApiBase(); + //var ortApiBase = NativeMethods.OrtGetApiBase(); + IntPtr ortApiBasePtr = NativeMethods.OrtGetApiBase(); + OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase)); NativeApiStatus.VerifySuccess( OrtExtensionsNativeMethods.RegisterCustomOps(this.handle, ref ortApiBase) ); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 7a538b2b2d4a..00046718a623 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -49,7 +49,7 @@ internal static class NativeTrainingMethods static IntPtr trainingApiPtr; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate OrtApi DOrtGetApi(UInt32 version); + public delegate IntPtr DOrtGetApi(UInt32 version); [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtTrainingApi* */ DOrtGetTrainingApi(UInt32 version); @@ -57,10 +57,15 @@ internal static class NativeTrainingMethods static NativeTrainingMethods() { - DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); + //DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); + IntPtr ortApiBasePtr = NativeMethods.OrtGetApiBase(); + OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase)); + DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetApi, typeof(DOrtGetApi)); // TODO: Make this save the pointer, and not copy the whole structure across - api_ = (OrtApi)OrtGetApi(13 /*ORT_API_VERSION*/); + //api_ = (OrtApi)OrtGetApi(13 /*ORT_API_VERSION*/); + IntPtr ortApiPtr = OrtGetApi(14 /*ORT_API_VERSION*/); + api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi)); OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi)); trainingApiPtr = OrtGetTrainingApi(13 /*ORT_API_VERSION*/);