Skip to content

Commit f32de2e

Browse files
committed
[Diffusion] Resolve Anima precision drift via diffusers' apply_rotary_emb
Uses diffusers' apply_rotary_emb to upcast RoPE calculations to float32, resolving the bfloat16 numerical drift vs the reference pipeline. Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
1 parent 98379bc commit f32de2e

1 file changed

Lines changed: 10 additions & 10 deletions

File tree

vllm_omni/diffusion/models/anima/anima_transformer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
import torch.nn as nn
99
import torch.nn.functional as F
10-
from diffusers.models.embeddings import Timesteps
10+
from diffusers.models.embeddings import Timesteps, apply_rotary_emb
1111
from diffusers.models.modeling_outputs import Transformer2DModelOutput
1212
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
1313

@@ -32,13 +32,11 @@
3232
}
3333

3434

35-
def _apply_rotary_emb(hidden_states, image_rotary_emb):
36-
cos, sin = image_rotary_emb
37-
cos = cos[None, :, None, :].to(device=hidden_states.device, dtype=hidden_states.dtype)
38-
sin = sin[None, :, None, :].to(device=hidden_states.device, dtype=hidden_states.dtype)
39-
x_real, x_imag = hidden_states.reshape(*hidden_states.shape[:-1], 2, -1).unbind(-2)
40-
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
41-
return hidden_states * cos + x_rotated * sin
35+
# NOTE: We import and use diffusers' `apply_rotary_emb` instead of a custom native implementation
36+
# to prevent numerical drift in bfloat16. Diffusers upcasts queries, keys, and rotary frequency
37+
# tensors to float32 before computing the rotation, and casts back to bfloat16 at the end.
38+
# Performing the entire computation in bfloat16 accumulates precision errors across the 28
39+
# transformer blocks, which is heavily amplified by Classifier-Free Guidance (CFG).
4240

4341

4442
class CosmosPatchEmbed(nn.Module):
@@ -235,8 +233,10 @@ def _attention(self, hidden_states, encoder_hidden_states=None, attention_mask=N
235233
key = self.norm_k(key)
236234

237235
if image_rotary_emb is not None:
238-
query = _apply_rotary_emb(query, image_rotary_emb)
239-
key = _apply_rotary_emb(key, image_rotary_emb)
236+
# We use diffusers' apply_rotary_emb to leverage its internal float32 rotation upcasting
237+
# logic, resolving the bfloat16 cumulative precision drift vs. the reference pipeline.
238+
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
239+
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
240240

241241
attn_metadata = AttentionMetadata(attn_mask=attention_mask) if attention_mask is not None else None
242242
hidden_states = self.attn(query, key, value, attn_metadata=attn_metadata)

0 commit comments

Comments
 (0)