diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2561ecf0..26ea1a9e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -21,6 +21,9 @@ repos:
     rev: v1.7.7
     hooks:
       - id: docformatter
+        # Pin to py3.12 — docformatter's transitive dep `untokenize 0.1.1`
+        # uses `ast.Constant.s` (removed in py3.14), so pre-commit.ci's
+        # default py3.14 env fails to build it.
         language_version: python3.12
 
   - repo: https://github.com/pre-commit/mirrors-mypy
diff --git a/fuse/plugins/pmt_and_daq/pmt_afterpulses.py b/fuse/plugins/pmt_and_daq/pmt_afterpulses.py
index 2c73c256..2b1a36fd 100644
--- a/fuse/plugins/pmt_and_daq/pmt_afterpulses.py
+++ b/fuse/plugins/pmt_and_daq/pmt_afterpulses.py
@@ -1,5 +1,6 @@
 import strax
 import numpy as np
+import numba
 import straxen
 
 from ...dtypes import propagated_photons_fields
@@ -119,6 +120,35 @@ def setup(self):
                 if isinstance(self.uniform_to_pmt_ap[k][q], list):
                     self.uniform_to_pmt_ap[k][q] = np.array(self.uniform_to_pmt_ap[k][q])
 
+        # Pre-flatten the nested CDF dict into parallel arrays for the per-element
+        # loop. Cast every CDF to float64 once here so the numba kernel sees a
+        # single dispatch specialisation and per-iteration arithmetic uses a
+        # fixed dtype throughout (avoids per-chunk promotions).
+        self._ap_elements = list(self.uniform_to_pmt_ap.keys())
+        self._ap_is_uniform = np.array(
+            [("Uniform" in e) for e in self._ap_elements], dtype=np.bool_
+        )
+        self._ap_delaytime_cdf = [
+            np.asarray(self.uniform_to_pmt_ap[e]["delaytime_cdf"], dtype=np.float64)
+            for e in self._ap_elements
+        ]
+        self._ap_amplitude_cdf = [
+            np.asarray(self.uniform_to_pmt_ap[e]["amplitude_cdf"], dtype=np.float64)
+            for e in self._ap_elements
+        ]
+        self._ap_amplitude_cdf_ndim = np.array(
+            [cdf.ndim for cdf in self._ap_amplitude_cdf], dtype=np.int8
+        )
+        self._ap_delaytime_bin_size = np.array(
+            [self.uniform_to_pmt_ap[e]["delaytime_bin_size"] for e in self._ap_elements],
+            dtype=np.float64,
+        )
+        self._ap_amplitude_bin_size = np.array(
+            [self.uniform_to_pmt_ap[e]["amplitude_bin_size"] for e in self._ap_elements],
+            dtype=np.float64,
+        )
+        self._gains_f64 = np.asarray(self.gains, dtype=np.float64)
+
     def compute(self, s1_photons, s2_photons):
         if not self.enable_pmt_afterpulses or (len(s1_photons) == 0 and len(s2_photons) == 0):
             return np.zeros(0, dtype=self.dtype)
@@ -157,83 +187,266 @@ def photon_afterpulse(
         self, merged_photon_timings, merged_photon_channels, merged_photon_id_dpe
     ):
         """For pmt afterpulses, gain and dpe generation is a bit different from
-        standard photons."""
-        element_list = self.uniform_to_pmt_ap.keys()
+        standard photons.
+
+        Two-pass numba scheme for the per-element afterpulse loop. Random
+        draws stay in Python and follow a fixed per-element order — same
+        argument shapes and same conditional skip on empty selection — so
+        the PCG64 sequence is preserved across calls.
+
+        Per element:
+          1. ``rng.random(N)`` for rU0 (Python).
+          2. ``_ap_select_kernel`` (numba): one tight loop that scales
+             rU0 in place by the per-element modifier (and halves the
+             dpe entries), compares against ``cdf[channel, -1]`` for
+             each row, and emits the indices that pass the selection
+             plus the running max of the per-row probability ceiling.
+          3. ``rng.random(N_sel)`` for rU1 (Python; skipped when
+             ``N_sel == 0`` — the skip MUST stay before this draw or
+             the random sequence diverges).
+          4. Uniform branch: ``rng.uniform(low, high)`` (Python; only
+             on Uniform elements). Non-Uniform branch:
+             ``_ap_kernel_nonuniform`` does the argmin-based inverse
+             CDF lookup.
+
+        Every float operation matches the equivalent numpy expression's
+        IEEE-754 sequence; rng calls happen in a deterministic order
+        with the same argument shapes per element.
+        """
         _photon_timings = []
         _photon_channels = []
-        _photon_amplitude = []
-
-        for element in element_list:
-            delaytime_cdf = self.uniform_to_pmt_ap[element]["delaytime_cdf"]
-            amplitude_cdf = self.uniform_to_pmt_ap[element]["amplitude_cdf"]
-
-            delaytime_bin_size = self.uniform_to_pmt_ap[element]["delaytime_bin_size"]
-            amplitude_bin_size = self.uniform_to_pmt_ap[element]["amplitude_bin_size"]
-
-            # Assign each photon FRIST random uniform number rU0 from (0, 1] for timing
-            rU0 = 1 - self.rng.random(len(merged_photon_timings))
-
-            # delaytime_cdf is intentionally not normalized to 1 but the probability of the AP
-            prob_ap = delaytime_cdf[merged_photon_channels, -1]
-            if prob_ap.max() * self.pmt_ap_modifier > 0.5:
-                prob = prob_ap.max() * self.pmt_ap_modifier
+        _photon_gains = []
+
+        N = len(merged_photon_timings)
+        # Allocate the selection-index scratch buffer ONCE per call. Reused
+        # across all elements; only the first n_sel entries are valid per element.
+        sel_buf = np.empty(N, dtype=np.int64)
+        # Coerce channels to int64 once outside the per-element loop.
+        chans_int = merged_photon_channels.astype(np.int64, copy=False)
+        modifier = float(self.pmt_ap_modifier)
+
+        for i, element in enumerate(self._ap_elements):
+            delaytime_cdf = self._ap_delaytime_cdf[i]
+            amplitude_cdf = self._ap_amplitude_cdf[i]
+            delaytime_bin_size = self._ap_delaytime_bin_size[i]
+            amplitude_bin_size = self._ap_amplitude_bin_size[i]
+
+            # First rng draw for this element. Must happen unconditionally
+            # before the selection branch — moving it inside the branch
+            # would diverge the PCG64 state on elements with zero
+            # selections.
+            rU0 = 1 - self.rng.random(N)
+
+            # Pass 1 (numba): scales rU0 in-place, emits selection indices,
+            # also returns max(prob_ap) for the warning check.
+            n_sel, prob_max = _ap_select_kernel(
+                rU0,
+                chans_int,
+                merged_photon_id_dpe,
+                delaytime_cdf,
+                modifier,
+                sel_buf,
+            )
+
+            if prob_max * modifier > 0.5:
+                prob = prob_max * modifier
                 self.log.warning(f"PMT after pulse probability is {prob} larger than 0.5?")
 
-            # Scaling down (up) rU0 effectivly increase (decrease) the ap rate
-            rU0 /= self.pmt_ap_modifier
-
-            # Double the probability for those photon emitting dpe
-            rU0[merged_photon_id_dpe] /= 2
-
-            # Select those photons with U <= max of cdf of specific channel
-            sel_photon_id = np.where(rU0 <= prob_ap)[0]
-            if len(sel_photon_id) == 0:
+            if n_sel == 0:
                 continue
-            sel_photon_channel = merged_photon_channels[sel_photon_id]
 
-            # Assign selected photon SECOND random uniform number rU1 from (0, 1] for amplitude
-            rU1 = 1 - self.rng.random(len(sel_photon_channel))
+            sel_photon_id = sel_buf[:n_sel]
+            sel_photon_channel = chans_int[sel_photon_id]
+
+            # Second rng draw. The `continue` on n_sel == 0 above MUST stay
+            # before this draw — skipping it would diverge the PCG64 state
+            # on elements with empty selections.
+            rU1 = 1 - self.rng.random(n_sel)
 
-            # The map is made so that the indices are delay time in unit of ns
-            if "Uniform" in element:
+            if self._ap_is_uniform[i]:
+                # Third rng draw, conditional on the Uniform branch.
+                # `rng.uniform(low, high)` consumes len(sel) uniforms internally.
                 ap_delay = (
                     self.rng.uniform(
-                        delaytime_cdf[sel_photon_channel, 0], delaytime_cdf[sel_photon_channel, 1]
+                        delaytime_cdf[sel_photon_channel, 0],
+                        delaytime_cdf[sel_photon_channel, 1],
                     )
                     * delaytime_bin_size
                 )
                 ap_amplitude = np.ones_like(ap_delay)
             else:
-                ap_delay = (
-                    np.argmin(
-                        np.abs(delaytime_cdf[sel_photon_channel] - rU0[sel_photon_id][:, None]),
-                        axis=-1,
-                    )
-                    * delaytime_bin_size
-                    - self.pmt_ap_t_modifier
-                )
-                if len(amplitude_cdf.shape) == 2:
-                    ap_amplitude = (
-                        np.argmin(np.abs(amplitude_cdf[sel_photon_channel] - rU1[:, None]), axis=-1)
-                        * amplitude_bin_size
-                    )
+                # Non-Uniform branch: argmin-based inverse CDF in the kernel
+                # below.
+                rU0_sel = rU0[sel_photon_id]
+                amp_ndim = int(self._ap_amplitude_cdf_ndim[i])
+                if amp_ndim == 2:
+                    amp_cdf_2d = amplitude_cdf
+                    amp_cdf_1d = _AP_1D_PLACEHOLDER
                 else:
-                    ap_amplitude = (
-                        np.argmin(np.abs(amplitude_cdf[None, :] - rU1[:, None]), axis=-1)
-                        * amplitude_bin_size
-                    )
+                    amp_cdf_2d = _AP_2D_PLACEHOLDER
+                    amp_cdf_1d = amplitude_cdf
+
+                ap_delay, ap_amplitude = _ap_kernel_nonuniform(
+                    sel_photon_channel,
+                    rU0_sel,
+                    rU1,
+                    delaytime_cdf,
+                    amp_cdf_2d,
+                    amp_cdf_1d,
+                    amp_ndim == 2,
+                    float(delaytime_bin_size),
+                    float(amplitude_bin_size),
+                    float(self.pmt_ap_t_modifier),
+                )
 
             _photon_timings.append(merged_photon_timings[sel_photon_id] + ap_delay)
-            _photon_channels.append(merged_photon_channels[sel_photon_id])
-            _photon_amplitude.append(np.atleast_1d(ap_amplitude))
+            _photon_channels.append(sel_photon_channel)
+            _photon_gains.append(self._gains_f64[sel_photon_channel] * ap_amplitude)
 
-        if len(_photon_timings) > 0:
-            _photon_timings = np.hstack(_photon_timings)
-            _photon_channels = np.hstack(_photon_channels).astype(np.int64)
-            _photon_amplitude = np.hstack(_photon_amplitude)
-            _photon_gains = np.array(self.gains)[_photon_channels] * _photon_amplitude
+        if not _photon_timings:
+            return np.zeros(0, np.int64), np.zeros(0, np.int64), np.zeros(0)
+        return (
+            np.concatenate(_photon_timings),
+            np.concatenate(_photon_channels),
+            np.concatenate(_photon_gains),
+        )
 
-            return _photon_timings, _photon_channels, _photon_gains
 
+# Module-level placeholders for the kernel's unused-rank amplitude_cdf parameter.
+# Numba does NOT accept a single array param whose shape (1D vs 2D) varies per
+# call — so we pass both shapes always; the kernel branches on `amp_is_2d` and
+# reads from whichever is real. The placeholder is a single zero element of the
+# unused rank.
+_AP_2D_PLACEHOLDER = np.zeros((1, 1), dtype=np.float64)
+_AP_1D_PLACEHOLDER = np.zeros(1, dtype=np.float64)
+
+
+# error_model="numpy": a zero `pmt_ap_modifier` must yield inf (nothing selected),
+# as the numpy `rU0 /= modifier` this replaces did, not raise ZeroDivisionError.
+@numba.njit(cache=True, nogil=True, error_model="numpy")
+def _ap_select_kernel(
+    rU0,
+    merged_photon_channels,
+    merged_photon_id_dpe,
+    delaytime_cdf,
+    pmt_ap_modifier,
+    sel_out,
+):
+    """Selection pass for the per-element afterpulse scheme.
+
+    For each row k of ``rU0``: divide by ``pmt_ap_modifier`` (and halve if
+    the row is flagged as a DPE photon) in place, compare against the
+    per-channel probability ceiling ``delaytime_cdf[channel, -1]``, and
+    if the scaled value passes the cut, record the row index. The kernel
+    also returns the maximum per-row probability ceiling for the caller's
+    > 0.5 warning check.
+
+    Modifies ``rU0`` in place. Writes the indices that pass the selection
+    into ``sel_out[:n_sel]`` and returns ``(n_sel, prob_max)``. Each row's
+    float ops (scalar division, halving, scalar compare) follow the same
+    IEEE-754 sequence as the equivalent numpy expression.
+    """
+    n = rU0.shape[0]
+    K_last = delaytime_cdf.shape[1] - 1
+    n_sel = 0
+    prob_max = -1.0
+    for k in range(n):
+        v = rU0[k] / pmt_ap_modifier
+        if merged_photon_id_dpe[k]:
+            v /= 2.0
+        rU0[k] = v
+        ch = merged_photon_channels[k]
+        prob = delaytime_cdf[ch, K_last]
+        if prob > prob_max:
+            prob_max = prob
+        if v <= prob:
+            sel_out[n_sel] = k
+            n_sel += 1
+    return n_sel, prob_max
+
+
+@numba.njit(cache=True, inline="always")
+def _argmin_abs_diff_1d(cdf, r, K):
+    """Bit-identical equivalent of ``np.argmin(np.abs(cdf - r))`` over a 1D CDF.
+
+    Uses strict ``<`` so ties break to the leftmost index — same convention
+    as ``np.argmin``. Linear O(K) scan with zero allocation; CDF bin counts
+    here are ~few hundred so the scan fits in L1 and avoids the per-call
+    ``(N_sel, K)`` intermediate that the broadcast-+-argmin pattern would
+    materialise.
+
+    The equivalence assumes NaN-free input: ``np.argmin`` returns the first
+    NaN position, whereas this scan never moves onto a later NaN entry.
+    """
+    best_i = 0
+    best_d = abs(cdf[0] - r)
+    for i in range(1, K):
+        d = abs(cdf[i] - r)
+        if d < best_d:
+            best_d = d
+            best_i = i
+    return best_i
+
+
+@numba.njit(cache=True, inline="always")
+def _argmin_abs_diff_row(cdf2d, row, r, K):
+    """Same as ``_argmin_abs_diff_1d`` but over ``cdf2d[row, :]``."""
+    best_i = 0
+    best_d = abs(cdf2d[row, 0] - r)
+    for i in range(1, K):
+        d = abs(cdf2d[row, i] - r)
+        if d < best_d:
+            best_d = d
+            best_i = i
+    return best_i
+
+
+@numba.njit(cache=True, nogil=True)
+def _ap_kernel_nonuniform(
+    sel_ch,
+    rU0_sel,
+    rU1,
+    delaytime_cdf,
+    amplitude_cdf_2d,
+    amplitude_cdf_1d,
+    amp_is_2d,
+    delaytime_bin_size,
+    amplitude_bin_size,
+    pmt_ap_t_modifier,
+):
+    """Inner loop body for the non-Uniform afterpulse case.
+
+    For each selected photon ``k`` of channel ``ch = sel_ch[k]``, computes:
+
+        idx_d = argmin_i |delaytime_cdf[ch, i] - rU0_sel[k]|
+        idx_a = argmin_i |amplitude_cdf[ch_or_flat, i] - rU1[k]|
+        ap_delay[k] = idx_d * delaytime_bin_size - pmt_ap_t_modifier
+        ap_amplitude[k] = idx_a * amplitude_bin_size
+
+    Two register-resident O(K) scans per photon (one over the delay-time
+    CDF, one over the amplitude CDF). The ``amp_is_2d`` flag picks the
+    active amplitude-CDF input — the unused one is a placeholder array
+    (numba does not accept a single parameter whose rank varies per call,
+    so both shapes are always passed).
+    """
+    n = sel_ch.shape[0]
+    ap_delay = np.empty(n, dtype=np.float64)
+    ap_amp = np.empty(n, dtype=np.float64)
+    K_d = delaytime_cdf.shape[1]
+    if amp_is_2d:
+        K_a = amplitude_cdf_2d.shape[1]
+    else:
+        K_a = amplitude_cdf_1d.shape[0]
+
+    for k in range(n):
+        ch = sel_ch[k]
+        idx_d = _argmin_abs_diff_row(delaytime_cdf, ch, rU0_sel[k], K_d)
+        ap_delay[k] = idx_d * delaytime_bin_size - pmt_ap_t_modifier
+
+        if amp_is_2d:
+            idx_a = _argmin_abs_diff_row(amplitude_cdf_2d, ch, rU1[k], K_a)
         else:
-            return np.zeros(0, np.int64), np.zeros(0, np.int64), np.zeros(0)
+            idx_a = _argmin_abs_diff_1d(amplitude_cdf_1d, rU1[k], K_a)
+        ap_amp[k] = idx_a * amplitude_bin_size
+
+    return ap_delay, ap_amp