diff --git a/dotnet/Extism.sln.DotSettings b/dotnet/Extism.sln.DotSettings new file mode 100644 index 000000000..4940646b8 --- /dev/null +++ b/dotnet/Extism.sln.DotSettings @@ -0,0 +1,2 @@ + + True \ No newline at end of file diff --git a/dotnet/samples/Extism.Sdk.Sample/Extism.Sdk.Sample.csproj b/dotnet/samples/Extism.Sdk.Sample/Extism.Sdk.Sample.csproj index f87856c7b..7205d0e8b 100644 --- a/dotnet/samples/Extism.Sdk.Sample/Extism.Sdk.Sample.csproj +++ b/dotnet/samples/Extism.Sdk.Sample/Extism.Sdk.Sample.csproj @@ -13,6 +13,10 @@ + + + + diff --git a/dotnet/src/Extism.Sdk/Context.cs b/dotnet/src/Extism.Sdk/Context.cs index bfbcc6b0e..fc140bd98 100644 --- a/dotnet/src/Extism.Sdk/Context.cs +++ b/dotnet/src/Extism.Sdk/Context.cs @@ -6,7 +6,7 @@ namespace Extism.Sdk.Native; /// /// Represents an Extism context through which you can load s. /// -public class Context : IDisposable +public unsafe class Context : IDisposable { private const int DisposedMarker = 1; @@ -17,13 +17,16 @@ public class Context : IDisposable /// public Context() { - NativeHandle = LibExtism.extism_context_new(); + unsafe + { + NativeHandle = LibExtism.extism_context_new(); + } } /// /// Native pointer to the Extism Context. /// - internal IntPtr NativeHandle { get; } + internal LibExtism.ExtismContext* NativeHandle { get; } /// /// Loads an Extism . @@ -34,12 +37,18 @@ public Plugin CreatePlugin(ReadOnlySpan wasm, bool withWasi) { CheckNotDisposed(); + unsafe { fixed (byte* wasmPtr = wasm) { var plugin = LibExtism.extism_plugin_new(NativeHandle, wasmPtr, wasm.Length, null, 0, withWasi); - return new Plugin(this, plugin); + if (plugin == -1) + { + throw new ExtismException(GetError() ?? "Unknown exception when calling extism_plugin_new"); + } + var cancelHandle = LibExtism.extism_plugin_cancel_handle(NativeHandle, plugin); + return new Plugin(this, plugin, cancelHandle); } } } @@ -131,6 +140,7 @@ public static string GetExtismVersion() return Marshal.PtrToStringUTF8(pointer); } + // TODO: this should not be within the context, neither should the version. /// /// Set Extism's log file and level. This is applied for all s. /// @@ -152,6 +162,7 @@ public static bool SetExtismLogFile(string logPath, LogLevel level) } } +// TODO: check if the enums are correctly set /// /// Extism Log Levels /// diff --git a/dotnet/src/Extism.Sdk/LibExtism.cs b/dotnet/src/Extism.Sdk/LibExtism.cs index 31b471c3d..2cc1ad06c 100644 --- a/dotnet/src/Extism.Sdk/LibExtism.cs +++ b/dotnet/src/Extism.Sdk/LibExtism.cs @@ -1,5 +1,4 @@ using System.Runtime.InteropServices; - namespace Extism.Sdk.Native; /// @@ -7,19 +6,182 @@ namespace Extism.Sdk.Native; /// internal static class LibExtism { + internal enum ExtismValType + { + /// + /// Signed 32 bit integer. Equivalent of or + /// + I32, + + /// + /// Signed 64 bit integer. Equivalent of or + /// + I64, + + /// + /// Floating point 32 bit integer. Equivalent of + /// + F32, + + /// + /// Floating point 64 bit integer. Equivalent of + /// + F64, + + /// + /// A 128 bit number. + /// + V128, + + /// + /// A reference to opaque data in the Wasm instance. + /// + FuncRef, + + /// + /// A reference to opaque data in the Wasm instance. + /// + ExternRef + } + + /// + /// A `Context` is used to store and manage plugins. + /// + [StructLayout(LayoutKind.Sequential)] + internal struct ExtismContext { } + + /// + /// Wraps host functions + /// + [StructLayout(LayoutKind.Sequential)] + internal struct ExtismFunction { } + + /// + /// Plugin contains everything needed to execute a WASM function. + /// + [StructLayout(LayoutKind.Sequential)] + internal struct ExtismCurrentPlugin { } + + /// + /// A union type for host function argument/return values. + /// + [StructLayout(LayoutKind.Explicit)] + internal struct ExtismValUnion + { + [FieldOffset(0)] + internal int i32; + + [FieldOffset(0)] + internal long i64; + + [FieldOffset(0)] + internal float f32; + + [FieldOffset(0)] + internal double f64; + } + + /// + /// `ExtismVal` holds the type and value of a function argument/return + /// + [StructLayout(LayoutKind.Sequential)] + internal struct ExtismVal + { + internal ExtismValType t; + internal ExtismValUnion v; + } + + /// + /// Host function signature + /// + /// + /// + /// + /// + /// + /// + internal delegate void ExtismFunctionType(ref ExtismCurrentPlugin plugin, Span inputs, uint n_inputs, Span outputs, uint n_outputs, IntPtr data); + + /// + /// Returns a pointer to the memory of the currently running plugin. + /// NOTE: this should only be called from host functions. + /// + /// + /// + [DllImport("extism", EntryPoint = "extism_current_plugin_memory")] + public static extern IntPtr CurrentPluginMemory(ref ExtismCurrentPlugin plugin); + + /// + /// + /// + /// + /// + /// + [DllImport("extism", EntryPoint = "extism_current_plugin_memory_alloc")] + public static extern IntPtr CurrentPluginMemoryAlloc(ref ExtismCurrentPlugin plugin, long n); + + /// + /// Allocate a memory block in the currently running plugin. + /// NOTE: this should only be called from host functions. + /// + /// + /// + /// + [DllImport("extism", EntryPoint = "extism_current_plugin_memory_length")] + public static extern long CurrentPluginMemoryLength(ref ExtismCurrentPlugin plugin, long n); + + /// + /// Get the length of an allocated block. + /// NOTE: this should only be called from host functions. + /// + /// + /// + [DllImport("extism", EntryPoint = "extism_current_plugin_memory_free")] + public static extern void CurrentPluginMemoryFree(ref ExtismCurrentPlugin plugin, IntPtr ptr); + + /// + /// Create a new host function. + /// + /// function name, this should be valid UTF-8 + /// argument types + /// number of argument types + /// return types + /// number of return types + /// the function to call + /// a pointer that will be passed to the function when it's called this value should live as long as the function exists + /// a callback to release the `user_data` value when the resulting `ExtismFunction` is freed. + /// + [DllImport("extism", EntryPoint = "extism_function_new")] + public static extern IntPtr FunctionNew(string name, IntPtr inputs, long nInputs, IntPtr outputs, long nOutputs, ExtismFunctionType func, IntPtr userData, IntPtr freeUserData); + + /// + /// Set the namespace of an + /// + /// + /// + [DllImport("extism", EntryPoint = "extism_function_set_namespace")] + public static extern void FunctionSetNamespace(ref ExtismFunction ptr, string @namespace); + + /// + /// Free an + /// + /// + [DllImport("extism", EntryPoint = "extism_function_free")] + public static extern void FunctionFree(ref ExtismFunction ptr); + /// /// Create a new context. /// /// A pointer to the newly created context. [DllImport("extism")] - public static extern IntPtr extism_context_new(); + unsafe internal static extern ExtismContext* extism_context_new(); /// /// Remove a context from the registry and free associated memory. /// /// [DllImport("extism")] - public static extern void extism_context_free(IntPtr context); + unsafe internal static extern void extism_context_free(ExtismContext* context); /// /// Load a WASM plugin. @@ -30,9 +192,9 @@ internal static class LibExtism /// Array of host function pointers. /// Number of host functions. /// Enables/disables WASI. - /// + /// The plugin's index in the Extism context, or -1 if the plugin could not be successfully created. [DllImport("extism")] - unsafe public static extern IntPtr extism_plugin_new(IntPtr context, byte* wasm, int wasmSize, IntPtr *functions, int nFunctions, bool withWasi); + unsafe internal static extern Int32 extism_plugin_new(ExtismContext* context, byte* wasm, int wasmSize, IntPtr* functions, int nFunctions, bool withWasi); /// /// Update a plugin, keeping the existing ID. @@ -42,13 +204,13 @@ internal static class LibExtism /// Pointer to the context the plugin is associated with. /// Pointer to the plugin you want to update. /// A WASM module (wat or wasm) or a JSON encoded manifest. - /// The length of the `wasm` parameter. + /// The length of the `wasm` parameter. /// Array of host function pointers. /// Number of host functions. /// Enables/disables WASI. /// [DllImport("extism")] - unsafe public static extern bool extism_plugin_update(IntPtr context, IntPtr plugin, byte* wasm, int wasmLength, IntPtr *functions, int nFunctions, bool withWasi); + unsafe internal static extern bool extism_plugin_update(ExtismContext* context, Int32 plugin, byte* wasm, long wasmSize, IntPtr* functions, long nFunctions, bool withWasi); /// /// Remove a plugin from the registry and free associated memory. @@ -56,74 +218,91 @@ internal static class LibExtism /// Pointer to the context the plugin is associated with. /// Pointer to the plugin you want to free. [DllImport("extism")] - public static extern void extism_plugin_free(IntPtr context, IntPtr plugin); + unsafe internal static extern void extism_plugin_free(ExtismContext* context, Int32 plugin); + + /// + /// Request cancellation on a currently-executing plugin. Will cancel before executing the next epoch. + /// + /// Pointer to the context the plugin is associated with. + /// Index of the plugin you want to cancel. + [DllImport("extism")] + unsafe internal static extern IntPtr extism_plugin_cancel_handle(ExtismContext* context, Int32 pluginIndex); + + /// + /// Request cancellation on a currently-executing plugin. Will cancel before executing the next epoch. + /// + /// Pointer to the plugin's cancel handle. + [DllImport("extism")] + unsafe internal static extern void extism_plugin_cancel(IntPtr handle); /// /// Remove all plugins from the registry. /// /// [DllImport("extism")] - public static extern void extism_context_reset(IntPtr context); + unsafe internal static extern void extism_context_reset(ExtismContext* context); /// /// Update plugin config values, this will merge with the existing values. /// /// Pointer to the context the plugin is associated with. - /// Pointer to the plugin you want to update the configurations for. + /// Index of the plugin you want to update the configurations for. /// The configuration JSON encoded in UTF8. /// The length of the `json` parameter. /// [DllImport("extism")] - unsafe public static extern bool extism_plugin_config(IntPtr context, IntPtr plugin, byte* json, int jsonLength); + unsafe internal static extern bool extism_plugin_config(ExtismContext* context, Int32 pluginIndex, byte* json, int jsonLength); /// /// Returns true if funcName exists. /// /// - /// + /// /// /// [DllImport("extism")] - public static extern bool extism_plugin_function_exists(IntPtr context, IntPtr plugin, string funcName); + unsafe internal static extern bool extism_plugin_function_exists(ExtismContext* context, Int32 pluginIndex, string funcName); /// /// Call a function. /// /// - /// + /// /// The function to call. /// Input data. /// The length of the `data` parameter. /// [DllImport("extism")] - unsafe public static extern int extism_plugin_call(IntPtr context, IntPtr plugin, string funcName, byte* data, int dataLen); + unsafe internal static extern int extism_plugin_call(ExtismContext* context, Int32 pluginIndex, string funcName, byte* data, int dataLen); + + /// /// Get the error associated with a Context or Plugin, if plugin is -1 then the context error will be returned. /// /// - /// A plugin pointer, or -1 for the context error. + /// A plugin index, or -1 for the context error. /// [DllImport("extism")] - public static extern IntPtr extism_error(IntPtr context, nint plugin); + unsafe internal static extern IntPtr extism_error(ExtismContext* context, Int32 pluginIndex); /// /// Get the length of a plugin's output data. /// /// - /// + /// /// [DllImport("extism")] - public static extern long extism_plugin_output_length(IntPtr context, IntPtr plugin); + unsafe internal static extern long extism_plugin_output_length(ExtismContext* context, Int32 pluginIndex); /// /// Get the plugin's output data. /// /// - /// + /// /// [DllImport("extism")] - public static extern IntPtr extism_plugin_output_data(IntPtr context, IntPtr plugin); + unsafe internal static extern IntPtr extism_plugin_output_data(ExtismContext* context, Int32 pluginIndex); /// /// Set log file and level. @@ -132,43 +311,83 @@ internal static class LibExtism /// /// [DllImport("extism")] - public static extern bool extism_log_file(string filename, string logLevel); + internal static extern bool extism_log_file(string filename, string logLevel); /// /// Get the Extism version string. /// /// [DllImport("extism", EntryPoint = "extism_version")] - public static extern IntPtr extism_version(); + internal static extern IntPtr extism_version(); /// /// Extism Log Levels /// - public static class LogLevels + internal static class LogLevels { /// /// Designates very serious errors. /// - public const string Error = "Error"; + internal const string Error = "Error"; /// /// Designates hazardous situations. /// - public const string Warn = "Warn"; + internal const string Warn = "Warn"; /// /// Designates useful information. /// - public const string Info = "Info"; + internal const string Info = "Info"; /// /// Designates lower priority information. /// - public const string Debug = "Debug"; + internal const string Debug = "Debug"; /// /// Designates very low priority, often extremely verbose, information. /// - public const string Trace = "Trace"; + internal const string Trace = "Trace"; } + + + + // TODO: I think this should be refactored to have host functions in their own file, as well as enums, etc. + /* + Pointer extism_function_new(String name, + int[] inputs, + int nInputs, + int[] outputs, + int nOutputs, + InternalExtismFunction func, + Pointer userData, + Pointer freeUserData); + + + /** + * Get the length of an allocated block + * NOTE: this should only be called from host functions. + * + int extism_current_plugin_memory_length(Pointer plugin, long n); + + /** + * Returns a pointer to the memory of the currently running plugin + * NOTE: this should only be called from host functions. + * + Pointer extism_current_plugin_memory(Pointer plugin); + + /** + * Allocate a memory block in the currently running plugin + * NOTE: this should only be called from host functions. + * + int extism_current_plugin_memory_alloc(Pointer plugin, long n); + + /** + * Free an allocated memory block + * NOTE: this should only be called from host functions. + * + void extism_current_plugin_memory_free(Pointer plugin, long ptr); + + */ } \ No newline at end of file diff --git a/dotnet/src/Extism.Sdk/Plugin.cs b/dotnet/src/Extism.Sdk/Plugin.cs index 9ec9f9a3e..ade3d2759 100644 --- a/dotnet/src/Extism.Sdk/Plugin.cs +++ b/dotnet/src/Extism.Sdk/Plugin.cs @@ -1,3 +1,4 @@ +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Runtime.InteropServices; @@ -12,17 +13,19 @@ public class Plugin : IDisposable private readonly Context _context; private int _disposed; + private readonly IntPtr _cancelHandle; - internal Plugin(Context context, IntPtr handle) + internal Plugin(Context context, int pluginIndex, IntPtr cancelHandle) { _context = context; - NativeHandle = handle; + PluginIndex = pluginIndex; + _cancelHandle = cancelHandle; } /// - /// A pointer to the native Plugin struct. + /// This plugin's current index in the Context /// - internal IntPtr NativeHandle { get; } + internal Int32 PluginIndex { get; } /// /// Update a plugin, keeping the existing ID. @@ -35,9 +38,20 @@ unsafe public bool Update(ReadOnlySpan wasm, bool withWasi) fixed (byte* wasmPtr = wasm) { - return LibExtism.extism_plugin_update(_context.NativeHandle, NativeHandle, wasmPtr, wasm.Length, null, 0, withWasi); + return LibExtism.extism_plugin_update(_context.NativeHandle, PluginIndex, wasmPtr, wasm.Length, null, 0, withWasi); } } + + /// + /// Request to cancel a currently-executing plugin at the next epoch. + /// + /// The plugin WASM bytes. + /// Enable/Disable WASI. + unsafe public void Cancel() + { + CheckNotDisposed(); + LibExtism.extism_plugin_cancel(_cancelHandle); + } /// /// Update plugin config values, this will merge with the existing values. @@ -49,28 +63,26 @@ unsafe public bool SetConfig(ReadOnlySpan json) fixed (byte* jsonPtr = json) { - return LibExtism.extism_plugin_config(_context.NativeHandle, NativeHandle, jsonPtr, json.Length); + return LibExtism.extism_plugin_config(_context.NativeHandle, PluginIndex, jsonPtr, json.Length); } } /// /// Checks if a specific function exists in the current plugin. /// - public bool FunctionExists(string name) + unsafe public bool FunctionExists(string name) { CheckNotDisposed(); - return LibExtism.extism_plugin_function_exists(_context.NativeHandle, NativeHandle, name); + return LibExtism.extism_plugin_function_exists(_context.NativeHandle, PluginIndex, name); } /// - /// Calls a function in the current plugin and returns a status. - /// If the status represents an error, call to get the error. - /// Othewise, call to get the function's output data. + /// Calls a function in the current plugin and returns the plugin's output bytes. /// /// Name of the function in the plugin to invoke. /// A buffer to provide as input to the function. - /// The exit code of the function. + /// A buffer with the plugin's output bytes. /// unsafe public ReadOnlySpan CallFunction(string functionName, ReadOnlySpan data) { @@ -78,29 +90,132 @@ unsafe public ReadOnlySpan CallFunction(string functionName, ReadOnlySpan< fixed (byte* dataPtr = data) { - int response = LibExtism.extism_plugin_call(_context.NativeHandle, NativeHandle, functionName, dataPtr, data.Length); - if (response == 0) { + int response = LibExtism.extism_plugin_call(_context.NativeHandle, PluginIndex, functionName, dataPtr, data.Length); + if (response == 0) + { return OutputData(); - } else { + } + else + { var errorMsg = GetError(); - if (errorMsg != null) { + if (errorMsg != null) + { throw new ExtismException(errorMsg); - } else { + } + else + { throw new ExtismException("Call to Extism failed"); } } } } + public async Task CallFunctionAsync(string functionName, byte[] data, int? timeoutMs = null, CancellationToken? cancellationToken = null) + { + + // If we don't set a timeout or a cancellation token, our watcher thread will run forever even after it leaves scope. + // A task does not get automatically canceled or garbage collected if it leaves scope. + // we need to make our own internal cancellation to ensure that we cancel this thread when we're done with it before exiting the function. + CancellationTokenSource internalTokenSource = new CancellationTokenSource(); + + + // Create an async function that will run forever or until timeoutMs, + // but will exit early if the external or internal cancellation tokens are canceled. + // Check token for cancellation every 10 ms to minimize CPU utilization. + // Even though the Task.Run gets a cancellation token, it doesn't continually check for cancellation, + // only checks once at the beginning to determine if it should run the task or not, + // so we must also check within this task to ensure the token hasn't been canceled. + var runUntilTimeoutOrCancelled = async () => + { + // If no timeout is set, save the resources of having a stopwatch and just run forever until task is canceled. + if (timeoutMs == null) + { + while (true) + { + if (internalTokenSource.IsCancellationRequested || (cancellationToken?.IsCancellationRequested ?? false)) + { + break; + } + + try + { + await Task.Delay(1000, cancellationToken ?? internalTokenSource.Token); + } + catch (Exception e) + { + } + } + } + else + { + var executionTime = Stopwatch.StartNew(); + while (executionTime.ElapsedMilliseconds < timeoutMs) + { + if (internalTokenSource.IsCancellationRequested || (cancellationToken?.IsCancellationRequested ?? false)) + { + break; + } + try + { + await Task.Delay(1000, cancellationToken ?? internalTokenSource.Token); + } + catch (Exception e) + { + } + } + } + }; + + // Create tasks for invoking called function as well as for running a timeout / cancelled checker. + // Note, we are not awaiting the task here. We will await the tasks later in parallel to determine when they are completed. + + Task cancellableTimeoutTask; + Task executeFunctionInternalTask; + + // If cancellation token has already been canceled, don't execute any code. + if (cancellationToken?.IsCancellationRequested ?? false) + { + throw new TaskCanceledException(); + } + + cancellableTimeoutTask = Task.Run(runUntilTimeoutOrCancelled); + executeFunctionInternalTask = Task.Run(() => { return CallFunction(functionName, data).ToArray(); }); + + // Race the 2 tasks. When they're complete, either executeFunctionInternal + // will have completed with a result, or the cancellableTimeout will have run to completion + // meaning the task was either cancelled or timed out. + await Task.WhenAny(executeFunctionInternalTask, cancellableTimeoutTask); + + if (executeFunctionInternalTask.IsCompletedSuccessfully) + { + internalTokenSource.Cancel(); // Cancel internal token so the background task will terminate. + await cancellableTimeoutTask; // Wait for the cancellation monitor task to exit gracefully + return await executeFunctionInternalTask; + } + + internalTokenSource.Cancel(); // Cancel internal token so the background task will terminate. + await cancellableTimeoutTask; // Wait for the task to exit gracefully + Cancel(); // Cancel the plugin so Extism can free it when we exit + + // If an exception was thrown by the task, rethrow it. + if (executeFunctionInternalTask.Exception.InnerException != null) + { + throw executeFunctionInternalTask.Exception.InnerException; + } + // Dispose(); + // Throw an exception so the caller knows that something timed out. + throw new TaskCanceledException(); + } + /// /// Get the length of a plugin's output data. /// /// - internal int OutputLength() + unsafe internal int OutputLength() { CheckNotDisposed(); - return (int)LibExtism.extism_plugin_output_length(_context.NativeHandle, NativeHandle); + return (int)LibExtism.extism_plugin_output_length(_context.NativeHandle, PluginIndex); } /// @@ -114,7 +229,7 @@ internal ReadOnlySpan OutputData() unsafe { - var ptr = LibExtism.extism_plugin_output_data(_context.NativeHandle, NativeHandle).ToPointer(); + var ptr = LibExtism.extism_plugin_output_data(_context.NativeHandle, PluginIndex).ToPointer(); return new Span(ptr, length); } } @@ -123,11 +238,11 @@ internal ReadOnlySpan OutputData() /// Get the error associated with the current plugin. /// /// - internal string? GetError() + unsafe internal string? GetError() { CheckNotDisposed(); - var result = LibExtism.extism_error(_context.NativeHandle, NativeHandle); + var result = LibExtism.extism_error(_context.NativeHandle, PluginIndex); return Marshal.PtrToStringUTF8(result); } @@ -168,7 +283,7 @@ private static void ThrowDisposedException() /// /// Frees all resources held by this Plugin. /// - protected virtual void Dispose(bool disposing) + unsafe protected virtual void Dispose(bool disposing) { if (disposing) { @@ -176,7 +291,7 @@ protected virtual void Dispose(bool disposing) } // Free up unmanaged resources - LibExtism.extism_plugin_free(_context.NativeHandle, NativeHandle); + LibExtism.extism_plugin_free(_context.NativeHandle, PluginIndex); } /// diff --git a/dotnet/test/Extism.Sdk/AsyncTests.cs b/dotnet/test/Extism.Sdk/AsyncTests.cs new file mode 100644 index 000000000..3571b4ee8 --- /dev/null +++ b/dotnet/test/Extism.Sdk/AsyncTests.cs @@ -0,0 +1,130 @@ +using System.Diagnostics; +using Extism.Sdk.Native; + +using System.Reflection; +using System.Text; + +using Xunit; +using Xunit.Abstractions; + +namespace Extism.Sdk.Tests; + +public class AsyncTests +{ + private readonly ITestOutputHelper output; + private byte[] count_vowels; + private byte[] sleepMs; + public AsyncTests(ITestOutputHelper output) + { + this.output = output; + var binDirectory = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!; + count_vowels = File.ReadAllBytes(Path.Combine(binDirectory, "code.wasm")); + sleepMs = File.ReadAllBytes(Path.Combine(binDirectory, "sleepMs.wasm")); + } + + [Fact] + public async void MultipleInvocations_InvokesAsyncMethod_ReturnsExpectedValues() + { + using var context = new Context(); + + + using var plugin = context.CreatePlugin(count_vowels, withWasi: true); + for (int i = 0; i < 1000; i++) + { + var response = await plugin.CallFunctionAsync("count_vowels", Encoding.UTF8.GetBytes("Hello World")); + Assert.Equal("{\"count\": 3}", Encoding.UTF8.GetString(response)); + } + } + + [Fact] + public async void InvokingUnknownFunctionAsync_DoesntWork_ThrowsException() { + using var context = new Context(); + // Test multiple plugin invocations to ensure that plugin calls can be repeated + using var plugin = context.CreatePlugin(count_vowels, withWasi: true); + var exception = await Assert.ThrowsAsync( + async () => { await plugin.CallFunctionAsync("unknown_function_name", Encoding.UTF8.GetBytes("Hello World")); }); + Assert.Equal("Function not found: unknown_function_name", exception.Message); + } + + [Fact] + public async void LongRunningtask_WithScheduledCancellation_ThrowsTaskCanceledException() + { + using var context = new Context(); + using var plugin = context.CreatePlugin(sleepMs, withWasi: true); + + CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + + // Schedule token to automatically cancel after 1 second. + cancellationTokenSource.CancelAfter(1000); + + var sw = Stopwatch.StartNew(); + await Assert.ThrowsAsync(async () => + { + // This should throw after ~1000 ms and not complete execution. + var response = await plugin.CallFunctionAsync("sleepMs", Encoding.UTF8.GetBytes("[50000]"), null, cancellationTokenSource.Token); + Assert.Fail("This should be unreachable."); + }); + + + // Verify that the time that has passed is less than 1.5 seconds (should be very close to 1 second). + output.WriteLine($"Expected: 1000 ms. Actual: {sw.ElapsedMilliseconds}."); + Assert.True(sw.ElapsedMilliseconds < 1500); + + } + + // TODO: No cancellation of multiple parallel plugins (should make the wasm return the # of ms slept or something) + // TODO: Cancel a plugin early and try running it again several times + // TODO: Test native timeout support in Extism Context + + + [Fact] + public async void PrecancelledToken_DoesntRunWASMModule_ThrowsTaskCanceledException() + { + using var context = new Context(); + using var plugin = context.CreatePlugin(count_vowels, withWasi: true); + + CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + var token = cancellationTokenSource.Token; + cancellationTokenSource.Cancel(); + await Assert.ThrowsAsync(async () => + { + await plugin.CallFunctionAsync("count_vowels", Encoding.UTF8.GetBytes("Hello World"), null, token); + }); + // TODO: once Host Functions are implemented, we can verify that a host function is not called to verify the WASM isn't actually executing. + } + + [Fact] + public async void ExecutingPlugin_IsCanceledEarlyWhenCallingCancelMethod_StopsExecutingWithin500ms() + { + using var context = new Context(); + using var plugin = context.CreatePlugin(sleepMs, withWasi: true); + + // Register a WASM function that will take 50 seconds to complete. + // NOTE: Purposely not awaited so plugin can run asynchronously in a parallel with our cancellation request. + var functionInvocationTask = plugin.CallFunctionAsync("sleepMs", Encoding.UTF8.GetBytes("[50000]")); + + // Cancel the plugin after 1 second + Stopwatch sw = new Stopwatch(); + await Task.Run(async () => + { + sw.Start(); + await Task.Delay(1000); + + plugin.Cancel(); + }); + + // Verify that the task throws a TaskCanceledException + await Assert.ThrowsAsync(async () => + { + var response = await functionInvocationTask; + }); + + sw.Stop(); + + // Verify that the time that has passed is less than 1.5 seconds (should be very close to 1 second). + output.WriteLine($"Expected: 1000 ms. Actual: {sw.ElapsedMilliseconds}."); + Assert.True(sw.ElapsedMilliseconds < 1500); + + } + +} \ No newline at end of file diff --git a/dotnet/test/Extism.Sdk/BasicTests.cs b/dotnet/test/Extism.Sdk/BasicTests.cs index 9ff0f1bec..70bcd50af 100644 --- a/dotnet/test/Extism.Sdk/BasicTests.cs +++ b/dotnet/test/Extism.Sdk/BasicTests.cs @@ -1,24 +1,84 @@ +using System.Diagnostics; using Extism.Sdk.Native; using System.Reflection; using System.Text; using Xunit; +using Xunit.Abstractions; namespace Extism.Sdk.Tests; public class BasicTests { + private byte[] count_vowels; + public BasicTests(ITestOutputHelper output) + { + var binDirectory = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!; + count_vowels = File.ReadAllBytes(Path.Combine(binDirectory, "code.wasm")); + } + [Fact] - public void CountHelloWorldVowels() + public void SingleInvocation_InvokesMethod_ReturnsExpectedValue() { using var context = new Context(); - - var binDirectory = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!; - var wasm = File.ReadAllBytes(Path.Combine(binDirectory, "code.wasm")); - using var plugin = context.CreatePlugin(wasm, withWasi: true); - + // Test multiple plugin invocations to ensure that plugin calls can be repeated + using var plugin = context.CreatePlugin(count_vowels, withWasi: true); + var response = plugin.CallFunction("count_vowels", Encoding.UTF8.GetBytes("Hello World")); + Assert.Equal("{\"count\": 3}", Encoding.UTF8.GetString(response)); + } + + [Fact] + public void PluginWithoutWasi_CanBeInvokedWithoutWasi_ReturnsExpectedValue() + { + using var context = new Context(); + // Test multiple plugin invocations to ensure that plugin calls can be repeated + using var plugin = context.CreatePlugin(count_vowels, withWasi: false); var response = plugin.CallFunction("count_vowels", Encoding.UTF8.GetBytes("Hello World")); Assert.Equal("{\"count\": 3}", Encoding.UTF8.GetString(response)); } + + + [Fact] + public void MultipleInvocations_InvokesMethod_ReturnsExpectedValues() + { + using var context = new Context(); + // Test multiple plugin invocations to ensure that plugin calls can be repeated + using var plugin = context.CreatePlugin(count_vowels, withWasi: true); + for (int i = 0; i < 1000; i++) + { + var response = plugin.CallFunction("count_vowels", Encoding.UTF8.GetBytes("Hello World")); + Assert.Equal("{\"count\": 3}", Encoding.UTF8.GetString(response)); + } + } + + [Fact] + public void InvokingUnknownFunction_DoesntWork_ThrowsException() { + using var context = new Context(); + // Test multiple plugin invocations to ensure that plugin calls can be repeated + using var plugin = context.CreatePlugin(count_vowels, withWasi: true); + var exception = Assert.Throws( + () => { plugin.CallFunction("unknown_function_name", Encoding.UTF8.GetBytes("Hello World")); }); + Assert.Equal("Function not found: unknown_function_name", exception.Message); + } + // TODO: implement these from Java SDK + /* + @Test + public void shouldInvokeFunctionWithMemoryOptions() { + //FIXME check whether memory options are effective + var manifest = new Manifest(List.of(CODE.pathWasmSource()), new MemoryOptions(0)); + var output = Extism.invokeFunction(manifest, "count_vowels", "Hello World"); + assertThat(output).isEqualTo("{\"count\": 3}"); + } + + @Test + public void shouldInvokeFunctionWithConfig() { + //FIXME check if config options are available in wasm call + var config = Map.of("key1", "value1"); + var manifest = new Manifest(List.of(CODE.pathWasmSource()), null, config); + var output = Extism.invokeFunction(manifest, "count_vowels", "Hello World"); + assertThat(output).isEqualTo("{\"count\": 3}"); + } + */ + } \ No newline at end of file diff --git a/dotnet/test/Extism.Sdk/Extism.Sdk.Tests.csproj b/dotnet/test/Extism.Sdk/Extism.Sdk.Tests.csproj index 1f2b092e9..887f6b3c1 100644 --- a/dotnet/test/Extism.Sdk/Extism.Sdk.Tests.csproj +++ b/dotnet/test/Extism.Sdk/Extism.Sdk.Tests.csproj @@ -1,13 +1,11 @@ - - + + net7.0 enable enable - false - @@ -20,16 +18,25 @@ all - PreserveNewest + + PreserveNewest + - - + + + + + + + PreserveNewest + + \ No newline at end of file diff --git a/dotnet/test/Extism.Sdk/HostFunctionTests.cs b/dotnet/test/Extism.Sdk/HostFunctionTests.cs new file mode 100644 index 000000000..5b50b976d --- /dev/null +++ b/dotnet/test/Extism.Sdk/HostFunctionTests.cs @@ -0,0 +1,30 @@ +using System.Diagnostics; +using Extism.Sdk.Native; + +using System.Reflection; +using System.Text; + +using Xunit; +using Xunit.Abstractions; + +namespace Extism.Sdk.Tests; + +public class HostFunctionTests +{ + private byte[] code_functions; + public HostFunctionTests(ITestOutputHelper output) + { + var binDirectory = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!; + code_functions = File.ReadAllBytes(Path.Combine(binDirectory, "code-functions.wasm")); + } + + [Fact] + public void InvokeCodeThatDependsOnHostFunction_WithNoHostFunctionDefined_ThrowsException() + { + using var context = new Context(); + + var exception = Assert.Throws(() => context.CreatePlugin(code_functions, withWasi: true)); + Assert.Equal("unknown import: `env::hello_world` has not been defined", exception.Message); + } + +} \ No newline at end of file diff --git a/dotnet/test/Extism.Sdk/README.md b/dotnet/test/Extism.Sdk/README.md new file mode 100644 index 000000000..a8e5e3642 --- /dev/null +++ b/dotnet/test/Extism.Sdk/README.md @@ -0,0 +1,6 @@ + +# Test wasm files +## code.wasm +code.wasm is an example Rust wasm application that has an exposed method `count_vowels`. It takes in a string and returns a json object of the form `{"count": 3}`. +## sleepMs.wasm +sleepMs.wasm is a js wasm module that takes in a json array with a single numeric (eg. `[5000]`) and sleeps for that many milliseconds before returning. \ No newline at end of file diff --git a/dotnet/test/Extism.Sdk/code-functions.wasm b/dotnet/test/Extism.Sdk/code-functions.wasm new file mode 100644 index 000000000..8301227b6 Binary files /dev/null and b/dotnet/test/Extism.Sdk/code-functions.wasm differ diff --git a/dotnet/test/Extism.Sdk/sleepMs.wasm b/dotnet/test/Extism.Sdk/sleepMs.wasm new file mode 100644 index 000000000..a325ad30b Binary files /dev/null and b/dotnet/test/Extism.Sdk/sleepMs.wasm differ