diff --git a/lifejacket/post_deployment_analysis.py b/lifejacket/post_deployment_analysis.py index 9552fea..a982e2e 100644 --- a/lifejacket/post_deployment_analysis.py +++ b/lifejacket/post_deployment_analysis.py @@ -53,8 +53,6 @@ level=logging.INFO, ) -jax.config.update("jax_enable_x64", True) - @click.group() def cli(): diff --git a/pyproject.toml b/pyproject.toml index f006de9..77f1c99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "lifejacket" -version = "1.1.0" +version = "1.2.0" description = "Consistent standard errors for longitudinal data collected under pooling online decision policies." readme = "README.md" requires-python = ">=3.10" diff --git a/tests/unit_tests/test_post_deployment_analysis.py b/tests/unit_tests/test_post_deployment_analysis.py index 285412a..edb3f4a 100644 --- a/tests/unit_tests/test_post_deployment_analysis.py +++ b/tests/unit_tests/test_post_deployment_analysis.py @@ -2935,7 +2935,7 @@ def test_construct_single_user_weighted_estimating_function_stacker_use_action_p jnp.mean( jnp.array([expected_weighted_stack_1, expected_weighted_stack_2]), axis=0 ), - rtol=1e-5, + rtol=1e-6, ) np.testing.assert_array_equal( result[1][0], diff --git a/tests/utils.py b/tests/utils.py index 1534b00..9af4fd3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -108,7 +108,7 @@ def assert_real_run_output_as_expected(test_file_path, relative_path_to_output_d np.testing.assert_allclose( observed_debug_pieces_dict["joint_meat_matrix"], expected_debug_pieces_dict["joint_meat_matrix"], - rtol=1e-3, + rtol=6e-4, ) np.testing.assert_allclose( observed_debug_pieces_dict["raw_joint_bread_matrix"],