Skip to content
6 changes: 4 additions & 2 deletions js/node/lib/binding.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { isMainThread } from 'worker_threads';

import { InferenceSession, OnnxValue, Tensor, TensorConstructor, env } from 'onnxruntime-common';

type SessionOptions = InferenceSession.SessionOptions;
Expand Down Expand Up @@ -57,7 +59,7 @@ export const binding =
// eslint-disable-next-line @typescript-eslint/naming-convention
InferenceSession: Binding.InferenceSessionConstructor;
listSupportedBackends: () => Binding.SupportedBackend[];
initOrtOnce: (logLevel: number, tensorConstructor: TensorConstructor) => void;
initOrtOnce: (logLevel: number, tensorConstructor: TensorConstructor, isMainThread: boolean) => void;
};

let ortInitialized = false;
Expand Down Expand Up @@ -86,6 +88,6 @@ export const initOrt = (): void => {
throw new Error(`Unsupported log level: ${env.logLevel}`);
}
}
binding.initOrtOnce(logLevel, Tensor);
binding.initOrtOnce(logLevel, Tensor, isMainThread);
}
};
30 changes: 25 additions & 5 deletions js/node/src/inference_session_wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,32 @@ Napi::Value InferenceSessionWrap::InitOrtOnce(const Napi::CallbackInfo& info) {

int log_level = info[0].As<Napi::Number>().Int32Value();
Napi::Function tensorConstructor = info[1].As<Napi::Function>();
bool is_main_thread = info[2].As<Napi::Boolean>().Value();

OrtInstanceData::InitOrt(env, log_level, tensorConstructor);
OrtInstanceData::InitOrt(env, log_level, tensorConstructor, is_main_thread);

return env.Undefined();
}

// Constructs an uninitialized session wrapper; the ORT session itself is created later in LoadModel().
InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info)
: Napi::ObjectWrap<InferenceSessionWrap>(info), initialized_(false), disposed_(false), session_(nullptr) {}

// Destructor. Normally the owning unique_ptr members free their ORT objects here; the only
// special handling is the shutdown-ordering case described below.
InferenceSessionWrap::~InferenceSessionWrap() {
// If the ORT singleton has already been destroyed (e.g. during process shutdown when the
// cleanup hook fires before N-API finalizers run), we must not call into ORT to
// release owned ORT objects — doing so would crash. Intentionally leak in that case.
if (!OrtSingletonData::GetOrtObjects()) {
// unique_ptr::release() relinquishes ownership WITHOUT running the deleter, so no ORT
// code is invoked; the (void) casts suppress [[nodiscard]] warnings on the raw pointers.
for (auto& type_info : inputTypes_) {
(void)type_info.release();
}
for (auto& type_info : outputTypes_) {
(void)type_info.release();
}
(void)ioBinding_.release();
(void)session_.release();
}
// Otherwise the members' destructors run normally and free their ORT objects.
}

Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);
Expand All @@ -73,7 +90,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) {
Napi::String value = info[0].As<Napi::String>();

ParseSessionOptions(info[1].As<Napi::Object>(), sessionOptions);
this->session_.reset(new Ort::Session(OrtSingletonData::Env(),
this->session_.reset(new Ort::Session(OrtSingletonData::GetOrtObjects()->env,
#ifdef _WIN32
reinterpret_cast<const wchar_t*>(value.Utf16Value().c_str()),
#else
Expand All @@ -88,7 +105,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) {
int64_t bytesLength = info[2].As<Napi::Number>().Int64Value();

ParseSessionOptions(info[3].As<Napi::Object>(), sessionOptions);
this->session_.reset(new Ort::Session(OrtSingletonData::Env(),
this->session_.reset(new Ort::Session(OrtSingletonData::GetOrtObjects()->env,
reinterpret_cast<char*>(buffer) + bytesOffset, bytesLength,
sessionOptions));
} else {
Expand Down Expand Up @@ -208,7 +225,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
ParseRunOptions(info[2].As<Napi::Object>(), runOptions);
}
if (preferredOutputLocations_.size() == 0) {
session_->Run(runOptions == nullptr ? OrtSingletonData::DefaultRunOptions() : runOptions,
session_->Run(runOptions == nullptr ? OrtSingletonData::GetOrtObjects()->default_run_options : runOptions,
inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
Expand Down Expand Up @@ -237,7 +254,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
}
}

session_->Run(runOptions == nullptr ? OrtSingletonData::DefaultRunOptions() : runOptions, *ioBinding_);
session_->Run(runOptions == nullptr ? OrtSingletonData::GetOrtObjects()->default_run_options : runOptions, *ioBinding_);

auto outputs = ioBinding_->GetOutputValues();
ORT_NAPI_THROW_ERROR_IF(outputs.size() != outputIndex, env, "Output count mismatch.");
Expand All @@ -260,6 +277,9 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) {
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");

this->inputTypes_.clear();
this->outputTypes_.clear();

this->ioBinding_.reset(nullptr);
this->session_.reset(nullptr);

Expand Down
1 change: 1 addition & 0 deletions js/node/src/inference_session_wrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
static Napi::Object Init(Napi::Env env, Napi::Object exports);

InferenceSessionWrap(const Napi::CallbackInfo& info);
~InferenceSessionWrap();

private:
/**
Expand Down
10 changes: 4 additions & 6 deletions js/node/src/ort_instance_data.cc
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <atomic>
#include <mutex>

#include "common.h"
#include "ort_instance_data.h"
#include "ort_singleton_data.h"
Expand All @@ -19,14 +16,15 @@ void OrtInstanceData::Create(Napi::Env env, Napi::Function inferenceSessionWrapp
env.SetInstanceData(data);
}

void OrtInstanceData::InitOrt(Napi::Env env, int log_level, Napi::Function tensorConstructor) {
void OrtInstanceData::InitOrt(Napi::Env env, int log_level, Napi::Function tensorConstructor, bool is_main_thread) {
auto data = env.GetInstanceData<OrtInstanceData>();
ORT_NAPI_THROW_ERROR_IF(data == nullptr, env, "OrtInstanceData not created.");

data->ortTensorConstructor = Napi::Persistent(tensorConstructor);

// Only the first time call to OrtSingletonData::GetOrCreateOrtObjects() will create the Ort::Env
OrtSingletonData::GetOrCreateOrtObjects(log_level);
// Initialize ORT singleton and register cleanup hook for this env.
// The first call creates the OrtObjects; subsequent calls increment the ref count.
OrtSingletonData::InitOrtObjects(env, log_level, is_main_thread);
}

const Napi::FunctionReference& OrtInstanceData::TensorConstructor(Napi::Env env) {
Expand Down
2 changes: 1 addition & 1 deletion js/node/src/ort_instance_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct OrtInstanceData {
// Create a new OrtInstanceData object related to the Napi::Env
static void Create(Napi::Env env, Napi::Function inferenceSessionWrapperFunction);
// Initialize Ort for the Napi::Env
static void InitOrt(Napi::Env env, int log_level, Napi::Function tensorConstructor);
static void InitOrt(Napi::Env env, int log_level, Napi::Function tensorConstructor, bool is_main_thread);
// Get the Tensor constructor reference for the Napi::Env
static const Napi::FunctionReference& TensorConstructor(Napi::Env env);

Expand Down
41 changes: 34 additions & 7 deletions js/node/src/ort_singleton_data.cc
Original file line number Diff line number Diff line change
@@ -1,22 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <atomic>
#include <mutex>

#include "ort_singleton_data.h"

namespace {
// Guards creation/destruction of the singleton and updates to ref_count below.
std::mutex ort_singleton_mutex;
// Heap-allocated singleton; nullptr either before first init or after CleanupHook deleted it.
std::atomic<OrtSingletonData::OrtObjects*> ort_objects{nullptr};
// Number of napi_envs that called InitOrtObjects and whose cleanup hook has not yet fired.
std::atomic<int> ref_count{0};
} // namespace

// Creates the process-wide Ort::Env (with the requested logging level) and an empty default
// RunOptions that is reused across all session runs.
OrtSingletonData::OrtObjects::OrtObjects(int log_level)
: env{OrtLoggingLevel(log_level), "onnxruntime-node"},
default_run_options{} {
}

OrtSingletonData::OrtObjects& OrtSingletonData::GetOrCreateOrtObjects(int log_level) {
static OrtObjects ort_objects(log_level);
return ort_objects;
// Initializes the ORT singleton for one napi_env (i.e. one JS thread).
// - First caller creates the heap-allocated OrtObjects; later callers just bump the ref count.
// - Each env registers CleanupHook so the ref count is decremented when that env is torn down.
// NOTE(review): assumes each napi_env calls this at most once (the JS side guards with
// `ortInitialized`); a second call from the same env would register a duplicate hook — confirm.
void OrtSingletonData::InitOrtObjects(napi_env env, int log_level,
bool is_main_thread) {
{
std::lock_guard<std::mutex> lock(ort_singleton_mutex);
if (!ort_objects.load(std::memory_order_relaxed)) {
ort_objects.store(new OrtObjects(log_level), std::memory_order_release);
}
ref_count++;
}

// Register a cleanup hook for this napi_env. The hook will be called when this env is torn down.
// We encode the is_main_thread flag directly into the void* arg to avoid a heap allocation.
napi_add_env_cleanup_hook(env, CleanupHook, reinterpret_cast<void*>(static_cast<uintptr_t>(is_main_thread)));
}

const Ort::Env& OrtSingletonData::Env() {
return GetOrCreateOrtObjects().env;
// Env-teardown hook registered by InitOrtObjects. `arg` carries the is_main_thread flag
// encoded as a uintptr_t (no heap allocation; see InitOrtObjects).
// The singleton is deleted only when BOTH conditions hold: no env still references it
// (ref_count == 0) AND this hook runs on the main thread. In every other case it is
// intentionally leaked — calling into an already-unloaded onnxruntime library would crash.
void OrtSingletonData::CleanupHook(void* arg) {
bool is_main_thread = static_cast<bool>(reinterpret_cast<uintptr_t>(arg));

std::lock_guard<std::mutex> lock(ort_singleton_mutex);
ref_count--;

if (ref_count == 0 && is_main_thread) {
// relaxed load is sufficient: the mutex above already orders access to ort_objects.
delete ort_objects.load(std::memory_order_relaxed);
ort_objects.store(nullptr, std::memory_order_release);
}
}

const Ort::RunOptions& OrtSingletonData::DefaultRunOptions() {
return GetOrCreateOrtObjects().default_run_options;
// Returns the singleton, or nullptr if it has not been created yet or was destroyed by
// CleanupHook. Callers (e.g. ~InferenceSessionWrap) use the nullptr case to detect shutdown.
OrtSingletonData::OrtObjects* OrtSingletonData::GetOrtObjects() {
return ort_objects.load(std::memory_order_acquire);
}
31 changes: 24 additions & 7 deletions js/node/src/ort_singleton_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,22 @@
* - The Ort::RunOptions singleton instance.
* This is an empty default RunOptions instance. It is created once to allow reuse across all session inference runs.
*
* The OrtSingletonData class uses the "Meyers Singleton" pattern to ensure thread-safe lazy initialization, as well as
* proper destruction order at program exit.
* The OrtSingletonData class uses a ref-counted, heap-allocated singleton with best-effort cleanup.
*
* Each napi_env (one per thread) that initializes ORT increments a ref count and registers a cleanup hook via
* napi_add_env_cleanup_hook. When the hook fires, the ref count is decremented. The singleton is only deleted when:
* 1. The ref count reaches 0, AND
* 2. The cleanup hook is running on the main thread (determined by the isMainThread flag from worker_threads).
*
* This ensures:
* - On normal single-threaded usage, the singleton is properly destroyed (no leak).
* - On multi-threaded usage where workers exit before the main thread, the main thread's hook fires last with
* ref count 0 and performs the cleanup.
* - If cleanup hooks don't fire (e.g., uncaught exception — see https://github.com/nodejs/node/issues/58341),
* the ref count stays >0 and the singleton safely leaks, avoiding crashes from calling into an already-unloaded
* onnxruntime shared library.
* - If the main thread's hook fires but workers are still alive (e.g., process.exit()), the ref count is >0 and
* the singleton safely leaks.
*/
struct OrtSingletonData {
struct OrtObjects {
Expand All @@ -30,11 +44,14 @@ struct OrtSingletonData {
friend struct OrtSingletonData;
};

static OrtObjects& GetOrCreateOrtObjects(int log_level = ORT_LOGGING_LEVEL_WARNING);
// Initialize ORT objects and register a cleanup hook for the given napi_env.
// Each napi_env (thread) should call this once.
// is_main_thread should be set to true if the calling thread is the main thread (from worker_threads.isMainThread).
static void InitOrtObjects(napi_env env, int log_level, bool is_main_thread);

// Get the global Ort::Env
static const Ort::Env& Env();
// Get the ORT singleton objects. Returns nullptr if the singleton has been destroyed.
static OrtObjects* GetOrtObjects();

// Get the default Ort::RunOptions
static const Ort::RunOptions& DefaultRunOptions();
private:
static void CleanupHook(void* arg);
};
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/webgpu/buffer_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ class BucketCacheManager : public IBufferCacheManager {
wgpuBufferRelease(buffer);
}
}
for (auto& buffer_info : pending_buffers_) {
wgpuBufferRelease(buffer_info.first);
}
}

protected:
Expand Down
40 changes: 26 additions & 14 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,15 @@

// TODO: revise temporary error handling
device_desc.SetUncapturedErrorCallback([](const wgpu::Device& /*device*/, wgpu::ErrorType type, wgpu::StringView message) {
LOGS_DEFAULT(ERROR) << "WebGPU device error(" << int(type) << "): " << std::string_view{message};
if (logging::LoggingManager::HasDefaultLogger()) {
LOGS_DEFAULT(ERROR) << "WebGPU device error(" << int(type) << "): " << std::string_view{message};
}
});
// TODO: revise temporary device lost handling
device_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device& /*device*/, wgpu::DeviceLostReason reason, wgpu::StringView message) {
LOGS_DEFAULT(INFO) << "WebGPU device lost (" << int(reason) << "): " << std::string_view{message};
if (logging::LoggingManager::HasDefaultLogger()) {
LOGS_DEFAULT(INFO) << "WebGPU device lost (" << int(reason) << "): " << std::string_view{message};
}
});

ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter.RequestDevice(
Expand Down Expand Up @@ -924,10 +928,11 @@
}
}

std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo>* WebGpuContextFactory::contexts_ = nullptr;
std::mutex WebGpuContextFactory::mutex_;
std::once_flag WebGpuContextFactory::init_default_flag_;
wgpu::Instance WebGpuContextFactory::default_instance_;

std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo>* WebGpuContextFactory::contexts_ = nullptr;
WGPUInstance WebGpuContextFactory::default_instance_ = nullptr;

WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& config) {
const int context_id = config.context_id;
Expand Down Expand Up @@ -960,18 +965,13 @@

std::lock_guard<std::mutex> lock(mutex_);

// Lazy-allocate the contexts map on first use (heap-allocated to avoid static destruction crash).
if (contexts_ == nullptr) {
contexts_ = new std::unordered_map<int32_t, WebGpuContextInfo>();
}

if (default_instance_ == nullptr) {
// Create wgpu::Instance
wgpu::InstanceFeatureName required_instance_features[] = {wgpu::InstanceFeatureName::TimedWaitAny};
wgpu::InstanceDescriptor instance_desc{};
instance_desc.requiredFeatures = required_instance_features;
instance_desc.requiredFeatureCount = sizeof(required_instance_features) / sizeof(required_instance_features[0]);
default_instance_ = wgpu::CreateInstance(&instance_desc);
default_instance_ = wgpu::CreateInstance(&instance_desc).MoveToCHandle();

ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance.");
}
Expand All @@ -981,13 +981,18 @@
ORT_ENFORCE(instance == nullptr && device == nullptr,
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device.");

instance = default_instance_.Get();
instance = default_instance_;
} else {
// for context ID > 0, user must provide custom WebGPU instance and device.
ORT_ENFORCE(instance != nullptr && device != nullptr,
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance and device.");
}

// Lazy-allocate the contexts map on first use (heap-allocated to avoid static destruction crash).
if (contexts_ == nullptr) {
contexts_ = new std::unordered_map<int32_t, WebGpuContextInfo>();

Check warning on line 993 in onnxruntime/core/providers/webgpu/webgpu_context.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/webgpu_context.cc:993: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
}

auto it = contexts_->find(context_id);
if (it == contexts_->end()) {
GSL_SUPPRESS(r.11)
Expand Down Expand Up @@ -1034,9 +1039,16 @@

// Frees the heap-allocated contexts map and releases the default WGPUInstance.
// NOTE(review): the three unguarded statements below duplicate the guarded blocks that follow
// and look like leftover pre-refactor lines fused into this paste (in particular,
// `default_instance_ = nullptr;` before the `wgpuInstanceRelease` check would leak the
// instance). Confirm against the real file and keep only the guarded version.
void WebGpuContextFactory::Cleanup() {
std::lock_guard<std::mutex> lock(mutex_);
delete contexts_;
contexts_ = nullptr;
default_instance_ = nullptr;

if (contexts_ != nullptr) {
delete contexts_;
contexts_ = nullptr;
}

if (default_instance_ != nullptr) {
wgpuInstanceRelease(default_instance_);
default_instance_ = nullptr;
}
}

WebGpuContext& WebGpuContextFactory::DefaultContext() {
Expand Down
7 changes: 4 additions & 3 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ class WebGpuContextFactory {
private:
WebGpuContextFactory() {}

static std::mutex mutex_;
static std::once_flag init_default_flag_;

// Use pointers to heap-allocated objects so that their destructors do NOT run
// during static destruction at process exit. This avoids crashes when dependent
// DLLs (e.g. dxcompiler.dll) have already been unloaded by the OS.
Expand All @@ -155,9 +158,7 @@ class WebGpuContextFactory {
// it is reached from OrtEnv::~OrtEnv via CleanupWebGpuContexts().
// On abnormal/process termination they simply leak, which is safe.
static std::unordered_map<int32_t, WebGpuContextInfo>* contexts_;
static std::mutex mutex_;
static std::once_flag init_default_flag_;
static wgpu::Instance default_instance_;
static WGPUInstance default_instance_;
};

// Class WebGpuContext includes all necessary resources for the context.
Expand Down
Loading