From 62c2f595d2c7e0fbdf61296a1e3fa851eddc53fa Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 3 Dec 2025 10:54:23 -0800 Subject: [PATCH] [pmap] Remove `jax_pmap_shmap_merge` configuration. PiperOrigin-RevId: 839825245 --- rlax/_src/pop_art_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rlax/_src/pop_art_test.py b/rlax/_src/pop_art_test.py index dba1375..c85a92a 100644 --- a/rlax/_src/pop_art_test.py +++ b/rlax/_src/pop_art_test.py @@ -222,7 +222,6 @@ class PopArtTestWithPmapShmapMerge(PopArtTest): def setUp(self): super().setUp() self.pmap_shmap_merge = jax.config.jax_pmap_shmap_merge - jax.config.update('jax_pmap_shmap_merge', True) def tearDown(self): super().tearDown()