Skip to content

Commit 0bc2e56

Browse files
committed
Various work on distortion modelling
Change min_line_length to be a fraction of the slit length Move GNIRS determineDistortion() wrapper to GNIRSSpect from GNIRSLongslit and handle XD Put distortion model *after* rectification model and calculate it correctly Allow peaks to be found in the non-linear regime
1 parent 7e0bd4b commit 0bc2e56

12 files changed

Lines changed: 310 additions & 217 deletions

geminidr/core/parameters_spect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ class determineDistortionConfig(config.Config):
145145
max_shift = config.RangeField("Maximum shift per pixel in line position",
146146
float, 0.05, min=0.001, max=0.1)
147147
max_missed = config.RangeField("Maximum number of steps to miss before a line is lost", int, 5, min=0)
148-
min_line_length = config.RangeField("Exclude line traces shorter than this fraction of spatial dimension",
149-
float, 0., min=0., max=1.)
148+
min_line_length = config.RangeField("Exclude line traces shorter than this fraction of slit length",
149+
float, 0.8, min=0., max=1.)
150150
debug_reject_bad = config.Field("Reject lines with suspiciously high SNR (e.g. bad columns)?", bool, True)
151151
debug = config.Field("Display line traces on image display?", bool, False)
152152

geminidr/core/primitives_spect.py

Lines changed: 101 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,67 +1473,105 @@ def determineDistortion(self, adinputs=None, **params):
14731473
initial_peaks, _ = peak_finding.find_wavelet_peaks(
14741474
data, widths=widths, mask=mask & DQ.not_signal,
14751475
variance=variance, min_snr=min_snr, reject_bad=debug_reject_bad)
1476-
log.stdinfo(f"Found {len(initial_peaks)} peaks")
14771476
# The coordinates are always returned as (x-coords, y-coords)
14781477
rwidth = 0.42466 * fwidth
14791478

1480-
# Straight slits, such as in longslit, can have all the lines
1481-
# traced simultaneously since they all have the same starting
1482-
# point. "Curved" slits need to be handled one-by-one. This is
1483-
# quite a bit slower, so this block of code does the line
1484-
# tracing based on the slit involved.
1485-
if constant_slit:
1486-
traces = tracing.trace_lines(
1487-
# Only need a single `start` value for all lines.
1488-
ext, axis=1 - dispaxis,
1489-
start=start, initial=initial_peaks,
1490-
rwidth=rwidth, cwidth=max(int(fwidth), 5), step=step,
1491-
nsum=nsum, max_missed=max_missed,
1492-
max_shift=max_shift * ybin / xbin,
1493-
viewer=self.viewer if debug else None,
1494-
min_line_length=min_line_length)
1495-
1496-
else:
1497-
traces = []
1498-
for peak in initial_peaks:
1499-
# Need to start midway along the slit, which varies
1500-
# along the dispersion axis. `extract_info` here is the
1501-
# polynomial describing that midway line.
1502-
start = extract_info(peak)
1503-
traces.extend(tracing.trace_lines(
1479+
if len(initial_peaks):
1480+
# The slit length may be smaller than the width of the slice,
1481+
# so we need to estimate the slit length and
1482+
# "min_line_length" is the fraction of that, not the fraction
1483+
# of the slice width.
1484+
if ext.mask is not None:
1485+
loc = int(np.median(initial_peaks))
1486+
if dispaxis == 0:
1487+
_slice = ext.mask[loc] & DQ.unilluminated
1488+
else:
1489+
_slice = ext.mask[:, loc] & DQ.unilluminated
1490+
slit_length_frac = 1 - ((_slice.argmin() +
1491+
_slice[::-1].argmin()) / _slice.size)
1492+
else:
1493+
try:
1494+
slit_length_frac = ad.MDF['slitlength_pixels'] / ext.shape[1 - dispaxis]
1495+
except (AttributeError, KeyError):
1496+
slit_length_frac = 1
1497+
1498+
# Straight slits, such as in longslit, can have all the lines
1499+
# traced simultaneously since they all have the same starting
1500+
# point. "Curved" slits need to be handled one-by-one. This is
1501+
# quite a bit slower, so this block of code does the line
1502+
# tracing based on the slit involved.
1503+
if constant_slit:
1504+
traces = tracing.trace_lines(
1505+
# Only need a single `start` value for all lines.
15041506
ext, axis=1 - dispaxis,
1505-
start=start, initial=[peak],
1507+
start=start, initial=initial_peaks,
15061508
rwidth=rwidth, cwidth=max(int(fwidth), 5), step=step,
15071509
nsum=nsum, max_missed=max_missed,
15081510
max_shift=max_shift * ybin / xbin,
15091511
viewer=self.viewer if debug else None,
1510-
min_line_length=0.1))
1512+
min_line_length=min_line_length*slit_length_frac)
15111513

1512-
# List of traced peak positions
1513-
in_coords = np.array([coord for trace in traces for
1514-
coord in trace.input_coordinates()]).T
1514+
else:
1515+
traces = []
1516+
for peak in initial_peaks:
1517+
# Need to start midway along the slit, which varies
1518+
# along the dispersion axis. `extract_info` here is the
1519+
# polynomial describing that midway line.
1520+
start = extract_info(peak)
1521+
traces.extend(tracing.trace_lines(
1522+
ext, axis=1 - dispaxis,
1523+
start=start, initial=[peak],
1524+
rwidth=rwidth, cwidth=max(int(fwidth), 5), step=step,
1525+
nsum=nsum, max_missed=max_missed,
1526+
max_shift=max_shift * ybin / xbin,
1527+
viewer=self.viewer if debug else None,
1528+
min_line_length=min_line_length*slit_length_frac))
1529+
1530+
log.stdinfo(f"Traced {len(traces)} lines from "
1531+
f"{len(initial_peaks)} peaks")
1532+
1533+
# List of traced peak positions
1534+
in_coords = np.array([coord for trace in traces for
1535+
coord in trace.input_coordinates()]).T
1536+
1537+
# We can't do anything if we have no coordinates
1538+
if in_coords.size == 0:
1539+
log.warning("Failed to trace any lines for "
1540+
f"{ad.filename}:{ext.id}")
1541+
continue
15151542

1516-
# We can't do anything if we have no coordinates
1517-
if in_coords.size == 0:
1518-
log.warning("Failed to trace any lines for "
1519-
f"{ad.filename}:{ext.id}")
1543+
else:
1544+
log.warning("Failed to find any peaks in "
1545+
"f{ad.filename}:{ext.id}")
15201546
continue
15211547

1522-
# If there's a "rectified" frame, we want to use the pixel
1523-
# coordinates in *that* frame as input so that the pixels
1524-
# -> rectified -> distortion_corrected transform works
1525-
# correctly.
1526-
try:
1527-
t = ext.wcs.get_transform(ext.wcs.input_frame, 'rectified')
1528-
except CoordinateFrameError:
1529-
pass
1530-
else:
1531-
in_coords = np.array(t(*in_coords))
15321548
# List of "reference" positions (i.e., the coordinate
15331549
# perpendicular to the line remains constant at its initial value
15341550
ref_coords = np.array([coord for trace in traces for
15351551
coord in trace.reference_coordinates()]).T
15361552

1553+
# If the frame has a rectification model, then we want to
1554+
# calculate the distortion transform *after* applying this
1555+
# model. This is important because, if we only have one line
1556+
# from which to determine the distortion, the model will only
1557+
# be a function of X and so we need a vertical spectrum.
1558+
try:
1559+
rect_model = ext.wcs.get_transform(ext.wcs.input_frame,
1560+
"rectified")
1561+
except CoordinateFrameError:
1562+
has_rect_model = False
1563+
else:
1564+
has_rect_model = True
1565+
in_coords = rect_model(*in_coords)
1566+
# For a vertically-dispersed spectrum, the rectification
1567+
# model alters X as a function of Y. Therefore, because
1568+
# incoords and ref_coords have different Y for the same X,
1569+
# putting them *both* through the rectification model will
1570+
# produce different X values, which is not what we want.
1571+
# S0 replace the X values in ref_coords with the ones in
1572+
# in_coords. Coords are *always* (x, y)
1573+
ref_coords[dispaxis] = in_coords[dispaxis]
1574+
15371575
# The model is computed entirely in the pixel coordinate frame
15381576
# of the data, so it could be used as a gWCS object
15391577
m_init = models.Chebyshev2D(x_degree=orders[1 - dispaxis],
@@ -1573,8 +1611,25 @@ def determineDistortion(self, adinputs=None, **params):
15731611
ext.wcs = gWCS([(cf.Frame2D(name="pixels"), model),
15741612
(cf.Frame2D(name="world"), None)])
15751613
else:
1576-
ext.wcs.insert_frame(ext.wcs.input_frame, model,
1577-
cf.Frame2D(name="distortion_corrected"))
1614+
try:
1615+
frame_index = ext.wcs.available_frames.index("distortion_corrected")
1616+
except ValueError:
1617+
pass
1618+
else:
1619+
log.warning("Deleting existing distortion model in "
1620+
f"{ad.filename}:{ext.id}")
1621+
ext.wcs = ext.wcs.__class__(
1622+
ext.wcs.pipeline[:frame_index-1] +
1623+
[(ext.wcs.pipeline[frame_index-1].frame,
1624+
ext.wcs.pipeline[frame_index].transform)] +
1625+
ext.wcs.pipeline[frame_index+1:]
1626+
)
1627+
if has_rect_model:
1628+
ext.wcs.insert_frame("rectified", model,
1629+
cf.Frame2D(name="distortion_corrected"))
1630+
else:
1631+
ext.wcs.insert_frame(ext.wcs.input_frame, model,
1632+
cf.Frame2D(name="distortion_corrected"))
15781633

15791634
nsuccess += 1
15801635

geminidr/f2/tests/longslit/test_determine_distortion.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
"nsum": 10,
3838
"spatial_order": 3,
3939
"spectral_order": 3,
40-
"min_line_length": 0.3,
40+
"min_line_length": 0.8,
4141
"debug_reject_bad": False
4242
}
4343

@@ -186,7 +186,6 @@
186186
# Tests Definitions ------------------------------------------------------------
187187
@pytest.mark.f2ls
188188
@pytest.mark.preprocessed_data
189-
@pytest.mark.regression
190189
@pytest.mark.parametrize("ad,params", input_pars, indirect=['ad'])
191190
def test_regression_for_determine_distortion_using_wcs(
192191
ad, params, change_working_dir, ref_ad_factory):
@@ -214,18 +213,24 @@ def test_regression_for_determine_distortion_using_wcs(
214213
distortion_determined_ad = p.writeOutputs().pop()
215214

216215
ref_ad = ref_ad_factory(distortion_determined_ad.filename)
217-
model = distortion_determined_ad[0].wcs.get_transform(
218-
"pixels", "distortion_corrected")[2]
219-
ref_model = ref_ad[0].wcs.get_transform("pixels", "distortion_corrected")[2]
216+
217+
# Confirm that the distortion model is placed after the rectification model
218+
assert (distortion_determined_ad[0].wcs.available_frames.index("distortion_corrected") >
219+
distortion_determined_ad[0].wcs.available_frames.index("rectified"))
220+
# assert (ref_ad[0].wcs.available_frames.index("distortion_corrected") >
221+
# ref_ad[0].wcs.available_frames.index("rectified"))
222+
223+
model = distortion_determined_ad[0].wcs.get_transform("pixels", "distortion_corrected")
224+
ref_model = ref_ad[0].wcs.get_transform("pixels", "distortion_corrected")
220225

221226
# Otherwise we're doing something wrong!
222-
assert model.__class__.__name__ == ref_model.__class__.__name__ == "Chebyshev2D"
227+
assert model[-1].__class__.__name__ == ref_model[-1].__class__.__name__ == "Chebyshev2D"
223228

224-
X, Y = np.mgrid[:ad[0].shape[0], :ad[0].shape[1]]
229+
Y, X = np.mgrid[:ad[0].shape[0], :ad[0].shape[1]]
225230

226-
# Increasing atol to 0.07 due to S20180114S0104_flatCorrected.fits producing
227-
# slightly different results on Jenkins vs. on MacOS. DB 20240820
228-
np.testing.assert_allclose(model(X, Y), ref_model(X, Y), atol=0.07)
231+
xx, yy = X[ad[0].mask == 0], Y[ad[0].mask == 0]
232+
diffs = model(xx, yy)[1] - ref_model(xx, yy)[1] # 1 is y-axis in astropy
233+
np.testing.assert_allclose(diffs, 0, atol=1)
229234

230235

231236
@pytest.mark.f2ls
@@ -252,8 +257,9 @@ def test_fitcoord_table_and_gwcs_match(ad, params, change_working_dir):
252257
p.determineDistortion(**fixed_parameters_for_determine_distortion)
253258
distortion_determined_ad = p.writeOutputs().pop()
254259

255-
model = distortion_determined_ad[0].wcs.get_transform(
256-
"pixels", "distortion_corrected")
260+
model = distortion_determined_ad[0].wcs.pipeline[
261+
distortion_determined_ad[0].wcs.available_frames.index(
262+
"distortion_corrected") - 1].transform
257263

258264
fitcoord = distortion_determined_ad[0].FITCOORD
259265
fitcoord_model = am.table_to_model(fitcoord[0])

geminidr/gnirs/parameters_gnirs_crossdispersed.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from geminidr.core import parameters_crossdispersed
55
from geminidr.core import parameters_standardize
66
from geminidr.core.parameters_standardize import addIllumMaskToDQConfig
7+
from . import parameters_gnirs_spect
78
from gempy.library import config
89

910

@@ -17,6 +18,12 @@ def setDefaults(self):
1718
self.add_illum_mask = True
1819

1920

21+
class determineDistortionConfig(parameters_gnirs_spect.determineDistortionConfig):
22+
def setDefaults(self):
23+
self.spatial_order = 2
24+
self.min_line_length = 0.5 # because some orders go off the edge
25+
26+
2027
class determineSlitEdgesConfig(parameters_spect.determineSlitEdgesConfig):
2128
# GNIRS XD has narrow slits with more curvature than the longslit flats
2229
# the default values were calibrated to, so adjust some values.

geminidr/gnirs/parameters_gnirs_longslit.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,6 @@ def list_of_ints_check(value):
1111
[int(x) for x in str(value).split(',')]
1212
return True
1313

14-
class determineDistortionConfig(parameters_spect.determineDistortionConfig):
15-
spectral_order = config.RangeField("Fitting order in spectral direction", int, None, min=1, optional=True)
16-
min_line_length = config.RangeField("Exclude line traces shorter than this fraction of spatial dimension",
17-
float, None, min=0., max=1., optional=True)
18-
max_missed = config.RangeField("Maximum number of steps to miss before a line is lost",
19-
int, None, min=0, optional=True)
20-
def setDefaults(self):
21-
self.min_snr = 10
22-
self.debug_reject_bad = False
23-
2414
class determineWavelengthSolutionConfig(parameters_spect.determineWavelengthSolutionConfig):
2515
order = config.RangeField("Order of fitting function", int, None, min=1,
2616
optional=True)

geminidr/gnirs/parameters_gnirs_spect.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# in the primitives_gnirs_spect.py file, in alphabetical order.
33
from astrodata import AstroData
44
from gempy.library import config
5-
from geminidr.core import parameters_preprocess
5+
from geminidr.core import parameters_preprocess, parameters_spect
66
from . import parameters_gnirs
77

88

@@ -16,6 +16,15 @@ def setDefaults(self):
1616
self.min_skies = 2
1717

1818

19+
class determineDistortionConfig(parameters_spect.determineDistortionConfig):
20+
spectral_order = config.RangeField("Fitting order in spectral direction", int, None, min=1, optional=True)
21+
max_missed = config.RangeField("Maximum number of steps to miss before a line is lost",
22+
int, None, min=0, optional=True)
23+
def setDefaults(self):
24+
self.min_snr = 10
25+
self.debug_reject_bad = False
26+
27+
1928
class skyCorrectConfig(parameters_preprocess.skyCorrectConfig):
2029
def setDefaults(self):
2130
# self.scale_sky = False #MS: IF for whatever reason the exposure times are different between frames being subtracted, that case may require a special treatment

0 commit comments

Comments
 (0)