
FlashAttention RuntimeError: batch size must be positive when running PPO training with Qwen2.5-3B + FSDP + remove-padding #5

@user50lab

Description


Hi, I encountered a critical runtime error when running PPO training using Router-R1 with:

- Qwen2.5-3B-Instruct
- FSDP (param/grad/optimizer offloading)
- use_remove_padding=True
- use_dynamic_bsz=True
- FlashAttention (the default attention backend for Qwen2.5)
- the rollout worker generating sequences (actor_rollout_generate_sequences)

During rollout generation, FlashAttention crashes with:

RuntimeError: batch size must be positive

The failure occurs inside flash_attn_varlen_fwd() → _flash_attn_varlen_forward(). It appears the model receives an empty micro-batch after remove-padding + dynamic batching, leading to a kernel call with batch_size == 0. The crash happens before PPO training can proceed.
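For context (my own reading of the error, not Router-R1 code): with remove-padding, flash-attn's varlen path describes the packed sequences via a cumulative-sequence-length tensor and derives the batch size as `len(cu_seqlens) - 1`, so an empty micro-batch (`cu_seqlens == [0]`) yields batch size 0 and trips exactly this error. A minimal sketch of that arithmetic and a guard, using plain Python lists in place of tensors (`varlen_batch_size` is a hypothetical helper name):

```python
def varlen_batch_size(cu_seqlens):
    """cu_seqlens mirrors flash-attn's cumulative-sequence-length
    tensor, e.g. [0, 5, 9] for two packed sequences of lengths 5 and 4.
    The varlen kernel derives batch size as len(cu_seqlens) - 1 and
    aborts with "batch size must be positive" when that is 0."""
    batch_size = len(cu_seqlens) - 1
    if batch_size <= 0:
        # An empty micro-batch (cu_seqlens == [0]) has to be skipped or
        # merged upstream, before the kernel is ever invoked.
        raise ValueError("empty micro-batch: no sequences to attend over")
    return batch_size
```

If dynamic batch splitting can ever produce an empty micro-batch, a check of this shape in the forward path would turn the opaque CUDA-side RuntimeError into an actionable Python-side error.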

Error Log:

(main_task pid=1274802) ray::WorkerDict.actor_rollout_generate_sequences() (pid=1275248, ip=10.32.35.206, actor_id=20605ce952c5d5c75042ff6901000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x15130becea60>)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/single_controller/ray/base.py", line 399, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/workers/fsdp_workers.py", line 465, in generate_sequences
    old_log_probs = self.actor.compute_log_prob(data=output)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/workers/actor/dp_actor.py", line 191, in compute_log_prob
    _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/workers/actor/dp_actor.py", line 94, in _forward_micro_batch
    output = self.actor_module(input_ids=input_ids_rmpad,
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1165, in forward
    outputs = self.model(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 895, in forward
    layer_outputs = decoder_layer(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 623, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 443, in forward
    attn_output = _flash_attention_forward(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/transformers/modeling_flash_attention_utils.py", line 346, in _flash_attention_forward
    attn_output = flash_attn_varlen_func(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 1443, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 925, in forward
    out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/ops.py", line 1061, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/_library/autograd.py", line 98, in autograd_impl
    result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/_library/autograd.py", line 40, in forward
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/ops.py", line 672, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/_library/custom_ops.py", line 236, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 165, in _flash_attn_varlen_forward
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
RuntimeError: batch size must be positive
Error executing job with overrides: ['data.train_files=data/nq_search/train_nh_qwen.parquet', 'data.val_files=data/nq_search/test_nh_qwen.parquet', 'data.train_data_num=null', 'data.val_data_num=null', 'data.train_batch_size=64', 'data.val_batch_size=64', 'data.max_prompt_length=4096', 'data.max_response_length=1024', 'data.max_start_length=2048', 'data.max_obs_length=600', 'data.shuffle_train_dataloader=True', 'algorithm.adv_estimator=gae', 'actor_rollout_ref.model.path=/vast/cj2851/proiects-local-models/ARLLM-series/Qwen2.5-3B-Instruct', 'actor_rollout_ref.actor.optim.lr=1e-6', 'actor_rollout_ref.model.enable_gradient_checkpointing=true', 'actor_rollout_ref.model.use_remove_padding=True', 'actor_rollout_ref.actor.use_dynamic_bsz=True', 'actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.0', 'actor_rollout_ref.actor.ppo_mini_batch_size=32', 'actor_rollout_ref.actor.ppo_micro_batch_size=8', 'actor_rollout_ref.actor.fsdp_config.param_offload=true', 'actor_rollout_ref.actor.fsdp_config.grad_offload=true', 'actor_rollout_ref.actor.fsdp_config.optimizer_offload=true', 'actor_rollout_ref.rollout.log_prob_micro_batch_size=16', 'actor_rollout_ref.rollout.tensor_model_parallel_size=1', 'actor_rollout_ref.rollout.name=vllm', 'actor_rollout_ref.rollout.gpu_memory_utilization=0.6', 'actor_rollout_ref.ref.log_prob_micro_batch_size=16', 'actor_rollout_ref.ref.fsdp_config.param_offload=True', 'actor_rollout_ref.rollout.n_agent=1', 'actor_rollout_ref.rollout.temperature=1', 'actor_rollout_ref.actor.state_masking=true', 'critic.optim.lr=1e-5', 'critic.model.use_remove_padding=True', 'critic.optim.lr_warmup_steps_ratio=0.0', 'critic.model.path=/vast/cj2851/proiects-local-models/ARLLM-series/Qwen2.5-3B-Instruct', 'critic.model.enable_gradient_checkpointing=true', 'critic.ppo_micro_batch_size=8', 'critic.model.fsdp_config.param_offload=true', 'critic.model.fsdp_config.grad_offload=true', 'critic.model.fsdp_config.optimizer_offload=true', 'algorithm.kl_ctrl.kl_coef=0.001', 
'algorithm.no_think_rl=false', 'trainer.logger=[wandb]', '+trainer.val_only=false', '+trainer.val_before_train=false', 'trainer.default_hdfs_dir=null', 'trainer.n_gpus_per_node=1', 'trainer.nnodes=1', 'trainer.save_freq=15', 'trainer.test_freq=15', 'trainer.project_name=Router-R1-Official', 'trainer.experiment_name=nh-bs64-ppo-qwen2.5-3b-it-em', 'trainer.total_epochs=100', 'trainer.total_training_steps=225', 'trainer.default_hdfs_dir=null', 'trainer.default_local_dir=verl_checkpoints/nh-bs64-ppo-qwen2.5-3b-it-em', 'max_turns=4', '+reward_metric=em', '+cost_coe=0.0', '+api_base=https://integrate.api.nvidia.com/v1', '+api_key=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']
Traceback (most recent call last):
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/trainer/main_ppo.py", line 294, in main
    ray.get(main_task.remote(config))
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
    return func(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/ray/_private/worker.py", line 2961, in get
    values, debugger_breakpoint = worker.get_objects(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/ray/_private/worker.py", line 1026, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AssertionError): ray::main_task() (pid=1274802, ip=10.32.35.206)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/trainer/main_ppo.py", line 382, in main_task
    trainer.fit()
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/trainer/ppo/ray_trainer.py", line 780, in fit
    final_gen_batch_output = generation_manager.run_llm_loop(
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/router_r1/llm_agent/generation.py", line 297, in run_llm_loop
    gen_output = self._generate_with_gpu_padding(rollings_active)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/router_r1/llm_agent/generation.py", line 181, in _generate_with_gpu_padding
    return self.actor_rollout_wg.generate_sequences(active_batch)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/single_controller/ray/base.py", line 42, in func
    output = ray.get(output)
ray.exceptions.RayTaskError(AssertionError): ray::WorkerDict.actor_rollout_generate_sequences() (pid=1275248, ip=10.32.35.206, actor_id=20605ce952c5d5c75042ff6901000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x15130becea60>)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/single_controller/ray/base.py", line 399, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/workers/fsdp_workers.py", line 446, in generate_sequences
    with self.rollout_sharding_manager:
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/workers/sharding_manager/fsdp_vllm.py", line 71, in __enter__
    params = self.module.state_dict()
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1935, in state_dict
    hook(self, prefix, keep_vars)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 787, in _pre_state_dict_hook
    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 308, in _full_pre_state_dict_hook
    _common_unshard_pre_state_dict_hook(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 175, in _common_unshard_pre_state_dict_hook
    _enter_unshard_params_ctx(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 139, in _enter_unshard_params_ctx
    fsdp_state._unshard_params_ctx[module].__enter__()
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 186, in _unshard_fsdp_state_params
    assert (
AssertionError: Expects the handle training to be IDLE but got HandleTrainingState.FORWARD

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
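Note on the second traceback: it is a distinct symptom. The rollout sharding manager calls `state_dict()` (which must unshard FSDP parameters) while an FSDP handle is still mid-forward, and FSDP asserts that the handle is IDLE before unsharding. A toy mimic of that state check, with hypothetical names standing in for FSDP internals (not FSDP's actual implementation):

```python
from contextlib import contextmanager
from enum import Enum


class HandleTrainingState(Enum):
    IDLE = "IDLE"
    FORWARD = "FORWARD"


@contextmanager
def unshard_params(state: HandleTrainingState):
    # FSDP's _unshard_fsdp_state_params makes an equivalent check:
    # unsharding for state_dict() is only legal between forward/backward,
    # never while a forward pass is still in flight.
    assert state == HandleTrainingState.IDLE, (
        f"Expects the handle training to be IDLE but got {state}"
    )
    yield
```

In other words, the crash in `_forward_micro_batch` appears to leave the FSDP handle stuck in FORWARD, so the subsequent `state_dict()` for the vLLM rollout fails with this assertion; fixing the empty-micro-batch crash should make this second error disappear.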
