From 60d595e3b6572e6a90714c4e67a6ca73763c33e9 Mon Sep 17 00:00:00 2001 From: YangFei1990 Date: Thu, 7 May 2026 16:41:58 -0700 Subject: [PATCH 1/2] fix weights/opt memory estimation for moe --- megatron/training/theoretical_memory_usage.py | 208 ++++++++++++++---- 1 file changed, 164 insertions(+), 44 deletions(-) diff --git a/megatron/training/theoretical_memory_usage.py b/megatron/training/theoretical_memory_usage.py index bdc5b9f5774..4be110225d9 100644 --- a/megatron/training/theoretical_memory_usage.py +++ b/megatron/training/theoretical_memory_usage.py @@ -19,7 +19,7 @@ def compute_weight_and_optimizer_memory(args, verbose=False): # MoE. num_experts = 1 if args.num_experts is None else args.num_experts gated_linear_multiplier = 3 / 2 if args.swiglu else 1 - + shared_expert_ffn_hidden_size = ( 0 if args.moe_shared_expert_intermediate_size is None @@ -79,6 +79,7 @@ def compute_weight_and_optimizer_memory(args, verbose=False): + (args.num_attention_heads * args.v_head_dim) * args.hidden_size ) else: + # Self-attention linear weights: fused QKV plus output projection. self_attn_term = ( 2 * args.hidden_size @@ -92,59 +93,127 @@ def compute_weight_and_optimizer_memory(args, verbose=False): ) ) + # Per-layer attention linear parameters, sharded by the regular tensor-parallel group. + num_parameters_in_attention = self_attn_term + # Per-layer dense MLP parameters + num_parameters_in_dense_mlp = ( + 2 * args.hidden_size * args.ffn_hidden_size * gated_linear_multiplier + ) + # Per-layer routed expert MLP parameters across all experts + num_parameters_in_routed_experts = ( + 2 * args.hidden_size * moe_ffn_hidden_size * num_experts * gated_linear_multiplier + ) + # Per-layer routed expert parameters active for one token; top-k experts are used per token. + num_active_parameters_in_routed_experts = ( + 2 * args.hidden_size * moe_ffn_hidden_size * args.moe_router_topk * gated_linear_multiplier + if args.num_experts is not None + else 0 + ) + # Per-layer shared expert MLP parameters; shared experts use regular TP, not ETP. + num_parameters_in_shared_experts = ( + 2 * args.hidden_size * shared_expert_ffn_hidden_size * gated_linear_multiplier + ) + # Per-layer normalization parameters; the factor 2 counts input norm and pre-MLP norm. + num_parameters_in_layernorms = 2 * args.hidden_size * norm_size + # Per-layer optional shared expert gate weight, replicated across tensor-parallel ranks. + num_parameters_in_shared_expert_gate = ( + args.hidden_size + if shared_expert_ffn_hidden_size > 0 and getattr(args, "moe_shared_expert_gate", False) + else 0 + ) + # Per-layer router gate parameters, replicated across tensor-parallel ranks. + num_parameters_in_router = ( + (args.hidden_size * num_experts) + (num_experts if args.add_bias_linear else 0) + if args.num_experts is not None + else 0 + ) + + # Per dense transformer layer, parameters sharded by regular tensor parallelism. + num_tp_sharded_parameters_in_transformer_layer_dense = ( + num_parameters_in_attention + num_parameters_in_dense_mlp + ) + # Per dense transformer layer, parameters replicated across tensor-parallel ranks. + num_replicated_parameters_in_transformer_layer_dense = num_parameters_in_layernorms + # Per dense transformer layer, total logical parameters before any parallel sharding. num_parameters_in_transformer_layer_dense = ( - 2 - * args.hidden_size - * ( - # Dense MoE MLP. - (args.ffn_hidden_size * gated_linear_multiplier) - # Transformer layernorms. 
- + norm_size - ) - + self_attn_term + num_tp_sharded_parameters_in_transformer_layer_dense + + num_replicated_parameters_in_transformer_layer_dense + ) + + # Per MoE transformer layer, non-routed parameters sharded by regular tensor parallelism. + num_tp_sharded_parameters_in_transformer_layer_moe = ( + num_parameters_in_attention + num_parameters_in_shared_experts + ) + # Per MoE transformer layer, non-routed parameters replicated across tensor-parallel ranks. + num_replicated_parameters_in_transformer_layer_moe = ( + num_parameters_in_layernorms + + num_parameters_in_router + + num_parameters_in_shared_expert_gate ) + # Per MoE transformer layer, total logical parameters before any parallel sharding. num_parameters_in_transformer_layer_moe = ( - 2 - * args.hidden_size - * ( - # MoE MLP. - + (moe_ffn_hidden_size * num_experts * gated_linear_multiplier) - # Shared MoE MLP. - + (shared_expert_ffn_hidden_size * gated_linear_multiplier) - # Transformer layernorms. - + norm_size - ) - + self_attn_term + num_tp_sharded_parameters_in_transformer_layer_moe + + num_replicated_parameters_in_transformer_layer_moe + + num_parameters_in_routed_experts ) + # Per MoE transformer layer, logical parameters used by one routed token. num_active_parameters_in_transformer_layer_moe = ( - 2 - * args.hidden_size - * ( - # MoE MLP. - + (moe_ffn_hidden_size * args.moe_router_topk * gated_linear_multiplier) - # Shared MoE MLP. - + (shared_expert_ffn_hidden_size * gated_linear_multiplier) - # Transformer layernorms. - + (2) - ) - + self_attn_term + num_tp_sharded_parameters_in_transformer_layer_moe + + num_replicated_parameters_in_transformer_layer_moe + + num_active_parameters_in_routed_experts ) + # Input embedding table parameters. embedding_size = args.hidden_size * args.padded_vocab_size + # Final normalization parameters, replicated across tensor-parallel ranks. final_layernorm = norm_size * args.hidden_size if args.untie_embeddings_and_output_weights: + # Untied embeddings have separate input embedding and output LM-head tables. num_parameters_in_embedding_layers = 2 * embedding_size else: + # Tied embeddings share the input embedding and output LM-head table. num_parameters_in_embedding_layers = embedding_size + # Transformer block parameters that will be divided by regular tensor parallelism. + num_tp_sharded_parameters_in_transformer_block = ( + num_tp_sharded_parameters_in_transformer_layer_dense * num_dense_layers + + num_tp_sharded_parameters_in_transformer_layer_moe * num_moe_layers + ) + # Transformer block parameters replicated across regular tensor-parallel ranks. + num_replicated_parameters_in_transformer_block = ( + num_replicated_parameters_in_transformer_layer_dense * num_dense_layers + + num_replicated_parameters_in_transformer_layer_moe * num_moe_layers + + final_layernorm + ) + # Transformer block routed expert parameters that will be divided by ETP and EP. + num_routed_expert_parameters_in_transformer_block = ( + num_parameters_in_routed_experts * num_moe_layers + ) + # Total logical transformer block parameters before model-parallel sharding. num_parameters_in_transformer_block = ( num_parameters_in_transformer_layer_dense * num_dense_layers + num_parameters_in_transformer_layer_moe * num_moe_layers + final_layernorm ) + # Total logical active transformer block parameters before model-parallel sharding. 
num_active_parameters_in_transformer_block = ( num_parameters_in_transformer_layer_dense * num_dense_layers + num_active_parameters_in_transformer_layer_moe * num_moe_layers + final_layernorm ) + # MTP block parameters that will be divided by regular tensor parallelism. + num_tp_sharded_parameters_in_mtp_block = ( + num_tp_sharded_parameters_in_transformer_layer_dense * mtp_num_dense_layers + + num_tp_sharded_parameters_in_transformer_layer_moe * mtp_num_moe_layers + ) + # MTP block parameters replicated across regular tensor-parallel ranks. + num_replicated_parameters_in_mtp_block = ( + num_replicated_parameters_in_transformer_layer_dense * mtp_num_dense_layers + + num_replicated_parameters_in_transformer_layer_moe * mtp_num_moe_layers + ) + # MTP block routed expert parameters that will be divided by ETP and EP. + num_routed_expert_parameters_in_mtp_block = ( + num_parameters_in_routed_experts * mtp_num_moe_layers + ) + # Total logical MTP block parameters before model-parallel sharding. num_parameters_in_mtp_block = ( num_parameters_in_transformer_layer_dense * mtp_num_dense_layers + num_parameters_in_transformer_layer_moe * mtp_num_moe_layers @@ -180,16 +249,48 @@ def compute_weight_and_optimizer_memory(args, verbose=False): print(f"Total number of parameters in billions: {num_total_parameters / 10**9:.2f}") print(f"Total number of active parameters in billions: {num_active_parameters / 10**9:.2f}") - # Most loaded model shard has (1/pp_size transformer layers + 1 mtp block + 1 embedding layer) / tp_size. - num_parameters_on_most_loaded_model_shard = ( - (num_parameters_in_transformer_block / args.pipeline_model_parallel_size) - + num_parameters_in_mtp_block + # Number of ranks that shard each routed expert's tensor dimensions. + expert_tensor_parallel_size = args.expert_tensor_parallel_size + # Number of ranks that split the global set of routed experts. + expert_model_parallel_size = args.expert_model_parallel_size + # Number of ranks in one expert tensor/expert/pipeline model-parallel group. + expert_tensor_model_pipeline_parallel_size = ( + expert_tensor_parallel_size + * expert_model_parallel_size + * args.pipeline_model_parallel_size + ) + # Data-parallel size used by expert parameters and distributed optimizer state. + expert_data_parallel_size = args.world_size // expert_tensor_model_pipeline_parallel_size + + # Most loaded model shard has 1/pp_size transformer layers, 1 mtp block, and 1 embedding layer. + # TP-sharded dense/shared parameters use regular TP. Routed experts use ETP and EP. Router and + # normalization parameters are replicated across TP/ETP ranks. + # Per-rank regular tensor-parallel parameters on the most loaded pipeline stage. + num_tp_sharded_parameters_on_most_loaded_model_shard = ( + (num_tp_sharded_parameters_in_transformer_block / args.pipeline_model_parallel_size) + + num_tp_sharded_parameters_in_mtp_block + embedding_size ) / args.tensor_model_parallel_size + # Per-rank replicated parameters on the most loaded pipeline stage. + num_replicated_parameters_on_most_loaded_model_shard = ( + num_replicated_parameters_in_transformer_block / args.pipeline_model_parallel_size + ) + num_replicated_parameters_in_mtp_block + # Per-rank routed expert parameters on the most loaded pipeline stage. 
+ num_routed_expert_parameters_on_most_loaded_model_shard = ( + (num_routed_expert_parameters_in_transformer_block / args.pipeline_model_parallel_size) + + num_routed_expert_parameters_in_mtp_block + ) / (expert_tensor_parallel_size * expert_model_parallel_size) + # Total per-rank parameters on the most loaded pipeline stage. + num_parameters_on_most_loaded_model_shard = ( + num_tp_sharded_parameters_on_most_loaded_model_shard + + num_replicated_parameters_on_most_loaded_model_shard + + num_routed_expert_parameters_on_most_loaded_model_shard + ) if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1: - num_parameters_on_most_loaded_model_shard += ( + num_tp_sharded_parameters_on_most_loaded_model_shard += ( embedding_size / args.tensor_model_parallel_size ) + num_parameters_on_most_loaded_model_shard += embedding_size / args.tensor_model_parallel_size if verbose: print( f"Number of parameters in most loaded shard in billions: " @@ -197,9 +298,18 @@ def compute_weight_and_optimizer_memory(args, verbose=False): ) if args.pipeline_model_parallel_size > 1: - # Other shards just have (1/pp_size transformer layers) / tp_size. - num_parameters_on_other_model_shards = num_parameters_in_transformer_block / ( - args.pipeline_model_parallel_size * args.tensor_model_parallel_size + # Other shards just have 1/pp_size transformer layers. + # Total per-rank parameters on non-embedding pipeline stages. + num_parameters_on_other_model_shards = ( + num_tp_sharded_parameters_in_transformer_block + / (args.pipeline_model_parallel_size * args.tensor_model_parallel_size) + + num_replicated_parameters_in_transformer_block / args.pipeline_model_parallel_size + + num_routed_expert_parameters_in_transformer_block + / ( + args.pipeline_model_parallel_size + * expert_tensor_parallel_size + * expert_model_parallel_size + ) ) if verbose: print( @@ -207,11 +317,21 @@ def compute_weight_and_optimizer_memory(args, verbose=False): f"{num_parameters_on_other_model_shards / 10**9:.4f}" ) - num_bytes_per_parameter = ( - 18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size) - ) + # Bf16 training bytes per logical parameter for the given data-parallel domain. + def num_bytes_per_parameter(data_parallel_size): + # This estimator assumes bf16 training: bf16 model params, fp32 main gradients, + # fp32 main params, and fp32 Adam states. See docs/user-guide/features/dist_optimizer.md. + return 18 if not args.use_distributed_optimizer else 6 + (12 / data_parallel_size) + + # Per-rank memory for weights, gradients, main params, and optimizer state. 
weight_and_optimizer_memory = ( - num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter + ( + num_tp_sharded_parameters_on_most_loaded_model_shard + + num_replicated_parameters_on_most_loaded_model_shard + ) + * num_bytes_per_parameter(args.data_parallel_size) + + num_routed_expert_parameters_on_most_loaded_model_shard + * num_bytes_per_parameter(expert_data_parallel_size) ) return weight_and_optimizer_memory From af44c3cd35114316ad40cdfcdd891f038f19f331 Mon Sep 17 00:00:00 2001 From: YangFei1990 Date: Fri, 8 May 2026 15:14:18 -0700 Subject: [PATCH 2/2] refactor to reduce verbose --- megatron/training/theoretical_memory_usage.py | 200 ++++++------------ 1 file changed, 70 insertions(+), 130 deletions(-) diff --git a/megatron/training/theoretical_memory_usage.py b/megatron/training/theoretical_memory_usage.py index 4be110225d9..fc208640c3a 100644 --- a/megatron/training/theoretical_memory_usage.py +++ b/megatron/training/theoretical_memory_usage.py @@ -79,7 +79,6 @@ def compute_weight_and_optimizer_memory(args, verbose=False): + (args.num_attention_heads * args.v_head_dim) * args.hidden_size ) else: - # Self-attention linear weights: fused QKV plus output projection. self_attn_term = ( 2 * args.hidden_size @@ -93,127 +92,65 @@ def compute_weight_and_optimizer_memory(args, verbose=False): ) ) - # Per-layer attention linear parameters, sharded by the regular tensor-parallel group. - num_parameters_in_attention = self_attn_term - # Per-layer dense MLP parameters - num_parameters_in_dense_mlp = ( - 2 * args.hidden_size * args.ffn_hidden_size * gated_linear_multiplier - ) - # Per-layer routed expert MLP parameters across all experts - num_parameters_in_routed_experts = ( + embedding_size = args.hidden_size * args.padded_vocab_size + final_layernorm = norm_size * args.hidden_size + if args.untie_embeddings_and_output_weights: + num_parameters_in_embedding_layers = 2 * embedding_size + else: + num_parameters_in_embedding_layers = embedding_size + + attention_params = self_attn_term + dense_mlp_params = 2 * args.hidden_size * args.ffn_hidden_size * gated_linear_multiplier + shared_expert_params = 2 * args.hidden_size * shared_expert_ffn_hidden_size * gated_linear_multiplier + routed_expert_params = ( 2 * args.hidden_size * moe_ffn_hidden_size * num_experts * gated_linear_multiplier ) - # Per-layer routed expert parameters active for one token; top-k experts are used per token. - num_active_parameters_in_routed_experts = ( + active_routed_expert_params = ( 2 * args.hidden_size * moe_ffn_hidden_size * args.moe_router_topk * gated_linear_multiplier if args.num_experts is not None else 0 ) - # Per-layer shared expert MLP parameters; shared experts use regular TP, not ETP. - num_parameters_in_shared_experts = ( - 2 * args.hidden_size * shared_expert_ffn_hidden_size * gated_linear_multiplier + layernorm_params = 2 * args.hidden_size * norm_size + router_params = ( + (args.hidden_size * num_experts) + (num_experts if args.add_bias_linear else 0) + if args.num_experts is not None + else 0 ) - # Per-layer normalization parameters; the factor 2 counts input norm and pre-MLP norm. - num_parameters_in_layernorms = 2 * args.hidden_size * norm_size - # Per-layer optional shared expert gate weight, replicated across tensor-parallel ranks. 
- num_parameters_in_shared_expert_gate = ( + shared_expert_gate_params = ( args.hidden_size if shared_expert_ffn_hidden_size > 0 and getattr(args, "moe_shared_expert_gate", False) else 0 ) - # Per-layer router gate parameters, replicated across tensor-parallel ranks. - num_parameters_in_router = ( - (args.hidden_size * num_experts) + (num_experts if args.add_bias_linear else 0) - if args.num_experts is not None - else 0 - ) - # Per dense transformer layer, parameters sharded by regular tensor parallelism. - num_tp_sharded_parameters_in_transformer_layer_dense = ( - num_parameters_in_attention + num_parameters_in_dense_mlp - ) - # Per dense transformer layer, parameters replicated across tensor-parallel ranks. - num_replicated_parameters_in_transformer_layer_dense = num_parameters_in_layernorms - # Per dense transformer layer, total logical parameters before any parallel sharding. num_parameters_in_transformer_layer_dense = ( - num_tp_sharded_parameters_in_transformer_layer_dense - + num_replicated_parameters_in_transformer_layer_dense - ) - - # Per MoE transformer layer, non-routed parameters sharded by regular tensor parallelism. - num_tp_sharded_parameters_in_transformer_layer_moe = ( - num_parameters_in_attention + num_parameters_in_shared_experts + attention_params + dense_mlp_params + layernorm_params ) - # Per MoE transformer layer, non-routed parameters replicated across tensor-parallel ranks. - num_replicated_parameters_in_transformer_layer_moe = ( - num_parameters_in_layernorms - + num_parameters_in_router - + num_parameters_in_shared_expert_gate - ) - # Per MoE transformer layer, total logical parameters before any parallel sharding. num_parameters_in_transformer_layer_moe = ( - num_tp_sharded_parameters_in_transformer_layer_moe - + num_replicated_parameters_in_transformer_layer_moe - + num_parameters_in_routed_experts + attention_params + + shared_expert_params + + routed_expert_params + + layernorm_params + + router_params + + shared_expert_gate_params ) - # Per MoE transformer layer, logical parameters used by one routed token. num_active_parameters_in_transformer_layer_moe = ( - num_tp_sharded_parameters_in_transformer_layer_moe - + num_replicated_parameters_in_transformer_layer_moe - + num_active_parameters_in_routed_experts - ) - # Input embedding table parameters. - embedding_size = args.hidden_size * args.padded_vocab_size - # Final normalization parameters, replicated across tensor-parallel ranks. - final_layernorm = norm_size * args.hidden_size - if args.untie_embeddings_and_output_weights: - # Untied embeddings have separate input embedding and output LM-head tables. - num_parameters_in_embedding_layers = 2 * embedding_size - else: - # Tied embeddings share the input embedding and output LM-head table. - num_parameters_in_embedding_layers = embedding_size - # Transformer block parameters that will be divided by regular tensor parallelism. - num_tp_sharded_parameters_in_transformer_block = ( - num_tp_sharded_parameters_in_transformer_layer_dense * num_dense_layers - + num_tp_sharded_parameters_in_transformer_layer_moe * num_moe_layers - ) - # Transformer block parameters replicated across regular tensor-parallel ranks. 
- num_replicated_parameters_in_transformer_block = ( - num_replicated_parameters_in_transformer_layer_dense * num_dense_layers - + num_replicated_parameters_in_transformer_layer_moe * num_moe_layers - + final_layernorm + attention_params + + shared_expert_params + + active_routed_expert_params + + layernorm_params + + router_params + + shared_expert_gate_params ) - # Transformer block routed expert parameters that will be divided by ETP and EP. - num_routed_expert_parameters_in_transformer_block = ( - num_parameters_in_routed_experts * num_moe_layers - ) - # Total logical transformer block parameters before model-parallel sharding. num_parameters_in_transformer_block = ( num_parameters_in_transformer_layer_dense * num_dense_layers + num_parameters_in_transformer_layer_moe * num_moe_layers + final_layernorm ) - # Total logical active transformer block parameters before model-parallel sharding. num_active_parameters_in_transformer_block = ( num_parameters_in_transformer_layer_dense * num_dense_layers + num_active_parameters_in_transformer_layer_moe * num_moe_layers + final_layernorm ) - # MTP block parameters that will be divided by regular tensor parallelism. - num_tp_sharded_parameters_in_mtp_block = ( - num_tp_sharded_parameters_in_transformer_layer_dense * mtp_num_dense_layers - + num_tp_sharded_parameters_in_transformer_layer_moe * mtp_num_moe_layers - ) - # MTP block parameters replicated across regular tensor-parallel ranks. - num_replicated_parameters_in_mtp_block = ( - num_replicated_parameters_in_transformer_layer_dense * mtp_num_dense_layers - + num_replicated_parameters_in_transformer_layer_moe * mtp_num_moe_layers - ) - # MTP block routed expert parameters that will be divided by ETP and EP. - num_routed_expert_parameters_in_mtp_block = ( - num_parameters_in_routed_experts * mtp_num_moe_layers - ) - # Total logical MTP block parameters before model-parallel sharding. num_parameters_in_mtp_block = ( num_parameters_in_transformer_layer_dense * mtp_num_dense_layers + num_parameters_in_transformer_layer_moe * mtp_num_moe_layers @@ -249,47 +186,56 @@ def compute_weight_and_optimizer_memory(args, verbose=False): print(f"Total number of parameters in billions: {num_total_parameters / 10**9:.2f}") print(f"Total number of active parameters in billions: {num_active_parameters / 10**9:.2f}") - # Number of ranks that shard each routed expert's tensor dimensions. expert_tensor_parallel_size = args.expert_tensor_parallel_size - # Number of ranks that split the global set of routed experts. expert_model_parallel_size = args.expert_model_parallel_size - # Number of ranks in one expert tensor/expert/pipeline model-parallel group. expert_tensor_model_pipeline_parallel_size = ( expert_tensor_parallel_size * expert_model_parallel_size * args.pipeline_model_parallel_size ) - # Data-parallel size used by expert parameters and distributed optimizer state. expert_data_parallel_size = args.world_size // expert_tensor_model_pipeline_parallel_size + # Split params by how they are held on each rank: regular TP, replicated, or EP/ETP. 
+ tp_sharded_params_in_transformer_block = ( + (attention_params + dense_mlp_params) * num_dense_layers + + (attention_params + shared_expert_params) * num_moe_layers + ) + replicated_params_in_transformer_block = ( + layernorm_params * num_dense_layers + + (layernorm_params + router_params + shared_expert_gate_params) * num_moe_layers + + final_layernorm + ) + expert_sharded_params_in_transformer_block = routed_expert_params * num_moe_layers + tp_sharded_params_in_mtp_block = ( + (attention_params + dense_mlp_params) * mtp_num_dense_layers + + (attention_params + shared_expert_params) * mtp_num_moe_layers + ) + replicated_params_in_mtp_block = ( + layernorm_params * mtp_num_dense_layers + + (layernorm_params + router_params + shared_expert_gate_params) * mtp_num_moe_layers + ) + expert_sharded_params_in_mtp_block = routed_expert_params * mtp_num_moe_layers + # Most loaded model shard has 1/pp_size transformer layers, 1 mtp block, and 1 embedding layer. - # TP-sharded dense/shared parameters use regular TP. Routed experts use ETP and EP. Router and - # normalization parameters are replicated across TP/ETP ranks. - # Per-rank regular tensor-parallel parameters on the most loaded pipeline stage. - num_tp_sharded_parameters_on_most_loaded_model_shard = ( - (num_tp_sharded_parameters_in_transformer_block / args.pipeline_model_parallel_size) - + num_tp_sharded_parameters_in_mtp_block + tp_sharded_params_on_most_loaded_shard = ( + (tp_sharded_params_in_transformer_block / args.pipeline_model_parallel_size) + + tp_sharded_params_in_mtp_block + embedding_size ) / args.tensor_model_parallel_size - # Per-rank replicated parameters on the most loaded pipeline stage. - num_replicated_parameters_on_most_loaded_model_shard = ( - num_replicated_parameters_in_transformer_block / args.pipeline_model_parallel_size - ) + num_replicated_parameters_in_mtp_block - # Per-rank routed expert parameters on the most loaded pipeline stage. - num_routed_expert_parameters_on_most_loaded_model_shard = ( - (num_routed_expert_parameters_in_transformer_block / args.pipeline_model_parallel_size) - + num_routed_expert_parameters_in_mtp_block + replicated_params_on_most_loaded_shard = ( + replicated_params_in_transformer_block / args.pipeline_model_parallel_size + ) + replicated_params_in_mtp_block + expert_sharded_params_on_most_loaded_shard = ( + (expert_sharded_params_in_transformer_block / args.pipeline_model_parallel_size) + + expert_sharded_params_in_mtp_block ) / (expert_tensor_parallel_size * expert_model_parallel_size) - # Total per-rank parameters on the most loaded pipeline stage. num_parameters_on_most_loaded_model_shard = ( - num_tp_sharded_parameters_on_most_loaded_model_shard - + num_replicated_parameters_on_most_loaded_model_shard - + num_routed_expert_parameters_on_most_loaded_model_shard + tp_sharded_params_on_most_loaded_shard + + replicated_params_on_most_loaded_shard + + expert_sharded_params_on_most_loaded_shard ) if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1: - num_tp_sharded_parameters_on_most_loaded_model_shard += ( - embedding_size / args.tensor_model_parallel_size - ) + tp_sharded_params_on_most_loaded_shard += embedding_size / args.tensor_model_parallel_size num_parameters_on_most_loaded_model_shard += embedding_size / args.tensor_model_parallel_size if verbose: print( @@ -299,12 +245,11 @@ def compute_weight_and_optimizer_memory(args, verbose=False): if args.pipeline_model_parallel_size > 1: # Other shards just have 1/pp_size transformer layers. 
- # Total per-rank parameters on non-embedding pipeline stages. num_parameters_on_other_model_shards = ( - num_tp_sharded_parameters_in_transformer_block + tp_sharded_params_in_transformer_block / (args.pipeline_model_parallel_size * args.tensor_model_parallel_size) - + num_replicated_parameters_in_transformer_block / args.pipeline_model_parallel_size - + num_routed_expert_parameters_in_transformer_block + + replicated_params_in_transformer_block / args.pipeline_model_parallel_size + + expert_sharded_params_in_transformer_block / ( args.pipeline_model_parallel_size * expert_tensor_parallel_size @@ -317,20 +262,15 @@ def compute_weight_and_optimizer_memory(args, verbose=False): f"{num_parameters_on_other_model_shards / 10**9:.4f}" ) - # Bf16 training bytes per logical parameter for the given data-parallel domain. def num_bytes_per_parameter(data_parallel_size): # This estimator assumes bf16 training: bf16 model params, fp32 main gradients, # fp32 main params, and fp32 Adam states. See docs/user-guide/features/dist_optimizer.md. return 18 if not args.use_distributed_optimizer else 6 + (12 / data_parallel_size) - # Per-rank memory for weights, gradients, main params, and optimizer state. weight_and_optimizer_memory = ( - ( - num_tp_sharded_parameters_on_most_loaded_model_shard - + num_replicated_parameters_on_most_loaded_model_shard - ) + (tp_sharded_params_on_most_loaded_shard + replicated_params_on_most_loaded_shard) * num_bytes_per_parameter(args.data_parallel_size) - + num_routed_expert_parameters_on_most_loaded_model_shard + + expert_sharded_params_on_most_loaded_shard * num_bytes_per_parameter(expert_data_parallel_size) )
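
Note (illustration, not part of the patches): the change above splits each rank's parameters into three buckets: TP-sharded (attention plus dense and shared-expert MLPs), replicated (norms, router, shared-expert gate), and expert-sharded (routed experts, divided by ETP x EP). The distributed-optimizer term of num_bytes_per_parameter is then evaluated with the dense data-parallel size for the first two buckets and with the expert data-parallel size for the third. Below is a minimal standalone sketch of that accounting; the parallel layout, bucket sizes, and the bytes_per_param helper are made-up illustrative stand-ins, not Megatron-LM code.

    # Illustrative sketch only. All configuration numbers are invented; the helper
    # mirrors num_bytes_per_parameter from the patch but is not the real implementation.
    def bytes_per_param(dp_size, use_distributed_optimizer=True):
        # bf16 params (2 B) + fp32 grads (4 B) stay fully resident on each rank;
        # fp32 main params + Adam m/v (12 B) are sharded across the data-parallel
        # group when the distributed optimizer is enabled.
        return 6 + 12 / dp_size if use_distributed_optimizer else 18

    # Assumed parallel layout: TP=2, PP=4, ETP=1, EP=8 on 64 GPUs.
    tp, pp, etp, ep, world = 2, 4, 1, 8, 64
    dense_dp = world // (tp * pp)          # 8-way DP for non-expert params
    expert_dp = world // (etp * ep * pp)   # 2-way DP for routed-expert params

    # Assumed per-pipeline-stage parameter counts (already divided by TP / ETPxEP).
    tp_sharded_params = 2.0e9 / tp               # attention + dense/shared MLPs
    replicated_params = 0.01e9                   # norms, router, shared-expert gate
    expert_sharded_params = 8.0e9 / (etp * ep)   # routed experts

    weight_and_optimizer_bytes = (
        (tp_sharded_params + replicated_params) * bytes_per_param(dense_dp)
        + expert_sharded_params * bytes_per_param(expert_dp)
    )
    print(f"~{weight_and_optimizer_bytes / 2**30:.1f} GiB of weights + optimizer state per rank")

Because expert_dp (2) is smaller than dense_dp (8) in this layout, each routed-expert parameter carries 12 bytes of state instead of 7.5, which is the effect the previous single-bucket estimate missed.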