
[Performance][Discussion]: FA2 varlen vs PyTorch SDPA varlen: performance tradeoff and numerical drift #4418

@asukaqaq-s


Related to #4024 and #4041.
I ran a small operator-level benchmark to compare FA2 varlen and PyTorch SDPA varlen for packed varlen attention. The goal is to check whether SDPA varlen is a good backend candidate for dynamic batching.

Setup

  • A100 80GB, BF16
  • heads=24, head_dim=128
  • FA2: flash_attn_varlen_func
  • SDPA: torch.nn.attention.varlen.varlen_attn
  • real cu_seqlens, no padding
  • metric: max absolute difference (max_abs) between attention outputs, not end-to-end (E2E) image diff
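For readers unfamiliar with the packed layout, the sketch below shows how a cu_seqlens vector is conventionally derived from per-sequence lengths: cumulative offsets with a leading zero, so that sequence i occupies rows cu_seqlens[i] to cu_seqlens[i+1] of the packed tensor. The helper names are hypothetical, and plain Python lists stand in for the integer device tensor a real kernel call would take.

```python
from itertools import accumulate


def build_cu_seqlens(lens):
    """Return cumulative sequence offsets with a leading 0 (length = len(lens) + 1)."""
    if any(n <= 0 for n in lens):
        raise ValueError("sequence lengths must be positive")
    return [0, *accumulate(lens)]


def segment_bounds(cu_seqlens):
    """Return the (start, end) row range of each sequence inside the packed tensor."""
    return list(zip(cu_seqlens[:-1], cu_seqlens[1:]))


cu = build_cu_seqlens([299, 1115])  # lens of the qwen_mixed_txt_256+512 case
print(cu)                  # [0, 299, 1414]
print(segment_bounds(cu))  # [(0, 299), (299, 1414)]
```

Because the offsets are exact, no padding rows are ever passed to the kernel, which is the "real cu_seqlens, no padding" condition listed in the setup.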

Main results

| case | lens | FA2 batch (ms) | SDPA batch (ms) | faster |
| --- | --- | --- | --- | --- |
| edge_256 | 256 | 0.137 | 0.070 | SDPA 1.97x |
| qwen_256_single_320 | 320 | 0.127 | 0.068 | SDPA 1.88x |
| qwen_256_batch_320x2 | 320,320 | 0.119 | 0.067 | SDPA 1.78x |
| qwen_mixed_txt_256+512 | 299,1115 | 0.179 | 0.180 | tie |
| qwen_512+768 | 1088,2368 | 0.539 | 0.571 | FA2 1.06x |
| qwen_512+1024 | 1088,4160 | 1.310 | 1.441 | FA2 1.10x |
| bsz8_1024 | 4160 x8 | 9.040 | 10.015 | FA2 1.11x |

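Batch runtime comparison: SDPA is roughly 1.8x to 2x faster for small-token cases (256 to 640 total tokens), the two backends tie at 299,1115, and FA2 is 1.06x to 1.11x faster for the larger cases, including the 1024-resolution ones.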

| case | lens | SDPA batch vs SDPA serial (max_abs) | FA2 batch vs FA2 serial (max_abs) |
| --- | --- | --- | --- |
| edge_256 | 256 | 0.003906 | 0.000000 |
| edge_319 | 319 | 0.003906 | 0.000000 |
| qwen_256_single_320 | 320 | 0.003906 | 0.000000 |
| edge_321 | 321 | 0.003906 | 0.000000 |
| edge_384 | 384 | 0.001953 | 0.000000 |
| qwen_256_batch_320x2 | 320,320 | 0.001953 | 0.000000 |
| qwen_mixed_txt_256+512 | 299,1115 | 0.003906 | 0.000000 |
| qwen_512+768 | 1088,2368 | 0.000000 | 0.000000 |
| qwen_512+1024 | 1088,4160 | 0.000000 | 0.000000 |
| bsz1_1024 | 4160 | 0.000000 | 0.000000 |
| bsz2_1024 | 4160,4160 | 0.000000 | 0.000000 |
| bsz4_1024 | 4160 x4 | 0.000000 | 0.000000 |
| bsz8_1024 | 4160 x8 | 0.000000 | 0.000000 |

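Batch-vs-serial precision: FA2 varlen batching matches serial execution exactly (max_abs = 0) in every tested case, while SDPA varlen drifts by 0.001953 to 0.003906 on the small-token cases (up to and including 299,1115) and is exact from 1088,2368 upward.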
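The batch-vs-serial metric can be sketched as follows: slice the packed batch output back into per-sequence segments and take the largest absolute element-wise difference against the output of running each sequence alone. This illustrates the metric only and is not the benchmark script; the function name is hypothetical, and flat Python lists with one value per token stand in for BF16 tensors.

```python
def batch_vs_serial_max_abs(packed_out, serial_outs, cu_seqlens):
    """Max absolute difference between a packed batch output and per-sequence outputs.

    packed_out  : flat list of values for all sequences, concatenated
    serial_outs : list of flat lists, one per sequence run on its own
    cu_seqlens  : cumulative offsets into packed_out, starting at 0
    """
    if len(serial_outs) != len(cu_seqlens) - 1:
        raise ValueError("need one serial output per sequence")
    worst = 0.0
    for i, serial in enumerate(serial_outs):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        segment = packed_out[start:end]
        if len(segment) != len(serial):
            raise ValueError(f"sequence {i}: length mismatch")
        for a, b in zip(segment, serial):
            worst = max(worst, abs(a - b))
    return worst


# Toy data: the second sequence differs by 2**-8 in one element.
packed = [0.5, 0.25, 0.75, 0.625]
serial = [[0.5, 0.25], [0.75, 0.625 + 2**-8]]
print(batch_vs_serial_max_abs(packed, serial, [0, 2, 4]))  # 0.00390625
```

A result of exactly 0.0 means the batched output reproduces the serial output element for element, which is what the FA2 column reports.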

| case | lens | max_abs (FA2 serial vs SDPA serial) |
| --- | --- | --- |
| edge_256 | 256 | 0.003906 |
| edge_319 | 319 | 0.003906 |
| qwen_256_single_320 | 320 | 0.003906 |
| edge_321 | 321 | 0.003906 |
| edge_384 | 384 | 0.001953 |
| qwen_256_batch_320x2 | 320,320 | 0.001953 |
| qwen_mixed_txt_256+512 | 299,1115 | 0.003906 |
| qwen_512+768 | 1088,2368 | 0.000488 |
| qwen_512+1024 | 1088,4160 | 0.000977 |
| bsz1_1024 | 4160 | 0.000488 |
| bsz2_1024 | 4160,4160 | 0.000977 |
| bsz4_1024 | 4160 x4 | 0.000488 |
| bsz8_1024 | 4160 x8 | 0.000488 |

Serial backend comparison: FA2 and SDPA differ numerically even without batching (max_abs from 0.000488 to 0.003906), so backend-substitution drift must be separated from batching drift.
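The reported max_abs values are all powers of two printed to six decimals, which is what a difference of one or a few rounding steps in a low-precision format looks like. The sketch below checks this and, for reference, computes the BF16 spacing (8 significand bits, including the implicit one) at a given magnitude. Reading the drift as rounding-step differences is an interpretation, not something the benchmark establishes, and `bf16_ulp` is a hypothetical helper.

```python
import math


def bf16_ulp(x):
    """Spacing between adjacent BF16 values near |x| (normal range, 8-bit significand)."""
    if x == 0:
        raise ValueError("x must be nonzero")
    exponent = math.floor(math.log2(abs(x)))
    return 2.0 ** (exponent - 7)


reported = ["0.003906", "0.001953", "0.000977", "0.000488"]
for text, k in zip(reported, [8, 9, 10, 11]):
    # Each reported value equals 2**-k when printed with six decimals.
    assert f"{2.0 ** -k:.6f}" == text

print(bf16_ulp(0.75))  # 0.00390625, spacing for magnitudes in [0.5, 1)
print(bf16_ulp(0.2))   # 0.0009765625, spacing for magnitudes in [0.125, 0.25)
```

Under this reading, a max_abs of 0.003906 corresponds to a single BF16 step for an output value of magnitude between 0.5 and 1, so a relative error measure alongside max_abs may help compare cases with different output scales.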

Observations

  1. FA2 and SDPA are not numerically identical: even in serial execution, their BF16 outputs differ, so backend-substitution drift must be tracked separately from batching drift.

  2. SDPA varlen has same-backend batch-vs-serial drift in small-token cases (max_abs = 0.001953-0.003906), while FA2 varlen batch-vs-serial is exact in this test.

  3. FA2 and SDPA have different performance profiles: on A100, SDPA is faster for small-token cases while FA2 is faster for larger image-token cases; on H200, SDPA appears faster for most tested cases.
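If the A100 numbers above were used to pick a backend per batch, one simple policy would be a total-token cutoff. The sketch below is purely illustrative: `pick_backend` and its 1414-token default (the total of the 299,1115 case, where the two backends tied) are assumptions drawn only from the table in this issue, not a tuned or proposed threshold. The policy also ignores the numerical drift from observations 1 and 2.

```python
def pick_backend(lens, tie_tokens=1414):
    """Hypothetical policy: SDPA below the observed tie point, FA2 at or above it."""
    if not lens or any(n <= 0 for n in lens):
        raise ValueError("lens must be a non-empty list of positive lengths")
    total = sum(lens)
    return "sdpa" if total < tie_tokens else "fa2"


# Cases taken from the A100 table in this issue.
for name, lens in [
    ("edge_256", [256]),
    ("qwen_256_batch_320x2", [320, 320]),
    ("qwen_512+768", [1088, 2368]),
    ("bsz8_1024", [4160] * 8),
]:
    print(name, sum(lens), pick_backend(lens))
```

A cutoff like this would need to be measured per GPU, since the H200 results are described as favoring SDPA for most cases.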

cc @Gaohan123 @wuhang2014 @hsliuustc0106 @SamitHuang @fake0fan
