-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathcreate_splits_fewshot.py
More file actions
51 lines (39 loc) · 1.83 KB
/
Copy pathcreate_splits_fewshot.py
File metadata and controls
51 lines (39 loc) · 1.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import pandas as pd
import numpy as np
import os
# Build few-shot variants of existing k-fold split files: for every fold the
# val/test columns of the original split are kept, while the train column is
# replaced by N randomly sampled slides per class.

N = 16  # training slides sampled per class
dataset = 'tcga_nsclc'
data_folder = f'splits/{dataset}/{dataset}_100'
save_folder = f'splits/{dataset}/{dataset}_{N}'
NUM_FOLDS = 5  # number of splits_<j>.csv files read from data_folder
SEED = None  # set to an int to make the per-class sampling reproducible


def strip_suffix(name, suffix='.svs'):
    """Return *name* with one trailing *suffix* removed, if present.

    ``str.rstrip('.svs')`` strips any run of trailing '.', 's' or 'v'
    characters (``'slides.svs'`` -> ``'slide'``); this removes only the exact
    suffix.
    """
    return name[:-len(suffix)] if suffix and name.endswith(suffix) else name


def build_label_map(all_data):
    """Map slide name -> label from the rows of the dataset CSV.

    Column 1 of each row is taken as the slide name (with its '.svs' suffix
    removed so it matches the ids used in the split files) and the last
    column as the label. Because the CSV is read with header=None, the
    header row is included as one extra entry as well.
    """
    return {strip_suffix(row[1]): row[-1] for row in all_data}


def fewshot_split(split, slidename2label, n_shot, val_num, test_num=0, rng=None):
    """Return a copy of *split* whose train column keeps *n_shot* slides per class.

    Args:
        split: 2-D array of rows read from a ``splits_<j>.csv`` file, laid out
            as (index, train, val, test); shorter columns are NaN-padded.
        slidename2label: mapping from slide name to class label.
        n_shot: number of training slides sampled without replacement from
            each class.
        val_num: number of val entries.
        test_num: number of test entries (0 keeps the old val-only sizing).
        rng: ``numpy.random.Generator``; a fresh unseeded one when None.

    Returns:
        Object array holding the first ``max(#selected, val_num, test_num)``
        rows. The sampled train slides are packed at the top of column 1
        (grouped by class in ``np.unique`` label order) with NaN below them.
        The val/test columns are left untouched.

    Raises:
        KeyError: a train slide is missing from *slidename2label*.
        ValueError: a class has fewer than *n_shot* training slides.
    """
    if rng is None:
        rng = np.random.default_rng()
    split = np.array(split, dtype=object)  # copy; object dtype so NaN can be stored
    train_col = split[:, 1]  # view: writes below go into `split`
    # Ignore the NaN padding at the bottom of the train column.
    train_ids = [sid for sid in train_col if not pd.isna(sid)]
    labels = np.array([slidename2label[sid] for sid in train_ids])
    selected = []
    for label in np.unique(labels):
        candidates = np.flatnonzero(labels == label)
        if len(candidates) < n_shot:
            raise ValueError(
                f'class {label!r} has {len(candidates)} training slides, '
                f'cannot sample {n_shot}')
        picked = rng.choice(candidates, size=n_shot, replace=False)
        selected.extend(train_ids[i] for i in picked)
    train_col[:len(selected)] = selected
    # Clear every remaining train entry, including the last row.
    train_col[len(selected):] = np.nan
    # Keep enough rows for the sampled train slides and the val/test entries.
    n_rows = max(len(selected), int(val_num), int(test_num))
    return split[:n_rows]


def main():
    """Write a few-shot version of every fold's split file into save_folder."""
    all_data = np.array(pd.read_csv(f'dataset_csv/{dataset}.csv', header=None))
    # The label lookup does not depend on the fold, so build it once.
    slidename2label = build_label_map(all_data)
    os.makedirs(save_folder, exist_ok=True)
    rng = np.random.default_rng(SEED)
    column_name = ['','train', 'val', 'test']
    for j in range(NUM_FOLDS):
        split = np.array(pd.read_csv(f'{data_folder}/splits_{j}.csv'))
        # Descriptor rows are per class; its columns are assumed to be
        # (class, train, val, test), the same order as column_name — confirm
        # against the script that writes the descriptor files.
        counts = np.array(pd.read_csv(f'{data_folder}/splits_{j}_descriptor.csv'))
        val_num = np.sum(counts[:, 2])
        test_num = np.sum(counts[:, 3]) if counts.shape[1] > 3 else 0
        new_split = fewshot_split(split, slidename2label, N, val_num, test_num, rng)
        pd.DataFrame(columns=column_name, data=new_split).to_csv(
            f'{save_folder}/splits_{j}.csv', index=False)


if __name__ == '__main__':
    main()