diff --git a/tests/test_simple_math.py b/tests/test_simple_math.py index e575dad..3da3b26 100644 --- a/tests/test_simple_math.py +++ b/tests/test_simple_math.py @@ -4,9 +4,9 @@ from efficient_kan import KAN - +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") def test_mul(): - kan = KAN([2, 2, 1], base_activation=nn.Identity) + kan = KAN([2, 2, 1], base_activation=nn.Identity).to(device=device) optimizer = torch.optim.LBFGS(kan.parameters(), lr=1) with tqdm(range(100)) as pbar: for i in pbar: @@ -14,7 +14,7 @@ def test_mul(): def closure(): optimizer.zero_grad() - x = torch.rand(1024, 2) + x = torch.rand(1024, 2).to(device=device) y = kan(x, update_grid=(i % 20 == 0)) assert y.shape == (1024, 1)