|
7 | 7 | import torch |
8 | 8 | import torch.nn as nn |
9 | 9 | import torch.nn.functional as F |
10 | | -from diffusers.models.embeddings import Timesteps |
| 10 | +from diffusers.models.embeddings import Timesteps, apply_rotary_emb |
11 | 11 | from diffusers.models.modeling_outputs import Transformer2DModelOutput |
12 | 12 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
13 | 13 |
|
|
32 | 32 | } |
33 | 33 |
|
34 | 34 |
|
35 | | -def _apply_rotary_emb(hidden_states, image_rotary_emb): |
36 | | - cos, sin = image_rotary_emb |
37 | | - cos = cos[None, :, None, :].to(device=hidden_states.device, dtype=hidden_states.dtype) |
38 | | - sin = sin[None, :, None, :].to(device=hidden_states.device, dtype=hidden_states.dtype) |
39 | | - x_real, x_imag = hidden_states.reshape(*hidden_states.shape[:-1], 2, -1).unbind(-2) |
40 | | - x_rotated = torch.cat([-x_imag, x_real], dim=-1) |
41 | | - return hidden_states * cos + x_rotated * sin |
| 35 | +# NOTE: We import and use diffusers' `apply_rotary_emb` instead of a custom native implementation |
| 36 | +# to prevent numerical drift in bfloat16. Diffusers upcasts queries, keys, and rotary frequency |
| 37 | +# tensors to float32 before computing the rotation, and casts back to bfloat16 at the end. |
| 38 | +# Performing the entire computation in bfloat16 accumulates precision errors across the 28 |
| 39 | +# transformer blocks, which is heavily amplified by Classifier-Free Guidance (CFG). |
42 | 40 |
|
43 | 41 |
|
44 | 42 | class CosmosPatchEmbed(nn.Module): |
@@ -235,8 +233,10 @@ def _attention(self, hidden_states, encoder_hidden_states=None, attention_mask=N |
235 | 233 | key = self.norm_k(key) |
236 | 234 |
|
237 | 235 | if image_rotary_emb is not None: |
238 | | - query = _apply_rotary_emb(query, image_rotary_emb) |
239 | | - key = _apply_rotary_emb(key, image_rotary_emb) |
| 236 | + # We use diffusers' apply_rotary_emb to leverage its internal float32 rotation upcasting |
| 237 | + # logic, resolving the bfloat16 cumulative precision drift vs. the reference pipeline. |
| 238 | + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) |
| 239 | + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) |
240 | 240 |
|
241 | 241 | attn_metadata = AttentionMetadata(attn_mask=attention_mask) if attention_mask is not None else None |
242 | 242 | hidden_states = self.attn(query, key, value, attn_metadata=attn_metadata) |
|
0 commit comments