diff --git a/megatron/training/theoretical_memory_usage.py b/megatron/training/theoretical_memory_usage.py
index bdc5b9f5774..fc208640c3a 100644
--- a/megatron/training/theoretical_memory_usage.py
+++ b/megatron/training/theoretical_memory_usage.py
@@ -19,7 +19,7 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
     # MoE.
     num_experts = 1 if args.num_experts is None else args.num_experts
     gated_linear_multiplier = 3 / 2 if args.swiglu else 1
-
+
     shared_expert_ffn_hidden_size = (
         0
         if args.moe_shared_expert_intermediate_size is None
@@ -92,49 +92,55 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
         )
     )
 
-    num_parameters_in_transformer_layer_dense = (
-        2
-        * args.hidden_size
-        * (
-            # Dense MoE MLP.
-            (args.ffn_hidden_size * gated_linear_multiplier)
-            # Transformer layernorms.
-            + norm_size
-        )
-        + self_attn_term
-    )
-    num_parameters_in_transformer_layer_moe = (
-        2
-        * args.hidden_size
-        * (
-            # MoE MLP.
-            + (moe_ffn_hidden_size * num_experts * gated_linear_multiplier)
-            # Shared MoE MLP.
-            + (shared_expert_ffn_hidden_size * gated_linear_multiplier)
-            # Transformer layernorms.
-            + norm_size
-        )
-        + self_attn_term
-    )
-    num_active_parameters_in_transformer_layer_moe = (
-        2
-        * args.hidden_size
-        * (
-            # MoE MLP.
-            + (moe_ffn_hidden_size * args.moe_router_topk * gated_linear_multiplier)
-            # Shared MoE MLP.
-            + (shared_expert_ffn_hidden_size * gated_linear_multiplier)
-            # Transformer layernorms.
-            + (2)
-        )
-        + self_attn_term
-    )
     embedding_size = args.hidden_size * args.padded_vocab_size
     final_layernorm = norm_size * args.hidden_size
     if args.untie_embeddings_and_output_weights:
         num_parameters_in_embedding_layers = 2 * embedding_size
     else:
         num_parameters_in_embedding_layers = embedding_size
+
+    attention_params = self_attn_term
+    dense_mlp_params = 2 * args.hidden_size * args.ffn_hidden_size * gated_linear_multiplier
+    shared_expert_params = 2 * args.hidden_size * shared_expert_ffn_hidden_size * gated_linear_multiplier
+    routed_expert_params = (
+        2 * args.hidden_size * moe_ffn_hidden_size * num_experts * gated_linear_multiplier
+    )
+    active_routed_expert_params = (
+        2 * args.hidden_size * moe_ffn_hidden_size * args.moe_router_topk * gated_linear_multiplier
+        if args.num_experts is not None
+        else 0
+    )
+    layernorm_params = 2 * args.hidden_size * norm_size
+    router_params = (
+        (args.hidden_size * num_experts) + (num_experts if args.add_bias_linear else 0)
+        if args.num_experts is not None
+        else 0
+    )
+    shared_expert_gate_params = (
+        args.hidden_size
+        if shared_expert_ffn_hidden_size > 0 and getattr(args, "moe_shared_expert_gate", False)
+        else 0
+    )
+
+    num_parameters_in_transformer_layer_dense = (
+        attention_params + dense_mlp_params + layernorm_params
+    )
+    num_parameters_in_transformer_layer_moe = (
+        attention_params
+        + shared_expert_params
+        + routed_expert_params
+        + layernorm_params
+        + router_params
+        + shared_expert_gate_params
+    )
+    num_active_parameters_in_transformer_layer_moe = (
+        attention_params
+        + shared_expert_params
+        + active_routed_expert_params
+        + layernorm_params
+        + router_params
+        + shared_expert_gate_params
+    )
     num_parameters_in_transformer_block = (
         num_parameters_in_transformer_layer_dense * num_dense_layers
         + num_parameters_in_transformer_layer_moe * num_moe_layers
@@ -180,16 +186,57 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
         print(f"Total number of parameters in billions: {num_total_parameters / 10**9:.2f}")
         print(f"Total number of active parameters in billions: {num_active_parameters / 10**9:.2f}")
 
-    # Most loaded model shard has (1/pp_size transformer layers + 1 mtp block + 1 embedding layer) / tp_size.
-    num_parameters_on_most_loaded_model_shard = (
-        (num_parameters_in_transformer_block / args.pipeline_model_parallel_size)
-        + num_parameters_in_mtp_block
+    expert_tensor_parallel_size = args.expert_tensor_parallel_size
+    expert_model_parallel_size = args.expert_model_parallel_size
+    expert_tensor_model_pipeline_parallel_size = (
+        expert_tensor_parallel_size
+        * expert_model_parallel_size
+        * args.pipeline_model_parallel_size
+    )
+    expert_data_parallel_size = args.world_size // expert_tensor_model_pipeline_parallel_size
+
+    # Split params by how they are held on each rank: regular TP, replicated, or EP/ETP.
+    tp_sharded_params_in_transformer_block = (
+        (attention_params + dense_mlp_params) * num_dense_layers
+        + (attention_params + shared_expert_params) * num_moe_layers
+    )
+    replicated_params_in_transformer_block = (
+        layernorm_params * num_dense_layers
+        + (layernorm_params + router_params + shared_expert_gate_params) * num_moe_layers
+        + final_layernorm
+    )
+    expert_sharded_params_in_transformer_block = routed_expert_params * num_moe_layers
+    tp_sharded_params_in_mtp_block = (
+        (attention_params + dense_mlp_params) * mtp_num_dense_layers
+        + (attention_params + shared_expert_params) * mtp_num_moe_layers
+    )
+    replicated_params_in_mtp_block = (
+        layernorm_params * mtp_num_dense_layers
+        + (layernorm_params + router_params + shared_expert_gate_params) * mtp_num_moe_layers
+    )
+    expert_sharded_params_in_mtp_block = routed_expert_params * mtp_num_moe_layers
+
+    # Most loaded model shard has 1/pp_size transformer layers, 1 mtp block, and 1 embedding layer.
+    tp_sharded_params_on_most_loaded_shard = (
+        (tp_sharded_params_in_transformer_block / args.pipeline_model_parallel_size)
+        + tp_sharded_params_in_mtp_block
         + embedding_size
     ) / args.tensor_model_parallel_size
+    replicated_params_on_most_loaded_shard = (
+        replicated_params_in_transformer_block / args.pipeline_model_parallel_size
+    ) + replicated_params_in_mtp_block
+    expert_sharded_params_on_most_loaded_shard = (
+        (expert_sharded_params_in_transformer_block / args.pipeline_model_parallel_size)
+        + expert_sharded_params_in_mtp_block
+    ) / (expert_tensor_parallel_size * expert_model_parallel_size)
+    num_parameters_on_most_loaded_model_shard = (
+        tp_sharded_params_on_most_loaded_shard
+        + replicated_params_on_most_loaded_shard
+        + expert_sharded_params_on_most_loaded_shard
+    )
     if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1:
-        num_parameters_on_most_loaded_model_shard += (
-            embedding_size / args.tensor_model_parallel_size
-        )
+        tp_sharded_params_on_most_loaded_shard += embedding_size / args.tensor_model_parallel_size
+        num_parameters_on_most_loaded_model_shard += embedding_size / args.tensor_model_parallel_size
     if verbose:
         print(
             f"Number of parameters in most loaded shard in billions: "
@@ -197,9 +244,17 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
         )
 
     if args.pipeline_model_parallel_size > 1:
-        # Other shards just have (1/pp_size transformer layers) / tp_size.
-        num_parameters_on_other_model_shards = num_parameters_in_transformer_block / (
-            args.pipeline_model_parallel_size * args.tensor_model_parallel_size
+        # Other shards just have 1/pp_size transformer layers.
+        num_parameters_on_other_model_shards = (
+            tp_sharded_params_in_transformer_block
+            / (args.pipeline_model_parallel_size * args.tensor_model_parallel_size)
+            + replicated_params_in_transformer_block / args.pipeline_model_parallel_size
+            + expert_sharded_params_in_transformer_block
+            / (
+                args.pipeline_model_parallel_size
+                * expert_tensor_parallel_size
+                * expert_model_parallel_size
+            )
         )
         if verbose:
             print(
@@ -207,11 +262,16 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
                 f"{num_parameters_on_other_model_shards / 10**9:.4f}"
            )
 
-    num_bytes_per_parameter = (
-        18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size)
-    )
+    def num_bytes_per_parameter(data_parallel_size):
+        # This estimator assumes bf16 training: bf16 model params, fp32 main gradients,
+        # fp32 main params, and fp32 Adam states. See docs/user-guide/features/dist_optimizer.md.
+        return 18 if not args.use_distributed_optimizer else 6 + (12 / data_parallel_size)
+
     weight_and_optimizer_memory = (
-        num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter
+        (tp_sharded_params_on_most_loaded_shard + replicated_params_on_most_loaded_shard)
+        * num_bytes_per_parameter(args.data_parallel_size)
+        + expert_sharded_params_on_most_loaded_shard
+        * num_bytes_per_parameter(expert_data_parallel_size)
     )
 
     return weight_and_optimizer_memory
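
Note on the byte accounting (a reviewer sketch, not part of the patch): the patch prices dense and replicated params at num_bytes_per_parameter(args.data_parallel_size) but routed-expert params at num_bytes_per_parameter(expert_data_parallel_size), because the distributed optimizer shards its fp32 state over whichever data-parallel group holds the param. The standalone snippet below reproduces that arithmetic in isolation; the args SimpleNamespace and the bytes_per_parameter helper are hypothetical stand-ins for the patched code, and only the formula itself is taken from the diff.

    from types import SimpleNamespace

    # Hypothetical example configuration; field names mirror the Megatron argparse
    # attributes used in the patch, values are made up for illustration.
    args = SimpleNamespace(
        use_distributed_optimizer=True,
        world_size=64,
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2,
        expert_tensor_parallel_size=1,
        expert_model_parallel_size=8,
    )

    def bytes_per_parameter(data_parallel_size):
        # bf16 training: bf16 param (2 B) + fp32 main grad (4 B) stay on every rank;
        # fp32 main param + Adam m + Adam v (12 B) shard over the DP group when the
        # distributed optimizer is on, otherwise all 18 B are fully replicated.
        if not args.use_distributed_optimizer:
            return 18
        return 6 + 12 / data_parallel_size

    # Dense and replicated params: DP group = world / (tp * pp).
    dense_dp = args.world_size // (
        args.tensor_model_parallel_size * args.pipeline_model_parallel_size
    )
    # Routed-expert params: expert DP group = world / (etp * ep * pp), as in the patch.
    expert_dp = args.world_size // (
        args.expert_tensor_parallel_size
        * args.expert_model_parallel_size
        * args.pipeline_model_parallel_size
    )

    print(bytes_per_parameter(dense_dp))   # 6 + 12/16 = 6.75 bytes per dense param
    print(bytes_per_parameter(expert_dp))  # 6 + 12/4  = 9.0 bytes per expert param

In this example the expert data-parallel group (4 ranks) is smaller than the regular one (16 ranks), so each routed-expert parameter costs more bytes of sharded optimizer state per rank; pricing all params at args.data_parallel_size, as the old single num_bytes_per_parameter did, would understate memory on expert ranks.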