Skip to content
13 changes: 13 additions & 0 deletions python/tokenspeed/runtime/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,10 @@ def load_model(
with device_loading_context(module, target_device):
module.process_weights_after_loading(module)

post_quant_warmup = getattr(model, "post_quant_warmup", None)
if callable(post_quant_warmup):
post_quant_warmup()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Run post-quant warmup on the target device

When CPU offloading is enabled, the preceding device_loading_context blocks intentionally move CPU parameters to target_device only for processing and then restore them to CPU before this new hook runs. post_quant_warmup() eventually uses next(model.parameters()).device to allocate the DeepGEMM warmup tensors, so an offloaded DeepSeek V4 FP8 load can try to invoke CUDA DeepGEMM with CPU tensors and fail during startup instead of just warming the kernels.

Useful? React with 👍 / 👎.


return model.eval()


Expand Down Expand Up @@ -460,6 +464,10 @@ def load_model(
if process_method is not None:
module.process_weights_after_loading(module)

post_quant_warmup = getattr(model, "post_quant_warmup", None)
if callable(post_quant_warmup):
post_quant_warmup()

# For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
Expand Down Expand Up @@ -603,6 +611,11 @@ def load_model(
state_dict.pop(key)
if state_dict:
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

post_quant_warmup = getattr(model, "post_quant_warmup", None)
if callable(post_quant_warmup):
post_quant_warmup()
Comment on lines +615 to +617

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Invoke DeepGEMM model warmup for sharded loads

When loading DeepSeek V4 from sharded-state checkpoints, this loader never calls model.load_weights(), and the only call sites for warmup_deep_gemm() / _warmup_prefill_jit() are inside DeepseekV4ForCausalLM.load_weights (checked with rg warmup_deep_gemm). This new hook only runs post_quant_warmup(), which warms FP8 linear GEMMs, so sharded loads still skip the new prefill and MegaMoE DeepGEMM startup warmups and can hit first-request JITs despite this commit’s warmup path.

Useful? React with 👍 / 👎.


return model.eval()


Expand Down
141 changes: 68 additions & 73 deletions python/tokenspeed/runtime/models/deepseek_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@

from tokenspeed.runtime.configs.deepseek_v4_cache_spec import (
DEEPSEEK_V4_MXFP4_BLOCK_SIZE,
V4_KERNEL_BLOCK_ROWS,
deepseek_v4_indexer_mxfp4_layout_from_row_bytes,
deepseek_v4_indexer_mxfp4_scale_dim,
deepseek_v4_indexer_mxfp4_value_bytes,
Expand Down Expand Up @@ -2387,71 +2388,6 @@ def forward(
)
return y

def warmup_jit_variants(self) -> None:
"""Pre-compile every DeepGEMM mega-MoE tile this layer can hit at runtime.

DeepGEMM JITs ``fp8_fp4_mega_moe`` on first use and selects the tile
(``block_m``) from the per-rank token count -- roughly
``num_tokens * num_ranks * num_topk / num_experts`` mapped to a
``block_m`` bucket (DeepGEMM ``heuristics/mega_moe.hpp``). CUDA-graph
capture only exercises decode-sized token counts, so the larger prefill
tiles would otherwise JIT mid-serving -- and because the kernel arms its
NVLink EP barrier on the same launch, a cold compile stalls one rank
past DeepGEMM's 30 s barrier timeout and aborts the job.

Sweeping a few token counts that cross every ``block_m`` boundary
compiles each tile at startup with all EP ranks in lock-step, so serving
never JITs on the hot path. The kernels are cached by shape + tile (not
by layer or activation values), so dummy activations on one module
cover the whole stack.
"""
if deep_gemm is None or self._transformed_l1_weights is None:
return

device = self._transformed_l1_weights[0].device
cap = self.max_num_tokens
# Token counts spanning the block_m buckets up to the symmetric-buffer
# capacity. Decode-sized tiles are also covered by CUDA-graph capture;
# the small counts here keep the warmup self-contained.
token_counts = sorted(
{n for n in (16, 32, 64, 128, 256, 512, 1024, 2048, 4096) if n < cap}
| {cap}
)

# The launch runs an NVLink barrier across the EP group, so every rank
# must enter the sweep together; align them once after weight load so a
# straggler cannot trip the barrier's 30 s timeout.
if torch.distributed.is_initialized():
group = pg_manager.get_process_group("nccl", self.mapping.moe.tp_ep_group)
torch.distributed.barrier(group=group)

for num_tokens in token_counts:
hidden_states = torch.randn(
(num_tokens, self.hidden_size),
dtype=torch.bfloat16,
device=device,
)
topk_weights = torch.full(
(num_tokens, self.top_k),
1.0 / self.top_k,
dtype=torch.float32,
device=device,
)
topk_ids = torch.randint(
0,
self.num_experts,
(num_tokens, self.top_k),
dtype=torch.int64,
device=device,
)
self.forward(
hidden_states,
topk_weights,
topk_ids,
activation_clamp=self.swiglu_limit,
)
torch.cuda.synchronize()


class DeepseekV4MoE(nn.Module):
def __init__(
Expand Down Expand Up @@ -4164,7 +4100,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader(param, loaded_weight)
del params_dict, moe_loader
self.post_load_weights()
self.warmup_mega_moe()
self.warmup_deep_gemm()

def post_load_weights(self):
mega_moe_experts: list[DeepseekV4MegaMoEExperts] = []
Expand All @@ -4177,24 +4113,83 @@ def post_load_weights(self):
elif isinstance(module, MoELayer):
module.process_weights_after_loading(module)

def warmup_mega_moe(self) -> None:
"""Pre-compile DeepGEMM mega-MoE tiles.
def warmup_deep_gemm(self) -> None:
"""Pre-compile all DeepGEMM JIT kernels used by this model.

Called after post_load_weights (not inside it) so that
finalize_weights temporaries have been freed by GC and there is
enough GPU memory for the symmetric buffer allocation.
"""
if os.environ.get("TOKENSPEED_DISABLE_MEGA_MOE_WARMUP") == "1":
if deep_gemm is None:
return
import gc

gc.collect()
torch.cuda.empty_cache()

self._warmup_mega_moe_jit()
self._warmup_prefill_jit()

def _warmup_mega_moe_jit(self) -> None:
if os.environ.get("TOKENSPEED_DISABLE_MEGA_MOE_WARMUP") == "1":
return
for module in self.modules():
if isinstance(module, DeepseekV4MegaMoEExperts):
logger.info("Pre-compiling DeepGEMM mega-MoE kernel variants...")
module.warmup_jit_variants()
return
if not isinstance(module, DeepseekV4MegaMoEExperts):
continue
if module._transformed_l1_weights is None:
continue
logger.info("Pre-compiling DeepGEMM mega-MoE kernel variants...")
group = pg_manager.get_process_group(
"nccl",
module.mapping.moe.tp_ep_group,
)
if torch.distributed.is_initialized():
torch.distributed.barrier(group=group)
deep_gemm.warmup_mega_moe_jit(
num_experts=module.num_experts,
max_num_tokens=module.max_num_tokens,
top_k=module.top_k,
hidden_size=module.hidden_size,
device=torch.device("cuda", torch.cuda.current_device()),
transformed_l1_weights=module._transformed_l1_weights,
transformed_l2_weights=module._transformed_l2_weights,
symm_buffer=module.get_symm_buffer(),
activation_clamp=module.swiglu_limit,
)
return

def _warmup_prefill_jit(self) -> None:
if deep_gemm is None:
return
if torch.cuda.get_device_capability()[0] < 10:
return
config = self.config
tp_size = self.mapping.attn.tp_size if self.mapping else 1
logger.info("Pre-compiling DeepGEMM prefill kernel variants...")
deep_gemm.warmup_prefill_jit(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
head_dim=getattr(config, "head_dim", 128),
hc_mult=getattr(config, "hc_mult", 0),
kv_lora_rank=getattr(config, "kv_lora_rank", 0),
index_n_heads=getattr(config, "index_n_heads", 0),
index_head_dim=getattr(config, "index_head_dim", 0),
indexer_cache_block_size=V4_KERNEL_BLOCK_ROWS,
max_decode_tokens=max(
int(global_server_args_dict.get("max_cudagraph_capture_size", 0) or 0),
int(global_server_args_dict.get("max_num_seqs", 0) or 0),
1,
),
mxfp4_block_size=DEEPSEEK_V4_MXFP4_BLOCK_SIZE,
tp_size=tp_size,
max_tokens=min(getattr(config, "max_position_embeddings", 8192), 8192),
device=torch.device("cuda", torch.cuda.current_device()),
)

def post_quant_warmup(self) -> None:
"""Called by the weight loader after all quant process_weights_after_loading."""
if deep_gemm is not None:
deep_gemm.warmup_fp8_gemm_nt_from_model(self)

@classmethod
def get_model_config_for_expert_location(cls, config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ def _prepare_deep_gemm_cuda_home() -> None:
transform_sf_into_required_layout,
transform_weights_for_mega_moe,
)
from tokenspeed_kernel.thirdparty.deep_gemm.warmup import (
warmup_fp8_gemm_nt,
warmup_fp8_gemm_nt_from_model,
warmup_mega_moe_jit,
warmup_prefill_jit,
)

__all__ = [
"ceil_div",
Expand All @@ -109,4 +115,8 @@ def _prepare_deep_gemm_cuda_home() -> None:
"get_paged_mqa_logits_metadata",
"fp8_paged_mqa_logits",
"fp8_mqa_logits",
"warmup_fp8_gemm_nt",
"warmup_fp8_gemm_nt_from_model",
"warmup_mega_moe_jit",
"warmup_prefill_jit",
]
Loading
Loading