Description
Hi, I encountered a critical runtime error when running PPO training with Router-R1 under the following setup:

Qwen2.5-3B-Instruct
FSDP with param/grad/optimizer offloading
use_remove_padding=True
use_dynamic_bsz=True
FlashAttention (the default attention backend for Qwen2.5)
rollout worker generating sequences (actor_rollout_generate_sequences)

During rollout generation, FlashAttention crashes with:

RuntimeError: batch size must be positive

The error is raised inside flash_attn_varlen_fwd() → _flash_attn_varlen_forward(). It appears that, after remove-padding and dynamic batch sizing, the model receives an empty micro-batch, so the varlen kernel is invoked with batch_size == 0. The crash happens during the rollout's log-prob computation, before PPO training can proceed.
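To illustrate why an empty micro-batch trips exactly this check (a sketch of my understanding, not Router-R1's actual code): with use_remove_padding=True, sequences are packed and the varlen FlashAttention kernel has no explicit batch dimension; it derives the batch size from the cumulative-sequence-length tensor as len(cu_seqlens) - 1.

```python
# Hypothetical illustration (the helper name cu_seqlens is mine, not
# Router-R1's API): how the varlen FlashAttention path infers batch size.

def cu_seqlens(seq_lens):
    """Cumulative sequence lengths, in the layout flash_attn_varlen_func expects."""
    out = [0]
    for n in seq_lens:
        out.append(out[-1] + n)
    return out

# A normal micro-batch of three packed sequences:
print(cu_seqlens([5, 3, 7]))   # [0, 5, 8, 15] -> implied batch size 3

# If dynamic batch sizing leaves a micro-batch with no sequences,
# cu_seqlens degenerates to [0], the implied batch size is 0, and the
# kernel rejects it with "batch size must be positive".
empty = cu_seqlens([])
print(empty, len(empty) - 1)   # [0] 0
```

So any code path that hands the actor an active batch with zero sequences on some rank would reproduce this error.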
Error Log:
(main_task pid=1274802) ray::WorkerDict.actor_rollout_generate_sequences() (pid=1275248, ip=10.32.35.206, actor_id=20605ce952c5d5c75042ff6901000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x15130becea60>)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/single_controller/ray/base.py", line 399, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/workers/fsdp_workers.py", line 465, in generate_sequences
    old_log_probs = self.actor.compute_log_prob(data=output)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/workers/actor/dp_actor.py", line 191, in compute_log_prob
    _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/workers/actor/dp_actor.py", line 94, in _forward_micro_batch
    output = self.actor_module(input_ids=input_ids_rmpad,
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1165, in forward
    outputs = self.model(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 895, in forward
    layer_outputs = decoder_layer(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 623, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 443, in forward
    attn_output = _flash_attention_forward(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/transformers/modeling_flash_attention_utils.py", line 346, in _flash_attention_forward
    attn_output = flash_attn_varlen_func(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 1443, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 925, in forward
    out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/_ops.py", line 1061, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/_library/autograd.py", line 98, in autograd_impl
    result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/_library/autograd.py", line 40, in forward
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/_ops.py", line 672, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/_library/custom_ops.py", line 236, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 165, in _flash_attn_varlen_forward
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
RuntimeError: batch size must be positive
Error executing job with overrides: ['data.train_files=data/nq_search/train_nh_qwen.parquet', 'data.val_files=data/nq_search/test_nh_qwen.parquet', 'data.train_data_num=null', 'data.val_data_num=null', 'data.train_batch_size=64', 'data.val_batch_size=64', 'data.max_prompt_length=4096', 'data.max_response_length=1024', 'data.max_start_length=2048', 'data.max_obs_length=600', 'data.shuffle_train_dataloader=True', 'algorithm.adv_estimator=gae', 'actor_rollout_ref.model.path=/vast/cj2851/proiects-local-models/ARLLM-series/Qwen2.5-3B-Instruct', 'actor_rollout_ref.actor.optim.lr=1e-6', 'actor_rollout_ref.model.enable_gradient_checkpointing=true', 'actor_rollout_ref.model.use_remove_padding=True', 'actor_rollout_ref.actor.use_dynamic_bsz=True', 'actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.0', 'actor_rollout_ref.actor.ppo_mini_batch_size=32', 'actor_rollout_ref.actor.ppo_micro_batch_size=8', 'actor_rollout_ref.actor.fsdp_config.param_offload=true', 'actor_rollout_ref.actor.fsdp_config.grad_offload=true', 'actor_rollout_ref.actor.fsdp_config.optimizer_offload=true', 'actor_rollout_ref.rollout.log_prob_micro_batch_size=16', 'actor_rollout_ref.rollout.tensor_model_parallel_size=1', 'actor_rollout_ref.rollout.name=vllm', 'actor_rollout_ref.rollout.gpu_memory_utilization=0.6', 'actor_rollout_ref.ref.log_prob_micro_batch_size=16', 'actor_rollout_ref.ref.fsdp_config.param_offload=True', 'actor_rollout_ref.rollout.n_agent=1', 'actor_rollout_ref.rollout.temperature=1', 'actor_rollout_ref.actor.state_masking=true', 'critic.optim.lr=1e-5', 'critic.model.use_remove_padding=True', 'critic.optim.lr_warmup_steps_ratio=0.0', 'critic.model.path=/vast/cj2851/proiects-local-models/ARLLM-series/Qwen2.5-3B-Instruct', 'critic.model.enable_gradient_checkpointing=true', 'critic.ppo_micro_batch_size=8', 'critic.model.fsdp_config.param_offload=true', 'critic.model.fsdp_config.grad_offload=true', 'critic.model.fsdp_config.optimizer_offload=true', 'algorithm.kl_ctrl.kl_coef=0.001', 
'algorithm.no_think_rl=false', 'trainer.logger=[wandb]', '+trainer.val_only=false', '+trainer.val_before_train=false', 'trainer.default_hdfs_dir=null', 'trainer.n_gpus_per_node=1', 'trainer.nnodes=1', 'trainer.save_freq=15', 'trainer.test_freq=15', 'trainer.project_name=Router-R1-Official', 'trainer.experiment_name=nh-bs64-ppo-qwen2.5-3b-it-em', 'trainer.total_epochs=100', 'trainer.total_training_steps=225', 'trainer.default_hdfs_dir=null', 'trainer.default_local_dir=verl_checkpoints/nh-bs64-ppo-qwen2.5-3b-it-em', 'max_turns=4', '+reward_metric=em', '+cost_coe=0.0', '+api_base=https://integrate.api.nvidia.com/v1', '+api_key=REDACTED']
Traceback (most recent call last):
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/trainer/main_ppo.py", line 294, in main
    ray.get(main_task.remote(config))
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
    return func(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/ray/_private/worker.py", line 2961, in get
    values, debugger_breakpoint = worker.get_objects(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/ray/_private/worker.py", line 1026, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AssertionError): ray::main_task() (pid=1274802, ip=10.32.35.206)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/trainer/main_ppo.py", line 382, in main_task
    trainer.fit()
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/trainer/ppo/ray_trainer.py", line 780, in fit
    final_gen_batch_output = generation_manager.run_llm_loop(
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/router_r1/llm_agent/generation.py", line 297, in run_llm_loop
    gen_output = self._generate_with_gpu_padding(rollings_active)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/router_r1/llm_agent/generation.py", line 181, in _generate_with_gpu_padding
    return self.actor_rollout_wg.generate_sequences(active_batch)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/single_controller/ray/base.py", line 42, in func
    output = ray.get(output)
ray.exceptions.RayTaskError(AssertionError): ray::WorkerDict.actor_rollout_generate_sequences() (pid=1275248, ip=10.32.35.206, actor_id=20605ce952c5d5c75042ff6901000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x15130becea60>)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/single_controller/ray/base.py", line 399, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/workers/fsdp_workers.py", line 446, in generate_sequences
    with self.rollout_sharding_manager:
  File "/scratch/cj2851/PROJECTS/SearchRouter_R1/Router-R1/verl/workers/sharding_manager/fsdp_vllm.py", line 71, in __enter__
    params = self.module.state_dict()
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1935, in state_dict
    hook(self, prefix, keep_vars)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 787, in _pre_state_dict_hook
    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 308, in _full_pre_state_dict_hook
    _common_unshard_pre_state_dict_hook(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 175, in _common_unshard_pre_state_dict_hook
    _enter_unshard_params_ctx(
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 139, in _enter_unshard_params_ctx
    fsdp_state._unshard_params_ctx[module].__enter__()
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/vast/cj2851/conda_env/router-r1/lib/python3.9/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 186, in _unshard_fsdp_state_params
    assert (
AssertionError: Expects the handle training to be IDLE but got HandleTrainingState.FORWARD

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
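As a possible mitigation while debugging (an assumption on my part, not a verified fix for Router-R1), the micro-batch loop could guard against empty micro-batches before they reach the FlashAttention kernel. A minimal sketch, where forward_micro_batches and forward_fn are hypothetical stand-ins for the actual loop in _forward_micro_batch:

```python
# Hypothetical guard (names are illustrative, not Router-R1's actual API):
# skip empty micro-batches produced by dynamic batch sizing before they
# reach the varlen FlashAttention kernel.

def forward_micro_batches(micro_batches, forward_fn):
    """Run forward_fn over micro-batches, skipping any that carry no sequences."""
    outputs = []
    for mb in micro_batches:
        if len(mb) == 0:   # an empty micro-batch would imply batch_size == 0
            continue       # skipping avoids "batch size must be positive"
        outputs.append(forward_fn(mb))
    return outputs

# Tiny usage example with lists standing in for tensors:
result = forward_micro_batches([[1, 2], [], [3]], forward_fn=sum)
print(result)  # [3, 3]
```

Whether skipping is safe here depends on how log-probs are later re-assembled per sample, so a real fix probably belongs in the dynamic batch partitioning itself (never emit an empty micro-batch) rather than in the forward loop.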