Skip to content

Commit 4d0d32c

Browse files
perf(deepseek-v4): dense deep_gemm warmup M-sweep + fp8_einsum coverage (#427)
1 parent 38b3a35 commit 4d0d32c

2 files changed

Lines changed: 189 additions & 164 deletions

File tree

  • python/tokenspeed/runtime/models
  • tokenspeed-kernel/python/tokenspeed_kernel/thirdparty/deep_gemm

python/tokenspeed/runtime/models/deepseek_v4.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4189,14 +4189,19 @@ def _warmup_prefill_jit(self) -> None:
41894189
),
41904190
mxfp4_block_size=DEEPSEEK_V4_MXFP4_BLOCK_SIZE,
41914191
tp_size=tp_size,
4192-
max_tokens=min(getattr(config, "max_position_embeddings", 8192), 8192),
4192+
# Prefill GEMM/prenorm M is capped per forward by chunked_prefill_size
4193+
# (continuous batching), the same ceiling mega_moe warms to. Hardcoding
4194+
# 8192 would leave M in (8192, chunked_prefill_size] to JIT inline.
4195+
max_tokens=_deepseek_v4_mega_moe_max_num_tokens(),
41934196
device=torch.device("cuda", torch.cuda.current_device()),
41944197
)
41954198

41964199
def post_quant_warmup(self) -> None:
41974200
"""Called by the weight loader after all quant process_weights_after_loading."""
41984201
if deep_gemm is not None:
4199-
deep_gemm.warmup_fp8_gemm_nt_from_model(self)
4202+
deep_gemm.warmup_fp8_gemm_nt_from_model(
4203+
self, max_tokens=_deepseek_v4_mega_moe_max_num_tokens()
4204+
)
42004205

42014206
@classmethod
42024207
def get_model_config_for_expert_location(cls, config):

0 commit comments

Comments
 (0)