@@ -233,24 +233,51 @@ def handle_tanh(codegen, n, out, input_names, input_args):
233233 codegen .add_forward_call ("MiCo_tanh{dim}d_{dtype}" , out , n .name , input_names )
234234
235235
236+ def _extract_scalar_param (param , param_name , default = None ):
237+ """Extract a scalar C API parameter from PyTorch int/tuple pooling args."""
238+ if param is None :
239+ if default is None :
240+ raise ValueError (f"{ param_name } cannot be None" )
241+ param = default
242+
243+ if isinstance (param , torch .fx .node .Node ):
244+ raise ValueError (f"Unresolved FX node for { param_name } : { param } " )
245+
246+ if isinstance (param , torch .Size ):
247+ param = tuple (param )
248+
249+ if isinstance (param , (tuple , list )):
250+ if len (param ) == 0 :
251+ raise ValueError (f"{ param_name } cannot be empty" )
252+ first = param [0 ]
253+ if any (value != first for value in param ):
254+ raise NotImplementedError (
255+ f"MiCo C pooling kernels only support scalar/symmetric { param_name } , got { param } "
256+ )
257+ param = first
258+
259+ if isinstance (param , bool ) or not isinstance (param , int ):
260+ raise ValueError (f"Unexpected { param_name } type: { type (param )} " )
261+ return param
262+
263+
236264def _extract_kernel_size (param ):
237- """Helper to extract kernel size from tuple or int parameter."""
238- if isinstance (param , Tuple ):
239- return param [0 ]
240- elif isinstance (param , int ):
241- return param
242- else :
243- raise ValueError (f"Unexpected kernel_size type: { type (param )} " )
265+ """Helper to extract scalar kernel size for the C pooling API."""
266+ return _extract_scalar_param (param , "kernel_size" )
244267
245268
246269def _extract_output_size (param ):
247- """Helper to extract output size from tuple or int parameter."""
248- if isinstance (param , Tuple ):
249- return param [0 ]
250- elif isinstance (param , int ):
251- return param
270+ """Helper to extract scalar output size for the C adaptive pooling API."""
271+ return _extract_scalar_param (param , "output_size" )
272+
273+
274+ def _pool_arg (n , input_args , index , name , default = None ):
275+ """Read pooling arg from positional or keyword FX args and normalize it."""
276+ if len (input_args ) > index :
277+ value = input_args [index ]
252278 else :
253- raise ValueError (f"Unexpected output_size type: { type (param )} " )
279+ value = n .kwargs .get (name , default )
280+ return _extract_scalar_param (value , name , default )
254281
255282
256283@MiCoOpRegistry .register_function (torch .nn .functional .linear )
@@ -275,20 +302,22 @@ def handle_linear(codegen, n, out, input_names, input_args):
275302def handle_avg_pool2d (codegen , n , out , input_names , input_args ):
276303 """Handler for 2D average pooling function."""
277304 codegen .add_uninitialized_tensor (n .name , out )
278- kernel_size = _extract_kernel_size (input_args [1 ])
279- stride = input_args [2 ] if len (input_args ) > 2 else 1
305+ kernel_size = _pool_arg (n , input_args , 1 , "kernel_size" )
306+ stride = _pool_arg (n , input_args , 2 , "stride" , kernel_size )
307+ padding = _pool_arg (n , input_args , 3 , "padding" , 0 )
280308 codegen .add_forward_call ("MiCo_avgpool{dim}d_{dtype}" , out , n .name , input_names ,
281- [kernel_size , stride ])
309+ [kernel_size , stride , padding ])
282310
283311
284312@MiCoOpRegistry .register_function (torch .nn .functional .max_pool2d )
285313def handle_max_pool2d (codegen , n , out , input_names , input_args ):
286314 """Handler for 2D max pooling function."""
287315 codegen .add_uninitialized_tensor (n .name , out )
288- kernel_size = _extract_kernel_size (input_args [1 ])
289- stride = input_args [2 ] if len (input_args ) > 2 else 1
316+ kernel_size = _pool_arg (n , input_args , 1 , "kernel_size" )
317+ stride = _pool_arg (n , input_args , 2 , "stride" , kernel_size )
318+ padding = _pool_arg (n , input_args , 3 , "padding" , 0 )
290319 codegen .add_forward_call ("MiCo_maxpool{dim}d_{dtype}" , out , n .name , input_names ,
291- [kernel_size , stride ])
320+ [kernel_size , stride , padding ])
292321
293322
294323@MiCoOpRegistry .register_function (torch .nn .functional .adaptive_avg_pool2d )
@@ -303,20 +332,22 @@ def handle_adaptive_avg_pool2d(codegen, n, out, input_names, input_args):
303332def handle_avg_pool1d (codegen , n , out , input_names , input_args ):
304333 """Handler for 1D average pooling function."""
305334 codegen .add_uninitialized_tensor (n .name , out )
306- kernel_size = _extract_kernel_size (input_args [1 ])
307- stride = input_args [2 ] if len (input_args ) > 2 else 1
335+ kernel_size = _pool_arg (n , input_args , 1 , "kernel_size" )
336+ stride = _pool_arg (n , input_args , 2 , "stride" , kernel_size )
337+ padding = _pool_arg (n , input_args , 3 , "padding" , 0 )
308338 codegen .add_forward_call ("MiCo_avgpool{dim}d_{dtype}" , out , n .name , input_names ,
309- [kernel_size , stride ])
339+ [kernel_size , stride , padding ])
310340
311341
312342@MiCoOpRegistry .register_function (torch .nn .functional .max_pool1d )
313343def handle_max_pool1d (codegen , n , out , input_names , input_args ):
314344 """Handler for 1D max pooling function."""
315345 codegen .add_uninitialized_tensor (n .name , out )
316- kernel_size = _extract_kernel_size (input_args [1 ])
317- stride = input_args [2 ] if len (input_args ) > 2 else 1
346+ kernel_size = _pool_arg (n , input_args , 1 , "kernel_size" )
347+ stride = _pool_arg (n , input_args , 2 , "stride" , kernel_size )
348+ padding = _pool_arg (n , input_args , 3 , "padding" , 0 )
318349 codegen .add_forward_call ("MiCo_maxpool{dim}d_{dtype}" , out , n .name , input_names ,
319- [kernel_size , stride ])
350+ [kernel_size , stride , padding ])
320351
321352
322353@MiCoOpRegistry .register_function (torch .nn .functional .adaptive_avg_pool1d )
@@ -550,8 +581,10 @@ def handle_avgpool2d_module(codegen, n, out, module, input_names):
550581 layer_name = n .name
551582 codegen .add_uninitialized_tensor (layer_name , out )
552583 kernel_size = _extract_kernel_size (module .kernel_size )
584+ stride = _extract_scalar_param (module .stride , "stride" , kernel_size )
585+ padding = _extract_scalar_param (module .padding , "padding" , 0 )
553586 codegen .add_forward_call ("MiCo_avgpool{dim}d_{dtype}" , out , layer_name , input_names ,
554- [kernel_size , module . stride , module . padding ])
587+ [kernel_size , stride , padding ])
555588
556589
557590@MiCoOpRegistry .register_module (torch .nn .MaxPool2d )
@@ -560,8 +593,10 @@ def handle_maxpool2d_module(codegen, n, out, module, input_names):
560593 layer_name = n .name
561594 codegen .add_uninitialized_tensor (layer_name , out )
562595 kernel_size = _extract_kernel_size (module .kernel_size )
596+ stride = _extract_scalar_param (module .stride , "stride" , kernel_size )
597+ padding = _extract_scalar_param (module .padding , "padding" , 0 )
563598 codegen .add_forward_call ("MiCo_maxpool{dim}d_{dtype}" , out , layer_name , input_names ,
564- [kernel_size , module . stride , module . padding ])
599+ [kernel_size , stride , padding ])
565600
566601
567602@MiCoOpRegistry .register_module (torch .nn .AdaptiveAvgPool2d )
@@ -579,8 +614,10 @@ def handle_avgpool1d_module(codegen, n, out, module, input_names):
579614 layer_name = n .name
580615 codegen .add_uninitialized_tensor (layer_name , out )
581616 kernel_size = _extract_kernel_size (module .kernel_size )
617+ stride = _extract_scalar_param (module .stride , "stride" , kernel_size )
618+ padding = _extract_scalar_param (module .padding , "padding" , 0 )
582619 codegen .add_forward_call ("MiCo_avgpool{dim}d_{dtype}" , out , layer_name , input_names ,
583- [kernel_size , module . stride , module . padding ])
620+ [kernel_size , stride , padding ])
584621
585622
586623@MiCoOpRegistry .register_module (torch .nn .MaxPool1d )
@@ -589,8 +626,10 @@ def handle_maxpool1d_module(codegen, n, out, module, input_names):
589626 layer_name = n .name
590627 codegen .add_uninitialized_tensor (layer_name , out )
591628 kernel_size = _extract_kernel_size (module .kernel_size )
629+ stride = _extract_scalar_param (module .stride , "stride" , kernel_size )
630+ padding = _extract_scalar_param (module .padding , "padding" , 0 )
592631 codegen .add_forward_call ("MiCo_maxpool{dim}d_{dtype}" , out , layer_name , input_names ,
593- [kernel_size , module . stride , module . padding ])
632+ [kernel_size , stride , padding ])
594633
595634
596635@MiCoOpRegistry .register_module (torch .nn .AdaptiveAvgPool1d )
0 commit comments