forked from jorgedelpozolerida/Segmentation_CMB
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_split.py
More file actions
137 lines (99 loc) · 4.73 KB
/
Copy pathgenerate_split.py
File metadata and controls
137 lines (99 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python
# -*-coding:utf-8 -*-
""" Generates split file for dataset
Two split types are implemented ('train_val_test' and 'train_val'); both keep the healthy/unhealthy
subject ratio in each split.
@author: jorgedelpozolerida
@date: 29/10/2023
"""
import os
import sys
import argparse
import traceback
import logging # NOQA E402
import numpy as np # NOQA E402
import pandas as pd # NOQA E402
import pickle
# Configure the root logger at INFO level (runs at import time) and create
# the module-level logger used for the script's status messages.
logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)
import json
import os
import random
from collections import defaultdict
def create_splits(data_dir, seed=42, split_type='train_val', train_fraction=0.7):
    """
    Splits subjects into train, validation and (optionally) test sets while
    keeping the proportion of healthy and unhealthy subjects in every split.

    A subject is a subfolder of `data_dir` that contains
    'Annotations_metadata/<subject>_raw.json'. It counts as healthy when the
    'centers_of_mass' list under the 'T2S' key of that file is empty, and as
    unhealthy otherwise. Entries that are not directories, or that have no
    metadata file, are skipped.

    Args:
        data_dir (str): Directory containing the subfolders for each subject.
        seed (int): Random seed for reproducibility.
        split_type (str): Type of split ('train_val_test' or 'train_val').
        train_fraction (float): Fraction (0..1) of each health group assigned
            to the training set. For 'train_val' the remainder is the
            validation set; for 'train_val_test' the remainder is divided
            between validation and test, with test receiving the extra
            subject when the remainder is odd.

    Returns:
        dict: Dictionary with keys 'train', 'valid' and, for 'train_val_test',
        'test', each mapping to a list of subject folder names.

    Raises:
        ValueError: If `split_type` is not a supported value or
            `train_fraction` is outside [0, 1].
    """
    if split_type not in ('train_val', 'train_val_test'):
        raise ValueError(
            f"Unknown split_type {split_type!r}; expected 'train_val' or 'train_val_test'"
        )
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be within [0, 1], got {train_fraction!r}")

    # A private generator keeps the split reproducible without reseeding the
    # process-wide `random` state as a side effect.
    rng = random.Random(seed)

    healthy_subjects = []
    unhealthy_subjects = []

    # os.listdir returns entries in arbitrary order; sorting makes the result
    # depend only on the folder names and the seed.
    for subj_folder in sorted(os.listdir(data_dir)):
        subj_path = os.path.join(data_dir, subj_folder)
        if not os.path.isdir(subj_path):
            continue
        metadata_filepath = os.path.join(subj_path, 'Annotations_metadata', f'{subj_folder}_raw.json')
        if not os.path.exists(metadata_filepath):
            continue
        with open(metadata_filepath, 'rb') as file:
            metadata_dict = json.load(file)
        # No annotated centers of mass means no CMBs, i.e. a healthy subject.
        if metadata_dict['T2S']['centers_of_mass']:
            unhealthy_subjects.append(subj_folder)
        else:
            healthy_subjects.append(subj_folder)

    rng.shuffle(healthy_subjects)
    rng.shuffle(unhealthy_subjects)

    # Slice each health group separately so every split keeps the same
    # healthy/unhealthy ratio.
    splits = defaultdict(list)
    for group in (healthy_subjects, unhealthy_subjects):
        train_end = int(len(group) * train_fraction)
        splits['train'] += group[:train_end]
        if split_type == 'train_val':
            splits['valid'] += group[train_end:]
        else:
            valid_end = train_end + (len(group) - train_end) // 2
            splits['valid'] += group[train_end:valid_end]
            splits['test'] += group[valid_end:]

    # Shuffle splits to mix healthy and unhealthy subjects
    for key in splits:
        rng.shuffle(splits[key])

    return splits


def main(args):
    """
    Creates the splits for the dataset and writes them to 'Data/splits.json'.

    Args:
        args (argparse.Namespace): Parsed script arguments; reads
            `dataset_dir`, `split_type` and `overwrite`.
    """
    data_dir = os.path.join(args.dataset_dir, "Data")
    split_path = os.path.join(data_dir, 'splits.json')
    _logger.info(f"File path: {split_path}")

    # Check for an existing file first so the dataset is not scanned for nothing.
    if os.path.exists(split_path) and not args.overwrite:
        _logger.warning("Splits already exist, add overwrite flag if wanted")
        return

    splits_dict = create_splits(data_dir, seed=42, split_type=args.split_type)

    # Save to splits.json
    with open(split_path, 'w') as f:
        json.dump(splits_dict, f)
    _logger.info("Splits created and saved to splits.json.")
def parse_args():
    '''
    Parses all script arguments.

    Returns:
        argparse.Namespace: Namespace with `overwrite` (bool), `dataset_dir`
        (str) and `split_type` ('train_val' or 'train_val_test').
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--overwrite', action='store_true', default=False,
                        help='Add this flag if you want to overwrite file')
    # Required: without it the script fails later when building the Data/ path
    # from None, so reject the invocation here with a clear usage error.
    parser.add_argument('--dataset_dir', type=str, required=True,
                        help='Path to the dataset folder of dataset which has Data/ folder inside with subjects')
    parser.add_argument('--split_type', type=str, choices=['train_val_test', 'train_val'],
                        default='train_val', help='Type of split to create (default: train_val)')
    return parser.parse_args()
# Script entry point: parse the command-line arguments and run main().
# NOTE: `args` is assigned at module scope here, so it is also visible as a
# global name to the functions in this file.
if __name__ == '__main__':
    args = parse_args()
    main(args)