164 changes: 112 additions & 52 deletions megatron/training/theoretical_memory_usage.py
@@ -19,7 +19,7 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
# MoE.
num_experts = 1 if args.num_experts is None else args.num_experts
gated_linear_multiplier = 3 / 2 if args.swiglu else 1

shared_expert_ffn_hidden_size = (
0
if args.moe_shared_expert_intermediate_size is None
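Note on `gated_linear_multiplier`: SwiGLU replaces the MLP's two weight matrices with three (gate, up, down), hence the 3/2 factor on every FFN term below. A quick self-contained check, with made-up sizes (not values from this PR):

```python
# Illustrative only: h and f are hypothetical hidden/FFN sizes.
h, f = 4096, 14336
standard_mlp = 2 * h * f  # up-projection + down-projection
swiglu_mlp = 3 * h * f    # gate + up + down
assert swiglu_mlp == standard_mlp * (3 / 2)
```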
@@ -92,49 +92,55 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
)
)

-    num_parameters_in_transformer_layer_dense = (
-        2
-        * args.hidden_size
-        * (
-            # Dense MoE MLP.
-            (args.ffn_hidden_size * gated_linear_multiplier)
-            # Transformer layernorms.
-            + norm_size
-        )
-        + self_attn_term
-    )
-    num_parameters_in_transformer_layer_moe = (
-        2
-        * args.hidden_size
-        * (
-            # MoE MLP.
-            + (moe_ffn_hidden_size * num_experts * gated_linear_multiplier)
-            # Shared MoE MLP.
-            + (shared_expert_ffn_hidden_size * gated_linear_multiplier)
-            # Transformer layernorms.
-            + norm_size
-        )
-        + self_attn_term
-    )
-    num_active_parameters_in_transformer_layer_moe = (
-        2
-        * args.hidden_size
-        * (
-            # MoE MLP.
-            + (moe_ffn_hidden_size * args.moe_router_topk * gated_linear_multiplier)
-            # Shared MoE MLP.
-            + (shared_expert_ffn_hidden_size * gated_linear_multiplier)
-            # Transformer layernorms.
-            + (2)
-        )
-        + self_attn_term
-    )
embedding_size = args.hidden_size * args.padded_vocab_size
final_layernorm = norm_size * args.hidden_size
if args.untie_embeddings_and_output_weights:
num_parameters_in_embedding_layers = 2 * embedding_size
else:
num_parameters_in_embedding_layers = embedding_size

+    attention_params = self_attn_term
+    dense_mlp_params = 2 * args.hidden_size * args.ffn_hidden_size * gated_linear_multiplier
+    shared_expert_params = 2 * args.hidden_size * shared_expert_ffn_hidden_size * gated_linear_multiplier
+    routed_expert_params = (
+        2 * args.hidden_size * moe_ffn_hidden_size * num_experts * gated_linear_multiplier
+    )
+    active_routed_expert_params = (
+        2 * args.hidden_size * moe_ffn_hidden_size * args.moe_router_topk * gated_linear_multiplier
+        if args.num_experts is not None
+        else 0
+    )
+    layernorm_params = 2 * args.hidden_size * norm_size
+    router_params = (
+        (args.hidden_size * num_experts) + (num_experts if args.add_bias_linear else 0)
+        if args.num_experts is not None
+        else 0
+    )
+    shared_expert_gate_params = (
+        args.hidden_size
+        if shared_expert_ffn_hidden_size > 0 and getattr(args, "moe_shared_expert_gate", False)
+        else 0
+    )
+
+    num_parameters_in_transformer_layer_dense = (
+        attention_params + dense_mlp_params + layernorm_params
+    )
+    num_parameters_in_transformer_layer_moe = (
+        attention_params
+        + shared_expert_params
+        + routed_expert_params
+        + layernorm_params
+        + router_params
+        + shared_expert_gate_params
+    )
+    num_active_parameters_in_transformer_layer_moe = (
+        attention_params
+        + shared_expert_params
+        + active_routed_expert_params
+        + layernorm_params
+        + router_params
+        + shared_expert_gate_params
+    )
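The refactor makes each per-layer term reusable downstream. A standalone sketch of the same accounting, with hypothetical Mixtral-8x7B-like numbers (not taken from this PR), shows how stored and active MoE layer sizes diverge:

```python
# Hypothetical config: h=4096, f=14336, 8 routed experts, top-2 routing, SwiGLU.
h, f, n_experts, topk, g = 4096, 14336, 8, 2, 3 / 2

routed_expert_params = 2 * h * f * n_experts * g  # all experts are stored
active_routed_params = 2 * h * f * topk * g       # only the top-k run per token
router_params = h * n_experts                     # router weight, no bias

print(f"stored per MoE layer: {routed_expert_params / 1e9:.2f}B")  # 1.41B
print(f"active per MoE layer: {active_routed_params / 1e9:.2f}B")  # 0.35B
```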
num_parameters_in_transformer_block = (
num_parameters_in_transformer_layer_dense * num_dense_layers
+ num_parameters_in_transformer_layer_moe * num_moe_layers
@@ -180,38 +186,92 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
print(f"Total number of parameters in billions: {num_total_parameters / 10**9:.2f}")
print(f"Total number of active parameters in billions: {num_active_parameters / 10**9:.2f}")

-    # Most loaded model shard has (1/pp_size transformer layers + 1 mtp block + 1 embedding layer) / tp_size.
-    num_parameters_on_most_loaded_model_shard = (
-        (num_parameters_in_transformer_block / args.pipeline_model_parallel_size)
-        + num_parameters_in_mtp_block
+    expert_tensor_parallel_size = args.expert_tensor_parallel_size
+    expert_model_parallel_size = args.expert_model_parallel_size
+    expert_tensor_model_pipeline_parallel_size = (
+        expert_tensor_parallel_size
+        * expert_model_parallel_size
+        * args.pipeline_model_parallel_size
+    )
+    expert_data_parallel_size = args.world_size // expert_tensor_model_pipeline_parallel_size
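Expert weights live on a different process grid (ETP × EP × PP) than the dense weights (TP × PP), so they are replicated across a different, usually smaller, data-parallel group. A toy example with hypothetical sizes:

```python
# Hypothetical launch: 64 GPUs; tp=4, pp=2 for dense params; etp=1, ep=8 for experts.
world_size, tp, pp, etp, ep = 64, 4, 2, 1, 8
data_parallel_size = world_size // (tp * pp)               # 8 copies of dense params
expert_data_parallel_size = world_size // (etp * ep * pp)  # 4 copies of expert params
assert (data_parallel_size, expert_data_parallel_size) == (8, 4)
```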

+    # Split params by how they are held on each rank: regular TP, replicated, or EP/ETP.
+    tp_sharded_params_in_transformer_block = (
+        (attention_params + dense_mlp_params) * num_dense_layers
+        + (attention_params + shared_expert_params) * num_moe_layers
+    )
+    replicated_params_in_transformer_block = (
+        layernorm_params * num_dense_layers
+        + (layernorm_params + router_params + shared_expert_gate_params) * num_moe_layers
+        + final_layernorm
+    )
+    expert_sharded_params_in_transformer_block = routed_expert_params * num_moe_layers
+    tp_sharded_params_in_mtp_block = (
+        (attention_params + dense_mlp_params) * mtp_num_dense_layers
+        + (attention_params + shared_expert_params) * mtp_num_moe_layers
+    )
+    replicated_params_in_mtp_block = (
+        layernorm_params * mtp_num_dense_layers
+        + (layernorm_params + router_params + shared_expert_gate_params) * mtp_num_moe_layers
+    )
+    expert_sharded_params_in_mtp_block = routed_expert_params * mtp_num_moe_layers
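The three buckets matter because each is divided by a different factor later: TP-sharded params by `tensor_model_parallel_size`, expert-sharded params by `etp * ep`, and replicated params (norms, routers, gates) by nothing. They still partition the per-layer totals exactly; a self-contained check with toy per-component counts:

```python
# Toy component sizes (arbitrary units), mirroring the bucketing above.
attn, mlp, shared, routed, ln, router, gate = 100, 80, 40, 640, 4, 8, 1
n_dense, n_moe = 2, 6

tp_sharded = (attn + mlp) * n_dense + (attn + shared) * n_moe
replicated = ln * n_dense + (ln + router + gate) * n_moe
expert_sharded = routed * n_moe

layer_dense = attn + mlp + ln
layer_moe = attn + shared + routed + ln + router + gate
assert tp_sharded + replicated + expert_sharded == layer_dense * n_dense + layer_moe * n_moe
```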

+    # Most loaded model shard has 1/pp_size transformer layers, 1 mtp block, and 1 embedding layer.
+    tp_sharded_params_on_most_loaded_shard = (
+        (tp_sharded_params_in_transformer_block / args.pipeline_model_parallel_size)
+        + tp_sharded_params_in_mtp_block
+ embedding_size
) / args.tensor_model_parallel_size
+    replicated_params_on_most_loaded_shard = (
+        replicated_params_in_transformer_block / args.pipeline_model_parallel_size
+    ) + replicated_params_in_mtp_block
+    expert_sharded_params_on_most_loaded_shard = (
+        (expert_sharded_params_in_transformer_block / args.pipeline_model_parallel_size)
+        + expert_sharded_params_in_mtp_block
+    ) / (expert_tensor_parallel_size * expert_model_parallel_size)
+    num_parameters_on_most_loaded_model_shard = (
+        tp_sharded_params_on_most_loaded_shard
+        + replicated_params_on_most_loaded_shard
+        + expert_sharded_params_on_most_loaded_shard
+    )
if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1:
-        num_parameters_on_most_loaded_model_shard += (
-            embedding_size / args.tensor_model_parallel_size
-        )
+        tp_sharded_params_on_most_loaded_shard += embedding_size / args.tensor_model_parallel_size
+        num_parameters_on_most_loaded_model_shard += embedding_size / args.tensor_model_parallel_size
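Putting the three buckets together for the heaviest pipeline stage, here is a numeric sketch using the hypothetical grid from above (tp=4, pp=2, etp=1, ep=8) and made-up parameter totals in billions:

```python
tp, pp, etp, ep = 4, 2, 1, 8
tp_sharded_block, replicated_block, expert_block = 8.0, 0.1, 40.0  # billions, made up
embedding, mtp = 0.5, 0.0  # no MTP block in this toy setup

most_loaded = (
    (tp_sharded_block / pp + mtp + embedding) / tp  # divided across tensor ranks
    + replicated_block / pp                         # held whole on every rank
    + (expert_block / pp) / (etp * ep)              # divided across expert ranks
)
print(f"{most_loaded:.3f}B params on the most loaded shard")  # 3.675B
```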
if verbose:
print(
f"Number of parameters in most loaded shard in billions: "
f"{num_parameters_on_most_loaded_model_shard / 10**9:.4f}"
)

if args.pipeline_model_parallel_size > 1:
-        # Other shards just have (1/pp_size transformer layers) / tp_size.
-        num_parameters_on_other_model_shards = num_parameters_in_transformer_block / (
-            args.pipeline_model_parallel_size * args.tensor_model_parallel_size
+        # Other shards just have 1/pp_size transformer layers.
+        num_parameters_on_other_model_shards = (
+            tp_sharded_params_in_transformer_block
+            / (args.pipeline_model_parallel_size * args.tensor_model_parallel_size)
+            + replicated_params_in_transformer_block / args.pipeline_model_parallel_size
+            + expert_sharded_params_in_transformer_block
+            / (
+                args.pipeline_model_parallel_size
+                * expert_tensor_parallel_size
+                * expert_model_parallel_size
+            )
)
if verbose:
print(
f"Number of parameters in other shards in billions: "
f"{num_parameters_on_other_model_shards / 10**9:.4f}"
)

-    num_bytes_per_parameter = (
-        18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size)
-    )
+    def num_bytes_per_parameter(data_parallel_size):
+        # This estimator assumes bf16 training: bf16 model params, fp32 main gradients,
+        # fp32 main params, and fp32 Adam states. See docs/user-guide/features/dist_optimizer.md.
+        return 18 if not args.use_distributed_optimizer else 6 + (12 / data_parallel_size)
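The byte counts come from bf16 training with Adam: 2 bytes for the bf16 param and 4 for the fp32 main gradient stay on every rank, while the 12 bytes of fp32 main param plus two fp32 Adam moments are what the distributed optimizer shards across the data-parallel group. Taking the group size as an argument is what lets expert and non-expert params use different divisors:

```python
# 18 = 2 (bf16 param) + 4 (fp32 grad) + 4 (fp32 main param) + 8 (fp32 Adam moments).
for d in (1, 8, 64):
    print(f"dp={d:3d}: {6 + 12 / d:5.2f} bytes/param with the distributed optimizer")
# dp=1 recovers the full 18 bytes; large groups approach the 6-byte floor.
```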

weight_and_optimizer_memory = (
-        num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter
+        (tp_sharded_params_on_most_loaded_shard + replicated_params_on_most_loaded_shard)
+        * num_bytes_per_parameter(args.data_parallel_size)
+        + expert_sharded_params_on_most_loaded_shard
+        * num_bytes_per_parameter(expert_data_parallel_size)
)

return weight_and_optimizer_memory
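End-to-end, the two groups are priced at their own data-parallel sizes. Continuing the toy numbers above (dense-like 1.175B params at dp=8, expert 2.5B at expert_dp=4; all hypothetical):

```python
def bytes_per(d):  # distributed optimizer, bf16 params + fp32 Adam states
    return 6 + 12 / d

dense_like, expert = 1.175e9, 2.5e9
mem_gib = (dense_like * bytes_per(8) + expert * bytes_per(4)) / 1024**3
print(f"~{mem_gib:.1f} GiB weights + optimizer state on the most loaded rank")  # ~29.2
```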