@@ -39,7 +39,7 @@ def compute_dts_from_landmarks(image_shape, landmarks_dict, spacing=None):
3939 dts [name ] = np .full (image_shape , 200.0 , dtype = np .float32 )
4040 return dts
4141
42- def compute_rsid_features (image , num_shifts = 30 , max_shift = 10 , seed = 42 , mask = None ):
42+ def compute_rsid_features (image , num_shifts = 30 , max_shift = 10 , seed = 42 , mask = None , dtype = np . float32 ):
4343 """
4444 Computes Random Shift Intensity Difference (RSID) features only at mask locations.
4545 Returns (num_masked_voxels, num_shifts) if mask provided, else full volume.
@@ -58,7 +58,7 @@ def compute_rsid_features(image, num_shifts=30, max_shift=10, seed=42, mask=None
5858 features .append (base_slice - padded [max_shift + dz : max_shift + z + dz ,
5959 max_shift + dy : max_shift + y + dy ,
6060 max_shift + dx : max_shift + x + dx ])
61- return np .stack (features , axis = - 1 ).astype (np . float32 )
61+ return np .stack (features , axis = - 1 ).astype (dtype )
6262
6363 # Masked computation
6464 if mask .ndim == 3 :
@@ -70,7 +70,7 @@ def compute_rsid_features(image, num_shifts=30, max_shift=10, seed=42, mask=None
7070 raise ValueError ("RSID requires a 3D mask for spatial optimization." )
7171
7272 n_vox = len (mask_indices )
73- rsid_features = np .zeros ((n_vox , num_shifts ), dtype = np . float32 )
73+ rsid_features = np .zeros ((n_vox , num_shifts ), dtype = dtype )
7474
7575 # Pad image once
7676 padded = np .pad (image , max_shift , mode = 'edge' )
@@ -87,7 +87,7 @@ def compute_rsid_features(image, num_shifts=30, max_shift=10, seed=42, mask=None
8787
8888 return rsid_features
8989
90- def compute_landmark_features (image_shape , landmarks_dict , indices_dict , mask = None , spacing = None ):
90+ def compute_landmark_features (image_shape , landmarks_dict , indices_dict , mask = None , spacing = None , dtype = np . float32 ):
9191 """
9292 Computes distance to landmarks only at mask locations.
9393 Operates in Physical Space (MM) if spacing provided.
@@ -115,12 +115,17 @@ def compute_landmark_features(image_shape, landmarks_dict, indices_dict, mask=No
115115 if idx < len (points ):
116116 p = points [idx ]
117117 dist = np .sqrt (np .sum ((coords - p )** 2 , axis = 1 ))
118- features .append (dist .astype (np . float32 ))
118+ features .append (dist .astype (dtype ))
119119 else :
120- features .append (np .full (len (coords ), 200.0 , dtype = np . float32 ))
120+ features .append (np .full (len (coords ), 200.0 , dtype = dtype ))
121121 return features
122122
123123def compute_dt_arithmetic_features (dts , mask = None ):
124+ # Returns a list of float arrays without casting them to a target dtype.
125+ # The caller, extract_features, casts each returned array itself
126+ # (via .astype(dtype)), so no dtype parameter is strictly needed here.
127+ # Casting inside this function would only help reduce the memory used by
128+ # the intermediate arrays, so it is intentionally kept as is.
124129 features = []
125130 if 'femur' in dts and 'tibia' in dts :
126131 f = dts ['femur' ]
@@ -139,32 +144,46 @@ def compute_dt_arithmetic_features(dts, mask=None):
139144 features .append (f - t )
140145 return features
141146
142- def extract_features (image , dts , sigma = 1.0 , mask = None , r_shifts = 30 , landmarks_dict = None , landmark_indices = None , prob_map = None , spacing = None , sorted_bones_override = None ):
147+ def extract_features (image , dts , sigma = 1.0 , mask = None , r_shifts = 30 , landmarks_dict = None , landmark_indices = None , prob_map = None , spacing = None , sorted_bones_override = None , target_dtype = 'float32' ):
143148 """
144- Optimized feature extraction that avoids 4D intermediate arrays and respects masks.
149+ Computes features using an optimized extraction that avoids 4D intermediate arrays and respects masks.
145150 """
151+ import ml_dtypes
152+
153+ # Resolve dtype
154+ if isinstance (target_dtype , str ):
155+ if target_dtype == 'bfloat16' :
156+ dtype = ml_dtypes .bfloat16
157+ else :
158+ dtype = np .float32
159+ else :
160+ dtype = target_dtype
161+
146162 img_mean = image .mean ()
147163 img_std = image .std ()
164+
165+ # Normalize and cast immediately to target dtype
148166 image_norm = (image .astype (np .float32 ) - img_mean ) / (img_std + 1e-6 )
167+ image_norm = image_norm .astype (dtype )
149168
150169 def get_masked (arr ):
151170 if mask is not None :
152171 if mask .ndim == arr .ndim :
153- return arr [mask ].astype (np . float32 )
172+ return arr [mask ].astype (dtype )
154173 else :
155- return arr .flatten ()[mask ].astype (np . float32 )
156- return arr .flatten ().astype (np . float32 )
174+ return arr .flatten ()[mask ].astype (dtype )
175+ return arr .flatten ().astype (dtype )
157176
158177 features = []
159178
160179 # 1. Intensity
161180 features .append (get_masked (image_norm ))
162181
163182 # 2. Gaussian
164- features .append (get_masked (gaussian_filter (image_norm , sigma = sigma )))
183+ features .append (get_masked (gaussian_filter (image_norm . astype ( np . float32 ) , sigma = sigma )))
165184
166185 # 3. Gradient
167- features .append (get_masked (gaussian_gradient_magnitude (image_norm , sigma = sigma )))
186+ features .append (get_masked (gaussian_gradient_magnitude (image_norm . astype ( np . float32 ) , sigma = sigma )))
168187
169188 # 5. DTs
170189 if sorted_bones_override is None :
@@ -185,31 +204,31 @@ def get_masked(arr):
185204 # get_masked returns flat array of size N_masked
186205 # We can use shape from an existing feature (like Intensity)
187206 ref_shape = features [0 ].shape
188- features .append (np .full (ref_shape , 100.0 , dtype = np . float32 ))
207+ features .append (np .full (ref_shape , 100.0 , dtype = dtype ))
189208
190209 # 6. DT Arithmetic
191210 # Only compute if 'femur' and 'tibia' are present/logic applies
192211 dt_arith = compute_dt_arithmetic_features (dts , mask = mask )
193212 for f in dt_arith :
194213 if mask is None :
195- features .append (f .flatten ().astype (np . float32 ))
214+ features .append (f .flatten ().astype (dtype ))
196215 else :
197- features .append (f .astype (np . float32 ))
216+ features .append (f .astype (dtype ))
198217
199218 # 7. RSID (Mask-aware)
200219 if mask is not None and mask .ndim == 3 :
201- rsid_masked = compute_rsid_features (image_norm , num_shifts = r_shifts , max_shift = 10 , mask = mask )
220+ rsid_masked = compute_rsid_features (image_norm , num_shifts = r_shifts , max_shift = 10 , mask = mask , dtype = dtype )
202221 for i in range (rsid_masked .shape [1 ]):
203222 features .append (rsid_masked [:, i ])
204223 else :
205224 # Fallback to full then mask (wasteful)
206- rsid_full = compute_rsid_features (image_norm , num_shifts = r_shifts , max_shift = 10 )
225+ rsid_full = compute_rsid_features (image_norm , num_shifts = r_shifts , max_shift = 10 , dtype = dtype )
207226 for i in range (rsid_full .shape [- 1 ]):
208227 features .append (get_masked (rsid_full [..., i ]))
209228
210229 # 8. Landmarks
211230 if landmarks_dict and landmark_indices :
212- lm_features = compute_landmark_features (image .shape , landmarks_dict , landmark_indices , mask = mask , spacing = spacing )
231+ lm_features = compute_landmark_features (image .shape , landmarks_dict , landmark_indices , mask = mask , spacing = spacing , dtype = dtype )
213232 for f in lm_features :
214233 features .append (f )
215234
@@ -220,15 +239,15 @@ def get_masked(arr):
220239
221240 for p_ch in channels :
222241 features .append (get_masked (p_ch ))
223- features .append (get_masked (gaussian_filter (p_ch , sigma = sigma )))
242+ features .append (get_masked (gaussian_filter (p_ch . astype ( np . float32 ) , sigma = sigma )))
224243
225244 # Context RSID
226245 if mask is not None and mask .ndim == 3 :
227- rsid_p = compute_rsid_features (p_ch , num_shifts = 15 , max_shift = 15 , mask = mask )
246+ rsid_p = compute_rsid_features (p_ch , num_shifts = 15 , max_shift = 15 , mask = mask , dtype = dtype )
228247 for i in range (rsid_p .shape [1 ]):
229248 features .append (rsid_p [:, i ])
230249 else :
231- rsid_p_full = compute_rsid_features (p_ch , num_shifts = 15 , max_shift = 15 )
250+ rsid_p_full = compute_rsid_features (p_ch , num_shifts = 15 , max_shift = 15 , dtype = dtype )
232251 for i in range (rsid_p_full .shape [- 1 ]):
233252 features .append (get_masked (rsid_p_full [..., i ]))
234253
0 commit comments