-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
117 lines (91 loc) · 3.47 KB
/
Copy pathevaluate.py
File metadata and controls
117 lines (91 loc) · 3.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import argparse
import sys
import re
def load_json(filepath):
    """Load and return the parsed contents of a JSON file.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The decoded JSON value (whatever ``json.load`` produces).

    Raises:
        SystemExit: With exit status 1 when the file cannot be opened,
            decoded as UTF-8, or parsed as JSON. The reason is printed
            to stderr first.
    """
    try:
        # Read as UTF-8 explicitly so the result does not depend on the
        # platform's default locale encoding.
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses; OSError covers missing/unreadable files.
        print(f"Error loading {filepath}: {e}", file=sys.stderr)
        sys.exit(1)
def extract_numbers(text):
    """Return every decimal number found in *text*, in order, as floats.

    Non-string input is converted with ``str()`` first. All commas are
    removed before scanning so that thousands separators do not split a
    number ("1,234.5" is read as 1234.5). An optional leading minus sign
    is kept as part of the number.

    Args:
        text: The value to scan.

    Returns:
        A list of floats; empty when no number is present.
    """
    if not isinstance(text, str):
        text = str(text)
    cleaned = text.replace(',', '')
    # First alternative: digits with an optional fractional part ("12",
    # "12.", "12.5"). Second alternative: a bare leading decimal point
    # (".5"), which would otherwise be misread as 5.0. The lookbehind keeps
    # dotted sequences such as "1.2.3" parsing as 1.2 and 3.
    matches = re.findall(r'-?(?:\d+\.?\d*|(?<!\d)\.\d+)', cleaned)
    return [float(m) for m in matches]
def evaluate_exact_match(gt, pred):
    """Decide whether a prediction matches the ground-truth answer.

    Two checks are applied in order:

    1. Text comparison of ``str(gt)`` and ``str(pred)``, ignoring case and
       leading/trailing whitespace.
    2. Numeric comparison: the first number extracted from each value must
       differ by less than 1% relative to the ground-truth number.

    Args:
        gt: Ground-truth answer.
        pred: Predicted answer.

    Returns:
        True when either check succeeds, otherwise False.
    """
    # Strip so that stray spaces or a trailing newline in a prediction do
    # not defeat an otherwise identical answer.
    gt_text = str(gt).strip().lower()
    pred_text = str(pred).strip().lower()
    if gt_text == pred_text:
        return True

    # Try numerical evaluation with 1% tolerance
    gt_nums = extract_numbers(gt)
    pred_nums = extract_numbers(pred)
    if gt_nums and pred_nums:
        expected, actual = gt_nums[0], pred_nums[0]
        # The small epsilon avoids division by zero when expected is 0.
        return abs(expected - actual) / (abs(expected) + 1e-9) < 0.01
    return False
def evaluate_json_match(gt, pred):
    """Score a dict prediction as the fraction of ground-truth keys matched.

    A key counts as matched when it is present in the prediction and its
    value satisfies ``evaluate_exact_match`` against the ground-truth value.

    Args:
        gt: Ground-truth mapping of key to expected value.
        pred: Predicted mapping, or a value ``json.loads`` can decode into
            one.

    Returns:
        A float in [0, 1]. 0.0 when the prediction is not (and cannot be
        parsed into) a dict, or when *gt* is empty.
    """
    if not isinstance(pred, dict):
        # Try to parse string as json
        try:
            pred = json.loads(pred)
        except (TypeError, ValueError):
            # TypeError: not str/bytes (e.g. None); ValueError: invalid JSON.
            return 0.0
        # Valid JSON that is not an object (a number, list, string, ...)
        # cannot be compared key by key.
        if not isinstance(pred, dict):
            return 0.0

    if not gt:
        return 0.0

    correct = sum(
        1
        for key, value in gt.items()
        if key in pred and evaluate_exact_match(value, pred[key])
    )
    return correct / len(gt)
def evaluate_list_match(gt, pred):
    """Score a list prediction as the fraction of ground-truth items found.

    Items are compared as lower-cased strings. Order is ignored and extra
    predicted items are not penalised.

    The prediction may be:
      * a list, used as is;
      * a string containing a JSON array;
      * a string containing a single JSON scalar, treated as a one-item
        list;
      * any other string, split on commas with each part stripped.

    Args:
        gt: List of expected items.
        pred: The prediction in one of the forms above.

    Returns:
        A float in [0, 1]. 0.0 when the prediction cannot be interpreted
        as a list, or when *gt* is empty.
    """
    if not isinstance(pred, list):
        try:
            parsed = json.loads(pred)
        except (TypeError, ValueError):
            # TypeError: not str/bytes (e.g. None); ValueError: invalid JSON.
            if not isinstance(pred, str):
                return 0.0
            parsed = [part.strip() for part in pred.split(',')]
        if isinstance(parsed, (str, int, float, bool)):
            # A lone scalar such as "42" is a one-item answer, not a miss.
            parsed = [parsed]
        elif not isinstance(parsed, list):
            # JSON objects and null have no list interpretation.
            return 0.0
        pred = parsed

    if not gt:
        return 0.0

    # A set gives constant-time membership tests for each expected item.
    predicted = {str(x).lower() for x in pred}
    correct = sum(1 for x in gt if str(x).lower() in predicted)
    return correct / len(gt)
def evaluate(ground_truth_path, predictions_path):
    """Score a predictions file against a ground-truth file and print a report.

    Each ground-truth item is read through its 'id', 'type' and 'answer'
    keys; the prediction for an item is looked up in the predictions data
    by that id. Supported types are 'exact_match', 'json_match',
    'list_match' and 'qualitative' (scored 1.0 for any truthy prediction).
    Missing predictions and unrecognised types score 0 and emit a warning.

    Args:
        ground_truth_path: Path to the ground-truth JSON file.
        predictions_path: Path to the predictions JSON file.

    Returns:
        The final accuracy as a percentage (0-100): the mean per-question
        score times 100, or 0.0 when there are no questions.
    """
    ground_truth = load_json(ground_truth_path)
    predictions = load_json(predictions_path)

    total_score = 0
    results = {}

    for item in ground_truth:
        q_id = item['id']
        q_type = item['type']
        gt_answer = item['answer']

        if q_id not in predictions:
            print(f"Warning: {q_id} missing from predictions.")
            results[q_id] = 0
            continue

        pred_answer = predictions[q_id]

        if q_type == 'exact_match':
            score = 1.0 if evaluate_exact_match(gt_answer, pred_answer) else 0.0
        elif q_type == 'json_match':
            score = evaluate_json_match(gt_answer, pred_answer)
        elif q_type == 'list_match':
            score = evaluate_list_match(gt_answer, pred_answer)
        elif q_type == 'qualitative':
            score = 1.0 if pred_answer else 0.0  # Just check if they provided something
        else:
            # Surface typos in the ground-truth file instead of silently
            # scoring the question as wrong.
            print(f"Warning: unknown question type {q_type!r} for {q_id}; scoring 0.")
            score = 0

        results[q_id] = score
        total_score += score

    num_questions = len(ground_truth)
    # Guard against an empty ground-truth file (would divide by zero).
    final_accuracy = (total_score / num_questions) * 100 if num_questions else 0.0

    print("\n--- Benchmark Results ---")
    for q_id, score in results.items():
        print(f"{q_id}: {score*100:.1f}%")
    print(f"\nFinal Accuracy: {final_accuracy:.2f}%")

    return final_accuracy
if __name__ == '__main__':
    # Command-line entry point: read the two file paths and run the
    # evaluation, which prints its own report.
    cli = argparse.ArgumentParser(description='Evaluate Pave Benchmark')
    cli.add_argument(
        '--gt',
        type=str,
        default='ground_truth.json',
        help='Path to ground truth JSON',
    )
    cli.add_argument(
        '--pred',
        type=str,
        required=True,
        help='Path to predictions JSON',
    )
    options = cli.parse_args()
    evaluate(options.gt, options.pred)