diff --git a/rlax/_src/pop_art_test.py b/rlax/_src/pop_art_test.py index dba1375..c85a92a 100644 --- a/rlax/_src/pop_art_test.py +++ b/rlax/_src/pop_art_test.py @@ -222,7 +222,6 @@ class PopArtTestWithPmapShmapMerge(PopArtTest): def setUp(self): super().setUp() self.pmap_shmap_merge = jax.config.jax_pmap_shmap_merge - jax.config.update('jax_pmap_shmap_merge', True) def tearDown(self): super().tearDown()