-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsd_benchmark_model.py
More file actions
66 lines (44 loc) · 2.27 KB
/
Copy pathsd_benchmark_model.py
File metadata and controls
66 lines (44 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# -*- coding: utf-8 -*-
"""
Script for estimating performance based on the model's predictions.
@author: iurii
"""
import argparse
import os
import csv
from glob import glob
import numpy as np
import face_morphing_benchmark.fmb_utils.fmb_utilities as fmb_utilities
import face_morphing_benchmark.fmb_utils.BPCER_at_APCER as bpcer_apcer
def parse_args():
    """Parse the command-line options of the benchmark script.

    All options are optional strings:

    * ``-m`` / ``--model_name``            (default ``"test_model"``)
    * ``-n`` / ``--protocol_name``         (default ``"test_protocol"``)
    * ``-p`` / ``--models_path``           (default ``"./models"``)
    * ``-r`` / ``--predictions_filename``  (default ``"predictions.txt"``)
    * ``-l`` / ``--gt_labels_filename``    (default ``"gt_labels.txt"``)

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model_name', default="test_model", type=str,
                        help='name of the model')
    parser.add_argument('-n', '--protocol_name', default="test_protocol", type=str,
                        help='name of the protocol')
    parser.add_argument('-p', '--models_path', default="./models", type=str,
                        help='path to the models')
    # The two options below are file names, not paths to the models; their
    # help texts used to be copy-pasted from --models_path.
    parser.add_argument('-r', '--predictions_filename', default="predictions.txt", type=str,
                        help='name of the predictions file')
    parser.add_argument('-l', '--gt_labels_filename', default="gt_labels.txt", type=str,
                        help='name of the ground-truth labels file')
    return parser.parse_args()
def _load_values(file_path):
    """Load a numeric text file with ``np.loadtxt`` and return it as a list."""
    # ndmin=1 guarantees a list even when the file holds a single value;
    # without it loadtxt yields a 0-d array and tolist() a bare float.
    return np.loadtxt(file_path, dtype=float, ndmin=1).tolist()


def benchmark_model(args):
    """Compute the ROC and DET metrics for one model on one protocol.

    The predictions and ground-truth labels are read from
    ``<models_path>/<model_name>/<protocol_name>/`` and handed to
    ``bpcer_apcer.get_ROC_metric`` and ``bpcer_apcer.get_DET_metric``.
    What those helpers do with the results (print, save, ...) is not
    visible from this script.

    Args:
        args: namespace providing ``models_path``, ``model_name``,
            ``protocol_name``, ``predictions_filename`` and
            ``gt_labels_filename`` (as produced by ``parse_args``).
    """
    protocol_dir = os.path.join(args.models_path, args.model_name, args.protocol_name)
    predictions = _load_values(os.path.join(protocol_dir, args.predictions_filename))
    gt_labels = _load_values(os.path.join(protocol_dir, args.gt_labels_filename))

    # extract metrics
    FMR_compare = [0.2, 0.1, 0.01, 0.001, 0.0001]
    bpcer_apcer.get_ROC_metric(predictions,
                               gt_labels,
                               args.protocol_name,
                               args.model_name,
                               benchmark_data_path=args.models_path,
                               FMR_compare=FMR_compare)

    APCER_compare = [0.2, 0.1, 0.01, 0.001, 0.0001]
    bpcer_apcer.get_DET_metric(predictions,
                               gt_labels,
                               args.protocol_name,
                               args.model_name,
                               benchmark_data_path=args.models_path,
                               APCER_compare=APCER_compare)
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the benchmark on them.
    benchmark_model(parse_args())