perf(deepseek-v4): dense deep_gemm warmup M-sweep + fp8_einsum coverage #427

Merged
zhyncs merged 1 commit into main from perf/deep-gemm-warmup-dense on Jun 12, 2026

@dongjiyingdjy

Contributor

Problem

After #398, the V4-Flash startup warmup still JIT-compiled deep_gemm cubins inline on the serving hot path, which can stall the engine past the gRPC health probe (the failure class #398 set out to fix). Two residual sources:

  1. Small-N GemmType::Normal FP8 projection GEMMs (e.g. N=1536, K=4096) whose block_m the previous wave-boundary M-enumerator under-sampled.
  2. The wo_a attention output projection — a per-group GemmType::Batched GEMM run via deep_gemm.fp8_einsum — which warmup_fp8_gemm_nt_from_model skipped as is_bmm.

Root cause

A cubin's JIT key includes block_m, block_n, the swap_ab flag (and num_groups for batched GEMMs), all chosen by a C++ heuristic (get_best_config in _C.so) that deep_gemm does not expose to Python. Modelling that selection in Python leaks configs — e.g. after warming the wave boundaries, a runtime call at M=257 still picks block_m=32, swap_ab=1 (a new cubin). The transition points are linear wave boundaries, not geometric, so any sparse heuristic mis-samples them.
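
The leak described above can be illustrated with a toy model. The sketch below is only an illustration: `toy_config` is a made-up stand-in for an opaque tile heuristic, not deep_gemm's `get_best_config`, and its thresholds are invented. It shows how a geometric M sample can skip a narrow config pocket just past a boundary (like the M=257 case), while a dense sweep reaches it.

```python
def toy_config(m: int) -> tuple[int, bool]:
    """Toy (block_m, swap_ab) chooser. Hypothetical; NOT deep_gemm's get_best_config."""
    if m <= 32:
        return (16, True)
    if m <= 256:
        return (64, False)
    if m <= 272:
        # Narrow pocket just past a boundary, analogous to M=257 in the PR text.
        return (32, True)
    return (128, False)


def configs_reached(m_values):
    """Set of distinct configs (i.e. distinct cubins in this toy) hit by a sweep."""
    return {toy_config(m) for m in m_values}


max_m = 4096
sparse = [1 << i for i in range(13)]   # 1, 2, 4, ..., 4096
dense = range(1, max_m + 1)            # every M the serving path could see

missed = configs_reached(dense) - configs_reached(sparse)
print(missed)  # {(32, True)}
```

In this toy, the sparse sweep would leave one config to be compiled on first use at serving time, which is the failure mode the PR targets.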

Fix

  • Replace _optimal_warmup_m_values/_ceil_div with a dense _warmup_m_values(max_tokens): M in [1, 2048] step 1, then step 16 up to max. Let the real heuristic pick at each M and compile whatever it selects — provably covering every config the serving path (M in [1, max]) can reach. Tiling changes most often at small M (swap_ab flips), hence step 1 there; it is stable above 2048.
  • Add warmup_fp8_einsum to cover the wo_a batched einsum. It reuses the loaded weight/scale and synthesises only the activation, in the exact layout produced by deepseek_v4_fused_inv_rope_fp8_quant. warmup_fp8_gemm_nt_from_model now warms is_bmm modules instead of skipping them.
  • Pass the chunked-prefill-derived ceiling (_deepseek_v4_mega_moe_max_num_tokens(), the same ceiling mega_moe uses) to the prefill/post-quant warmups instead of a hardcoded 8192, so --chunked-prefill-size > 8192 is covered.
  • Delete the now-unused _token_count_sweep.
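
A minimal sketch of the dense enumerator described in the first bullet, under stated assumptions: the function name `warmup_m_values` and the `dense_limit` / `coarse_step` parameters are illustrative rather than taken from the diff, and always including `max_tokens` as the final value is an assumption (the PR text only says "step 16 up to max").

```python
def warmup_m_values(max_tokens: int, dense_limit: int = 2048, coarse_step: int = 16) -> list[int]:
    """Dense M sweep: every M in [1, dense_limit], then every coarse_step up to max_tokens.

    Illustrative only; the merged _warmup_m_values may differ in detail.
    """
    if max_tokens < 1:
        return []
    # Step 1 over the small-M range, where tiling choices change most often.
    values = list(range(1, min(dense_limit, max_tokens) + 1))
    # Coarser steps above the dense range.
    values.extend(range(dense_limit + coarse_step, max_tokens + 1, coarse_step))
    # Assumption: always warm the ceiling itself, even if it is off the step grid.
    if values[-1] != max_tokens:
        values.append(max_tokens)
    return values


ms = warmup_m_values(8192)
print(len(ms), ms[0], ms[-1])  # 2432 1 8192
```

A warmup loop would then call the real kernel entry point once per M in this list, so that whatever config the C++ heuristic selects gets compiled before serving starts.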

Validation (B200, V4-Flash, cold DG_JIT_CACHE_DIR)

  • Exhaustive probe: fp8_gemm_nt at every M in 1..8192 across all 6 Normal shapes → 0 inline-JIT cubins. Dense prefill sweep exercising the wo_a einsum → 0 inline-JIT cubins. (Inline-JIT count: 17 → 0.)
  • GSM8K 5-shot, 200 samples: flexible 0.960 / strict 0.955.
  • Startup warmup ~294s; the dense sweep adds only ~3s over the prior sparse enumerator (compilation dominates, not launch count), well within the 1800s CI readiness timeout.
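
One way to count inline-JIT compilations like those reported above is to snapshot the JIT cache directory before and after a probe and diff the file sets. The PR does not say how its count was obtained, so the sketch below is an assumption: it presumes each newly compiled kernel leaves at least one new file under the cache directory, and the helper names and file layout are hypothetical.

```python
import tempfile
from pathlib import Path


def snapshot_cache(cache_dir: str) -> set[str]:
    """Relative paths of all files currently under cache_dir (empty if it is missing)."""
    root = Path(cache_dir)
    if not root.is_dir():
        return set()
    return {p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file()}


def new_cache_files(cache_dir: str, before: set[str]) -> set[str]:
    """Files that appeared under cache_dir since the `before` snapshot was taken."""
    return snapshot_cache(cache_dir) - before


# Demo against a temporary directory standing in for the real JIT cache.
with tempfile.TemporaryDirectory() as cache_dir:
    before = snapshot_cache(cache_dir)
    # Simulate a kernel compiled during the probe (hypothetical file layout).
    kernel_dir = Path(cache_dir) / "kernel_a"
    kernel_dir.mkdir()
    (kernel_dir / "kernel.cubin").write_bytes(b"\x00")
    added = new_cache_files(cache_dir, before)

print(sorted(added))  # ['kernel_a/kernel.cubin']
```

In a real run, the "before" snapshot would be taken after warmup finishes and the diff taken after the serving probe; an empty diff corresponds to zero inline-JIT cubins.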

🤖 Generated with Claude Code

The prior startup warmup still JIT-compiled cubins inline on the serving
hot path: (1) small-N GemmType::Normal FP8 projection GEMMs whose block_m
the vLLM-style _optimal_warmup_m_values under-sampled, and (2) the wo_a
attention output projection (deep_gemm.fp8_einsum, GemmType::Batched) which
warmup_fp8_gemm_nt_from_model skipped as is_bmm.

block_m/block_n/swap_ab (and num_groups for batched) are all part of the
cubin JIT key and are chosen by a C++ heuristic (get_best_config in _C.so)
that deep_gemm does not expose to Python, so modelling the selection leaks
configs (e.g. M=257 picked block_m=32 swap_ab=1 after warming the wave
boundaries). Replace _optimal_warmup_m_values/_ceil_div with a dense
_warmup_m_values (M in [1,2048] step 1, then step 16) that lets the real
heuristic pick at each M and compiles whatever it selects -- provably
covering every config the serving path (M in [1,max]) can reach.

Add warmup_fp8_einsum to cover the wo_a batched einsum (reuses the loaded
weight/scale, synthesises only the activation in the layout produced by
deepseek_v4_fused_inv_rope_fp8_quant). Pass the chunked_prefill_size-derived
ceiling (_deepseek_v4_mega_moe_max_num_tokens, the same ceiling mega_moe
uses) to the prefill/post-quant warmups instead of a hardcoded 8192, so
configs with chunked_prefill_size > 8192 are covered too.

Validated on B200 (cold DG_JIT_CACHE_DIR, V4-Flash): exhaustive fp8_gemm_nt
probe at every M in 1..8192 across all 6 Normal shapes and a dense prefill
sweep both produce 0 inline-JIT cubins (was 17); startup ~450s; GSM8K
5-shot 200-sample flexible 0.960 / strict 0.955.

Signed-off-by: jiyingd <jiyingd@nvidia.com>

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@dongjiyingdjy dongjiyingdjy requested a review from a team as a code owner June 12, 2026 01:30

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f4379b8642

k_vals = torch.zeros(
    max_kv_len, head_dim_bytes, dtype=torch.uint8, device=device
).view(torch.int8)
k_scales = torch.zeros(max_kv_len, dtype=torch.int32, device=device)

P1: Mirror runtime MXFP4 scale layout in warmup

For sparse-indexer prefill, this warmup no longer uses the same KV scale tensor that serving passes to fp8_fp4_mqa_logits: the runtime gathered_k scales come from _prefill_gather_scales_workspace, which is allocated as a 2-D uint8 tensor and then passed unchanged via kv=(k_values.contiguous(), k_scales.contiguous()). Warming with a 1-D int32 tensor here either compiles a different specialization or fails to exercise the real call shape, so the first actual prefill can still JIT the ragged MQA logits kernel on the hot path.

@zhyncs zhyncs merged commit 4d0d32c into main Jun 12, 2026
30 of 36 checks passed
@zhyncs zhyncs deleted the perf/deep-gemm-warmup-dense branch June 12, 2026 02:05