-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprompt.py
More file actions
137 lines (112 loc) · 4.87 KB
/
Copy pathprompt.py
File metadata and controls
137 lines (112 loc) · 4.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
from helper import utils, loader, metric
import torch
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from openprompt.prompts import ManualTemplate
from openprompt.prompts import ManualVerbalizer
from openprompt import PromptForClassification
# Module-level TensorBoard writer used by run() to log the running training
# loss. It is created at import time, so merely importing this module
# instantiates a SummaryWriter. NOTE(review): no log_dir is passed, so the
# output directory is whatever torch.utils.tensorboard picks by default.
writer = SummaryWriter()
def run(device, seed, dataset, args_scale, args_fewshot, args_batch_size, args_learning_rate, args_training_epoch,
        args_mode, args_plm):
    """Train a prompt-based clickbait classifier, checkpointing and evaluating every epoch.

    Builds an OpenPrompt ``PromptForClassification`` (manual template + manual
    verbalizer) on top of the pretrained model chosen by
    ``loader.choose_pretrained_model``, fine-tunes it with a class-weighted
    cross-entropy loss, saves a checkpoint after each epoch and evaluates it
    with :func:`test`.

    Args:
        device: torch device (or device string) that the model, every batch
            and the loss class weights are moved to.
        seed: random seed passed to ``utils.setup_seed``.
        dataset: forwarded unchanged to ``loader.generate_dataloader``.
        args_scale: forwarded unchanged to ``loader.generate_dataloader``.
        args_fewshot: forwarded unchanged to ``loader.generate_dataloader``.
        args_batch_size: forwarded unchanged to ``loader.generate_dataloader``.
        args_learning_rate: learning rate for the AdamW optimizer.
        args_training_epoch: number of training epochs (2-4 recommended).
        args_mode: passed to ``loader.choose_pretrained_model`` as ``load`` and
            to :func:`test`, which uses it in the result file name.
        args_plm: pretrained-model selector passed to
            ``loader.choose_pretrained_model`` and to :func:`test`.

    Returns:
        The accuracy returned by the last :func:`test` call, or ``None`` when
        ``args_training_epoch`` is 0 and no epoch runs.
    """
    utils.setup_seed(seed)
    # There are two classes in Clickbait Detection, one for normal and one for clickbait.
    # NOTE(review): the list order presumably fixes the label index
    # (0 = normal, 1 = clickbait) — confirm it matches the dataset encoding.
    classes = [
        "normal",
        "clickbait",
    ]
    plm_config = loader.choose_pretrained_model(args_plm, load=args_mode)
    plm, tokenizer = plm_config.plm, plm_config.tokenizer
    collate_fn = utils.set_collate_fn(tokenizer, device)

    # Template: "<text> 这是 [MASK] 标题" == "<text> This is a [MASK] headline".
    promptTemplate = ManualTemplate(
        text='{"placeholder":"text_a"} 这是 {"mask"} 标题',
        tokenizer=tokenizer,
    )
    # Label words scored at the [MASK] position, per class:
    #   normal:    accurate, appropriate, compliant, standard, qualified, high-quality
    #   clickbait: clickbait, misleading, inducing, bait, insinuating, untrue
    promptVerbalizer = ManualVerbalizer(
        classes=classes,
        label_words={
            "normal": ["准确", "合适", "合规", "规范", "合格", "高质量"],
            "clickbait": ["标题党", "误导", "诱导", "诱饵", "暗示", "失实"],
        },
        tokenizer=tokenizer,
    )
    promptModel = PromptForClassification(
        template=promptTemplate,
        plm=plm,
        verbalizer=promptVerbalizer,
    )
    # Attached to the config before the dataloaders are built; presumably
    # generate_dataloader wraps examples with this template — verify in helper.loader.
    plm_config.promptTemplate = promptTemplate
    promptModel.to(device)
    print(promptModel.template.text)

    train_loader, _, test_loader, title_loader = loader.generate_dataloader(
        args_scale, args_fewshot, dataset, args_batch_size, plm_config, collate_fn, load="pepl"
    )

    optimizer = AdamW(promptModel.parameters(), lr=args_learning_rate)
    # Per-class loss weights, indexed like `classes`: mistakes on "normal" count
    # double. Built directly on `device` (the original hard-coded `.cuda()`,
    # which fails on CPU-only machines and ignores the requested device).
    class_weights = torch.tensor([2.0, 1.0], dtype=torch.float, device=device)
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

    # Checkpoints are written below; make sure the directory exists.
    os.makedirs("checkpoints", exist_ok=True)
    accuracy = None
    for exec_index in range(args_training_epoch):  # recommend 2-4
        promptModel.train()
        loss_cum = 0.0
        for ind, batch in enumerate(train_loader):
            for k in batch.keys():
                batch[k] = batch[k].to(device)
            logits = promptModel(batch)
            loss = criterion(logits, batch["label"])
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_cum += loss.item()
            if ind % 5 == 0:
                # Running mean of the loss over the batches seen this epoch.
                avg_loss = loss_cum / (ind + 1)
                print(ind, avg_loss)
                writer.add_scalar(str(exec_index) + " Exec Loss", avg_loss, ind)
        loader.save_model(os.path.join("checkpoints", "bert_prompt_" + str(exec_index) + ".pt"), exec_index,
                          promptModel, optimizer)
        accuracy = test(device, args_mode, args_plm, promptModel, test_loader, title_loader)
    return accuracy
def test(device, args_mode, args_plm, promptModel, test_loader, title_loader):
    """Evaluate ``promptModel`` on ``test_loader``, print metrics and dump per-sample results.

    ``title_loader`` is read completely first. Each item is unpacked as
    ``(title, input_ids, attention_mask, token_type_ids, labels)``, and only
    ``title`` and ``labels`` are kept. Its batches are matched to
    ``test_loader`` batches by position. Their labels are cross-checked
    against the test batch labels, and ``"error"`` is printed on a mismatch.

    For every sample, a tab-separated line ``gold_label<TAB>logits<TAB>title``
    is written to ``results/<args_mode>_<args_plm>_case_res.txt``. Any
    existing file is overwritten.

    Args:
        device: torch device every test batch is moved to.
        args_mode: used in the result file name.
        args_plm: used in the result file name.
        promptModel: model called as ``promptModel(batch)`` to obtain logits.
        test_loader: yields batches supporting ``keys()`` / item access with a
            ``"label"`` entry.
        title_loader: yields raw titles and labels aligned with ``test_loader``.

    Returns:
        Accuracy ``correct / total`` as counted by ``metric.accuracy``, or
        ``0.0`` when ``test_loader`` yields no samples.

    Raises:
        IndexError: if ``test_loader`` has more batches than ``title_loader``.
    """
    promptModel.eval()
    types_number = 2
    # Confusion-matrix accumulator; row/column meaning is whatever
    # metric.multi_label_metric returns — TODO confirm (gold x pred?).
    matrix = [[0] * types_number for _ in range(types_number)]
    correct, total = 0, 0

    # (titles, labels) per title_loader batch, matched to test batches by index.
    ans_list = [(title, labels) for title, _, _, _, labels in title_loader]

    result_save_path = os.path.join("results", str(args_mode) + "_" + str(args_plm) + "_case_res.txt")
    os.makedirs(os.path.dirname(result_save_path), exist_ok=True)
    with torch.no_grad(), open(result_save_path, "w", encoding="utf-8") as case_f:
        for cnt, batch in enumerate(test_loader):
            title, th_label = ans_list[cnt]
            for k in batch.keys():
                batch[k] = batch[k].to(device)
            logits = promptModel(batch)
            preds = torch.argmax(logits, dim=-1)
            labels = batch["label"]
            # Compare the whole label batch element-wise. The original compared
            # the title-label batch to labels[0] alone, which only works for
            # batch size 1 and raises for larger batches.
            if torch.as_tensor(th_label).reshape(-1).tolist() != labels.reshape(-1).tolist():
                print("error")
            # One case line per sample (the original wrote only the first one).
            for gold, row_logits, row_title in zip(labels.tolist(), logits.tolist(), title):
                case_f.writelines([str(gold), "\t", str(row_logits), "\t", row_title, "\n"])
            cor, tot = metric.accuracy(preds, labels)
            correct += cor
            total += tot
            cur_mat = metric.multi_label_metric(preds, labels, types_number)
            for j in range(types_number):
                for k in range(types_number):
                    matrix[j][k] += cur_mat[j][k]

    if total == 0:
        # Avoid ZeroDivisionError when the test loader is empty.
        print("Accuracy: no test samples")
        return 0.0

    accuracy = correct / total
    print("Accuracy: " + str(accuracy))
    print("Confusion Matrix")
    for row in matrix:
        print(" ".join(str(value) for value in row))
    print("marco_Precision: " + str(metric.cal_marco_Pre(matrix)))
    print("marco_Recall: " + str(metric.cal_marco_Rec(matrix)))
    print("marco_F1: " + str(metric.cal_marco_F1(matrix)))
    return accuracy