-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils_predictions.py
More file actions
81 lines (63 loc) · 3.42 KB
/
Copy pathutils_predictions.py
File metadata and controls
81 lines (63 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import mne
import numpy as np
import matplotlib.pyplot as plt
import Metric.EvaluatorConverter as Converter
from Database.DatabaseMetrics import DatabaseMetrics
from IA.NNBase import NNBase
from Preprocessor.Preprocessor import Preprocessor
from Dataset.DatasetTypeEnum import dataset_enum_by_name
from Object.Signal.SignalTypeEnum import signal_enum_by_name
from Object.Signal.SignalTypeEnum import SignalTypeEnum
from Utils.Graphics.Heatmap import plot_2d_heatmap
from Utils.Graphics.SignalPlot import plot_signal
from Preprocessor.Loader.Loader import Loader
from IA.NeuralNetworkTypeEnum import neural_network_enum_by_name
# Experiment grid: the script below visits every (model, domain, window)
# combination built from these lists for the selected dataset.
DATASET = dataset_enum_by_name("CHBMIT")

# Network architectures to inspect. "CNN" is currently disabled; add it to the
# tuple to include it again.
MODELS = [neural_network_enum_by_name(name) for name in ("RNN", "CRNN")]

# Signal representations fed to the networks. "PSDWelch" is currently
# disabled; add it to the tuple to include it again.
DOMAINS = [signal_enum_by_name(name) for name in ("PSDMultitaper",)]

# Values passed as both window_length and overlap_shift_size. The value 5 is
# currently disabled; add it to the list to include it again.
WINDOWS = [10]
# Inspect the misclassified windows of one recording for every stored model.
#
# For each (model, domain, window) combination the script looks up the stored
# evaluation rows, reloads the matching saved networks, predicts on the
# preprocessed recording and plots every window whose prediction disagrees
# with its label. A one-line summary per evaluated network is collected in
# ``all_wrong_pred`` and printed at the end.
#
# NOTE(review): the indentation of this script was lost; the nesting below is
# reconstructed from which names each statement uses — confirm against the
# original file.

# Recording under inspection (previously repeated as a literal in three calls).
FILE_NAME = "chb04_28.edf"
# Fourth positional argument of plot_signal; presumably the sampling frequency
# in Hz — TODO confirm against plot_signal's signature.
SAMPLING_RATE = 256

database = DatabaseMetrics()
all_wrong_pred = []

for model in MODELS:
    for domain in DOMAINS:
        for window in WINDOWS:
            db_objects = database.metrics_by_model_domain_window(
                DATASET.name, model.name, domain.name, window
            )
            # Truthiness also covers empty results that are not a list
            # (e.g. an empty tuple or None), which `== []` let through.
            if not db_objects:
                print(f"No data found for {model.name}, {domain.name}, window size {window}")
                continue

            evaluations = [Converter.model_from_tuple(db_object) for db_object in db_objects]

            # Time-domain segments of the recording, used for the raw plots.
            summaries = Loader.load_summary(dataset_type=DATASET, file_name=FILE_NAME)
            mne.utils.set_log_level("CRITICAL")
            Loader.load_segmented_data(
                summaries,
                signal_type=SignalTypeEnum.Time,
                window_length=window,
                overlap_shift_size=window,
            )
            summary = summaries[0]
            signal_original = summary.signal.get_data_segmented()
            labels_original = summary.signal.get_label_segmented()
            # The segments used to be concatenated along axis 1 here for an
            # optional whole-recording plot; the result was never used, so the
            # concatenation was dropped.

            # The same recording, transformed into the domain the model expects.
            data, labels = Preprocessor.preprocess_with_file_name(
                dataset_type=DATASET,
                model_type=model,
                signal_type=domain,
                window_length=window,
                overlap_shift_size=window,
                file_name=FILE_NAME,
            )

            for evaluation in evaluations:
                # Saved networks are stored under a file named after their accuracy.
                network = NNBase(0, 0)
                network.load_model(f"data/Models/{DATASET.name}/{domain.name}/{model.name}/{window}/{evaluation.accuracy}.keras")
                network.predict_classes(data)
                predictions = network.predictions

                # Positions where the predicted class differs from the label.
                index = [
                    i
                    for i in range(predictions.size)
                    if predictions[i] != bool(labels[i])
                ]
                all_wrong_pred.append(f"Model: {model.name} | Domain: {domain.name} | Window: {window} | INDEXES: {index} | Acc: {evaluation.accuracy}")

                # Print the annotated occurrence intervals of the recording so
                # they can be compared with the misclassified indexes.
                for occurrence in range(summary.get_nr_occurrence()):
                    print(f"START: {summary.start_time_of_ocurrence(occurrence)} | END: {summary.end_time_of_ocurrence(occurrence)}")

                # NOTE(review): indexing signal_original with positions taken
                # from `data` assumes both segmentations line up one-to-one —
                # confirm against Loader / Preprocessor.
                for i in index:
                    predicted = 'normal' if predictions[i] else 'epilepsy'
                    plot_signal(signal_original[i], f"Index: {i} | GROUND TRUTH: {labels_original[i]} | PREDICTION: {predicted}", SignalTypeEnum.Time, SAMPLING_RATE)
                    plot_signal(data[i], f"Index: {i} | GROUND TRUTH: {'normal' if bool(labels[i]) else 'epilepsy'} | PREDICTION: {predicted}", domain, SAMPLING_RATE)

print(all_wrong_pred)