diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index 15553727..10238ac4 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -35,7 +35,7 @@ AxisNames = tuple[str, ...] # Physical axis names for device meshes. DATA = "data" -FSDP = "fsdp" +FSDP = "fsdp_tpu" TENSOR = "tensor" # Logical axis names for model parameters and activations. BATCH = "activation_batch" diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index 7bd8ae70..58c81a3e 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -106,7 +106,7 @@ skip_jax_distributed_system: False base_output_directory: "" # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp_tpu', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -122,16 +122,16 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', ['data','fsdp']], + ['activation_batch', ['data','fsdp_tpu']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], - ['embed','fsdp'], + ['embed','fsdp_tpu'], ['heads', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data','fsdp_tpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'fsdp_tpu'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp_tpu', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index 24dffe40..2d54dab0 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -108,7 +108,7 @@ skip_jax_distributed_system: False base_output_directory: "" # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp_tpu', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -124,16 +124,16 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', ['data','fsdp']], + ['activation_batch', ['data','fsdp_tpu']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], - ['embed','fsdp'], + ['embed','fsdp_tpu'], ['heads', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data','fsdp_tpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'fsdp_tpu'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp_tpu', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. 
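A minimal sketch of how the renamed axis is consumed: MaxDiffusion resolves logical axis names to mesh axes through Flax's logical axis rules, and the attention hunks further down replace hard-coded `PartitionSpec("data", "fsdp", "tensor")` constraints with this lookup so a rename like `fsdp` → `fsdp_tpu` stays local to `common_types.py` and the config rules. The rules below are trimmed from `base14.yml`; no physical mesh is needed for the resolution itself.

```python
# Sketch: resolving logical axis names against the renamed 'fsdp_tpu' mesh axis.
# Rules are a trimmed copy of base14.yml above.
import flax.linen as nn

rules = (
    ("activation_batch", ("data", "fsdp_tpu")),
    ("activation_heads", "tensor"),
    ("embed", "fsdp_tpu"),
)

with nn.logical_axis_rules(rules):
    # A [batch, heads]-shaped activation resolves to
    # PartitionSpec(('data', 'fsdp_tpu'), 'tensor').
    print(nn.logical_to_mesh_axes(("activation_batch", "activation_heads")))
```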
diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index 7b224058..3e826105 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -121,7 +121,7 @@ skip_jax_distributed_system: False base_output_directory: "" # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp_tpu', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -137,16 +137,16 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', ['data','fsdp']], + ['activation_batch', ['data','fsdp_tpu']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], - ['embed','fsdp'], + ['embed','fsdp_tpu'], ['heads', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data','fsdp_tpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'fsdp_tpu'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp_tpu', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 0036b363..49a48c61 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -132,7 +132,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp_tpu', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -148,17 +148,17 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', ['data','fsdp']], + ['activation_batch', ['data','fsdp_tpu']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], - ['embed','fsdp'], + ['embed','fsdp_tpu'], ['heads', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data','fsdp_tpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'fsdp_tpu'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp_tpu', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. 
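For context, the `data_sharding` entry that these configs keep updating maps onto a `NamedSharding` for input batches, with the batch dimension split across every listed mesh axis. A runnable sketch under an illustrative 1×N×1 mesh; the real mesh is built by `create_device_mesh` from the `ici_*`/`dcn_*` parallelism knobs.

```python
# Sketch: what data_sharding: [['data', 'fsdp_tpu', 'tensor']] means for an input batch.
# The 1 x N x 1 mesh is illustrative; MaxDiffusion builds the real one in create_device_mesh.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()).reshape(1, n, 1), ("data", "fsdp_tpu", "tensor"))
batch_spec = PartitionSpec(("data", "fsdp_tpu", "tensor"))  # batch dim over all three axes

batch = jax.device_put(np.zeros((2 * n, 128), np.float32), NamedSharding(mesh, batch_spec))
print(batch.sharding)
```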
diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index ac0a0f8c..433b34d6 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -132,7 +132,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp_tpu', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -148,17 +148,17 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', ['data','fsdp']], + ['activation_batch', ['data','fsdp_tpu']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], -# ['embed','fsdp'], - ['mlp',['fsdp','tensor']], +# ['embed','fsdp_tpu'], + ['mlp',['fsdp_tpu','tensor']], ['heads', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data','fsdp_tpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'fsdp_tpu'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp_tpu', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index c60dd79e..12171041 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -140,7 +140,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp_tpu', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -156,17 +156,17 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', ['data','fsdp']], + ['activation_batch', ['data','fsdp_tpu']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], ['mlp','tensor'], - ['embed','fsdp'], + ['embed','fsdp_tpu'], ['heads', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data','fsdp_tpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'fsdp_tpu'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp_tpu', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. 
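One behavioral note on `base_flux_dev_multi_res.yml`: with the `embed` rule commented out, that logical axis no longer matches any rule and resolves to `None`, i.e. the dimension is left replicated, while `mlp` is sharded across both `fsdp_tpu` and `tensor`. A tiny sketch with the rules trimmed accordingly:

```python
# Sketch: an unmatched logical axis ('embed') resolves to None (replicated),
# while 'mlp' is sharded over ('fsdp_tpu', 'tensor') per base_flux_dev_multi_res.yml.
import flax.linen as nn

rules = (("mlp", ("fsdp_tpu", "tensor")), ("heads", "tensor"))
with nn.logical_axis_rules(rules):
    print(nn.logical_to_mesh_axes(("embed", "mlp")))
```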
diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 1b647424..1b731ccf 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -148,7 +148,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'tensor', 'fsdp_tpu', 'fsdp_gpu'] # batch : batch dimension of data and activations # hidden : @@ -163,32 +163,34 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_in : conv.shape[2] weight # conv_out : conv.shape[-1] weight logical_axis_rules: [ - ['batch', 'data'], - ['activation_batch', 'data'], - ['activation_self_attn_heads', ['fsdp', 'tensor']], - ['activation_cross_attn_q_length', ['fsdp', 'tensor']], - ['activation_length', 'fsdp'], + ['batch', ['data', 'fsdp_gpu']], + ['activation_batch', ['data', 'fsdp_gpu']], + ['activation_length', 'fsdp_tpu'], + ['activation_self_attn_heads', ['fsdp_tpu', 'tensor']], + ['activation_cross_attn_q_length', ['fsdp_tpu', 'tensor']], ['activation_heads', 'tensor'], ['mlp','tensor'], - ['embed','fsdp'], + ['embed', ['fsdp_tpu', 'fsdp_gpu']], ['heads', 'tensor'], ['norm', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data', 'fsdp_tpu', 'fsdp_gpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'fsdp_tpu'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'tensor', 'fsdp_tpu', 'fsdp_gpu']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded -dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 +dcn_fsdp_tpu_parallelism: -1 +dcn_fsdp_gpu_parallelism: 1 # recommended DCN axis to be auto-sharded ici_data_parallelism: 1 -ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +ici_fsdp_tpu_parallelism: -1 +ici_fsdp_gpu_parallelism: 1 # recommended ICI axis to be auto-sharded allow_split_physical_axes: False diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 1b93a32a..acbd6479 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -137,7 +137,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp_tpu', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -154,18 +154,18 @@ mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], ['activation_batch', 'data'], - ['activation_length', 'fsdp'], + ['activation_length', 'fsdp_tpu'], ['activation_heads', 'tensor'], ['mlp','tensor'], - ['embed','fsdp'], + ['embed','fsdp_tpu'], ['heads', 'tensor'], ['norm', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data','fsdp_tpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'fsdp_tpu'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp_tpu', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. 
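`base_wan_14b.yml` is the first config to split FSDP into separate `fsdp_tpu` and `fsdp_gpu` mesh axes, each with its own `dcn_*`/`ici_*` knob, and `create_device_mesh` (see the `max_utils.py` hunk below) selects the four-entry ordering whenever `dcn_fsdp_tpu_parallelism` is present. A hedged sketch of the resulting mesh-shape arithmetic; `resolve_unspecified` is an illustrative stand-in for `fill_unspecified_mesh_axes`, not the actual implementation.

```python
# Sketch: resolving the -1 placeholder in the four-axis ICI knobs above.
# resolve_unspecified is a stand-in for MaxDiffusion's fill_unspecified_mesh_axes.
import numpy as np

def resolve_unspecified(axis_sizes, num_devices):
    """Replace a single -1 entry so the product of sizes equals num_devices."""
    sizes = list(axis_sizes)
    if -1 in sizes:
        known = int(np.prod([s for s in sizes if s != -1]))
        sizes[sizes.index(-1)] = num_devices // known
    assert int(np.prod(sizes)) == num_devices, "axis sizes must multiply to the device count"
    return sizes

# base_wan_14b.yml ICI order: ['data', 'tensor', 'fsdp_tpu', 'fsdp_gpu']
ici = resolve_unspecified([1, 1, -1, 1], num_devices=8)  # e.g. an 8-chip slice (illustrative)
print(dict(zip(("data", "tensor", "fsdp_tpu", "fsdp_gpu"), ici)))  # {'fsdp_tpu': 8, ...}
```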
diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index 49e53ae5..df1cf97c 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -106,7 +106,7 @@ base_output_directory: "" hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp_tpu', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -122,16 +122,16 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', ['data','fsdp']], + ['activation_batch', ['data','fsdp_tpu']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], - ['embed','fsdp'], + ['embed','fsdp_tpu'], ['heads', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data','fsdp_tpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'fsdp_tpu'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp_tpu', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index 6f6662b0..65210668 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -86,7 +86,7 @@ skip_jax_distributed_system: False base_output_directory: "" # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp_tpu', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -102,16 +102,16 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_batch', ['data','fsdp']], + ['activation_batch', ['data','fsdp_tpu']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], - ['embed','fsdp'], + ['embed','fsdp_tpu'], ['heads', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data','fsdp_tpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'fsdp_tpu'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp_tpu', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. 
diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 71316ea1..b87a6d10 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -62,22 +62,22 @@ second_pass: cfg_star_rescale: True #parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp_tpu', 'tensor'] logical_axis_rules: [ ['batch', 'data'], - ['activation_heads', 'fsdp'], + ['activation_heads', 'fsdp_tpu'], ['activation_batch', 'data'], ['activation_kv', 'tensor'], ['mlp','tensor'], - ['embed','fsdp'], + ['embed','fsdp_tpu'], ['heads', 'tensor'], - ['norm', 'fsdp'], - ['conv_batch', ['data','fsdp']], + ['norm', 'fsdp_tpu'], + ['conv_batch', ['data','fsdp_tpu']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], - ['conv_in', 'fsdp'] + ['conv_out', 'fsdp_tpu'], + ['conv_in', 'fsdp_tpu'] ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp_tpu', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 48c6ca44..39fb7a64 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -268,17 +268,30 @@ def create_device_mesh(config, devices=None, logging=True): max_logging.log(f"Devices: {devices} (num_devices: {num_devices})") multi_slice_env = num_slices > 1 - - dcn_parallelism = [ - config.dcn_data_parallelism, - config.dcn_fsdp_parallelism, - config.dcn_tensor_parallelism, - ] - ici_parallelism = [ - config.ici_data_parallelism, - config.ici_fsdp_parallelism, - config.ici_tensor_parallelism, - ] + if "dcn_fsdp_tpu_parallelism" in config.get_keys(): + dcn_parallelism = [ + config.dcn_data_parallelism, + config.dcn_tensor_parallelism, + config.dcn_fsdp_tpu_parallelism, + config.dcn_fsdp_gpu_parallelism, + ] + ici_parallelism = [ + config.ici_data_parallelism, + config.ici_tensor_parallelism, + config.ici_fsdp_tpu_parallelism, + config.ici_fsdp_gpu_parallelism, + ] + else: + dcn_parallelism = [ + config.dcn_data_parallelism, + config.dcn_fsdp_parallelism, + config.dcn_tensor_parallelism, + ] + ici_parallelism = [ + config.ici_data_parallelism, + config.ici_fsdp_parallelism, + config.ici_tensor_parallelism, + ] # Find possible unspecified parallelisms ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") @@ -651,3 +664,15 @@ def maybe_initialize_jax_distributed_system(raw_keys): max_logging.log("Jax distributed system initialized on GPU!") else: jax.distributed.initialize() + +def get_axis_names(axis_key: str, config=None) -> str: + """Returns the mesh axis names given the logical axis key from config.logical_axis_rules.""" + axis_name = '' + if config: + axis_rules = config.logical_axis_rules + else: + axis_rules = nn.get_logical_axis_rules() + for rules in axis_rules: + if rules[0] == axis_key: + axis_name = rules[1] + return axis_name \ No newline at end of file diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 218b3b79..a08512f7 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -30,6 +30,7 @@ from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel from einops import rearrange from .. import common_types, max_logging +from .. import max_utils from . 
import quantizations @@ -78,8 +79,11 @@ def _reshape_data_from_cudnn_flash(tensor): def _reshape_data_for_cudnn_flash(tensor, heads): # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format) - batch, seq, heads_and_dim_head = tensor.shape - tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads) + if len(tensor.shape) == 3: + batch, seq, dim_head = tensor.shape + tensor = tensor.reshape(batch, seq, heads, dim_head // heads) + else: + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) return tensor @@ -89,7 +93,8 @@ def _reshape_batch_dim_to_heads(tensor, heads): tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) reshaped_tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) - return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) + axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD)) + return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names) def _reshape_heads_to_batch_dim(tensor, heads): @@ -102,8 +107,8 @@ def _reshape_heads_to_batch_dim(tensor, heads): else: batch_size, head_size, seq_len, head_dim = tensor.shape reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim) - - return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) + axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD)) + return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names) def _reshape_heads_to_head_dim(tensor): @@ -112,7 +117,8 @@ def _reshape_heads_to_head_dim(tensor): b, h, s, d = tensor.shape tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3]) reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d)) - return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) + axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD)) + return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names) def _unflatten_heads(tensor, heads): @@ -247,7 +253,8 @@ def _tpu_flash_attention( block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]), use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False, ) - num_fsdp_shards = mesh.shape["fsdp"] + fsdp_key = max_utils.get_axis_names("activation_length") + num_fsdp_shards = mesh.shape[fsdp_key] query = _reshape_data_for_flash(query, heads) key = _reshape_data_for_flash(key, heads) value = _reshape_data_for_flash(value, heads) @@ -361,13 +368,13 @@ def wrap_flash_attention(query, key, value): perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)] - k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm) - v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm) + k1 = jax.lax.ppermute(key, axis_name=fsdp_key, perm=perm) + v1 = jax.lax.ppermute(value, axis_name=fsdp_key, perm=perm) def ring_scan_body(carry, _): m, l, o, k_current, v_current = carry - k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm) - v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm) + k_next = jax.lax.ppermute(k_current, axis_name=fsdp_key, perm=perm) + v_next = jax.lax.ppermute(v_current, axis_name=fsdp_key, perm=perm) out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) @@ -394,7 +401,7 @@ def ring_scan_body(carry, _): return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) - devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"] + devices_in_data_fsdp = 
mesh.shape["data"] * mesh.shape[fsdp_key] # This warning might show up when doing model eval for example, when calculating model flops # and that is expected. if not (query.shape[0] / devices_in_data_fsdp).is_integer(): @@ -492,24 +499,12 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m key = _reshape_data_for_cudnn_flash(key, heads) value = _reshape_data_for_cudnn_flash(value, heads) - cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV) - axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names) - - query = nn.with_logical_constraint(query, axis_names) - key = nn.with_logical_constraint(key, axis_names) - value = nn.with_logical_constraint(value, axis_names) - - @functools.partial( - shard_map.shard_map, - mesh=mesh, - in_specs=(axis_names, axis_names, axis_names), - out_specs=axis_names, - check_rep=False, - ) - def wrap_flash_attention(query, key, value): - return jax.vmap(dpa_layer)(query, key, value, mask=None) - - out = wrap_flash_attention(query, key, value) + axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD, D_KV)) + query = jax.lax.with_sharding_constraint(query, axis_names) + key = jax.lax.with_sharding_constraint(key, axis_names) + value = jax.lax.with_sharding_constraint(value, axis_names) + + out = dpa_layer(query, key, value, mask=None) return _reshape_data_from_cudnn_flash(out) @@ -706,7 +701,24 @@ def __init__( ): self.dpa_layer = None if attention_kernel == "cudnn_flash_te": - raise NotImplementedError(f"{self} has not been tested with {attention_kernel}") + from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + jax.config.update("jax_use_shardy_partitioner", False) + + dpa_layer = DotProductAttention( + head_dim=dim_head, + num_attention_heads=heads, + num_gqa_groups=heads, + attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal' + attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + # attention_dropout=self.dropout_rate, + dropout_rng_name="aqt", + dtype=dtype, + qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=scale, + transpose_batch_sequence=False, + ) + variables = {} + self.dpa_layer = functools.partial(dpa_layer.apply, variables) self.mesh = mesh self.scale = scale @@ -769,8 +781,9 @@ def setup(self): self.dpa_layer = None if self.attention_kernel == "cudnn_flash_te": from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + jax.config.update("jax_use_shardy_partitioner", False) - self.dpa_layer = DotProductAttention( + dpa_layer = DotProductAttention( head_dim=self.dim_head, num_attention_heads=self.heads, num_gqa_groups=self.heads, @@ -784,6 +797,9 @@ def setup(self): scale_factor=self.scale, transpose_batch_sequence=False, ) + variables = {} + self.dpa_layer = functools.partial(dpa_layer.apply, variables) + def apply_attention(self, query: Array, key: Array, value: Array): return _apply_attention( @@ -839,9 +855,6 @@ def __init__( residual_checkpoint_name: str | None = None, enable_jax_named_scopes: bool = False, ): - if attention_kernel == "cudnn_flash_te": - raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}") - if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") self.dim_head = dim_head @@ -998,8 +1011,9 @@ def __call__( deterministic: bool = True, rngs: nnx.Rngs = None, ) -> jax.Array: 
- hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) - encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor")) + axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD)) + hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names) dtype = hidden_states.dtype if encoder_hidden_states is None: encoder_hidden_states = hidden_states diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 77f35073..944a10dd 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -21,6 +21,7 @@ import jax.numpy as jnp from flax import nnx from ...configuration_utils import ConfigMixin +from ... import max_utils from ..modeling_flax_utils import FlaxModelMixin, get_activation from ... import common_types from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) @@ -28,7 +29,10 @@ BlockSizes = common_types.BlockSizes CACHE_T = 2 -flax.config.update('flax_always_shard_variable', False) +try: + flax.config.update('flax_always_shard_variable', False) +except: + pass # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: @@ -73,7 +77,10 @@ def __init__( self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] # Set sharding dynamically based on out_channels. - num_fsdp_axis_devices = mesh.device_ids.shape[1] + fsdp_key = max_utils.get_axis_names("activation_length") + if not fsdp_key: + fsdp_key = "fsdp_tpu" + num_fsdp_axis_devices = mesh.shape[fsdp_key] kernel_sharding = (None, None, None, None, None) if out_channels % num_fsdp_axis_devices == 0: kernel_sharding = (None, None, None, None, "conv_out") diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index cb952afa..a432c4d9 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -362,9 +362,11 @@ def __call__( shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) - hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) + axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads")) + hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) hidden_states = checkpoint_name(hidden_states, "hidden_states") - encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None)) + axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv")) + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names) # 1. 
Self-attention with self.conditional_named_scope("self_attn"): @@ -515,7 +517,7 @@ def init_block(rngs): if scan_layers: self.blocks = init_block(rngs) else: - blocks = nnx.List([]) + blocks = [] for _ in range(num_layers): block = WanTransformerBlock( rngs=rngs, @@ -535,7 +537,7 @@ def init_block(rngs): enable_jax_named_scopes=enable_jax_named_scopes, ) blocks.append(block) - self.blocks = blocks + self.blocks = nnx.data(blocks) self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) self.proj_out = nnx.Linear( diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 060cc1bf..b5a6f16a 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -195,13 +195,13 @@ def user_init(raw_keys): raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) # Verify qkv is sharded across sequence. - if "ring" in raw_keys["attention"] or raw_keys["attention_sharding_uniform"]: + if "ring" in raw_keys["attention"] or (raw_keys["attention_sharding_uniform"] and "cudnn_flash_te" not in raw_keys["attention"]): max_logging.log(f"Adding sequence sharding to q and kv if not already present because '{raw_keys['attention']}' contains 'ring' or {raw_keys['attention_sharding_uniform']} is set.") logical_axis_rules = list(raw_keys["logical_axis_rules"]) max_logging.log(f"Initial logical axis rules: {logical_axis_rules}") new_rules = [] - q_seq_sharding = (LENGTH, "fsdp") - kv_seq_sharding = (KV_LENGTH, "fsdp") + q_seq_sharding = (LENGTH, "fsdp_tpu") + kv_seq_sharding = (KV_LENGTH, "fsdp_tpu") if q_seq_sharding not in logical_axis_rules: logical_axis_rules.append(q_seq_sharding) if kv_seq_sharding not in logical_axis_rules: diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py index fea15720..cc246797 100644 --- a/src/maxdiffusion/train_wan.py +++ b/src/maxdiffusion/train_wan.py @@ -35,7 +35,10 @@ def main(argv: Sequence[str]) -> None: config = pyconfig.config validate_train_config(config) max_logging.log(f"Found {jax.device_count()} devices.") - flax.config.update("flax_always_shard_variable", False) + try: + flax.config.update("flax_always_shard_variable", False) + except: + pass train(config) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index f23836a5..f27279d9 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -40,6 +40,7 @@ from flax.training import train_state from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline from jax.experimental import multihost_utils +from transformer_engine.jax.sharding import global_shard_guard, MeshResource class TrainState(train_state.TrainState): @@ -210,8 +211,8 @@ def prepare_sample_eval(features): return data_iterator def start_training(self): - - pipeline, opt_state, step = self.checkpointer.load_checkpoint() + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + pipeline, opt_state, step = self.checkpointer.load_checkpoint() restore_args = {} if opt_state and step: restore_args = {"opt_state": opt_state, "step": step} @@ -309,7 +310,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60) max_logging.log(pretty_string) max_logging.log("------------------------------------------------") - max_utils.delete_pytree(params) + if self.config.hardware != 'gpu': + max_utils.delete_pytree(params) data_shardings = 
self.get_data_shardings(mesh) eval_data_shardings = self.get_eval_data_shardings(mesh) @@ -359,15 +361,18 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data scheduler_state = pipeline.scheduler_state example_batch = load_next_batch(train_data_iterator, None, self.config) + # Designate the context parallel axis for sharding + cp_resource = max_utils.get_axis_names("activation_length", config=self.config) + mesh_resource = MeshResource(cp_resource=cp_resource) + with ThreadPoolExecutor(max_workers=1) as executor: for step in np.arange(start_step, self.config.max_train_steps): if self.config.enable_profiler and step == first_profiling_step: max_utils.activate_profiler(self.config) start_step_time = datetime.datetime.now() next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config) - with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( - self.config.logical_axis_rules - ): + with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, \ + global_shard_guard(mesh_resource), nn_partitioning.axis_rules(self.config.logical_axis_rules): state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state) train_metric["scalar"]["learning/loss"].block_until_ready() last_step_completion = datetime.datetime.now()
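A condensed, runnable sketch of the context-parallel wiring added in `wan_trainer.py`: the mesh axis that backs `activation_length` is declared to Transformer Engine as its context-parallel resource, so the cuDNN attention path shards the sequence over the same axis as the rest of the model. In the trainer the axis name comes from `max_utils.get_axis_names("activation_length", config=self.config)`; here it is written out as `'fsdp_tpu'` and the 1×1×N×1 mesh is illustrative.

```python
# Sketch of the global_shard_guard / MeshResource usage added above.
import jax
import numpy as np
from jax.sharding import Mesh
from transformer_engine.jax.sharding import MeshResource, global_shard_guard

n = jax.device_count()
mesh = Mesh(
    np.array(jax.devices()).reshape(1, 1, n, 1),
    ("data", "tensor", "fsdp_tpu", "fsdp_gpu"),
)

cp_axis = "fsdp_tpu"  # in wan_trainer.py: max_utils.get_axis_names("activation_length", config=self.config)
with mesh, global_shard_guard(MeshResource(cp_resource=cp_axis)):
    # Inside this scope, Transformer Engine's attention treats 'fsdp_tpu' as the
    # context-parallel axis, matching the 'activation_length' logical rule.
    pass
```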