Skip to content

Commit 8eccd24

Browse files
committed
Add bfloat16 support
1 parent 8348230 commit 8eccd24

10 files changed

Lines changed: 162 additions & 56 deletions

File tree

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ A valid configuration file has three main sections:
202202
},
203203
"model_config": {
204204
"target_bones": ["femur", "tibia", "patella"],
205-
"model_directory": "/path/to/save/models"
205+
"model_directory": "/path/to/save/models",
206+
"dtype": "bfloat16"
206207
},
207208
"output_config": {
208209
"prediction_directory": "/path/to/save/predictions"
@@ -215,6 +216,9 @@ A valid configuration file has three main sections:
215216
- Default for SKI10: `["femur", "tibia"]`
216217
- Default for OAI: `["femur", "tibia", "patella"]`
217218
- *Cartilage is closely coupled: "femur" includes "femoral cartilage".*
219+
- **`dtype`**: (Optional) Data type for feature extraction matrices.
220+
- Options: `"float32"` (default), `"bfloat16"`.
221+
- **Recommedation**: Use `"bfloat16"` to reduce memory usage by ~50%. Requires `ml_dtypes`.
218222

219223
> **Note**: The `split_file` should be a JSON containing `{"train": ["file1.mhd", ...], "eval": ["file2.mhd", ...]}`.
220224

kneeseg/bone_rf.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .features import extract_features, compute_rsid_features
66
from scipy.ndimage import gaussian_filter
77
from tqdm import tqdm
8+
import ml_dtypes
89

910
class BoneClassifier:
1011
def __init__(self, n_estimators=100, max_depth=25, n_jobs=-1):
@@ -16,47 +17,57 @@ def __init__(self, n_estimators=100, max_depth=25, n_jobs=-1):
1617
class_weight='balanced'
1718
)
1819

19-
def extract_bone_features(self, image, prob_map=None, spacing=None):
20+
def extract_bone_features(self, image, prob_map=None, spacing=None, target_dtype='float32'):
2021
"""
2122
Extract features suitable for Bone Segmentation.
2223
Focuses on larger context and spatial coordinates.
2324
"""
25+
# Resolve dtype
26+
if isinstance(target_dtype, str):
27+
if target_dtype == 'bfloat16':
28+
dtype = ml_dtypes.bfloat16
29+
else:
30+
dtype = np.float32
31+
else:
32+
dtype = target_dtype
33+
2434
features = []
2535

2636
# Normalize
2737
img_mean = image.mean()
2838
img_std = image.std()
2939
img_norm = (image.astype(np.float32) - img_mean) / (img_std + 1e-6)
40+
img_norm = img_norm.astype(dtype)
3041

3142
# 1. Intensity & Smooth (Multi-scale)
3243
features.append(img_norm.flatten())
33-
features.append(gaussian_filter(img_norm, sigma=2.0).flatten())
34-
features.append(gaussian_filter(img_norm, sigma=4.0).flatten())
44+
features.append(gaussian_filter(img_norm.astype(np.float32), sigma=2.0).astype(dtype).flatten())
45+
features.append(gaussian_filter(img_norm.astype(np.float32), sigma=4.0).astype(dtype).flatten())
3546

3647
# 2. Spatial Coordinates (Normalized 0-1)
3748
z, y, x = np.mgrid[0:image.shape[0], 0:image.shape[1], 0:image.shape[2]]
38-
features.append(z.flatten() / image.shape[0])
39-
features.append(y.flatten() / image.shape[1])
40-
features.append(x.flatten() / image.shape[2])
49+
features.append((z.flatten() / image.shape[0]).astype(dtype))
50+
features.append((y.flatten() / image.shape[1]).astype(dtype))
51+
features.append((x.flatten() / image.shape[2]).astype(dtype))
4152

4253
# 3. RSID (Texture) - Sparse but helpful
4354
# Keep shift small-ish but larger than cartilage
4455
# Downsample image for RSID to save memory? No, standard RSID.
45-
rsid = compute_rsid_features(img_norm, num_shifts=20, max_shift=20, seed=42)
56+
rsid = compute_rsid_features(img_norm, num_shifts=20, max_shift=20, seed=42, dtype=dtype)
4657
for i in range(rsid.shape[-1]):
4758
features.append(rsid[..., i].flatten())
4859

4960
# 4. Auto-Context Probabilities
5061
if prob_map is not None:
5162
# prob_map is (Z, Y, X, C)
5263
for c in range(prob_map.shape[-1]):
53-
p_ch = prob_map[..., c]
64+
p_ch = prob_map[..., c].astype(dtype)
5465
features.append(p_ch.flatten())
55-
features.append(gaussian_filter(p_ch, sigma=2.0).flatten())
66+
features.append(gaussian_filter(p_ch.astype(np.float32), sigma=2.0).astype(dtype).flatten())
5667

5768
return np.stack(features, axis=1)
5869

59-
def train(self, images, labels, prob_maps=None, subsample=50000):
70+
def train(self, images, labels, prob_maps=None, subsample=50000, dtype='float32'):
6071
"""
6172
images: list of 3D arrays
6273
labels: list of 3D arrays (0=bg, 1=Femur, 2=FemCart, 3=Tibia, 4=TibCart)
@@ -113,7 +124,7 @@ def train(self, images, labels, prob_maps=None, subsample=50000):
113124
])
114125

115126
pm = prob_maps[i] if prob_maps else None
116-
feats_flat = self.extract_bone_features(img, prob_map=pm) # (N_all, F)
127+
feats_flat = self.extract_bone_features(img, prob_map=pm, target_dtype=dtype) # (N_all, F)
117128

118129
# Convert 3D coords to 1D indices
119130
flat_indices = np.ravel_multi_index(coords.T, img.shape)
@@ -127,8 +138,8 @@ def train(self, images, labels, prob_maps=None, subsample=50000):
127138
print(" Fitting Bone Random Forest...")
128139
self.clf.fit(np.vstack(X_all), np.concatenate(y_all))
129140

130-
def predict(self, image, prob_map=None):
131-
feats_flat = self.extract_bone_features(image, prob_map)
141+
def predict(self, image, prob_map=None, dtype='float32'):
142+
feats_flat = self.extract_bone_features(image, prob_map, target_dtype=dtype)
132143

133144
# Predict in chunks to be safe? Or full.
134145
# 100GB RAM. Image ~10M voxels. Features ~30 floats -> 300MB * 4 = 1.2GB.

kneeseg/configs/config_schema.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,15 @@
142142
"type": "string",
143143
"description": "Optional override for model directory."
144144
},
145+
"dtype": {
146+
"type": "string",
147+
"enum": [
148+
"float32",
149+
"bfloat16"
150+
],
151+
"default": "float32",
152+
"description": "Data type for feature extraction and training. Use 'bfloat16' to reduce memory usage."
153+
},
145154
"n_jobs": {
146155
"type": "integer",
147156
"default": -1,

kneeseg/features.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def compute_dts_from_landmarks(image_shape, landmarks_dict, spacing=None):
3939
dts[name] = np.full(image_shape, 200.0, dtype=np.float32)
4040
return dts
4141

42-
def compute_rsid_features(image, num_shifts=30, max_shift=10, seed=42, mask=None):
42+
def compute_rsid_features(image, num_shifts=30, max_shift=10, seed=42, mask=None, dtype=np.float32):
4343
"""
4444
Computes Random Shift Intensity Difference (RSID) features only at mask locations.
4545
Returns (num_masked_voxels, num_shifts) if mask provided, else full volume.
@@ -58,7 +58,7 @@ def compute_rsid_features(image, num_shifts=30, max_shift=10, seed=42, mask=None
5858
features.append(base_slice - padded[max_shift+dz : max_shift+z+dz,
5959
max_shift+dy : max_shift+y+dy,
6060
max_shift+dx : max_shift+x+dx])
61-
return np.stack(features, axis=-1).astype(np.float32)
61+
return np.stack(features, axis=-1).astype(dtype)
6262

6363
# Masked computation
6464
if mask.ndim == 3:
@@ -70,7 +70,7 @@ def compute_rsid_features(image, num_shifts=30, max_shift=10, seed=42, mask=None
7070
raise ValueError("RSID requires a 3D mask for spatial optimization.")
7171

7272
n_vox = len(mask_indices)
73-
rsid_features = np.zeros((n_vox, num_shifts), dtype=np.float32)
73+
rsid_features = np.zeros((n_vox, num_shifts), dtype=dtype)
7474

7575
# Pad image once
7676
padded = np.pad(image, max_shift, mode='edge')
@@ -87,7 +87,7 @@ def compute_rsid_features(image, num_shifts=30, max_shift=10, seed=42, mask=None
8787

8888
return rsid_features
8989

90-
def compute_landmark_features(image_shape, landmarks_dict, indices_dict, mask=None, spacing=None):
90+
def compute_landmark_features(image_shape, landmarks_dict, indices_dict, mask=None, spacing=None, dtype=np.float32):
9191
"""
9292
Computes distance to landmarks only at mask locations.
9393
Operates in Physical Space (MM) if spacing provided.
@@ -115,12 +115,17 @@ def compute_landmark_features(image_shape, landmarks_dict, indices_dict, mask=No
115115
if idx < len(points):
116116
p = points[idx]
117117
dist = np.sqrt(np.sum((coords - p)**2, axis=1))
118-
features.append(dist.astype(np.float32))
118+
features.append(dist.astype(dtype))
119119
else:
120-
features.append(np.full(len(coords), 200.0, dtype=np.float32))
120+
features.append(np.full(len(coords), 200.0, dtype=dtype))
121121
return features
122122

123123
def compute_dt_arithmetic_features(dts, mask=None):
124+
# This just returns float arrays, casting happens in extract_features usually
125+
# But let's check its usage.
126+
# It returns list of arrays. extract_features casts them.
127+
# So we don't strictly need to update this unless we want intermediate memory savings.
128+
# Let's keep it as is, casting is done by caller.
124129
features = []
125130
if 'femur' in dts and 'tibia' in dts:
126131
f = dts['femur']
@@ -139,32 +144,46 @@ def compute_dt_arithmetic_features(dts, mask=None):
139144
features.append(f - t)
140145
return features
141146

142-
def extract_features(image, dts, sigma=1.0, mask=None, r_shifts=30, landmarks_dict=None, landmark_indices=None, prob_map=None, spacing=None, sorted_bones_override=None):
147+
def extract_features(image, dts, sigma=1.0, mask=None, r_shifts=30, landmarks_dict=None, landmark_indices=None, prob_map=None, spacing=None, sorted_bones_override=None, target_dtype='float32'):
143148
"""
144-
Optimized feature extraction that avoids 4D intermediate arrays and respects masks.
149+
Computes Optimized feature extraction that avoids 4D intermediate arrays and respects masks.
145150
"""
151+
import ml_dtypes
152+
153+
# Resolve dtype
154+
if isinstance(target_dtype, str):
155+
if target_dtype == 'bfloat16':
156+
dtype = ml_dtypes.bfloat16
157+
else:
158+
dtype = np.float32
159+
else:
160+
dtype = target_dtype
161+
146162
img_mean = image.mean()
147163
img_std = image.std()
164+
165+
# Normalize and cast immediately to target dtype
148166
image_norm = (image.astype(np.float32) - img_mean) / (img_std + 1e-6)
167+
image_norm = image_norm.astype(dtype)
149168

150169
def get_masked(arr):
151170
if mask is not None:
152171
if mask.ndim == arr.ndim:
153-
return arr[mask].astype(np.float32)
172+
return arr[mask].astype(dtype)
154173
else:
155-
return arr.flatten()[mask].astype(np.float32)
156-
return arr.flatten().astype(np.float32)
174+
return arr.flatten()[mask].astype(dtype)
175+
return arr.flatten().astype(dtype)
157176

158177
features = []
159178

160179
# 1. Intensity
161180
features.append(get_masked(image_norm))
162181

163182
# 2. Gaussian
164-
features.append(get_masked(gaussian_filter(image_norm, sigma=sigma)))
183+
features.append(get_masked(gaussian_filter(image_norm.astype(np.float32), sigma=sigma)))
165184

166185
# 3. Gradient
167-
features.append(get_masked(gaussian_gradient_magnitude(image_norm, sigma=sigma)))
186+
features.append(get_masked(gaussian_gradient_magnitude(image_norm.astype(np.float32), sigma=sigma)))
168187

169188
# 5. DTs
170189
if sorted_bones_override is None:
@@ -185,31 +204,31 @@ def get_masked(arr):
185204
# get_masked returns flat array of size N_masked
186205
# We can use shape from an existing feature (like Intensity)
187206
ref_shape = features[0].shape
188-
features.append(np.full(ref_shape, 100.0, dtype=np.float32))
207+
features.append(np.full(ref_shape, 100.0, dtype=dtype))
189208

190209
# 6. DT Arithmetic
191210
# Only compute if 'femur' and 'tibia' are present/logic applies
192211
dt_arith = compute_dt_arithmetic_features(dts, mask=mask)
193212
for f in dt_arith:
194213
if mask is None:
195-
features.append(f.flatten().astype(np.float32))
214+
features.append(f.flatten().astype(dtype))
196215
else:
197-
features.append(f.astype(np.float32))
216+
features.append(f.astype(dtype))
198217

199218
# 7. RSID (Mask-aware)
200219
if mask is not None and mask.ndim == 3:
201-
rsid_masked = compute_rsid_features(image_norm, num_shifts=r_shifts, max_shift=10, mask=mask)
220+
rsid_masked = compute_rsid_features(image_norm, num_shifts=r_shifts, max_shift=10, mask=mask, dtype=dtype)
202221
for i in range(rsid_masked.shape[1]):
203222
features.append(rsid_masked[:, i])
204223
else:
205224
# Fallback to full then mask (wasteful)
206-
rsid_full = compute_rsid_features(image_norm, num_shifts=r_shifts, max_shift=10)
225+
rsid_full = compute_rsid_features(image_norm, num_shifts=r_shifts, max_shift=10, dtype=dtype)
207226
for i in range(rsid_full.shape[-1]):
208227
features.append(get_masked(rsid_full[..., i]))
209228

210229
# 8. Landmarks
211230
if landmarks_dict and landmark_indices:
212-
lm_features = compute_landmark_features(image.shape, landmarks_dict, landmark_indices, mask=mask, spacing=spacing)
231+
lm_features = compute_landmark_features(image.shape, landmarks_dict, landmark_indices, mask=mask, spacing=spacing, dtype=dtype)
213232
for f in lm_features:
214233
features.append(f)
215234

@@ -220,15 +239,15 @@ def get_masked(arr):
220239

221240
for p_ch in channels:
222241
features.append(get_masked(p_ch))
223-
features.append(get_masked(gaussian_filter(p_ch, sigma=sigma)))
242+
features.append(get_masked(gaussian_filter(p_ch.astype(np.float32), sigma=sigma)))
224243

225244
# Context RSID
226245
if mask is not None and mask.ndim == 3:
227-
rsid_p = compute_rsid_features(p_ch, num_shifts=15, max_shift=15, mask=mask)
246+
rsid_p = compute_rsid_features(p_ch, num_shifts=15, max_shift=15, mask=mask, dtype=dtype)
228247
for i in range(rsid_p.shape[1]):
229248
features.append(rsid_p[:, i])
230249
else:
231-
rsid_p_full = compute_rsid_features(p_ch, num_shifts=15, max_shift=15)
250+
rsid_p_full = compute_rsid_features(p_ch, num_shifts=15, max_shift=15, dtype=dtype)
232251
for i in range(rsid_p_full.shape[-1]):
233252
features.append(get_masked(rsid_p_full[..., i]))
234253

kneeseg/pipeline/inference.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def inference_improved(config=None):
3030
target_bones = model_cfg.get('target_bones', ['femur', 'tibia'])
3131
out_cfg = config['output_config']
3232

33+
target_dtype = model_cfg.get('dtype', 'float32')
34+
print(f"Target Bones: {target_bones}, dtype: {target_dtype}")
35+
3336
image_dir = data_cfg['image_directory']
3437
label_dir = data_cfg.get('label_directory') # Optional
3538
split_file = data_cfg.get('split_file') # Optional
@@ -110,10 +113,10 @@ def inference_improved(config=None):
110113
img, spacing = load_volume(img_path, return_spacing=True)
111114

112115
# 1. Predict Bones (Pass 1)
113-
_, prob1 = bone_rf_p1.predict(img)
116+
_, prob1 = bone_rf_p1.predict(img, dtype=target_dtype)
114117

115118
# 2. Predict Bones (Pass 2 - Auto-Context)
116-
bone_pred_flat, _ = bone_rf_p2.predict(img, prob_map=prob1)
119+
bone_pred_flat, _ = bone_rf_p2.predict(img, prob_map=prob1, dtype=target_dtype)
117120
bone_pred = bone_pred_flat.reshape(img.shape)
118121

119122
bone_masks = {}
@@ -127,14 +130,14 @@ def inference_improved(config=None):
127130
if cart_rf_p1 is not None and cart_rf_p2 is not None:
128131
# 2. Predict Cartilage
129132
# Pass 1: Prob Map
130-
c_prob1 = cart_rf_p1.predict_proba_map(img, bone_masks, proximity_mm=20.0)
133+
c_prob1 = cart_rf_p1.predict_proba_map(img, bone_masks, proximity_mm=20.0, dtype=target_dtype)
131134

132135
# Pass 2: Final Prediction (Auto-Context)
133-
cart_pred, _ = cart_rf_p2.predict(img, bone_masks, proximity_mm=20.0, prob_map=c_prob1)
136+
cart_pred, _ = cart_rf_p2.predict(img, bone_masks, proximity_mm=20.0, prob_map=c_prob1, dtype=target_dtype)
134137
elif cart_rf_p1 is not None:
135138
# Fallback to single pass if p2 missing (compatibility)
136139
print("Warning: Only P1 model found. Running single pass.")
137-
cart_pred, _ = cart_rf_p1.predict(img, bone_masks, proximity_mm=20.0)
140+
cart_pred, _ = cart_rf_p1.predict(img, bone_masks, proximity_mm=20.0, dtype=target_dtype)
138141

139142
# 3. Evaluate
140143
if lbl is not None:

0 commit comments

Comments
 (0)