I was experimenting with the new JAX sharding features and hit an error when trying to shard over a lineax problem. The code below seems like the straightforward approach, but it errors — am I doing something wrong?
import lineax as lx
import jax.scipy.linalg as jsla
import jax
from jax.sharding import PartitionSpec as P, NamedSharding, AxisType
from jax import numpy as jnp
num_devices = jax.local_device_count()
mesh = jax.make_mesh((num_devices,), ("batch",), axis_types=(AxisType.Explicit,))
spec = P("batch")
sharding = NamedSharding(mesh, spec)
n_systems = num_devices * 2
A_batch = jnp.stack([jnp.eye(3) * (i + 1) for i in range(n_systems)])
b_batch = jnp.ones((n_systems, 3))
A_sharded = jax.device_put(A_batch, sharding)
b_sharded = jax.device_put(b_batch, sharding)
def solve_lineax(A, b):
op = lx.MatrixLinearOperator(A)
return lx.linear_solve(op, b, solver=lx.LU()).value
def solve_jax(A, b):
return jsla.solve(A, b)
print("lineax:")
try:
sharded_lineax = jax.jit(jax.shard_map(
lambda A, b: jax.vmap(solve_lineax)(A, b),
mesh=mesh, in_specs=(spec, spec), out_specs=spec, check_vma=False
))
result = sharded_lineax(A_sharded, b_sharded)
print(f"Result: {result.mean()}")
except Exception as e:
print(f"FAILED: {type(e).__name__}: {e}")
print("\njax.scipy.linalg.solve:")
try:
sharded_jax = jax.jit(jax.shard_map(
lambda A, b: jax.vmap(solve_jax)(A, b),
mesh=mesh, in_specs=(spec, spec), out_specs=spec, check_vma=False
))
result = sharded_jax(A_sharded, b_sharded)
print(f"Result:\n{result.mean()}")
except Exception as e:
print(f"FAILED: {type(e).__name__}: {e}")
14
/var/folders/fk/wsltfy6d1hv2bbrp75ph4kl00000gn/T/ipykernel_44869/2699943594.py:3: DeprecationWarning: The default axis_types will change in JAX v0.9.0 to jax.sharding.AxisType.Explicit. To maintain the old behavior, pass `axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names)`. To opt-into the new behavior, pass `axis_types=(jax.sharding.AxisType.Explicit,) * len(axis_names)
mesh = jax.make_mesh((num_devices,), ("batch",))
Output exceeds the [size limit](command:workbench.action.openSettings?[). Open the full output data [in a text editor](command:workbench.action.openLargeOutput?565a00ce-2f66-4532-ad43-f8a5971be416)
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[5], line 27
16 return jax.vmap(integ, in_axes=(None, None, 0))(init, 100., params_local)
18 sharded_fun = jax.jit(
19 jax.shard_map(
20 solve_single,
(...) 25 )
26 )
---> 27 _ = sharded_fun(params_sharded, init_rep).block_until_ready()
29 time = timeit.repeat(
30 "sharded_fun(params_sharded, init_rep).block_until_ready()",
31 globals=globals(),
32 number=2,
33 repeat=2,
34 )
36 print(f"took {time[0]:.2e} seconds to compile, {time[1]:.2e} to run")
[... skipping hidden 26 frame]
Cell In[5], line 16, in solve_single(params_local, init)
15 def solve_single(params_local, init):
---> 16 return jax.vmap(integ, in_axes=(None, None, 0))(init, 100., params_local)
[... skipping hidden 7 frame]
...
---> 30 assert not hlo_sharding.is_manual()
31 if hlo_sharding.is_replicated():
32 return [], 1
AssertionError:
Output exceeds the [size limit](command:workbench.action.openSettings?[). Open the full output data [in a text editor](command:workbench.action.openLargeOutput?565a00ce-2f66-4532-ad43-f8a5971be416)
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[5], line 27
16 return jax.vmap(integ, in_axes=(None, None, 0))(init, 100., params_local)
18 sharded_fun = jax.jit(
19 jax.shard_map(
20 solve_single,
(...) 25 )
26 )
---> 27 _ = sharded_fun(params_sharded, init_rep).block_until_ready()
29 time = timeit.repeat(
30 "sharded_fun(params_sharded, init_rep).block_until_ready()",
31 globals=globals(),
32 number=2,
33 repeat=2,
34 )
36 print(f"took {time[0]:.2e} seconds to compile, {time[1]:.2e} to run")
[... skipping hidden 26 frame]
Cell In[5], line 16, in solve_single(params_local, init)
15 def solve_single(params_local, init):
---> 16 return jax.vmap(integ, in_axes=(None, None, 0))(init, 100., params_local)
[... skipping hidden 7 frame]
...
---> 30 assert not hlo_sharding.is_manual()
31 if hlo_sharding.is_replicated():
32 return [], 1
AssertionError:
params type: float64[140@x]
init type: float64[10]
Explicit sharding: took 5.66e-01 seconds to compile, 2.54e-01 to run
lineax:
FAILED: ValueError: Vector and operator structures do not match. Got a vector with structure ShapeDtypeStruct(shape=(3,), dtype=float64, sharding=NamedSharding(mesh=AbstractMesh('batch': 14, axis_types=(Manual,), device_kind=cpu, num_cores=None), spec=PartitionSpec(None,))) and an operator with out-structure ShapeDtypeStruct(shape=(3,), dtype=float64)
jax.scipy.linalg.solve:
Result:
0.14025610853451315
I was playing around with the new JAX shardings and noticed an error when trying to shard over a lineax problem. This seems like the simple approach, but it errors, am I doing something wrong?
Latest versions of jax / equinox / lineax, etc.