Your current environment
Summary
vLLM-Omni QwenImage pipelines set txt_seq_lens from prompt_embeds_mask.sum(dim=1) when preparing generation context / prepare_encode(). Diffusers (latest main) instead derives RoPE text length from the padded encoder hidden-states width (encoder_hidden_states.shape[1]) inside QwenImageTransformer2DModel.forward() via compute_text_seq_len_from_mask().
When prompt embeddings are wider than the number of valid (non-padding) tokens — common under continuous batching (CB) after in-place padding, RL training with fixed-width collation, or any caller that supplies pre-padded embeds — vLLM-Omni builds a too-short text RoPE table. That changes attention numerics relative to diffusers/PyTorch reference and can inflate rollout–trainer KL in RL setups.
Verified on vllm-omni main @ a693ae67 against diffusers main @ d1f8e55c3.
Affected code (vllm-omni)
All QwenImage pipeline variants set txt_seq_lens from mask.sum() in their generation-context helpers:
vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py (_prepare_generation_context, used by forward(), diffuse(), and prepare_encode())
vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
# Current (vllm-omni):
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
    negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
prepare_encode() copies these values onto DiffusionRequestState.txt_seq_lens / negative_txt_seq_lens. The stepwise CB path (InputBatch) reuses them but does not refresh them after _prepare_request_prompt_field() pads prompt_embeds / masks on the request state in place.
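The two length conventions can be compared in a minimal torch-free sketch. Masks are modelled as nested Python lists, and the helper names `valid_token_counts` and `padded_widths` are hypothetical, introduced only for illustration; they are not vllm-omni or diffusers APIs.

```python
# Illustrative only: masks are plain nested lists (1 = valid token, 0 = padding).
# Helper names are hypothetical, not vllm-omni or diffusers APIs.


def valid_token_counts(mask):
    """List-based analogue of mask.sum(dim=1).tolist(): valid tokens per sample."""
    return [sum(row) for row in mask]


def padded_widths(mask):
    """List-based analogue of [shape[1]] * shape[0]: every sample reports the tensor width."""
    width = len(mask[0]) if mask else 0
    return [width] * len(mask)


# A naturally encoded prompt with no right-padding: both conventions agree.
unpadded = [[1, 1, 1, 1]]
# The same prompt right-padded to width 8: the conventions diverge.
padded = [[1, 1, 1, 1, 0, 0, 0, 0]]

for name, mask in (("unpadded", unpadded), ("padded", padded)):
    print(name, valid_token_counts(mask), padded_widths(mask))
```

The two helpers only disagree once padding is present, which is why the issue stays hidden for a single naturally encoded prompt.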
Reference behavior (diffusers latest main)
diffusers.models.transformers.transformer_qwenimage.compute_text_seq_len_from_mask() returns the encoder tensor width for RoPE:
batch_size, text_seq_len = encoder_hidden_states.shape[:2]
# ...
return text_seq_len, per_sample_len, encoder_hidden_states_mask
QwenImageTransformer2DModel.forward() uses that width:
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
    encoder_hidden_states, encoder_hidden_states_mask
)
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
The diffusers pipeline does not pass txt_seq_lens into the transformer; RoPE length is inferred from the padded tensor geometry.
Why this harms precision
QwenEmbedRope in vllm-omni slices text frequencies with:
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
Joint attention applies those frequencies to all encoder_hidden_states.shape[1] text positions (qwen_image_transformer.py).
If max(txt_seq_lens) < encoder_hidden_states.shape[1]:
- The RoPE table is too short for the padded embedding width.
- On the native CPU RoPE path this can raise a sequence-length RuntimeError during rotary application.
- When execution proceeds (or on other backends), positions beyond max(txt_seq_lens) do not receive the same RoPE as diffusers, so attention logits and denoiser outputs diverge from the reference.
Concrete example: 200 valid tokens in a tensor padded to width 1058 → vllm-omni uses RoPE length 200, diffusers uses 1058.
For positions 0..199 the frequency values match (same offset into pos_freqs), but the vllm-omni table has no entries for positions 200..1057, whereas diffusers applies positional encoding across the full padded width. In diffusers, padding tokens are masked in attention but still occupy RoPE indices.
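The 200-versus-1058 example can be reproduced with a toy frequency table. This is a sketch under stated assumptions: `pos_freqs` here is a plain list of integers rather than real rotary frequencies, and `TABLE_SIZE` and `MAX_VID_INDEX` are made-up values; only the slicing pattern is taken from the snippet above.

```python
# Illustrative only: a toy table stands in for QwenEmbedRope.pos_freqs.
# TABLE_SIZE and MAX_VID_INDEX are assumed values chosen for the example.

TABLE_SIZE = 4096
MAX_VID_INDEX = 64

# Each entry is just its own index, so slices are easy to compare.
pos_freqs = list(range(TABLE_SIZE))


def slice_text_freqs(max_len):
    """Same slicing pattern as above: pos_freqs[max_vid_index : max_vid_index + max_len]."""
    return pos_freqs[MAX_VID_INDEX:MAX_VID_INDEX + max_len]


valid_tokens, padded_width = 200, 1058

from_mask_sum = slice_text_freqs(valid_tokens)  # length taken from mask.sum()
from_width = slice_text_freqs(padded_width)     # length taken from shape[1]

# The first 200 entries are identical; the remainder has no counterpart.
shared_prefix_matches = from_mask_sum == from_width[:valid_tokens]
uncovered_positions = padded_width - len(from_mask_sum)

print(len(from_mask_sum), len(from_width), shared_prefix_matches, uncovered_positions)
```

With these numbers, 858 text positions are left without a table entry when the length comes from the mask sum.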
CB vs non-CB
Non-CB (forward() / diffuse())
Stock encode_prompt() usually produces tensors where prompt_embeds.shape[1] == prompt_embeds_mask.sum() for a single naturally-encoded prompt (no extra right-padding). In that narrow case mask.sum() and shape[1] agree and the bug is latent.
The mismatch appears whenever embedding width exceeds valid token count, e.g.:
- Caller-supplied prompt_embeds padded to a fixed max_model_len
- Batched / collated training tensors with fixed sequence width
- Any path that right-pads embeddings without making mask.sum() == shape[1]
Continuous batching (CB)
prepare_encode() stores txt_seq_lens = [mask.sum()].
InputBatch._prepare_request_prompt_field() may pad state.prompt_embeds / masks to a shared target_seq_len without updating state.txt_seq_lens.
denoise_step() passes the stale input_batch.txt_seq_lens into the transformer.
When a short request is batched with a longer one, max(stored txt_seq_lens) often equals the padded width, so the mismatch stays hidden. The failure mode is when max(txt_seq_lens) stays below the padded embed width (e.g. single padded request, fixed-width training pad, or stale per-request values).
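The stale-state behaviour can be modelled with a toy request state. Everything in this sketch is hypothetical: `ToyRequestState`, `toy_prepare_encode`, `toy_pad_in_place`, and `batch_rope_len` are simplified stand-ins that only mimic the order of operations described above, not the real DiffusionRequestState or InputBatch code.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRequestState:
    """Hypothetical stand-in for a per-request state; not the real DiffusionRequestState."""
    prompt_mask: list  # 1 = valid token, 0 = padding
    txt_seq_lens: list = field(default_factory=list)


def toy_prepare_encode(num_valid_tokens):
    """Store txt_seq_lens from the mask sum at encode time."""
    mask = [1] * num_valid_tokens
    return ToyRequestState(prompt_mask=mask, txt_seq_lens=[sum(mask)])


def toy_pad_in_place(state, target_seq_len):
    """Right-pad the mask in place; txt_seq_lens is deliberately left untouched."""
    state.prompt_mask.extend([0] * (target_seq_len - len(state.prompt_mask)))


def batch_rope_len(states):
    """RoPE length the batch would use: the max over the stored per-request values."""
    return max(n for s in states for n in s.txt_seq_lens)


# Case 1: a short and a long request padded to a shared width.
short, long_ = toy_prepare_encode(200), toy_prepare_encode(1058)
target = max(len(s.prompt_mask) for s in (short, long_))
for s in (short, long_):
    toy_pad_in_place(s, target)
mixed_batch_gap = target - batch_rope_len([short, long_])

# Case 2: a single request padded to a fixed width.
single = toy_prepare_encode(200)
toy_pad_in_place(single, 1058)
single_gap = len(single.prompt_mask) - batch_rope_len([single])

print(mixed_batch_gap, single_gap)
```

In the mixed batch the longer request happens to cover the padded width, so the gap is zero; in the single padded request the stored value is stale and the gap is the full amount of padding.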
CFG (classifier-free guidance)
Both branches are affected independently — positive and negative txt_seq_lens / negative_txt_seq_lens are each computed via mask.sum() in _prepare_generation_context() and passed through:
- non-CB: cfg_parallel.py::diffuse() (positive_kwargs / negative_kwargs)
- CB: denoise_step() → _build_denoise_kwargs() → predict_noise_maybe_with_cfg()
Short negative prompts with padded width exhibit the same RoPE length mismatch as the positive branch.
Suggested fix
Align with diffusers / the verl-omni stepwise workaround:
# In _prepare_generation_context() (all QwenImage pipeline variants):
txt_seq_lens = [int(prompt_embeds.shape[1])] * int(prompt_embeds.shape[0])
if negative_prompt_embeds is not None:
    negative_txt_seq_lens = [int(negative_prompt_embeds.shape[1])] * int(negative_prompt_embeds.shape[0])
else:
    negative_txt_seq_lens = None
Optionally also refresh state.txt_seq_lens inside _prepare_request_prompt_field() after CB padding so request state stays self-consistent.
Alternative longer-term: stop threading txt_seq_lens through the pipeline and infer RoPE length inside QwenImageTransformer2DModel.forward() the way diffusers does.
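Whichever fix is chosen, a small invariant check could guard against regressions. The sketch below is an assumption about how such a guard might look, not existing vllm-omni code: `check_rope_covers_width` and `width_based_seq_lens` are hypothetical names, and plain integers stand in for tensor shapes.

```python
# Illustrative only: hypothetical guard, not part of vllm-omni.


def width_based_seq_lens(batch_size, embed_width):
    """Width-based convention from the suggested fix: [shape[1]] * shape[0]."""
    return [int(embed_width)] * int(batch_size)


def check_rope_covers_width(txt_seq_lens, embed_width):
    """Raise if the RoPE length implied by txt_seq_lens is shorter than the embedding width."""
    if txt_seq_lens is None:
        return  # nothing to validate, e.g. no negative prompt
    rope_len = max(txt_seq_lens)
    if rope_len < embed_width:
        raise ValueError(
            f"RoPE length {rope_len} is shorter than embedding width {embed_width}"
        )


# Width-based lengths always satisfy the invariant.
check_rope_covers_width(width_based_seq_lens(batch_size=1, embed_width=1058), 1058)

# Mask-sum lengths fail it as soon as the embeddings are padded.
try:
    check_rope_covers_width([200], 1058)
    mask_sum_rejected = False
except ValueError:
    mask_sum_rejected = True

print(mask_sum_rejected)
```

A check like this could run on both the positive and negative branches, since each carries its own sequence lengths.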
Impact
- Numerical divergence vs diffusers / PyTorch reference for padded prompts
- RL training: elevated actor/ppo_kl at step 1 when comparing vLLM rollout to trainer log-prob recomputation (reported downstream in verl-omni stepwise adapters; fixed there by overriding txt_seq_lens, but stock vllm-omni prepare_encode() remains affected)
Environment
- vllm-omni: main @ a693ae67 (2026-06-15)
- diffusers: main @ d1f8e55c3
- CPU verification: conda env torch_2.9.0