From 699b1386877396a16fe8d7b0433bfbab083e04ad Mon Sep 17 00:00:00 2001 From: Rongkang Xiong Date: Tue, 7 May 2024 10:32:09 +0800 Subject: [PATCH] fix bug: false INTERNAL ASSERT FAILED --- tests/test_simple_math.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_simple_math.py b/tests/test_simple_math.py index e575dad..3da3b26 100644 --- a/tests/test_simple_math.py +++ b/tests/test_simple_math.py @@ -4,9 +4,9 @@ from efficient_kan import KAN - +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") def test_mul(): - kan = KAN([2, 2, 1], base_activation=nn.Identity) + kan = KAN([2, 2, 1], base_activation=nn.Identity).to(device=device) optimizer = torch.optim.LBFGS(kan.parameters(), lr=1) with tqdm(range(100)) as pbar: for i in pbar: @@ -14,7 +14,7 @@ def test_mul(): def closure(): optimizer.zero_grad() - x = torch.rand(1024, 2) + x = torch.rand(1024, 2).to(device=device) y = kan(x, update_grid=(i % 20 == 0)) assert y.shape == (1024, 1)