perf(deepseek-v4): pre-compile deep_gemm JIT kernels at startup #398
Changes from 4 commits
bceb897
f6f4124
53ccbcb
46b736d
2338d0f
3672551
eda4040
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -400,6 +400,10 @@ def load_model( | |
| with device_loading_context(module, target_device): | ||
| module.process_weights_after_loading(module) | ||
|
|
||
| post_quant_warmup = getattr(model, "post_quant_warmup", None) | ||
| if callable(post_quant_warmup): | ||
| post_quant_warmup() | ||
|
|
||
| return model.eval() | ||
@@ -460,6 +464,10 @@ def load_model( | |
| if process_method is not None: | ||
| module.process_weights_after_loading(module) | ||
|
|
||
| post_quant_warmup = getattr(model, "post_quant_warmup", None) | ||
| if callable(post_quant_warmup): | ||
| post_quant_warmup() | ||
|
|
||
| # For accurate performance evaluation, we assign | ||
| # random values to the weights. | ||
| initialize_dummy_weights(model) | ||
|
|
@@ -603,6 +611,11 @@ def load_model( | |
| state_dict.pop(key) | ||
| if state_dict: | ||
| raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!") | ||
|
|
||
| post_quant_warmup = getattr(model, "post_quant_warmup", None) | ||
| if callable(post_quant_warmup): | ||
| post_quant_warmup() | ||
|
Comment on lines +615 to +617
When loading DeepSeek V4 from sharded-state checkpoints, this loader never calls |
||
|
|
||
| return model.eval() | ||
When CPU offloading is enabled, the preceding `device_loading_context` blocks intentionally move CPU parameters to `target_device` only for processing and then restore them to CPU before this new hook runs. `post_quant_warmup()` eventually uses `next(model.parameters()).device` to allocate the DeepGEMM warmup tensors, so an offloaded DeepSeek V4 FP8 load can try to invoke CUDA DeepGEMM with CPU tensors and fail during startup instead of just warming the kernels.
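
One possible way to address this, sketched under assumptions and not taken from the PR: route the three repeated hook calls through a single helper that skips the warmup when the model's first parameter is not on a CUDA device. The names `run_post_quant_warmup` and `_device_type` are hypothetical; only `post_quant_warmup` and the `next(model.parameters()).device` lookup come from the diff and the comment above. The helper is duck-typed so it also works with plain device strings.

```python
from typing import Any


def _device_type(device: Any) -> str:
    """Return 'cuda', 'cpu', ... for a torch.device-like object or a string."""
    # torch.device exposes .type; plain strings such as "cuda:0" are split on ":".
    return getattr(device, "type", None) or str(device).split(":")[0]


def run_post_quant_warmup(model: Any) -> bool:
    """Call model.post_quant_warmup() only if parameters live on a CUDA device.

    Hypothetical helper: returns True if the hook ran, False if the hook is
    absent or was skipped because the parameters are not on CUDA.
    """
    hook = getattr(model, "post_quant_warmup", None)
    if not callable(hook):
        return False

    # Same lookup the warmup is reported to use for allocating its tensors.
    first_param = next(iter(model.parameters()), None)
    if first_param is None or _device_type(first_param.device) != "cuda":
        # Offloaded (CPU) parameters: skip rather than fail during startup.
        return False

    hook()
    return True
```

Each loader would then call `run_post_quant_warmup(model)` in place of the three-line `getattr` / `callable` block. Skipping silently trades startup robustness for a cold JIT on the first request; an alternative design would pass `target_device` into the hook so the warmup tensors are allocated there regardless of where the parameters currently reside.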