perf(deepseek-v4): dense deep_gemm warmup M-sweep + fp8_einsum coverage #427
The prior startup warmup still JIT-compiled cubins inline on the serving hot path: (1) small-N GemmType::Normal FP8 projection GEMMs whose block_m the vLLM-style _optimal_warmup_m_values under-sampled, and (2) the wo_a attention output projection (deep_gemm.fp8_einsum, GemmType::Batched) which warmup_fp8_gemm_nt_from_model skipped as is_bmm.

block_m/block_n/swap_ab (and num_groups for batched) are all part of the cubin JIT key and are chosen by a C++ heuristic (get_best_config in _C.so) that deep_gemm does not expose to Python, so modelling the selection leaks configs (e.g. M=257 picked block_m=32 swap_ab=1 after warming the wave boundaries).

Replace _optimal_warmup_m_values/_ceil_div with a dense _warmup_m_values (M in [1,2048] step 1, then step 16) that lets the real heuristic pick at each M and compiles whatever it selects -- provably covering every config the serving path (M in [1,max]) can reach.

Add warmup_fp8_einsum to cover the wo_a batched einsum (reuses the loaded weight/scale, synthesises only the activation in the layout produced by deepseek_v4_fused_inv_rope_fp8_quant).

Pass the chunked_prefill_size-derived ceiling (_deepseek_v4_mega_moe_max_num_tokens, the same ceiling mega_moe uses) to the prefill/post-quant warmups instead of a hardcoded 8192, so configs with chunked_prefill_size > 8192 are covered too.

Validated on B200 (cold DG_JIT_CACHE_DIR, V4-Flash): exhaustive fp8_gemm_nt probe at every M in 1..8192 across all 6 Normal shapes and a dense prefill sweep both produce 0 inline-JIT cubins (was 17); startup ~450s; GSM8K 5-shot 200-sample flexible 0.960 / strict 0.955.

Signed-off-by: jiyingd <jiyingd@nvidia.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f4379b8642
```python
k_vals = torch.zeros(
    max_kv_len, head_dim_bytes, dtype=torch.uint8, device=device
).view(torch.int8)
k_scales = torch.zeros(max_kv_len, dtype=torch.int32, device=device)
```
Mirror runtime MXFP4 scale layout in warmup
For sparse-indexer prefill, this warmup no longer uses the same KV scale tensor that serving passes to fp8_fp4_mqa_logits: the runtime gathered_k scales come from _prefill_gather_scales_workspace, which is allocated as a 2-D uint8 tensor and then passed unchanged via kv=(k_values.contiguous(), k_scales.contiguous()). Warming with a 1-D int32 tensor here either compiles a different specialization or fails to exercise the real call shape, so the first actual prefill can still JIT the ragged MQA logits kernel on the hot path.
Problem
After #398, the V4-Flash startup warmup still JIT-compiled deep_gemm cubins inline on the serving hot path, which can stall the engine past the gRPC health probe (the failure class #398 set out to fix). Two residual sources:
- Small-N `GemmType::Normal` FP8 projection GEMMs (e.g. N=1536, K=4096) whose `block_m` the previous wave-boundary M-enumerator under-sampled.
- `wo_a` attention output projection, a per-group `GemmType::Batched` GEMM run via `deep_gemm.fp8_einsum`, which `warmup_fp8_gemm_nt_from_model` skipped as `is_bmm`.

Root cause
A cubin's JIT key includes `block_m`, `block_n`, the `swap_ab` flag (and `num_groups` for batched GEMMs), all chosen by a C++ heuristic (`get_best_config` in `_C.so`) that deep_gemm does not expose to Python. Modelling that selection in Python leaks configs: for example, after warming the wave boundaries, a runtime call at M=257 still picks `block_m=32, swap_ab=1` (a new cubin). The transition points are linear wave boundaries, not geometric, so any sparse heuristic mis-samples them.

Fix
- Replace `_optimal_warmup_m_values`/`_ceil_div` with a dense `_warmup_m_values(max_tokens)`: M in `[1, 2048]` step 1, then step 16 up to max. Let the real heuristic pick at each M and compile whatever it selects, provably covering every config the serving path (M in `[1, max]`) can reach. Tiling changes most often at small M (`swap_ab` flips), hence step 1 there; it is stable above 2048.
- Add `warmup_fp8_einsum` to cover the `wo_a` batched einsum. It reuses the loaded weight/scale and synthesises only the activation, in the exact layout produced by `deepseek_v4_fused_inv_rope_fp8_quant`. `warmup_fp8_gemm_nt_from_model` now warms `is_bmm` modules instead of skipping them.
- Pass the `chunked_prefill_size`-derived ceiling (`_deepseek_v4_mega_moe_max_num_tokens()`, the same ceiling mega_moe uses) to the prefill/post-quant warmups instead of a hardcoded 8192, so `--chunked-prefill-size > 8192` is covered.
- `_token_count_sweep`.

Validation (B200, V4-Flash, cold `DG_JIT_CACHE_DIR`)

- Exhaustive `fp8_gemm_nt` probe at every M in `1..8192` across all 6 Normal shapes → 0 inline-JIT cubins.
- Dense prefill sweep exercising the `wo_a` einsum → 0 inline-JIT cubins. (Inline-JIT count: 17 → 0.)

🤖 Generated with Claude Code
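For readers who want to see the shape of the dense sweep described under Fix, here is a minimal sketch. It is not the PR's code: the body of `_warmup_m_values` is not shown in this conversation, so `warmup_m_values` below is a reconstruction from the stated rule (step 1 up to 2048, then step 16), and appending the ceiling itself is an assumption. `toy_pick_config` is a purely hypothetical stand-in for `get_best_config`. It does not reproduce deep_gemm's real heuristic and exists only to show why asking the heuristic at every M finds configs that a sparse sample misses.

```python
from typing import Callable, Hashable, Iterable, List, Set

DENSE_LIMIT = 2048  # step-1 range upper bound, as stated in the PR description
COARSE_STEP = 16    # stride above DENSE_LIMIT, as stated in the PR description


def warmup_m_values(max_tokens: int) -> List[int]:
    """Dense M sweep: every M up to DENSE_LIMIT, then every COARSE_STEP-th M."""
    if max_tokens < 1:
        return []
    dense_end = min(max_tokens, DENSE_LIMIT)
    values = list(range(1, dense_end + 1))
    values.extend(range(dense_end + COARSE_STEP, max_tokens + 1, COARSE_STEP))
    # Assumption (not confirmed by the PR): always warm the ceiling itself.
    if values[-1] != max_tokens:
        values.append(max_tokens)
    return values


def toy_pick_config(m: int) -> tuple:
    """HYPOTHETICAL stand-in for get_best_config; not deep_gemm's real logic.

    Picks the block_m that pads the fewest rows (ties go to the larger block)
    and sets a toy swap_ab flag for small M that still needs padding.
    """
    def waste(block: int) -> int:
        return -(-m // block) * block - m  # padded rows in the last tile

    block_m = min((32, 64, 128, 256), key=lambda b: (waste(b), -b))
    swap_ab = m <= 512 and waste(block_m) > 0
    return (block_m, swap_ab)


def reachable_configs(
    ms: Iterable[int], pick: Callable[[int], Hashable]
) -> Set[Hashable]:
    """Ask the heuristic at each M and collect the distinct configs it returns."""
    return {pick(m) for m in ms}


# Sparse sample (powers of two up to 8192) versus the dense sweep.
sparse = reachable_configs([2 ** k for k in range(14)], toy_pick_config)
dense = reachable_configs(warmup_m_values(8192), toy_pick_config)
missed_by_sparse = dense - sparse
```

In this toy model the power-of-two sample never sees a padded M that selects a block larger than 32, so `missed_by_sparse` is non-empty, while the dense sweep matches an exhaustive scan of `1..8192`. A real warmup would call the actual GEMM at each M instead of `toy_pick_config` and let the JIT cache deduplicate repeated configs.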