Skip to content

Commit 403b9fa

Browse files
committed
fix: bypass FP8 linear replacement for unaligned layers to prevent ROCm scaled MM kernel crashes
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
1 parent 7a63f39 commit 403b9fa

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

vllm_omni/diffusion/models/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,15 @@ def _recursive_replace(module: nn.Module, prefix: str):
8787
# Replace modules as needed
8888
if isinstance(child_module, nn.Linear):
8989
style = "replicate"
90-
new_module = replace_linear_class(child_module, style, quant_config, prefix=qual_name)
90+
# AMD ROCm FP8 kernels require K (in_features) and N (out_features) to be divisible by 16.
91+
# If they are not divisible, bypass FP8 replacement for this layer to avoid runtime errors.
92+
is_fp8 = (
93+
quant_config is not None and hasattr(quant_config, "get_name") and quant_config.get_name() == "fp8"
94+
)
95+
layer_quant_config = quant_config
96+
if is_fp8 and (child_module.in_features % 16 != 0 or child_module.out_features % 16 != 0):
97+
layer_quant_config = None
98+
new_module = replace_linear_class(child_module, style, layer_quant_config, prefix=qual_name)
9199
else:
92100
_recursive_replace(child_module, prefix=qual_name)
93101
if new_module is not child_module:

0 commit comments

Comments
 (0)