Related to #4024 and #4041.
I ran a small operator-level benchmark to compare FA2 varlen and PyTorch SDPA varlen for packed varlen attention. The goal is to check whether SDPA varlen is a good backend candidate for dynamic batching.
Setup
- A100 80GB, BF16
- heads=24, head_dim=128
- FA2: flash_attn_varlen_func
- SDPA: torch.nn.attention.varlen.varlen_attn
- real cu_seqlens, no padding
- precision metric: max absolute difference (max_abs) between attention outputs, not an end-to-end image diff
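For reference, this is how the `lens` column in the tables maps to the packed `cu_seqlens` layout. It is a minimal pure-Python sketch. `parse_lens` and `build_cu_seqlens` are hypothetical helpers, not part of flash-attn or PyTorch. With `flash_attn_varlen_func`, the resulting offsets would be passed as int32 CUDA tensors (`cu_seqlens_q` / `cu_seqlens_k`, together with `max_seqlen_q` / `max_seqlen_k`), and q/k/v would be packed as `(total_tokens, heads, head_dim)` with no padding.

```python
from itertools import accumulate


def parse_lens(spec: str) -> list[int]:
    """Parse the `lens` notation used in the tables (hypothetical helper).

    "1088,2368" -> [1088, 2368]; "4160 x8" -> eight sequences of 4160 tokens.
    """
    spec = spec.strip()
    if " x" in spec:
        length, count = spec.split(" x")
        return [int(length)] * int(count)
    return [int(part) for part in spec.split(",")]


def build_cu_seqlens(lens: list[int]) -> tuple[list[int], int]:
    """Return (cu_seqlens, max_seqlen) for one packed batch with no padding.

    Tokens cu_seqlens[i]:cu_seqlens[i + 1] of the packed
    (total_tokens, heads, head_dim) tensor belong to sequence i.
    """
    cu_seqlens = [0, *accumulate(lens)]
    return cu_seqlens, max(lens)


cu, max_len = build_cu_seqlens(parse_lens("1088,2368"))
print(cu, max_len)  # [0, 1088, 3456] 2368
```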
Main results

| case | lens | FA2 batch (ms) | SDPA batch (ms) | faster |
|---|---|---|---|---|
| edge_256 | 256 | 0.137 | 0.070 | SDPA 1.97x |
| qwen_256_single_320 | 320 | 0.127 | 0.068 | SDPA 1.88x |
| qwen_256_batch_320x2 | 320,320 | 0.119 | 0.067 | SDPA 1.78x |
| qwen_mixed_txt_256+512 | 299,1115 | 0.179 | 0.180 | tie |
| qwen_512+768 | 1088,2368 | 0.539 | 0.571 | FA2 1.06x |
| qwen_512+1024 | 1088,4160 | 1.310 | 1.441 | FA2 1.10x |
| bsz8_1024 | 4160 x8 | 9.040 | 10.015 | FA2 1.11x |
Batch runtime comparison: SDPA is 1.78x to 1.97x faster for the small-token cases (up to 640 total tokens). The two backends tie at 299,1115. FA2 is 1.06x to 1.11x faster for the larger and 1024-resolution cases.
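The `faster` column can be recomputed from the two timing columns. This sketch assumes the ratio is slower time divided by faster time, and that a gap under about 1% is reported as a tie. The 1% threshold is a guess; the post does not state it. The published ratios were presumably computed from unrounded timings, so recomputing from the rounded milliseconds can change the last digit. For example, 0.137 / 0.070 gives 1.96x rather than 1.97x.

```python
def faster_label(fa2_ms: float, sdpa_ms: float, tie_tol: float = 0.01) -> str:
    """Label the winning backend and its speedup (tie_tol is an assumed threshold)."""
    ratio = max(fa2_ms, sdpa_ms) / min(fa2_ms, sdpa_ms)
    if ratio - 1.0 < tie_tol:
        return "tie"
    winner = "SDPA" if sdpa_ms < fa2_ms else "FA2"
    return f"{winner} {ratio:.2f}x"


rows = {
    "qwen_256_batch_320x2": (0.119, 0.067),
    "qwen_mixed_txt_256+512": (0.179, 0.180),
    "bsz8_1024": (9.040, 10.015),
}
for case, (fa2_ms, sdpa_ms) in rows.items():
    print(case, faster_label(fa2_ms, sdpa_ms))
```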
| case | lens | SDPA batch vs SDPA serial (max_abs) | FA2 batch vs FA2 serial (max_abs) |
|---|---|---|---|
| edge_256 | 256 | 0.003906 | 0.000000 |
| edge_319 | 319 | 0.003906 | 0.000000 |
| qwen_256_single_320 | 320 | 0.003906 | 0.000000 |
| edge_321 | 321 | 0.003906 | 0.000000 |
| edge_384 | 384 | 0.001953 | 0.000000 |
| qwen_256_batch_320x2 | 320,320 | 0.001953 | 0.000000 |
| qwen_mixed_txt_256+512 | 299,1115 | 0.003906 | 0.000000 |
| qwen_512+768 | 1088,2368 | 0.000000 | 0.000000 |
| qwen_512+1024 | 1088,4160 | 0.000000 | 0.000000 |
| bsz1_1024 | 4160 | 0.000000 | 0.000000 |
| bsz2_1024 | 4160,4160 | 0.000000 | 0.000000 |
| bsz4_1024 | 4160 x4 | 0.000000 | 0.000000 |
| bsz8_1024 | 4160 x8 | 0.000000 | 0.000000 |
Batch-vs-serial precision: FA2 varlen batching is exact (max_abs = 0) in every tested case. SDPA varlen drifts (max_abs 0.001953 to 0.003906) in the small-token cases up to 299,1115 and is exact from qwen_512+768 upward.
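The non-zero max_abs values in these tables are, to the printed precision, exact powers of two: 0.003906 ≈ 2^-8, 0.001953 ≈ 2^-9, 0.000977 ≈ 2^-10, 0.000488 ≈ 2^-11. BF16 keeps 8 significant bits (7 stored plus the implicit one). One unit in the last place (ULP) is therefore 2^-8 for values in [0.5, 1) and 2^-9 for values in [0.25, 0.5). The drift is consistent with single-ULP rounding differences on outputs of magnitude roughly 0.06 to 1. The benchmark does not record output magnitudes, so this is an interpretation, not a measurement. The sketch below computes the BF16 ULP and, for each reported value, the magnitude band where it equals one ULP.

```python
import math


def bf16_ulp(x: float) -> float:
    """Spacing between adjacent BF16 values near x (normal, non-zero x only).

    BF16 has 7 explicit mantissa bits, so ulp = 2 ** (floor(log2|x|) - 7).
    """
    _, exp = math.frexp(abs(x))  # abs(x) = m * 2**exp with 0.5 <= m < 1
    return math.ldexp(1.0, exp - 8)  # floor(log2|x|) == exp - 1


for reported in (0.003906, 0.001953, 0.000977, 0.000488):
    k = round(math.log2(reported))  # nearest power of two, e.g. -8
    lo, hi = 2.0 ** (k + 7), 2.0 ** (k + 8)
    # one BF16 ULP equals 2**k for |x| in [lo, hi)
    print(f"{reported} ~ 2^{k}: one ULP for |x| in [{lo}, {hi})")
```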
| case | lens | FA2 serial vs SDPA serial (max_abs) |
|---|---|---|
| edge_256 | 256 | 0.003906 |
| edge_319 | 319 | 0.003906 |
| qwen_256_single_320 | 320 | 0.003906 |
| edge_321 | 321 | 0.003906 |
| edge_384 | 384 | 0.001953 |
| qwen_256_batch_320x2 | 320,320 | 0.001953 |
| qwen_mixed_txt_256+512 | 299,1115 | 0.003906 |
| qwen_512+768 | 1088,2368 | 0.000488 |
| qwen_512+1024 | 1088,4160 | 0.000977 |
| bsz1_1024 | 4160 | 0.000488 |
| bsz2_1024 | 4160,4160 | 0.000977 |
| bsz4_1024 | 4160 x4 | 0.000488 |
| bsz8_1024 | 4160 x8 | 0.000488 |
Serial backend comparison: FA2 and SDPA differ numerically even without batching (max_abs 0.000488 to 0.003906 across all cases), so backend-substitution drift must be separated from batching drift.
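One plausible but unverified source of this drift is accumulation order. Two kernels, or one kernel under different tiling or split settings, may sum the same terms in a different order before rounding the output to BF16. The toy sketch below is not a model of either FA2 or SDPA. It accumulates four numbers in FP32 in two orders, then rounds each result to BF16 with round-to-nearest-even. The two FP32 sums differ by only 2^-23. One lands exactly on a BF16 rounding tie and the other just above it, which produces a one-ULP (2^-7) difference in the BF16 output.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round an FP32 value to BF16 (round-to-nearest-even; finite inputs only)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    lsb = (bits >> 16) & 1  # lowest bit that survives in BF16
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def fp32_sum(values: list[float]) -> float:
    """Left-to-right sum, rounding to FP32 after every addition."""
    acc = 0.0
    for v in values:
        acc = to_fp32(acc + v)
    return acc


terms_a = [1.0, 2**-8, 2**-24, 2**-24]  # small terms added last
terms_b = [2**-24, 2**-24, 2**-8, 1.0]  # same terms, small ones first

out_a = to_bf16(fp32_sum(terms_a))
out_b = to_bf16(fp32_sum(terms_b))
print(out_a, out_b, abs(out_a - out_b))  # 1.0 1.0078125 0.0078125
```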
Observations
- FA2 and SDPA are not numerically identical: even in serial execution their BF16 outputs differ, so backend-substitution drift must be tracked separately from batching drift.
- SDPA varlen shows same-backend batch-vs-serial drift in the small-token cases (max_abs = 0.001953 to 0.003906), while FA2 varlen batch-vs-serial is exact in this test.
- FA2 and SDPA have different performance profiles. On A100, SDPA is faster for small-token cases and FA2 is faster for larger image-token cases. On H200 (not included in the tables above), SDPA appears faster for most tested cases.
cc @Gaohan123 @wuhang2014 @hsliuustc0106 @SamitHuang @fake0fan