Skip to content

Commit

Permalink
Update OrtApi wrappers to use structs and IntPtr's
Browse files Browse the repository at this point in the history
  • Loading branch information
tommcdon committed Sep 11, 2023
1 parent 50555ae commit abbec2e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 17 deletions.
17 changes: 10 additions & 7 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace Microsoft.ML.OnnxRuntime
{
[StructLayout(LayoutKind.Sequential)]
public class OrtApiBase
public struct OrtApiBase
{
public IntPtr GetApi;
public IntPtr GetVersionString;
Expand All @@ -17,7 +17,7 @@ public class OrtApiBase
// OrtApi ort_api_1_to_<latest_version> (onnxruntime/core/session/onnxruntime_c_api.cc)
// If syncing your new C API, any other C APIs before yours also need to be synced here if haven't
[StructLayout(LayoutKind.Sequential)]
public class OrtApi
public struct OrtApi
{
public IntPtr CreateStatus;
public IntPtr GetErrorCode;
Expand Down Expand Up @@ -294,7 +294,7 @@ internal static class NativeMethods
static OrtApi api_;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate OrtApi DOrtGetApi(UInt32 version);
public delegate IntPtr DOrtGetApi(UInt32 version);

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr DOrtGetVersionString();
Expand All @@ -303,11 +303,14 @@ internal static class NativeMethods

static NativeMethods()
{
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetApi, typeof(DOrtGetApi));
IntPtr ortApiBasePtr = OrtGetApiBase();
OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase));
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetApi, typeof(DOrtGetApi));

// TODO: Make this save the pointer, and not copy the whole structure across
api_ = (OrtApi)OrtGetApi(14 /*ORT_API_VERSION*/);
OrtGetVersionString = (DOrtGetVersionString)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetVersionString, typeof(DOrtGetVersionString));
IntPtr ortApiPtr = OrtGetApi(14 /*ORT_API_VERSION*/);
api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi));
OrtGetVersionString = (DOrtGetVersionString)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetVersionString, typeof(DOrtGetVersionString));

OrtCreateEnv = (DOrtCreateEnv)Marshal.GetDelegateForFunctionPointer(api_.CreateEnv, typeof(DOrtCreateEnv));
OrtCreateEnvWithCustomLogger = (DOrtCreateEnvWithCustomLogger)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithCustomLogger, typeof(DOrtCreateEnvWithCustomLogger));
Expand Down Expand Up @@ -509,7 +512,7 @@ internal class NativeLib
}

[DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)]
public static extern OrtApiBase OrtGetApiBase();
public static extern IntPtr OrtGetApiBase();

#region Runtime/Environment API

Expand Down
16 changes: 9 additions & 7 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -397,10 +397,10 @@ public void AppendExecutionProvider(string providerName, Dictionary<string, stri
/// Loads a DLL named 'libraryPath' and looks for this entry point:
/// OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
/// It then passes in the provided session options to this function along with the api base.
///
/// Prior to v1.15 this leaked the library handle and RegisterCustomOpLibraryV2
/// was added to resolve that.
///
///
/// Prior to v1.15 this leaked the library handle and RegisterCustomOpLibraryV2
/// was added to resolve that.
///
/// From v1.15 on ONNX Runtime will manage the lifetime of the handle.
/// </summary>
/// <param name="libraryPath">path to the custom op library</param>
Expand Down Expand Up @@ -435,20 +435,22 @@ public void RegisterCustomOpLibraryV2(string libraryPath, out IntPtr libraryHand
// SessionOptions.RegisterCustomOpLibrary calls NativeMethods.OrtRegisterCustomOpsLibrary_V2
// SessionOptions.RegisterCustomOpLibraryV2 calls NativeMethods.OrtRegisterCustomOpsLibrary
var utf8Path = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath);
NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, utf8Path,
NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, utf8Path,
out libraryHandle));
}

/// <summary>
/// Register the custom operators from the Microsoft.ML.OnnxRuntime.Extensions NuGet package.
/// A reference to Microsoft.ML.OnnxRuntime.Extensions must be manually added to your project.
/// A reference to Microsoft.ML.OnnxRuntime.Extensions must be manually added to your project.
/// </summary>
/// <exception cref="OnnxRuntimeException">Throws if the extensions library is not found.</exception>
public void RegisterOrtExtensions()
{
try
{
var ortApiBase = NativeMethods.OrtGetApiBase();
//var ortApiBase = NativeMethods.OrtGetApiBase();
IntPtr ortApiBasePtr = NativeMethods.OrtGetApiBase();
OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase));
NativeApiStatus.VerifySuccess(
OrtExtensionsNativeMethods.RegisterCustomOps(this.handle, ref ortApiBase)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,23 @@ internal static class NativeTrainingMethods
static IntPtr trainingApiPtr;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate OrtApi DOrtGetApi(UInt32 version);
public delegate IntPtr DOrtGetApi(UInt32 version);

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtTrainingApi* */ DOrtGetTrainingApi(UInt32 version);
public static DOrtGetTrainingApi OrtGetTrainingApi;

static NativeTrainingMethods()
{
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi));
//DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi));
IntPtr ortApiBasePtr = NativeMethods.OrtGetApiBase();
OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase));
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(ortApiBase.GetApi, typeof(DOrtGetApi));

// TODO: Make this save the pointer, and not copy the whole structure across
api_ = (OrtApi)OrtGetApi(13 /*ORT_API_VERSION*/);
//api_ = (OrtApi)OrtGetApi(13 /*ORT_API_VERSION*/);
IntPtr ortApiPtr = OrtGetApi(14 /*ORT_API_VERSION*/);
api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi));

OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi));
trainingApiPtr = OrtGetTrainingApi(13 /*ORT_API_VERSION*/);
Expand Down

0 comments on commit abbec2e

Please sign in to comment.