Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions src/transformers/integrations/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,24 +282,38 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
# On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
return_lse = query.device.type != "cpu"

# PyTorch >= 2.9 renamed return_lse to return_aux
torch_version = get_torch_version()
use_return_aux = version.parse(torch_version).base_version >= "2.9"

if not return_lse and s_aux is not None:
raise ValueError(
"Attention sinks cannot be run on CPU with flex attention. Please switch to a different device, e.g. CUDA"
)

# Build the kwargs for flex attention
flex_attn_kwargs = {
"score_mod": score_mod,
"block_mask": block_mask,
"enable_gqa": enable_gqa,
"scale": scaling,
"kernel_options": kernel_options,
"training": module.training,
}

# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
# In PyTorch >= 2.9, return_lse was renamed to return_aux
if use_return_aux:
flex_attn_kwargs["return_aux"] = return_lse
else:
flex_attn_kwargs["return_lse"] = return_lse

flex_attention_output = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=enable_gqa,
scale=scaling,
kernel_options=kernel_options,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=return_lse,
training=module.training,
**flex_attn_kwargs,
)
# lse is returned in float32
if return_lse:
Expand Down
30 changes: 23 additions & 7 deletions src/transformers/models/doge/modeling_doge.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...integrations.flex_attention import compile_friendly_flex_attention
from ...utils.import_utils import get_torch_version
from packaging import version
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
Expand Down Expand Up @@ -233,17 +235,31 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
return score

# PyTorch >= 2.9 renamed return_lse to return_aux
torch_version = get_torch_version()
use_return_aux = version.parse(torch_version).base_version >= "2.9"

# Build kwargs for flex attention
flex_attn_kwargs = {
"score_mod": score_mod,
"block_mask": block_mask,
"enable_gqa": True,
"scale": scaling,
}

# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
# In PyTorch >= 2.9, return_lse was renamed to return_aux
if use_return_aux:
flex_attn_kwargs["return_aux"] = True
else:
flex_attn_kwargs["return_lse"] = True

attn_output, attention_weights = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
scale=scaling,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
**flex_attn_kwargs,
)
# lse is returned in float32
attention_weights = attention_weights.to(value.dtype)
Expand Down
30 changes: 23 additions & 7 deletions src/transformers/models/doge/modular_doge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from ...cache_utils import Cache
from ...configuration_utils import PreTrainedConfig
from ...integrations.flex_attention import compile_friendly_flex_attention
from ...utils.import_utils import get_torch_version
from packaging import version
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import RopeParameters
Expand Down Expand Up @@ -202,17 +204,31 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
return score

# PyTorch >= 2.9 renamed return_lse to return_aux
torch_version = get_torch_version()
use_return_aux = version.parse(torch_version).base_version >= "2.9"

# Build kwargs for flex attention
flex_attn_kwargs = {
"score_mod": score_mod,
"block_mask": block_mask,
"enable_gqa": True,
"scale": scaling,
}

# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
# In PyTorch >= 2.9, return_lse was renamed to return_aux
if use_return_aux:
flex_attn_kwargs["return_aux"] = True
else:
flex_attn_kwargs["return_lse"] = True

attn_output, attention_weights = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
scale=scaling,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
**flex_attn_kwargs,
)
# lse is returned in float32
attention_weights = attention_weights.to(value.dtype)
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/utils/generic.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I'm not sure about the other changes in this PR, but the fix that replaces int() with float() where a float is expected is correct. I'm surprised it went unnoticed for so long.

Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ def is_flash_attention_requested(
else:
checked_attention_implementation = requested_attention_implementation

# `checked_attention_implementation` can theoretically be None, which is equivalent to the default implementation (sdpa/eager)
if checked_attention_implementation is None:
return False

# If a specific version is requested, look for a pattern of type "flash...{version}"
if version is not None:
return re.match(r".*flash.*" + str(version), checked_attention_implementation) is not None
Expand Down Expand Up @@ -656,9 +660,9 @@ def torch_float(x):
Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float.
"""
if not _is_torch_available:
return int(x)
return float(x)

return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else float(x)


def filter_out_non_signature_kwargs(extra: list | None = None):
Expand Down
Loading