@@ -204,6 +204,7 @@ def _fit(self, image, weights=None, plot=False):
204204 astropy_model = issubclass (self .model_class , Model )
205205 except TypeError :
206206 astropy_model = False
207+ self ._astropy_model = astropy_model
207208
208209 # Convert user regions to a Boolean mask for slicing:
209210 user_reg = ~ at .create_mask_from_regions (points , self ._slices )
@@ -262,6 +263,7 @@ def _fit(self, image, weights=None, plot=False):
262263 # Treat single model specially because FittingWithOutlierRemoval
263264 # fails for "1-model sets" and it should be more efficient anyway:
264265 image_to_fit = image
266+
265267 if image .ndim == 1 :
266268 n_models = 1
267269 elif image .mask is np .ma .nomask :
@@ -276,43 +278,47 @@ def _fit(self, image, weights=None, plot=False):
276278 if weights is not None :
277279 weights = weights [:, self ._good_cols ]
278280
279- model_set = self .model_class (
280- degree = self .order , n_models = n_models ,
281- domain = self .domain ,
282- model_set_axis = (None if n_models == 1 else 1 ),
283- ** self .model_args
284- )
285-
286- # Configure iterative linear fitter with rejection:
287- fitter = fitting .FittingWithOutlierRemoval (
288- fitting .LinearLSQFitter (),
289- sigma_clip ,
290- niter = self .niter ,
291- # additional args are passed to outlier_func, i.e. sigma_clip
292- sigma_lower = self .sigma_lower ,
293- sigma_upper = self .sigma_upper ,
294- maxiters = 1 ,
295- cenfunc = 'mean' ,
296- stdfunc = 'std' ,
297- grow = self .grow # requires AstroPy 4.2 (#10613)
298- )
299-
300- # Fit the pixel data with rejection of outlying points:
301- fitted_models , fitted_mask = fitter (
302- model_set ,
303- points [user_reg ], image_to_fit [user_reg ],
304- weights = None if weights is None else weights [user_reg ]
305- )
306- self .fit_info = fitter .fit_info
307-
308- # Incorporate mask for fitted columns into the full-sized mask:
309- if image .ndim > 1 and n_models < image .shape [1 ]:
310- # this is quite ugly, but seems the best way to assign to an
311- # array with a mask on both dimensions. This is equivalent to:
312- # mask[user_reg, masked_cols] = fitted_mask
313- mask [user_reg [:, None ] & self ._good_cols ] = fitted_mask .flat
281+ if n_models > 0 :
282+ model_set = self .model_class (
283+ degree = self .order , n_models = n_models ,
284+ domain = self .domain ,
285+ model_set_axis = (None if n_models == 1 else 1 ),
286+ ** self .model_args
287+ )
288+
289+ # Configure iterative linear fitter with rejection:
290+ fitter = fitting .FittingWithOutlierRemoval (
291+ fitting .LinearLSQFitter (),
292+ sigma_clip ,
293+ niter = self .niter ,
294+ # additional args are passed to outlier_func, i.e. sigma_clip
295+ sigma_lower = self .sigma_lower ,
296+ sigma_upper = self .sigma_upper ,
297+ maxiters = 1 ,
298+ cenfunc = 'mean' ,
299+ stdfunc = 'std' ,
300+ grow = self .grow # requires AstroPy 4.2 (#10613)
301+ )
302+
303+ # Fit the pixel data with rejection of outlying points:
304+ fitted_models , fitted_mask = fitter (
305+ model_set ,
306+ points [user_reg ], image_to_fit [user_reg ],
307+ weights = None if weights is None else weights [user_reg ]
308+ )
309+ self .fit_info = fitter .fit_info
310+
311+ # Incorporate mask for fitted columns into the full-sized mask:
312+ if image .ndim > 1 and n_models < image .shape [1 ]:
313+ # this is quite ugly, but seems the best way to assign to an
314+ # array with a mask on both dimensions. This is equivalent to:
315+ # mask[user_reg, masked_cols] = fitted_mask
316+ mask [user_reg [:, None ] & self ._good_cols ] = fitted_mask .flat
317+ else :
318+ mask [user_reg ] = fitted_mask
319+
314320 else :
315- mask [ user_reg ] = fitted_mask
321+ fitted_models = []
316322
317323 else :
318324 #max_order = len(points) - self.model_args["k"]
@@ -359,20 +365,23 @@ def _fit(self, image, weights=None, plot=False):
359365 # Convert the mask to the ordering & shape of the input array and
360366 # save it. Calculate rms. Suppress warnings for no data points
361367 mask = mask .reshape (self ._tmpshape )
362- with np .errstate (invalid = "ignore" , divide = "ignore" ):
363- if astropy_model :
364- start = (self .axis + 1 ) or mask .ndim
365- self .mask = np .rollaxis (mask , 0 , start )
366- rms = (np .rollaxis (image .reshape (self ._tmpshape ), 0 , start ) -
367- self .evaluate ())[~ self .mask ].std ()
368- else :
369- self .mask = np .rollaxis (mask , - 1 , self .axis )
370- rms = (np .rollaxis (image .reshape (self ._tmpshape ), - 1 , self .axis ) -
371- self .evaluate ())[~ self .mask ].std ()
372- self .rms = rms if rms is not np .ma .masked else np .nan
368+ if len (fitted_models ) > 0 :
369+ with np .errstate (invalid = "ignore" , divide = "ignore" ):
370+ if astropy_model :
371+ start = (self .axis + 1 ) or mask .ndim
372+ self .mask = np .rollaxis (mask , 0 , start )
373+ rms = (np .rollaxis (image .reshape (self ._tmpshape ), 0 , start ) -
374+ self .evaluate ())[~ self .mask ].std ()
375+ else :
376+ self .mask = np .rollaxis (mask , - 1 , self .axis )
377+ rms = (np .rollaxis (image .reshape (self ._tmpshape ), - 1 , self .axis ) -
378+ self .evaluate ())[~ self .mask ].std ()
379+ self .rms = rms if rms is not np .ma .masked else np .nan
380+ else :
381+ self .rms = np .nan
373382
374383 # Plot the fit:
375- if plot :
384+ if plot and len ( fitted_models ) > 0 :
376385 self ._plot (origim , index = None if plot is True else plot )
377386
378387 # Basic plot for debugging/inspection (interactive plotting will be handled
@@ -446,9 +455,8 @@ def evaluate(self, points=None):
446455 input `image` to which the fit was performed along any other axes.
447456
448457 """
449- astropy_model = isinstance (self ._models , Model )
450458
451- tmpaxis = 0 if astropy_model else - 1
459+ tmpaxis = 0 if self . _astropy_model else - 1
452460
453461 # Determine how to reproduce the correct array shape, orientation and
454462 # sampling from flattened model output:
@@ -471,23 +479,24 @@ def evaluate(self, points=None):
471479 fitvals = np .zeros (stack_shape , dtype = self ._dtype )
472480
473481 # Determine the model values we want to return:
474- if astropy_model :
475- if fitvals .ndim > 1 and len (self ._models ) < fitvals .shape [1 ]:
476- # If we removed bad columns, we now need to fill them properly
477- # in the output array
478- fitvals [:, self ._good_cols ] = self ._models (points ,
479- model_set_axis = False )
482+ if np .sum (self ._good_cols ) > 0 : # skip if no good columns
483+ if self ._astropy_model :
484+ if fitvals .ndim > 1 and len (self ._models ) < fitvals .shape [1 ]:
485+ # If we removed bad columns, we now need to fill them properly
486+ # in the output array
487+ fitvals [:, self ._good_cols ] = self ._models (points ,
488+ model_set_axis = False )
489+ else :
490+ fitvals [:] = self ._models (points , model_set_axis = False )
480491 else :
481- fitvals [:] = self ._models (points , model_set_axis = False )
482- else :
483- for n , single_model in enumerate (self ._models ):
484- # Determine model values to be returned (see comment in _fit
485- # about discarding values stored in the spline object):
486- fitvals [n ] = single_model (points )
492+ for n , single_model in enumerate (self ._models ):
493+ # Determine model values to be returned (see comment in _fit
494+ # about discarding values stored in the spline object):
495+ fitvals [n ] = single_model (points )
487496
488497 # Restore the ordering & shape of the original input array:
489498 fitvals = fitvals .reshape (tmpshape )
490- if astropy_model :
499+ if self . _astropy_model :
491500 fitvals = np .rollaxis (fitvals , 0 , (self .axis + 1 ) or fitvals .ndim )
492501 else :
493502 fitvals = np .rollaxis (fitvals , - 1 , self .axis )
@@ -508,8 +517,8 @@ def model(self):
508517 if len (self ._models ) > 1 :
509518 raise ValueError ("Cannot provide model property if there are "
510519 "greater than one models." )
511- astropy_model = isinstance ( self . _models , Model )
512- if astropy_model :
520+
521+ if self . _astropy_model :
513522 return self ._models
514523 else :
515524 return self ._models [0 ]
@@ -523,8 +532,8 @@ def offset_fit(self, offset):
523532 offset: float
524533 amount by which all fits are to be shifted
525534 """
526- astropy_model = isinstance ( self . _models , Model )
527- if astropy_model :
535+
536+ if self . _astropy_model :
528537 self ._models .c0 += offset
529538 else :
530539 for spline in self ._models :
0 commit comments