diff --git a/atom/plugin/vllm/attention_backend/attention_gdn.py b/atom/plugin/vllm/attention_backend/attention_gdn.py index b6158a086..a47f16431 100644 --- a/atom/plugin/vllm/attention_backend/attention_gdn.py +++ b/atom/plugin/vllm/attention_backend/attention_gdn.py @@ -395,8 +395,8 @@ def forward( query=query_non_spec, key=key_non_spec, value=value_non_spec, - a=a, - b=b, + a=a.unsqueeze(1), + b=b.unsqueeze(1), dt_bias=self.dt_bias, A_log=self.A_log, indices=non_spec_state_indices_tensor, diff --git a/atom/utils/envs.py b/atom/utils/envs.py index fb487bc10..8403947a1 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -89,7 +89,7 @@ "ATOM_USE_CUSTOM_ALL_GATHER": lambda: ( os.getenv("ATOM_USE_CUSTOM_ALL_GATHER", "1").lower() == "1" ), - "ATOM_USE_FLYDSL_GDR": lambda: os.getenv("ATOM_USE_FLYDSL_GDR", "0").lower() == "1", + "ATOM_USE_FLYDSL_GDR": lambda: os.getenv("ATOM_USE_FLYDSL_GDR", "1").lower() == "1", # --- MoE (DeepSeek-style shared experts) --- # Dual-stream MoE only when num_tokens <= threshold; 0 disables dual-stream registration. "ATOM_DUAL_STREAM_MOE_TOKEN_THRESHOLD": lambda: int(