-
Notifications
You must be signed in to change notification settings - Fork 75
Open
Labels
type:bug — Something isn't working
Description
On main and jax 0.7/0.8, if a global mesh is set with `jax.sharding.set_mesh`, it is impossible to save a checkpoint because of the error shown below.
Am I doing something wrong, or is this an issue in orbax?
# Minimal reproduction: build a pytree of sharded jax arrays and set up an
# orbax AsyncCheckpointer while a global mesh is set with jax.sharding.set_mesh.
import numpy as np
import orbax.checkpoint as ocp
import jax
import jax.numpy as jnp
# NamedSharding over a 1-D mesh (axis name 'model') built from every device
# returned by jax.devices(), with an empty PartitionSpec.
sharding = jax.sharding.NamedSharding(
jax.sharding.Mesh(jax.devices(), ('model',)),
jax.sharding.PartitionSpec(),
)
# Setting the global mesh on the line below is what triggers the ckptr.save
# failure described in this report.
# NOTE(review): the original comment read "Commenting this line below will
# break cptr.save", which contradicts the description -- confirm that removing
# this line makes the save succeed.
jax.sharding.set_mesh(sharding.mesh)
# Helper that places a host array on the devices using the sharding above.
create_sharded_array = lambda x: jax.device_put(x, sharding)
# Two 1-D host arrays of length 16 used as the checkpoint payload.
state = {
'a': np.arange(16),
'b': np.ones(16),
}
# Replace every leaf of the pytree with its device-placed counterpart.
state = jax.tree.map(create_sharded_array, state)
# Abstract description of the state; it is not used again in this snippet.
abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)
# Target directory for the checkpoint (presumably cleared and recreated by
# this test utility -- verify against the orbax implementation).
path = ocp.test_utils.erase_and_create_empty('/tmp/basic/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(path / '1', args=ocp.args.StandardSave(state))
Error:
File ~/Nextcloud/Codes/Python/netket_pro/.venv/lib/python3.13/site-packages/orbax/checkpoint/_src/serialization/jax_array_handlers.py:262, in _serialize_arrays(arrays, infos, args, dispatcher, replica_id, use_replica_parallel, min_slice_bytes_for_replica_parallel, max_replicas_for_replica_parallel, primary_host, metadata_key, array_metadata_store, enable_replica_parallel_separate_folder, ext_metadata)
259 """D2H transfer and serialize arrays using dispatcher if provided."""
260 if dispatcher is None:
261 # Complete D2H transfer in parallel for each array.
--> 262 values_on_host = replica_slices.transfer_arrays_to_host(
263 arrays,
264 replica_id,
265 use_replica_parallel,
266 enable_pinned_host_transfer=infos[0].enable_pinned_host_transfer,
267 min_slice_bytes_for_replica_parallel=min_slice_bytes_for_replica_parallel,
268 max_replicas_for_replica_parallel=max_replicas_for_replica_parallel,
269 )
270 return future.CommitFutureAwaitingContractedSignals(
271 _async_serialize_replica_slices(
272 values_on_host,
(...) 282 name='array_type_handler',
283 )
284 else:
....
File ~/Nextcloud/Codes/Python/netket_pro/.venv/lib/python3.13/site-packages/jax/_src/pjit.py:159, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
156 fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
157 msg = stages._device_assignment_mismatch_error(
158 fun_name, fails, args_flat, 'jit', p.arg_names)
--> 159 raise ValueError(msg) from None
160 except dtypes.InvalidInputException as e:
161 arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names
ValueError: Received incompatible devices for jitted computation. Got argument args[0] of slice with shape int32[16] and device ids [0] on platform CPU and jit's context mesh with device ids [0, 1, 2, 3] on platform CPU
Metadata
Assignees
Labels
type:bug — Something isn't working