Skip to content

Commit 9c638d1

Browse files
committed
[Diffusion] Resolve Anima precision drift via diffusers' apply_rotary_emb
Uses diffusers' apply_rotary_emb to upcast RoPE calculations to float32, resolving the bfloat16 numerical drift vs the reference pipeline. Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
1 parent 98379bc commit 9c638d1

2 files changed

Lines changed: 10 additions & 81 deletions

File tree

benchmarks/diffusion/README.md

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -149,74 +149,3 @@ batch may still pay compile or CUDA-graph capture cost.
149149

150150
For a Qwen-Image continuous-batching replay example, see
151151
[`performance_dashboard/qwen_image_serving_performance.md`](./performance_dashboard/qwen_image_serving_performance.md).
152-
153-
## 4. Anima Native Single-File Benchmarking
154-
155-
Native Anima is benchmarked as a text-to-image model through the same serving
156-
benchmark entrypoint. Unlike standard HuggingFace model IDs, Anima serves the
157-
raw single-file transformer checkpoint and loads non-denoiser components from a
158-
Diffusers-layout component directory.
159-
160-
Download the official Anima checkpoint and components first. The commands below
161-
use `/path/to/models` as a placeholder; replace it with any local directory that
162-
has enough space for the checkpoint and component files.
163-
164-
```bash
165-
mkdir -p /path/to/models/anima-official
166-
mkdir -p /path/to/models/anima-components
167-
168-
hf download circlestone-labs/Anima \
169-
split_files/diffusion_models/anima-base-v1.0.safetensors \
170-
--local-dir /path/to/models/anima-official
171-
172-
hf download circlestone-labs/Anima-Base-v1.0-Diffusers \
173-
--local-dir /path/to/models/anima-components
174-
175-
CHECKPOINT=/path/to/models/anima-official/split_files/diffusion_models/anima-base-v1.0.safetensors
176-
COMPONENTS=/path/to/models/anima-components
177-
```
178-
179-
Run these commands from the vLLM-Omni repository in the Python environment or
180-
container where vLLM-Omni is installed.
181-
182-
Start the server with the checkpoint as `--model` and pass the component
183-
directory through `--diffusers-load-kwargs`:
184-
185-
```bash
186-
vllm serve "$CHECKPOINT" \
187-
--omni \
188-
--port 8099 \
189-
--model-class-name AnimaPipeline \
190-
--diffusers-load-kwargs "{\"components_path\":\"$COMPONENTS\"}"
191-
```
192-
193-
Then run the standard diffusion serving benchmark:
194-
195-
```bash
196-
python3 benchmarks/diffusion/diffusion_benchmark_serving.py \
197-
--base-url http://localhost:8099 \
198-
--endpoint /v1/chat/completions \
199-
--model "$CHECKPOINT" \
200-
--task t2i \
201-
--dataset random \
202-
--num-prompts 10 \
203-
--max-concurrency 1 \
204-
--warmup-requests 1 \
205-
--warmup-concurrency 1 \
206-
--width 1024 \
207-
--height 1024 \
208-
--num-inference-steps 50
209-
```
210-
211-
This matches the Diffusers baseline defaults for Anima: 1024x1024, 50 denoising
212-
steps, `max_sequence_length=512`, one image per prompt, empty negative prompt,
213-
and CFG scale 4.0 from the default guider. Do not pass `guidance_scale` through
214-
the benchmark unless you are intentionally measuring a non-default CFG setting.
215-
216-
Native Anima currently supports baseline single-GPU execution. Cache-DiT,
217-
TeaCache, CPU offload, layer-wise offload, quantization, TP/SP, CFG parallel,
218-
HSDP, and step execution are not supported by `AnimaPipeline` yet.
219-
220-
Anima uses the default single diffusion stage for local single-file checkpoint
221-
discovery when `--model-class-name AnimaPipeline` is provided; no deploy config
222-
is required.

vllm_omni/diffusion/models/anima/anima_transformer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
import torch.nn as nn
99
import torch.nn.functional as F
10-
from diffusers.models.embeddings import Timesteps
10+
from diffusers.models.embeddings import Timesteps, apply_rotary_emb
1111
from diffusers.models.modeling_outputs import Transformer2DModelOutput
1212
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
1313

@@ -32,13 +32,11 @@
3232
}
3333

3434

35-
def _apply_rotary_emb(hidden_states, image_rotary_emb):
36-
cos, sin = image_rotary_emb
37-
cos = cos[None, :, None, :].to(device=hidden_states.device, dtype=hidden_states.dtype)
38-
sin = sin[None, :, None, :].to(device=hidden_states.device, dtype=hidden_states.dtype)
39-
x_real, x_imag = hidden_states.reshape(*hidden_states.shape[:-1], 2, -1).unbind(-2)
40-
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
41-
return hidden_states * cos + x_rotated * sin
35+
# NOTE: We import and use diffusers' `apply_rotary_emb` instead of a custom native implementation
36+
# to prevent numerical drift in bfloat16. Diffusers upcasts queries, keys, and rotary frequency
37+
# tensors to float32 before computing the rotation, and casts back to bfloat16 at the end.
38+
# Performing the entire computation in bfloat16 accumulates precision errors across the 28
39+
# transformer blocks, which is heavily amplified by Classifier-Free Guidance (CFG).
4240

4341

4442
class CosmosPatchEmbed(nn.Module):
@@ -235,8 +233,10 @@ def _attention(self, hidden_states, encoder_hidden_states=None, attention_mask=N
235233
key = self.norm_k(key)
236234

237235
if image_rotary_emb is not None:
238-
query = _apply_rotary_emb(query, image_rotary_emb)
239-
key = _apply_rotary_emb(key, image_rotary_emb)
236+
# We use diffusers' apply_rotary_emb to leverage its internal float32 rotation upcasting
237+
# logic, resolving the bfloat16 cumulative precision drift vs. the reference pipeline.
238+
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
239+
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
240240

241241
attn_metadata = AttentionMetadata(attn_mask=attention_mask) if attention_mask is not None else None
242242
hidden_states = self.attn(query, key, value, attn_metadata=attn_metadata)

0 commit comments

Comments
 (0)