Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ repos:
rev: v1.7.7
hooks:
- id: docformatter
# Pin to py3.12 — docformatter's transitive dep `untokenize 0.1.1`
# uses `ast.Constant.s` (removed in py3.14), so pre-commit.ci's
# default py3.14 env fails to build it.
language_version: python3.12

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.1
Expand Down
326 changes: 267 additions & 59 deletions fuse/plugins/pmt_and_daq/pmt_afterpulses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import strax
import numpy as np
import numba
import straxen

from ...dtypes import propagated_photons_fields
Expand Down Expand Up @@ -119,6 +120,35 @@ def setup(self):
if isinstance(self.uniform_to_pmt_ap[k][q], list):
self.uniform_to_pmt_ap[k][q] = np.array(self.uniform_to_pmt_ap[k][q])

# Pre-flatten the nested CDF dict into parallel arrays for the per-element
# loop. Cast every CDF to float64 once here so the numba kernel sees a
# single dispatch specialisation and per-iteration arithmetic uses a
# fixed dtype throughout (avoids per-chunk promotions).
self._ap_elements = list(self.uniform_to_pmt_ap.keys())
self._ap_is_uniform = np.array(
[("Uniform" in e) for e in self._ap_elements], dtype=np.bool_
)
self._ap_delaytime_cdf = [
np.asarray(self.uniform_to_pmt_ap[e]["delaytime_cdf"], dtype=np.float64)
for e in self._ap_elements
]
self._ap_amplitude_cdf = [
np.asarray(self.uniform_to_pmt_ap[e]["amplitude_cdf"], dtype=np.float64)
for e in self._ap_elements
]
self._ap_amplitude_cdf_ndim = np.array(
[cdf.ndim for cdf in self._ap_amplitude_cdf], dtype=np.int8
)
self._ap_delaytime_bin_size = np.array(
[self.uniform_to_pmt_ap[e]["delaytime_bin_size"] for e in self._ap_elements],
dtype=np.float64,
)
self._ap_amplitude_bin_size = np.array(
[self.uniform_to_pmt_ap[e]["amplitude_bin_size"] for e in self._ap_elements],
dtype=np.float64,
)
self._gains_f64 = np.asarray(self.gains, dtype=np.float64)

def compute(self, s1_photons, s2_photons):
if not self.enable_pmt_afterpulses or (len(s1_photons) == 0 and len(s2_photons) == 0):
return np.zeros(0, dtype=self.dtype)
Expand Down Expand Up @@ -157,83 +187,261 @@ def photon_afterpulse(
self, merged_photon_timings, merged_photon_channels, merged_photon_id_dpe
):
"""For pmt afterpulses, gain and dpe generation is a bit different from
standard photons."""
element_list = self.uniform_to_pmt_ap.keys()
standard photons.

Two-pass numba scheme for the per-element afterpulse loop. Random
draws stay in Python and follow a fixed per-element order — same
argument shapes and same conditional skip on empty selection — so
the PCG64 sequence is preserved across calls.

Per element:
1. ``rng.random(N)`` for rU0 (Python).
2. ``_ap_select_kernel`` (numba): one tight loop that scales
rU0 in place by the per-element modifier (and halves the
dpe entries), compares against ``cdf[channel, -1]`` for
each row, and emits the indices that pass the selection
plus the running max of the per-row probability ceiling.
3. ``rng.random(N_sel)`` for rU1 (Python; skipped when
``N_sel == 0`` — the skip MUST stay before this draw or
the random sequence diverges).
4. Uniform branch: ``rng.uniform(low, high)`` (Python; only
on Uniform elements). Non-Uniform branch:
``_ap_kernel_nonuniform`` does the argmin-based inverse
CDF lookup.

Every float operation matches the equivalent numpy expression's
IEEE-754 sequence; rng calls happen in a deterministic order
with the same argument shapes per element.
"""
_photon_timings = []
_photon_channels = []
_photon_amplitude = []

for element in element_list:
delaytime_cdf = self.uniform_to_pmt_ap[element]["delaytime_cdf"]
amplitude_cdf = self.uniform_to_pmt_ap[element]["amplitude_cdf"]

delaytime_bin_size = self.uniform_to_pmt_ap[element]["delaytime_bin_size"]
amplitude_bin_size = self.uniform_to_pmt_ap[element]["amplitude_bin_size"]

# Assign each photon a FIRST random uniform number rU0 from (0, 1] for timing
rU0 = 1 - self.rng.random(len(merged_photon_timings))

# delaytime_cdf is intentionally normalized to the AP probability rather than to 1
prob_ap = delaytime_cdf[merged_photon_channels, -1]
if prob_ap.max() * self.pmt_ap_modifier > 0.5:
prob = prob_ap.max() * self.pmt_ap_modifier
_photon_gains = []

N = len(merged_photon_timings)
# Allocate the selection-index scratch buffer ONCE per call. Reused
# across all elements; only the first n_sel entries are valid per element.
sel_buf = np.empty(N, dtype=np.int64)
# Coerce channels to int64 once outside the per-element loop.
chans_int = merged_photon_channels.astype(np.int64, copy=False)
modifier = float(self.pmt_ap_modifier)

for i, element in enumerate(self._ap_elements):
delaytime_cdf = self._ap_delaytime_cdf[i]
amplitude_cdf = self._ap_amplitude_cdf[i]
delaytime_bin_size = self._ap_delaytime_bin_size[i]
amplitude_bin_size = self._ap_amplitude_bin_size[i]

# First rng draw for this element. Must happen unconditionally
# before the selection branch — moving it inside the branch
# would diverge the PCG64 state on elements with zero
# selections.
rU0 = 1 - self.rng.random(N)

# Pass 1 (numba): scales rU0 in-place, emits selection indices,
# also returns max(prob_ap) for the warning check.
n_sel, prob_max = _ap_select_kernel(
rU0,
chans_int,
merged_photon_id_dpe,
delaytime_cdf,
modifier,
sel_buf,
)

if prob_max * modifier > 0.5:
prob = prob_max * modifier
self.log.warning(f"PMT after pulse probability is {prob} larger than 0.5?")

# Scaling down (up) rU0 effectively increases (decreases) the AP rate
rU0 /= self.pmt_ap_modifier

# Double the probability for those photon emitting dpe
rU0[merged_photon_id_dpe] /= 2

# Select those photons with U <= max of cdf of specific channel
sel_photon_id = np.where(rU0 <= prob_ap)[0]
if len(sel_photon_id) == 0:
if n_sel == 0:
continue
sel_photon_channel = merged_photon_channels[sel_photon_id]

# Assign selected photon SECOND random uniform number rU1 from (0, 1] for amplitude
rU1 = 1 - self.rng.random(len(sel_photon_channel))
sel_photon_id = sel_buf[:n_sel]
sel_photon_channel = chans_int[sel_photon_id]

# Second rng draw. The `continue` on n_sel == 0 above MUST stay
# before this draw — skipping it would diverge the PCG64 state
# on elements with empty selections.
rU1 = 1 - self.rng.random(n_sel)

# The map is made so that the indices are delay time in unit of ns
if "Uniform" in element:
if self._ap_is_uniform[i]:
# Third rng draw, conditional on the Uniform branch.
# `rng.uniform(low, high)` consumes len(sel) uniforms internally.
ap_delay = (
self.rng.uniform(
delaytime_cdf[sel_photon_channel, 0], delaytime_cdf[sel_photon_channel, 1]
delaytime_cdf[sel_photon_channel, 0],
delaytime_cdf[sel_photon_channel, 1],
)
* delaytime_bin_size
)
ap_amplitude = np.ones_like(ap_delay)
else:
ap_delay = (
np.argmin(
np.abs(delaytime_cdf[sel_photon_channel] - rU0[sel_photon_id][:, None]),
axis=-1,
)
* delaytime_bin_size
- self.pmt_ap_t_modifier
)
if len(amplitude_cdf.shape) == 2:
ap_amplitude = (
np.argmin(np.abs(amplitude_cdf[sel_photon_channel] - rU1[:, None]), axis=-1)
* amplitude_bin_size
)
# Non-Uniform branch: argmin-based inverse CDF in the kernel
# below.
rU0_sel = rU0[sel_photon_id]
amp_ndim = int(self._ap_amplitude_cdf_ndim[i])
if amp_ndim == 2:
amp_cdf_2d = amplitude_cdf
amp_cdf_1d = _AP_1D_PLACEHOLDER
else:
ap_amplitude = (
np.argmin(np.abs(amplitude_cdf[None, :] - rU1[:, None]), axis=-1)
* amplitude_bin_size
)
amp_cdf_2d = _AP_2D_PLACEHOLDER
amp_cdf_1d = amplitude_cdf

ap_delay, ap_amplitude = _ap_kernel_nonuniform(
sel_photon_channel,
rU0_sel,
rU1,
delaytime_cdf,
amp_cdf_2d,
amp_cdf_1d,
amp_ndim == 2,
float(delaytime_bin_size),
float(amplitude_bin_size),
float(self.pmt_ap_t_modifier),
)

_photon_timings.append(merged_photon_timings[sel_photon_id] + ap_delay)
_photon_channels.append(merged_photon_channels[sel_photon_id])
_photon_amplitude.append(np.atleast_1d(ap_amplitude))
_photon_channels.append(sel_photon_channel)
_photon_gains.append(self._gains_f64[sel_photon_channel] * ap_amplitude)

if len(_photon_timings) > 0:
_photon_timings = np.hstack(_photon_timings)
_photon_channels = np.hstack(_photon_channels).astype(np.int64)
_photon_amplitude = np.hstack(_photon_amplitude)
_photon_gains = np.array(self.gains)[_photon_channels] * _photon_amplitude
if not _photon_timings:
return np.zeros(0, np.int64), np.zeros(0, np.int64), np.zeros(0)
return (
np.concatenate(_photon_timings),
np.concatenate(_photon_channels).astype(np.int64),
np.concatenate(_photon_gains),
)

return _photon_timings, _photon_channels, _photon_gains

# Dummy stand-ins for the amplitude-CDF argument that is NOT in use on a given
# call of `_ap_kernel_nonuniform`. That kernel takes both a 2D and a 1D
# amplitude CDF and reads only the one selected by its `amp_is_2d` flag, so the
# caller always has to supply an array of each rank. Each placeholder holds a
# single float64 zero in the rank it substitutes for.
# NOTE(review): the stated motivation is that one numba array parameter cannot
# change rank between calls of the same compiled kernel — confirm against the
# numba documentation before relying on it.
_AP_1D_PLACEHOLDER = np.zeros(shape=(1,), dtype=np.float64)
_AP_2D_PLACEHOLDER = np.zeros(shape=(1, 1), dtype=np.float64)


@numba.njit(cache=True, nogil=True)
def _ap_select_kernel(
    rU0,
    merged_photon_channels,
    merged_photon_id_dpe,
    delaytime_cdf,
    pmt_ap_modifier,
    sel_out,
):
    """Rescale the uniform draws and collect the photons that pass the cut.

    Each entry of ``rU0`` is divided by ``pmt_ap_modifier`` and, when the
    matching ``merged_photon_id_dpe`` flag is truthy, divided by 2 as well.
    The scaled value is written back to ``rU0`` and compared against the
    last column of ``delaytime_cdf`` for that photon's channel; photons
    whose scaled value is ``<=`` that ceiling are selected.

    Parameters
    ----------
    rU0 : 1D array, modified in place with the scaled values.
    merged_photon_channels : per-photon row index into ``delaytime_cdf``.
    merged_photon_id_dpe : per-photon flag; truthy entries get the extra
        division by 2.
    delaytime_cdf : 2D array indexed as ``[channel, bin]``; only its last
        column is read here.
    pmt_ap_modifier : scalar divisor applied to every draw.
    sel_out : output buffer receiving the selected photon indices in
        ascending order; it must hold at least ``len(rU0)`` entries.

    Returns
    -------
    (n_sel, prob_max) : the number of valid entries in ``sel_out`` and the
        largest ceiling seen over all photons. ``prob_max`` is ``-1.0``
        when ``rU0`` is empty or no ceiling exceeds ``-1.0``.
    """
    last_col = delaytime_cdf.shape[1] - 1
    ceiling_max = -1.0
    count = 0
    for row in range(rU0.shape[0]):
        # Per-channel probability ceiling: last bin of that channel's CDF.
        ceiling = delaytime_cdf[merged_photon_channels[row], last_col]
        if ceiling > ceiling_max:
            ceiling_max = ceiling

        # Same operation order as the original: divide by the modifier
        # first, then halve for flagged photons.
        scaled = rU0[row] / pmt_ap_modifier
        if merged_photon_id_dpe[row]:
            scaled /= 2.0
        rU0[row] = scaled

        if scaled <= ceiling:
            sel_out[count] = row
            count += 1
    return count, ceiling_max


@numba.njit(cache=True, inline="always")
def _argmin_abs_diff_1d(cdf, r, K):
    """Return the index in ``[0, K)`` whose ``cdf`` entry is closest to ``r``.

    Linear scan over the first ``K`` entries of the 1D array ``cdf``,
    tracking the smallest ``abs(cdf[i] - r)`` seen so far. The comparison
    is a strict ``<``, so on ties the lowest index wins. No temporary
    array is created.

    ``cdf[0]`` is read unconditionally, so ``cdf`` must hold at least one
    element. NOTE(review): intended as a drop-in for
    ``np.argmin(np.abs(cdf - r))``; a NaN difference never compares as
    smaller here, so results can differ if ``cdf`` or ``r`` contains NaN
    — confirm inputs are NaN-free.
    """
    winner = 0
    smallest = abs(cdf[0] - r)
    for pos in range(1, K):
        gap = abs(cdf[pos] - r)
        if gap < smallest:
            winner = pos
            smallest = gap
    return winner


@numba.njit(cache=True, inline="always")
def _argmin_abs_diff_row(cdf2d, row, r, K):
    """Return the column in ``[0, K)`` of ``cdf2d[row]`` closest to ``r``.

    Scans ``cdf2d[row, 0] .. cdf2d[row, K - 1]`` and keeps the column with
    the smallest ``abs(cdf2d[row, i] - r)``. The comparison is a strict
    ``<``, so on ties the lowest column wins. ``cdf2d[row, 0]`` is read
    unconditionally, so the row must hold at least one column.
    """
    winner = 0
    smallest = abs(cdf2d[row, 0] - r)
    for col in range(1, K):
        gap = abs(cdf2d[row, col] - r)
        if gap < smallest:
            winner = col
            smallest = gap
    return winner


@numba.njit(cache=True, nogil=True)
def _ap_kernel_nonuniform(
    sel_ch,
    rU0_sel,
    rU1,
    delaytime_cdf,
    amplitude_cdf_2d,
    amplitude_cdf_1d,
    amp_is_2d,
    delaytime_bin_size,
    amplitude_bin_size,
    pmt_ap_t_modifier,
):
    """Compute afterpulse delay and amplitude for the non-Uniform case.

    For each selected photon ``k`` with channel ``ch = sel_ch[k]``:

        idx_d = argmin_i |delaytime_cdf[ch, i] - rU0_sel[k]|
        idx_a = argmin_i |amplitude_cdf[...] - rU1[k]|
        ap_delay[k]     = idx_d * delaytime_bin_size - pmt_ap_t_modifier
        ap_amplitude[k] = idx_a * amplitude_bin_size

    The amplitude lookup uses row ``ch`` of ``amplitude_cdf_2d`` when
    ``amp_is_2d`` is true, and the whole of ``amplitude_cdf_1d`` otherwise.
    The array not selected by ``amp_is_2d`` is never read, so the caller
    may pass a placeholder for it.

    Parameters
    ----------
    sel_ch : 1D array of channel indices, one per selected photon.
    rU0_sel : 1D array of scaled timing draws, aligned with ``sel_ch``.
    rU1 : 1D array of amplitude draws, aligned with ``sel_ch``.
    delaytime_cdf : 2D array indexed as ``[channel, bin]``.
    amplitude_cdf_2d : 2D array indexed as ``[channel, bin]``; read only
        when ``amp_is_2d`` is true.
    amplitude_cdf_1d : 1D array indexed as ``[bin]``; read only when
        ``amp_is_2d`` is false.
    amp_is_2d : flag choosing which amplitude CDF is active.
    delaytime_bin_size, amplitude_bin_size : scalars multiplying the
        respective bin index.
    pmt_ap_t_modifier : scalar subtracted from every delay.

    Returns
    -------
    (ap_delay, ap_amp) : two float64 arrays of length ``len(sel_ch)``.
    """
    n = sel_ch.shape[0]
    ap_delay = np.empty(n, dtype=np.float64)
    ap_amp = np.empty(n, dtype=np.float64)

    # Bin counts are loop-invariant; resolve them once up front.
    K_d = delaytime_cdf.shape[1]
    if amp_is_2d:
        K_a = amplitude_cdf_2d.shape[1]
    else:
        K_a = amplitude_cdf_1d.shape[0]

    for k in range(n):
        ch = sel_ch[k]
        idx_d = _argmin_abs_diff_row(delaytime_cdf, ch, rU0_sel[k], K_d)
        ap_delay[k] = idx_d * delaytime_bin_size - pmt_ap_t_modifier

        if amp_is_2d:
            idx_a = _argmin_abs_diff_row(amplitude_cdf_2d, ch, rU1[k], K_a)
        else:
            # The 1D branch must perform the lookup like the 2D branch.
            # An early return of a 3-tuple of empty arrays here would
            # abort the loop and disagree with the 2-tuple returned below.
            idx_a = _argmin_abs_diff_1d(amplitude_cdf_1d, rU1[k], K_a)
        ap_amp[k] = idx_a * amplitude_bin_size

    return ap_delay, ap_amp
Loading