
Commit 0f5099f

feat: separate module call and forward
1 parent db90bcb

3 files changed

Lines changed: 114 additions & 3 deletions


examples/pyodide/bridge.py

Lines changed: 3 additions & 2 deletions
@@ -479,10 +479,11 @@ def __init__(self, js_module):
 
     def __call__(self, *args):
         js_args = [a._js if isinstance(a, Tensor) else a for a in args]
-        return Tensor(self._module.forward(*js_args))
+        return Tensor(self._module.call(*js_args))
 
     def forward(self, *args):
-        return self(*args)
+        js_args = [a._js if isinstance(a, Tensor) else a for a in args]
+        return Tensor(self._module.forward(*js_args))
 
     def parameters(self):
         return [Tensor(p) for p in self._module.parameters().to_py()]
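
Note: the wrapper above assumes every bridged JS module now exposes both call() and forward(). A minimal sketch of that contract, for orientation only — the real definition is the Module class in src/nn/base.ts below, and JsTensor is a placeholder name, not a type from the library:

// Illustrative sketch of the JS-side shape the Python wrapper relies on.
type JsTensor = unknown; // placeholder for the JS-side tensor handle

interface BridgedModule {
  call(...args: JsTensor[]): JsTensor;    // Python `model(x)` routes here via __call__
  forward(...args: JsTensor[]): JsTensor; // Python `model.forward(x)` routes here directly
  parameters(): JsTensor[];               // converted to a Python list via `.to_py()`
}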

examples/pyodide/py/nn_module.py

Lines changed: 101 additions & 0 deletions
# nn_module.py — Tests for custom nn.Module subclasses, parameter registration,
# call vs forward separation, and nested models.

class MyLinearLayer(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(in_features, out_features))
        self.bias = torch.nn.Parameter(torch.rand(out_features))

    def forward(self, input):
        return input @ self.weight + self.bias

class MySmallModel(torch.nn.Module):
    def __init__(self, in_features, intermediate_features, out_features):
        super().__init__()
        # Using our own defined layer
        self.lin1 = MyLinearLayer(in_features, intermediate_features)
        # Using pre-defined Linear Layer
        self.lin2 = torch.nn.Linear(intermediate_features, out_features)

    def forward(self, x):
        x = self.lin1(x)
        x = self.lin2(x)
        return x


print("=== MyLinearLayer: output shape ===")
layer = MyLinearLayer(4, 3)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
out = layer(x)
print("output shape:", list(out.shape))
print("> output shape correct:", list(out.shape) == [1, 3])

print("\n=== MyLinearLayer: parameter registration ===")
params = list(layer.parameters())
print("num parameters:", len(params))
print("> num parameters:", len(params) == 2)  # weight + bias
for name, p in layer.named_parameters():
    print(f" {name}: shape={list(p.shape)}")
print("> weight shape:", list(layer.weight.shape) == [4, 3])
print("> bias shape:", list(layer.bias.shape) == [3])

print("\n=== MyLinearLayer: __call__ vs forward() ===")
out_call = layer(x)
out_forward = layer.forward(x)
print("> outputs match:", torch.allclose(out_call, out_forward))

print("\n=== MyLinearLayer: backward ===")
layer.zero_grad()
layer(x).sum().backward()
print("> weight.grad exists:", layer.weight.grad is not None)
print("> bias.grad exists:", layer.bias.grad is not None)

print("\n=== MySmallModel: output shape ===")
model = MySmallModel(4, 8, 2)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
out = model(x)
print("output shape:", list(out.shape))
print("> output shape correct:", list(out.shape) == [1, 2])

print("\n=== MySmallModel: parameters collected from both sub-modules ===")
params = list(model.parameters())
# lin1: weight (4x8) + bias (8) = 2 params
# lin2: weight (8x2) + bias (2) = 2 params
print("num parameters:", len(params))
print("> num parameters:", len(params) == 4)
for name, p in model.named_parameters():
    print(f" {name}: shape={list(p.shape)}")

print("\n=== MySmallModel: __call__ vs forward() ===")
out_call = model(x)
out_forward = model.forward(x)
print("> outputs match:", torch.allclose(out_call, out_forward))

print("\n=== MySmallModel: backward through nested modules ===")
model.zero_grad()
model(x).sum().backward()
print("> all grads computed:", all(p.grad is not None for p in model.parameters()))

print("\n=== __call__ vs forward() on a built-in module ===")
fc = torch.nn.Linear(3, 2)
x = torch.tensor([[1.0, 2.0, 3.0]])
out_call = fc(x)
out_forward = fc.forward(x)
print("> outputs match:", torch.allclose(out_call, out_forward))
print("> shapes match:", out_call.shape == out_forward.shape)

print("\n=== Sequential: submodules run via __call__ path ===")
seq = torch.nn.Sequential(
    torch.nn.Linear(2, 4),
    torch.nn.ReLU(),
    torch.nn.Linear(4, 1),
)
x = torch.tensor([[1.0, 2.0]])
out_seq_call = seq(x)
out_seq_forward = seq.forward(x)
print("output shape:", list(out_seq_call.shape))
print("> output shape correct:", list(out_seq_call.shape) == [1, 1])
print("> call and forward match:", torch.allclose(out_seq_call, out_seq_forward))

print("\nAll nn_module checks passed.")

src/nn/base.ts

Lines changed: 10 additions & 1 deletion
@@ -31,6 +31,15 @@ export abstract class Module {
 
   public abstract forward(...args: Tensor[]): Tensor;
 
+  /**
+   * Entry point for running the module. Equivalent to `model(x)` in Python.
+   * In the future, this is where forward hooks will be triggered.
+   * Call `forward()` directly to bypass hooks.
+   */
+  public call(...args: Tensor[]): Tensor {
+    return this.forward(...args);
+  }
+
   public train(mode: boolean = true): this {
     this.training = mode;
     for (const module of Object.values(this._modules)) {
@@ -100,7 +109,7 @@ export class Sequential extends Module {
   forward(input: Tensor) {
     let x = input;
     for (const module of this._modulesArr) {
-      x = module.forward(x);
+      x = module.call(x);
     }
     return x;
   }
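
To show why the split matters, here is a self-contained sketch of how call() could later trigger forward hooks, as the doc comment anticipates. Only the call-delegates-to-forward behavior comes from this commit; the hook API (Hook, registerForwardHook) and the use of plain number[] in place of Tensor are assumptions for illustration:

// Hypothetical extension of the pattern above; not part of this commit.
type Hook = (module: SketchModule, output: number[]) => void;

abstract class SketchModule {
  private hooks: Hook[] = [];

  abstract forward(...args: number[][]): number[];

  // Assumed hook-registration API, mirroring what the doc comment plans for.
  registerForwardHook(hook: Hook): void {
    this.hooks.push(hook);
  }

  // Hook-aware entry point: run forward, then fire any registered hooks.
  call(...args: number[][]): number[] {
    const output = this.forward(...args);
    for (const hook of this.hooks) hook(this, output);
    return output;
  }
}

class Doubler extends SketchModule {
  forward(x: number[]): number[] {
    return x.map((v) => v * 2);
  }
}

const m = new Doubler();
m.registerForwardHook((_mod, out) => console.log("hook saw:", out));
m.call([1, 2, 3]);    // runs forward, then the hook fires
m.forward([1, 2, 3]); // bypasses hooks entirely

Because Sequential now drives its children through call(x) rather than forward(x), a hook registered on a submodule would fire even when that submodule runs inside a container.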
