Skip to content

Commit bb36f44

Browse files
committed
pyodide fixes
1 parent fbc840f commit bb36f44

6 files changed

Lines changed: 89 additions & 47 deletions

File tree

webgpu/canvas.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pathlib
77

88
from . import platform
9-
from .utils import get_device, read_texture
9+
from .utils import get_device, read_texture, Lock
1010
from .webgpu_api import *
1111

1212
_TARGET_FPS = 60
@@ -20,9 +20,6 @@ class _DebounceData:
2020

2121
def debounce(arg=None):
2222
def decorator(func):
23-
if platform.is_pyodide:
24-
return arg
25-
2623
# Render only once every 1/_TARGET_FPS seconds
2724
@functools.wraps(func)
2825
def debounced(obj, *args, **kwargs):
@@ -35,16 +32,25 @@ def debounced(obj, *args, **kwargs):
3532

3633
data = obj._debounce_data[fname]
3734

38-
if data.timer is not None:
39-
# we already have a render scheduled, so do nothing
40-
return
35+
# check if we already have a render scheduled
36+
if platform.is_pyodide:
37+
if data.timer is not None and not data.timer.done():
38+
return
39+
else:
40+
if data.timer is not None:
41+
return
4142

4243
def f():
4344
# clear the timer, so we can schedule a new one with the next function call
4445
t = time.time()
4546
data.timer = None
46-
func(obj, *args, **kwargs)
47-
data.t_last = t
47+
if platform.is_pyodide:
48+
# due to async nature, we need to update t_last before calling func
49+
data.t_last = t
50+
func(obj, *args, **kwargs)
51+
else:
52+
func(obj, *args, **kwargs)
53+
data.t_last = t
4854

4955
if data.t_last is None:
5056
# first call -> just call the function immediately
@@ -53,8 +59,16 @@ def f():
5359
return
5460

5561
t_wait = max(1 / target_fps - (time.time() - data.t_last), 0)
56-
data.timer = threading.Timer(t_wait, f)
57-
data.timer.start()
62+
if platform.is_pyodide:
63+
import asyncio
64+
async def _runner():
65+
if t_wait > 0:
66+
await asyncio.sleep(t_wait)
67+
f()
68+
data.timer = asyncio.create_task(_runner())
69+
else:
70+
data.timer = threading.Timer(t_wait, f)
71+
data.timer.start()
5872

5973
return debounced
6074

@@ -91,7 +105,7 @@ class Canvas:
91105
_on_update_html_canvas: list[Callable]
92106

93107
def __init__(self, device, canvas, multisample_count=4):
94-
self._update_mutex = threading.RLock()
108+
self._update_mutex = Lock()
95109
self.target_texture = None
96110

97111
self._on_resize_callbacks = []

webgpu/input_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import threading
21
from typing import Callable
2+
from .utils import Lock
33

44

55
class InputHandler:
@@ -24,7 +24,7 @@ def get_set(self):
2424
return s
2525

2626
def __init__(self):
27-
self._mutex = threading.Lock()
27+
self._mutex = Lock(True)
2828
self._callbacks = {}
2929
self._js_handlers = {}
3030
self._is_mousedown = False

webgpu/link/base.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import base64
33
import itertools
44
import json
5+
import time
56
import threading
67
from collections.abc import Mapping
78
from typing import Callable
@@ -96,6 +97,8 @@ def __init__(self):
9697
self._objects = {}
9798
self._cache = {}
9899

100+
next(self._request_id) # make sure first id is 1, in case 0 is interpreted as None
101+
99102
def _call_data(self, id, prop, args, ignore_result=False):
100103
buffer = []
101104
args = self._dump_data(args, buffer)
@@ -313,7 +316,11 @@ async def _on_message_async(self, message: str):
313316
self._requests[request_id] = self._load_data(data.get("value", None))
314317
if key and data.get("cache", False):
315318
self._cache[key] = self._requests[request_id]
316-
event.set()
319+
320+
if isinstance(event, asyncio.Future):
321+
event.set_result(self._requests[request_id])
322+
else:
323+
event.set()
317324
return
318325

319326
case "call":
@@ -407,11 +414,10 @@ def _on_message(self, message: str):
407414

408415

409416
class PyodideLink(LinkBase):
410-
def __init__(self, send_message, size_buffer, result_buffer):
417+
def __init__(self, send_message):
411418
super().__init__()
412419
self._send_message = send_message
413-
self._size_buffer = size_buffer
414-
self._result_buffer = result_buffer
420+
self._requests = {}
415421

416422
def create_proxy(self, func, ignore_return_value=False):
417423
id_ = id(func)
@@ -424,26 +430,23 @@ def create_proxy(self, func, ignore_return_value=False):
424430
}
425431

426432
def _send_data(self, metadata, data, key=None):
427-
if type(data) is bytes:
428-
self._send_message(data)
429-
else:
430-
if (
431-
metadata.get("request_id", None) is not None
432-
and metadata["type"] != "response"
433-
and not metadata.get("ignore_return_value", False)
434-
):
435-
import js
436-
437-
js.Atomics.store(self._size_buffer, 0, 0)
438-
self._send_message(data)
439-
js.Atomics.wait(self._size_buffer, 0, 0, 10000)
440-
n = self._size_buffer[0]
441-
res = bytes(self._result_buffer.slice(0, n))
442-
s = res.decode("utf-8")
443-
data = json.loads(s)
444-
return self._load_data(data.get("value", None))
445-
else:
446-
self._send_message(data)
433+
"""Send data to the remote environment,
434+
if request_id is set, (blocking-)wait for the response and return it"""
435+
request_id = metadata.get("request_id", None)
436+
type = metadata.get("type", None)
437+
event = None
438+
self._send_message(data)
439+
if type != "response" and request_id is not None:
440+
# from pyodide.ffi import run_sync
441+
import asyncio
442+
event = asyncio.Future()
443+
self._requests[request_id] = event, key
444+
# todo: this shouldn't be necessary
445+
# but run_sync(event) gives an error
446+
while not event.done():
447+
time.sleep(0.001)
448+
449+
return self._requests.pop(request_id)
447450

448451

449452
class LinkBaseAsync(LinkBase):

webgpu/platform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def create_event_handler(
3535

3636

3737
try:
38-
import js
38+
import js as pyodide_js
3939
import pyodide.ffi
4040
from pyodide.ffi import JsPromise, JsProxy
4141
from pyodide.ffi import create_proxy as _create_proxy
@@ -48,7 +48,7 @@ def destroy_proxy(proxy):
4848

4949
is_pyodide = True
5050
try:
51-
is_pyodide_main_thread = bool(js.window.document)
51+
is_pyodide_main_thread = bool(pyodide_js.window.document)
5252
except:
5353
is_pyodide_main_thread = False
5454

@@ -88,7 +88,7 @@ def toJS(value):
8888
value = _convert(value)
8989
ret = pyodide.ffi.to_js(
9090
value,
91-
dict_converter=js.Object.fromEntries,
91+
dict_converter=pyodide_js.Object.fromEntries,
9292
default_converter=_default_converter,
9393
create_pyproxies=False,
9494
)

webgpu/scene.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .canvas import Canvas, debounce
66
from .input_handler import InputHandler
77
from .renderer import BaseRenderer, RenderOptions, SelectEvent
8-
from .utils import max_bounding_box, read_buffer
8+
from .utils import max_bounding_box, read_buffer, Lock
99
from .platform import is_pyodide, is_pyodide_main_thread
1010
from .webgpu_api import *
1111
from .camera import Camera
@@ -88,7 +88,7 @@ def init(self, canvas):
8888
self.input_handler.set_canvas(canvas.canvas)
8989
self.options.set_canvas(canvas)
9090

91-
self._render_mutex = canvas._update_mutex
91+
self._render_mutex = Lock(True) if is_pyodide else canvas._update_mutex
9292

9393
with self._render_mutex:
9494
self.options.timestamp = time.time()
@@ -110,7 +110,6 @@ def init(self, canvas):
110110
)
111111
self._select_buffer_valid = False
112112

113-
canvas._update_mutex = self._render_mutex
114113
canvas.on_resize(self.render)
115114

116115
canvas.on_update_html_canvas(self.__on_update_html_canvas)

webgpu/utils.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,38 @@
55
from . import platform
66
from .webgpu_api import *
77
from .webgpu_api import toJS as to_js
8-
import threading
98

10-
_device: Device = None
9+
try:
10+
import pyodide.ffi
11+
import asyncio
12+
class Lock():
13+
def __init__(self, do: bool = False):
14+
self._lock = asyncio.Lock()
15+
self.do = do
16+
17+
def __enter__(self):
18+
if self.do:
19+
pyodide.ffi.run_sync(self._lock.acquire())
20+
21+
def __exit__(self, exc_type, exc, tb):
22+
if self.do:
23+
self._lock.release()
24+
25+
except ImportError:
26+
from threading import RLock
1127

12-
_lock_init_device = threading.Lock()
28+
class Lock():
29+
def __init__(self, do: bool = True):
30+
self._lock = RLock()
1331

32+
def __enter__(self):
33+
self._lock.acquire()
34+
35+
def __exit__(self, exc_type, exc, tb):
36+
self._lock.release()
37+
38+
_lock_init_device = Lock()
39+
_device: Device = None
1440

1541
def init_device_sync():
1642
global _device

0 commit comments

Comments
 (0)