
Memory leak when log_max_attention_logits is set #4699

@supershadoe


Describe the bug

mcore currently provides a flag, `--log-max-attention-logit`, to log the maximum value of the QK attention logits to TensorBoard. This flag and `--qk-clip` are the only two flags that trigger obtaining the max QK logit from TE and storing it in `TEDotProductAttention`.

When QK-clip is enabled, `clip_qk()` uses the QKClip implementation in `SelfAttention` or `MLASelfAttention` to clip the weights, after which `core_attention.current_max_attn_logits` is reset.

When `--qk-clip` is not used, the stored logits are never reset because `self_attention.clip_qk()` is never called. Additionally, since `torch.max()` is called without `torch.no_grad()` or `.detach()` in `TEDotProductAttention`, every batch's max logit keeps a reference to its autograd graph, causing a memory leak (in host RAM, not GPU memory) that grows with every DPA call into TE.
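To illustrate the retention mechanism, here is a minimal standalone sketch (toy tensors and names, not the actual mcore/TE code):

```python
import torch

# Storing an un-detached reduction of an activation keeps that step's
# autograd graph alive, so host memory grows with every forward call.
stored_max_logits = []  # analogous to current_max_attn_logits never being reset

q = torch.randn(8, 64, requires_grad=True)
k = torch.randn(8, 64, requires_grad=True)
for step in range(10_000):
    logits = q @ k.t()                      # stands in for the QK^T logits
    stored_max_logits.append(logits.max())  # grad_fn retained -> graph leaks

# Fix: keep only the scalar value, not the graph:
#     stored_max_logits.append(logits.max().detach())
# or compute the max under torch.no_grad().
```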

Steps/Code to reproduce bug

In the NeMo framework container 26.04, the attached train.py and repro.txt (a standard hydra config file; GitHub didn't allow uploading it as YAML) can be used. The attached config is a slightly modified version of Llama 3's pretraining recipe available in mbridge. In short, the bug occurs whenever `model.log_max_attention_logit` is `True` but `model.qk_clip` is `False`, as in the excerpt below.
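For reference, an illustrative excerpt of the relevant overrides (key names taken from the description above; see the attached repro.txt for the full config):

```yaml
model:
  log_max_attention_logit: true   # triggers storing the max QK logit
  qk_clip: false                  # clip_qk() never runs, so it is never reset
```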

```bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
export HYDRA_FULL_ERROR=1
export NCCL_NVLS_ENABLE=0
export HF_TOKEN=<redacted>
export MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING=true
export MLFLOW_DISABLE_TELEMETRY=true

python -u "/workspace/scripts/train.py" \
    --config-path="/workspace/scripts/config" \
    --config-name="repro" \
  1>> "/workspace/logs/stdout.log" 2>> "/workspace/logs/stderr.log"
```

Expected behavior

CPU memory usage should not increase continuously during pretraining.

Additional context

Possible solution: either detach the tensor from the autograd graph in TE/mcore, or reset the stored logits to `None` even when qk-clip is not enabled. A rough sketch of both options is shown below.
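A minimal Python sketch of both options, using a toy module in place of `TEDotProductAttention` (every name except `current_max_attn_logits` is an illustrative assumption, not the real mcore/TE API):

```python
import torch

class ToyAttention(torch.nn.Module):
    """Illustrative stand-in for TEDotProductAttention (not the real API)."""

    def __init__(self):
        super().__init__()
        self.current_max_attn_logits = None

    def forward(self, q, k):
        logits = q @ k.transpose(-2, -1)
        # Option 1: record the max outside autograd so no graph is retained.
        with torch.no_grad():
            self.current_max_attn_logits = logits.max()
        return logits

attn = ToyAttention()
for _ in range(3):
    q = torch.randn(2, 8, 16, requires_grad=True)
    k = torch.randn(2, 8, 16, requires_grad=True)
    attn(q, k)
    print(attn.current_max_attn_logits.item())  # safe to log to TensorBoard
    # Option 2: reset after logging, mirroring what clip_qk() does when
    # qk-clip is enabled, so stale values never accumulate.
    attn.current_max_attn_logits = None
```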

System memory usage graph illustrating the issue:

[Image: memory usage graph comparing the three runs]

Here,

- control = both `log_max_attention_logit` and `qk_clip` set to `False`
- qk_clip = both `log_max_attention_logit` and `qk_clip` set to `True`
- repro = `log_max_attention_logit` alone set to `True` (`qk_clip = False`)

Attached: train.py, repro.txt
