Skip to content

saving with default/global mesh is broken #2545

@PhilipVinc

Description

@PhilipVinc

On main and jax 0.7/0.8, if a global mesh is set with jax.sharding.set_mesh, it is impossible to save a checkpoint because of the error shown below.

Am I doing something wrong, or is this an issue in orbax?

import numpy as np
import orbax.checkpoint as ocp
import jax
import jax.numpy as jnp

# Minimal reproduction: save a sharded pytree with orbax while a global
# (context) mesh is set via jax.sharding.set_mesh.
# 1-D mesh named 'model' over all available devices, with an empty
# PartitionSpec (no array axis is partitioned across the mesh).
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(),
)
# Setting the global mesh on the line below is what breaks ckptr.save;
# per the report, the failure appears only when a global mesh is set.
jax.sharding.set_mesh(sharding.mesh)

# Put each leaf of the state on device using the sharding defined above.
create_sharded_array = lambda x: jax.device_put(x, sharding)
state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
state = jax.tree.map(create_sharded_array, state)
# NOTE(review): abstract_state is not used by the save call below.
abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)

path = ocp.test_utils.erase_and_create_empty('/tmp/basic/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())

# This is the call that raises the ValueError shown in the traceback below.
ckptr.save(path / '1', args=ocp.args.StandardSave(state))

error:

File ~/Nextcloud/Codes/Python/netket_pro/.venv/lib/python3.13/site-packages/orbax/checkpoint/_src/serialization/jax_array_handlers.py:262, in _serialize_arrays(arrays, infos, args, dispatcher, replica_id, use_replica_parallel, min_slice_bytes_for_replica_parallel, max_replicas_for_replica_parallel, primary_host, metadata_key, array_metadata_store, enable_replica_parallel_separate_folder, ext_metadata)
    259 """D2H transfer and serialize arrays using dispatcher if provided."""
    260 if dispatcher is None:
    261   # Complete D2H transfer in parallel for each array.
--> 262   values_on_host = replica_slices.transfer_arrays_to_host(
    263       arrays,
    264       replica_id,
    265       use_replica_parallel,
    266       enable_pinned_host_transfer=infos[0].enable_pinned_host_transfer,
    267       min_slice_bytes_for_replica_parallel=min_slice_bytes_for_replica_parallel,
    268       max_replicas_for_replica_parallel=max_replicas_for_replica_parallel,
    269   )
    270   return future.CommitFutureAwaitingContractedSignals(
    271       _async_serialize_replica_slices(
    272           values_on_host,
   (...)    282       name='array_type_handler',
    283   )
    284 else:
....
File ~/Nextcloud/Codes/Python/netket_pro/.venv/lib/python3.13/site-packages/jax/_src/pjit.py:159, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
    156   fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
    157   msg = stages._device_assignment_mismatch_error(
    158       fun_name, fails, args_flat, 'jit', p.arg_names)
--> 159   raise ValueError(msg) from None
    160 except dtypes.InvalidInputException as e:
    161   arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names

ValueError: Received incompatible devices for jitted computation. Got argument args[0] of slice with shape int32[16] and device ids [0] on platform CPU and jit's context mesh with device ids [0, 1, 2, 3] on platform CPU

Metadata

Metadata

Assignees

No one assigned

    Labels

    type:bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions