diff --git a/test/test_interop.py b/test/test_interop.py index 571a1d5..2c85ea8 100644 --- a/test/test_interop.py +++ b/test/test_interop.py @@ -127,7 +127,7 @@ def forward(self, x): actual = m_jitted(x) # assert - torch.testing.assert_allclose(actual, expected) + torch.testing.assert_close(actual, expected) # arrange # make sure buffer donation works @@ -139,7 +139,7 @@ def forward(self, x): # act actual = functional_forward(m_jitted.params, m_jitted.buffers, x) # assert - torch.testing.assert_allclose(actual, expected) + torch.testing.assert_close(actual, expected) def test_to_jax_device(self): a = torch.ones(3, 3) diff --git a/test/test_train.py b/test/test_train.py index 1c820fe..86c81b4 100644 --- a/test/test_train.py +++ b/test/test_train.py @@ -46,7 +46,9 @@ def test_scan_module(self): x = x.to("jax") model.to("jax") result2 = model(x) - torch.testing.assert_allclose(result, result2.to("cpu")) + # Explicit rtol/atol match previous assert_allclose's defaults for float32 (1e-4 / 1e-5) + # to accommodate small numerical drift accumulating from eager loop vs ScannedModule executions. + torch.testing.assert_close(result, result2.to("cpu"), rtol=1e-4, atol=1e-5) def test_train_step_can_run(self): import optax