Skip to content

Issue with shard_map #185

@lockwo

Description

@lockwo

I was playing around with the new JAX shardings and noticed an error when trying to shard over a lineax problem. This seems like the simplest approach, but it errors — am I doing something wrong?

import lineax as lx
import jax.scipy.linalg as jsla
import jax
from jax.sharding import PartitionSpec as P, NamedSharding, AxisType
from jax import numpy as jnp

num_devices = jax.local_device_count()
mesh = jax.make_mesh((num_devices,), ("batch",), axis_types=(AxisType.Explicit,))
spec = P("batch")
sharding = NamedSharding(mesh, spec)

n_systems = num_devices * 2
A_batch = jnp.stack([jnp.eye(3) * (i + 1) for i in range(n_systems)])
b_batch = jnp.ones((n_systems, 3))

A_sharded = jax.device_put(A_batch, sharding)
b_sharded = jax.device_put(b_batch, sharding)

def solve_lineax(A, b):
    """Solve the single (unbatched) linear system A x = b with lineax's LU solver."""
    operator = lx.MatrixLinearOperator(A)
    solution = lx.linear_solve(operator, b, solver=lx.LU())
    return solution.value

def solve_jax(A, b):
    """Solve the dense linear system A x = b via ``jax.scipy.linalg.solve``."""
    solution = jsla.solve(A, b)
    return solution

print("lineax:")
try:
    # vmap over the leading (per-device shard) axis, then shard_map the
    # batched solve across the "batch" mesh axis.
    batched_solve = jax.vmap(solve_lineax)
    sharded_lineax = jax.jit(
        jax.shard_map(
            batched_solve,
            mesh=mesh,
            in_specs=(spec, spec),
            out_specs=spec,
            check_vma=False,
        )
    )
    result = sharded_lineax(A_sharded, b_sharded)
    print(f"Result: {result.mean()}")
except Exception as e:
    print(f"FAILED: {type(e).__name__}: {e}")

print("\njax.scipy.linalg.solve:")
try:
    # Same batching/sharding recipe as above, but for the plain JAX solver.
    sharded_jax = jax.jit(
        jax.shard_map(
            jax.vmap(solve_jax),
            mesh=mesh,
            in_specs=(spec, spec),
            out_specs=spec,
            check_vma=False,
        )
    )
    result = sharded_jax(A_sharded, b_sharded)
    print(f"Result:\n{result.mean()}")
except Exception as e:
    print(f"FAILED: {type(e).__name__}: {e}")
14
/var/folders/fk/wsltfy6d1hv2bbrp75ph4kl00000gn/T/ipykernel_44869/2699943594.py:3: DeprecationWarning: The default axis_types will change in JAX v0.9.0 to jax.sharding.AxisType.Explicit. To maintain the old behavior, pass `axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names)`. To opt into the new behavior, pass `axis_types=(jax.sharding.AxisType.Explicit,) * len(axis_names)`.
  mesh = jax.make_mesh((num_devices,), ("batch",))
Output exceeds the [size limit](command:workbench.action.openSettings?[). Open the full output data [in a text editor](command:workbench.action.openLargeOutput?565a00ce-2f66-4532-ad43-f8a5971be416)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[5], line 27
     16     return jax.vmap(integ, in_axes=(None, None, 0))(init, 100., params_local)
     18 sharded_fun = jax.jit(
     19     jax.shard_map(
     20         solve_single, 
   (...)     25     )
     26 )
---> 27 _ = sharded_fun(params_sharded, init_rep).block_until_ready()
     29 time = timeit.repeat(
     30     "sharded_fun(params_sharded, init_rep).block_until_ready()",
     31     globals=globals(),
     32     number=2,
     33     repeat=2,
     34 )
     36 print(f"took {time[0]:.2e} seconds to compile, {time[1]:.2e} to run")

    [... skipping hidden 26 frame]

Cell In[5], line 16, in solve_single(params_local, init)
     15 def solve_single(params_local, init):
---> 16     return jax.vmap(integ, in_axes=(None, None, 0))(init, 100., params_local)

    [... skipping hidden 7 frame]
...
---> 30   assert not hlo_sharding.is_manual()
     31   if hlo_sharding.is_replicated():
     32     return [], 1

AssertionError: 
Output exceeds the [size limit](command:workbench.action.openSettings?[). Open the full output data [in a text editor](command:workbench.action.openLargeOutput?565a00ce-2f66-4532-ad43-f8a5971be416)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[5], line 27
     16     return jax.vmap(integ, in_axes=(None, None, 0))(init, 100., params_local)
     18 sharded_fun = jax.jit(
     19     jax.shard_map(
     20         solve_single, 
   (...)     25     )
     26 )
---> 27 _ = sharded_fun(params_sharded, init_rep).block_until_ready()
     29 time = timeit.repeat(
     30     "sharded_fun(params_sharded, init_rep).block_until_ready()",
     31     globals=globals(),
     32     number=2,
     33     repeat=2,
     34 )
     36 print(f"took {time[0]:.2e} seconds to compile, {time[1]:.2e} to run")

    [... skipping hidden 26 frame]

Cell In[5], line 16, in solve_single(params_local, init)
     15 def solve_single(params_local, init):
---> 16     return jax.vmap(integ, in_axes=(None, None, 0))(init, 100., params_local)

    [... skipping hidden 7 frame]
...
---> 30   assert not hlo_sharding.is_manual()
     31   if hlo_sharding.is_replicated():
     32     return [], 1

AssertionError: 
params type: float64[140@x]
init type: float64[10]
Explicit sharding: took 5.66e-01 seconds to compile, 2.54e-01 to run
lineax:
FAILED: ValueError: Vector and operator structures do not match. Got a vector with structure ShapeDtypeStruct(shape=(3,), dtype=float64, sharding=NamedSharding(mesh=AbstractMesh('batch': 14, axis_types=(Manual,), device_kind=cpu, num_cores=None), spec=PartitionSpec(None,))) and an operator with out-structure ShapeDtypeStruct(shape=(3,), dtype=float64)

jax.scipy.linalg.solve:
Result:
0.14025610853451315

Latest versions of jax/equinox/lineax, etc.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions