File tree Expand file tree Collapse file tree
vllm_omni/diffusion/models Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -87,7 +87,15 @@ def _recursive_replace(module: nn.Module, prefix: str):
8787 # Replace modules as needed
8888 if isinstance (child_module , nn .Linear ):
8989 style = "replicate"
90- new_module = replace_linear_class (child_module , style , quant_config , prefix = qual_name )
90+ # AMD ROCm FP8 kernels require K (in_features) and N (out_features) to be divisible by 16.
91+ # If either dimension is not a multiple of 16, replace this layer with quant_config=None (no FP8) to avoid runtime errors.
92+ is_fp8 = (
93+ quant_config is not None and hasattr (quant_config , "get_name" ) and quant_config .get_name () == "fp8"
94+ )
95+ layer_quant_config = quant_config
96+ if is_fp8 and (child_module .in_features % 16 != 0 or child_module .out_features % 16 != 0 ):
97+ layer_quant_config = None
98+ new_module = replace_linear_class (child_module , style , layer_quant_config , prefix = qual_name )
9199 else :
92100 _recursive_replace (child_module , prefix = qual_name )
93101 if new_module is not child_module :
You can’t perform that action at this time.
0 commit comments