-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun-rewrite.py
More file actions
139 lines (109 loc) · 4.76 KB
/
Copy pathrun-rewrite.py
File metadata and controls
139 lines (109 loc) · 4.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import ast
import os
import sqlite3
import time

from dotenv import load_dotenv
from groq import Groq
from openai import OpenAI
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors

from statuti_utils import *
# Load variables from a local .env file; the API keys are read later via os.getenv.
load_dotenv()

# --- Command-line configuration ------------------------------------------------
parser = argparse.ArgumentParser(description="Rewriting engine configuration parameters")
parser.add_argument("--count", type=int, default=1, help="Instance number")
parser.add_argument(
    "--service",
    type=str,
    choices=["openai", "groq"],
    default="groq",
    help="Which service to use",
)
parser.add_argument("--model-name", type=str, required=True, help="Model name to use")
parser.add_argument("--stop-after", type=int, default=None, help="Stop after N items (if specified)")
parser.add_argument("--pause", type=float, default=0.5, help="Pause time in seconds between requests")
args = parser.parse_args()

# Expose the parsed options as module-level constants used by the rest of the script.
COUNT, SERVICE, MODEL_NAME, STOP_AFTER, PAUSE = (
    args.count,
    args.service,
    args.model_name,
    args.stop_after,
    args.pause,
)
# Input/output locations, relative to the working directory.
RES_FOLDER = "db-rewrite"
PROMPT_FILE = "prompt/prompt-rewrite.txt"
# At this point DB_PATH is only the filename suffix: a non-zero instance number
# (--count) yields e.g. "-2.db", otherwise plain ".db". The folder and model name
# are prepended further down in the script.
DB_PATH = f"-{COUNT}.db" if COUNT else ".db"
TSV_PATH = "data/parallel_task_test_data.tsv"
# Rewriting strategies: key -> Italian instruction text that is substituted for the
# "{strategy_description}" placeholder in the prompt template.
strategies = dict(
    # Both forms, feminine first, slash-separated.
    fslashm="inserire entrambe le forme, femminile e maschile, separate dallo slash, come studentesse/studenti o rettrice/rettore",
    # Both forms, masculine first, slash-separated.
    mslashf="inserire entrambe le forme, maschile e femminile, separate dallo slash, come studenti/studentesse o rettore/rettrice",
    # Both forms, feminine first, joined by "e".
    fandm="inserire entrambe le forme, femminile e maschile, separate dalla congiunzione 'e', come 'studentesse e studenti' o 'rettrice e rettore'",
    # Both forms, masculine first, joined by "e".
    mandf="inserire entrambe le forme, maschile e femminile, separate dalla congiunzione 'e', come 'studenti e studentesse' o 'rettore e rettrice'",
    # Gender-neutral collective noun.
    group="utilizzare una forma collettiva neutra che includa entrambi i generi, come 'il corpo studentesco' o 'la dirigenza universitaria'",
    # Rephrase to avoid gendered terms (passive voice, neutral nouns).
    obscuring="riformulare la frase in modo da evitare l'uso di termini specifici di genere, ad esempio utilizzando il passivo o sostantivi neutri",
    # Both forms, feminine first, either "e" or slash allowed.
    both_fm="inserire entrambe le forme, femminile e maschile, separate dalla congiunzione 'e' o dallo slash, come 'studentesse/studenti' o 'studentesse e studenti'",
    # Both forms, masculine first, either "e" or slash allowed.
    both_mf="inserire entrambe le forme, maschile e femminile, separate dalla congiunzione 'e' o dallo slash, come 'studenti/studentesse' o 'studenti e studentesse'",
)
# Results are stored under <RES_FOLDER>/<SERVICE>/. exist_ok=True replaces the
# separate exists() check, which could race with another process creating the
# same folder between the check and makedirs().
service_dir = os.path.join(RES_FOLDER, SERVICE)
os.makedirs(service_dir, exist_ok=True)

# A "/" in the model name (e.g. "org/model") would otherwise be treated as a
# path separator. The replacement is computed outside the f-string: reusing the
# f-string's own quote character inside it is a SyntaxError before Python 3.12.
safe_model_name = MODEL_NAME.replace("/", "-")
DB_PATH = os.path.join(service_dir, f"{safe_model_name}{DB_PATH}")

CHATGPT_KEY = os.getenv("CHATGPT_KEY")
GROQ_KEY = os.getenv("GROQ_KEY")

# Pick the API client for the selected service. argparse already restricts
# --service to these choices; the else branch is a defensive guard.
if SERVICE == "groq":
    client = Groq(api_key=GROQ_KEY)
elif SERVICE == "openai":
    client = OpenAI(api_key=CHATGPT_KEY)
else:
    raise ValueError("Invalid SERVICE value. Use 'openai' or 'groq'.")
TABLE_NAME = "records"

# NOTE(review): these containers (like the sklearn imports) are not referenced
# anywhere else in this script; they look like leftovers — confirm before removing.
train_data = {}
vectorizers = {}
tfidf_train = {}
nn = {}

with open(PROMPT_FILE, "r", encoding="utf-8") as f:
    prompt_text = f.read()

conn = sqlite3.connect(DB_PATH)
# DDL for the results table; passed to init_db (from statuti_utils) together
# with the TSV path.
query = f"""
CREATE TABLE IF NOT EXISTS {TABLE_NAME} (
id INTEGER PRIMARY KEY AUTOINCREMENT,
text_old TEXT,
text_new TEXT,
strategy TEXT,
similarity_score INTEGER,
prompt TEXT,
chatgpt_answer TEXT
)
"""
init_db(conn, TABLE_NAME, TSV_PATH, query)
cur = conn.cursor()

# Select only the rows that have no answer stored yet.
select_sql = f"SELECT id, text_old, strategy FROM {TABLE_NAME} WHERE chatgpt_answer IS NULL"
select_params = ()
# `is not None` (not truthiness) so that --stop-after 0 processes nothing instead
# of silently meaning "no limit"; the value is bound rather than formatted in.
if STOP_AFTER is not None:
    select_sql += " LIMIT ?"
    select_params = (STOP_AFTER,)
cur.execute(select_sql, select_params)
rows = cur.fetchall()
print(f"🔍 Found {len(rows)} records to process.")
def _resolve_strategy(raw_strategy):
    """Map a stored ``strategy`` value to a key of ``strategies``.

    A value not starting with "[" is returned unchanged. Otherwise it is parsed
    as a list literal:
    - if it contains both "fslashm" and "fandm", the combined "both_fm" is used;
    - else if it contains both "mslashf" and "mandf", "both_mf" is used;
    - otherwise the first entry that is neither "group" nor "obscuring" is chosen,
      falling back to "obscuring" when no such entry exists.
    """
    if not raw_strategy.startswith("["):
        return raw_strategy
    # literal_eval parses the list literal without executing arbitrary code,
    # unlike eval() on values coming from the DB/TSV.
    strategy_list = ast.literal_eval(raw_strategy)
    # Each name must be tested with its own `in`: `"a" and "b" in xs` only
    # checks "b", because the non-empty string "a" is always truthy.
    if "fslashm" in strategy_list and "fandm" in strategy_list:
        return "both_fm"
    if "mslashf" in strategy_list and "mandf" in strategy_list:
        return "both_mf"
    remaining = [x for x in strategy_list if x not in ("group", "obscuring")]
    return remaining[0] if remaining else "obscuring"


# Same SQL text as before, hoisted out of the loop since it never changes.
update_sql = (
    f"\nUPDATE {TABLE_NAME}\n"
    "SET chatgpt_answer = ?, prompt = ?\n"
    "WHERE id = ?\n"
)

try:
    for rec_id, text, raw_strategy in rows:
        print(f"\n🧠 Processing ID {rec_id}...")
        strategy = _resolve_strategy(raw_strategy)
        # Skip (rather than abort the whole run with KeyError) on an unknown key.
        if strategy not in strategies:
            print(f"❌ Unknown strategy {strategy!r}, skipping.")
            continue
        prompt_instance = prompt_text.replace("{strategy_description}", strategies[strategy])
        prompt_instance = prompt_instance.replace("{sentence_to_rewrite}", text)
        answer = get_chatgpt_answer(prompt_instance, client, MODEL_NAME)
        if not answer:
            print("❌ No response, skipping.")
            continue
        cur.execute(update_sql, (answer, prompt_instance, rec_id))
        # Commit per row so completed answers persist even if a later request fails.
        conn.commit()
        print(f"✅ Updated ID {rec_id}")
        time.sleep(PAUSE)
    print("\n🏁 All done!")
finally:
    # Close the connection even if a request or DB write raises mid-run.
    conn.close()