-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
151 lines (120 loc) · 6.06 KB
/
Copy pathmain.py
File metadata and controls
151 lines (120 loc) · 6.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from __future__ import annotations
import argparse
import random
import numpy as np
import torch
import gcngrasp
import gcngrasp.utils.lit_extension as lite
from gcngrasp.utils import analysis
from gcngrasp.utils import utils
from gcngrasp.utils.functional import ItemStack
from gcngrasp.utils.loss_fcn import LossFunction, dist_peak
# warnings.filterwarnings("ignore", category=UserWarning)
# Select the "medium" float32 matmul precision globally at import time.
# NOTE(review): per the PyTorch docs this permits lower-precision internal
# matmul arithmetic in exchange for speed - confirm it is acceptable for eval.
torch.set_float32_matmul_precision("medium")
class MainModule(lite.VanillaModule):
    """Training / evaluation module wrapping the model chosen by a ``ConfigCollector``.

    The module runs the model on a partial point cloud (the "student" pass,
    results stored under the ``"part"`` mode) and, at test time only, also on
    the full point cloud (the "teacher" pass, stored under ``"full"``).
    Per-batch predictions are accumulated in ``self.item_stack`` and turned
    into metrics in :meth:`_on_shared_epoch_end`.
    """

    # Evaluation types accepted by ``database.evaluate``; taken from the first
    # character of the fold name (e.g. "o0", "t1").
    _ETYPES = ("o", "t", "c")

    def __init__(self,
                 cfg_col: gcngrasp.ConfigCollector,
                 fold: str):
        """Build the model, the loss function and the per-stage prediction stacks.

        Args:
            cfg_col: Configuration collector providing ``train_file``,
                ``model_type``, ``model_config``, ``model_file`` and ``database``.
            fold: Data fold identifier (e.g. ``"o0"``); stored as a hyperparameter.
        """
        # Kept before ``super().__init__`` exactly as in the original ordering.
        # NOTE(review): the base class may read ``self.cfg_col`` - confirm before reordering.
        self.cfg_col = cfg_col
        super().__init__(self.cfg_col.train_file)
        self.save_hyperparameters("fold")
        self.model = self.cfg_col.model_type(self.cfg_col.model_config)
        self.loss_fcn = LossFunction(
            self.config["train"]["loss_weight"],
            gripper_info=self.cfg_col.model_config.get("input", {}).get("gripper"),
        )
        # One prediction accumulator per (stage, point-cloud mode).
        self.item_stack = {
            stage: {mode: ItemStack() for mode in ("part", "full")}
            for stage in ("train", "val", "test")
        }

    def _log_dist_peak(self, stage: str, tag: str, ret, batch) -> None:
        """Log the mean ``dist_peak`` of ``ret["affordance"]`` as ``{stage}/{tag}/dist_peak``.

        Nothing is logged when the model output has no (truthy) ``"affordance"``
        entry or when ``dist_peak`` returns an empty result.
        """
        affordance = ret.get("affordance")
        if not affordance:
            return
        dp = dist_peak(affordance, batch)
        if len(dp):
            self.log(f"{stage}/{tag}/dist_peak", dp.mean(), batch_size=len(dp),
                     prog_bar=True, on_step=False, on_epoch=True)

    def forward(self, stage: str, batch):
        """Run the model for one batch and record its predictions.

        Args:
            stage: One of ``"train"``, ``"val"`` or ``"test"``.
            batch: Mapping indexed with the keys ``"obj"``, ``"task"``,
                ``"pcd_part"``, ``"pcd_full"`` and ``"grasp"``.

        Returns:
            The result of ``self.loss_fcn(batch, ret_stu, ret_tea)`` for the
            train/val stages, ``None`` for the test stage.
        """
        infer_kwargs = dict(objs=batch["obj"], tasks=batch["task"])
        # if stage == "test": infer_kwargs["test_time_augment"] = 4
        pcd_part, pcd_full, grasp = batch["pcd_part"], batch["pcd_full"], batch["grasp"]
        ret_tea = None
        if self.training:
            # Randomly feed the full or the partial cloud; clone so the random
            # rotation cannot touch the tensor stored in the batch.
            use_full = random.random() < self.config["p_pcd_full"]
            pcd_infer = (pcd_full if use_full else pcd_part).clone()
            pcd_infer, grasp = utils.random_rotate(pcd_infer, grasp)
            ret_stu = self.model(pcd_infer, grasp, **infer_kwargs)
        else:
            with torch.inference_mode():
                ret_stu = self.model(*utils.random_rotate(pcd_part, grasp), **infer_kwargs)
                if stage == "test":
                    # Teacher pass on the full point cloud, evaluated at test time only.
                    ret_tea = self.model(*utils.random_rotate(pcd_full, grasp), **infer_kwargs)
                    self.item_stack[stage]["full"].push(analysis.dump_prediction(batch, ret_tea["score"]))
                    self._log_dist_peak(stage, "affF", ret_tea, batch)
        # dump prediction
        self.item_stack[stage]["part"].push(analysis.dump_prediction(batch, ret_stu["score"].detach()))
        self._log_dist_peak(stage, "affP", ret_stu, batch)
        if stage != "test":
            return self.loss_fcn(batch, ret_stu, ret_tea)
        return None

    def on_fit_start(self):
        """Run the parent hook, then copy the model file into the project via ``copy_into_project``."""
        super().on_fit_start()
        self.copy_into_project(self.cfg_col.model_file)

    def _on_shared_epoch_end(self, stage: str):
        """Evaluate the predictions accumulated during ``stage`` and log the metrics.

        For train/val the evaluation type comes from the first character of the
        fold hyperparameter (falling back to ``"o"``). For test, the raw
        predictions are written to ``test_pred_{mode}.txt`` in the project
        directory and every evaluation type is computed; keys starting with
        ``"mAP"`` get an ``_{etype}`` suffix so they do not overwrite each other.
        """
        # A missing/empty fold falls back to "o", like any unknown prefix,
        # instead of raising on ``None[0]`` / ``""[0]``.
        fold = self.hparams["fold"] or ""
        etype = fold[:1]
        if etype not in self._ETYPES:
            etype = "o"
        metrics_all = {}
        # evaluation
        for mode, stack in self.item_stack[stage].items():
            if not stack:
                continue
            items = stack.pop(self.all_gather)
            # Keep the rows whose last column is non-zero, then drop that column.
            items = items[items[:, -1].long().bool(), :-1].cpu().numpy()
            if stage != "test":
                metrics = self.cfg_col.database.evaluate(items, etype=etype)
            else:
                metrics = {}
                np.savetxt(self.project / f"test_pred_{mode}.txt", items)
                # Separate loop variable: must not clobber the fold-derived ``etype``.
                for test_etype in self._ETYPES:
                    ms = self.cfg_col.database.evaluate(items, etype=test_etype)
                    for k in list(ms):
                        if k.startswith("mAP"):
                            ms[f"{k}_{test_etype}"] = ms.pop(k)
                    metrics.update(ms)
            prefix = f"{stage}/{mode}/"
            metrics_all.update({prefix + k: v for k, v in metrics.items()})
        # fuse metrics for optimal model selection
        if stage == "val":
            monitor_tog = self.config["train"]["metric"]["monitor"]
            monitor_aff = self.trainer.callback_metrics.get("val/affP/dist_peak")
            # Only penalise the monitored metric when it was actually produced
            # this epoch; otherwise the in-place update would raise KeyError
            # here (e.g. when the prediction stacks were empty and skipped).
            if monitor_aff is not None and monitor_tog in metrics_all:
                metrics_all[monitor_tog] -= monitor_aff * .01
        for k, v in metrics_all.items():
            self.log(k, v, prog_bar=True, on_step=False, on_epoch=True)

    def _shared_step(self, stage: str, batch, batch_idx):
        """Run one step of ``stage``; log and return the loss when one is produced.

        Returns:
            ``ret["loss"]`` when ``forward`` returns a non-empty result,
            otherwise ``None`` (the test stage produces no loss).
        """
        ret = self(stage, batch)
        if not ret:
            return None
        self.log("/".join([stage, "loss"]), ret["loss"], prog_bar=True,
                 on_step=False, on_epoch=True, batch_size=ret["bs_grasp"])
        return ret["loss"]

    @classmethod
    def from_command(cls, args) -> "MainModule":
        """Create the module from parsed command-line arguments.

        Args:
            args: Namespace with ``train``, ``resume``, ``test`` and ``fold``.
                ``args.fold`` is overwritten in place with ``cfg_col.fold``
                when a resume/test path is given and no fold was specified.

        Returns:
            The constructed module. When ``gcngrasp.get_org_ckpt`` returns a
            checkpoint, the version-"1" configuration is used and the
            checkpoint is loaded through ``model.load_org_ckpt``.
        """
        org_ckpt = gcngrasp.get_org_ckpt(args.test, args.fold)
        if org_ckpt:
            # load origin GCNGrasp-v1
            cfg_col = gcngrasp.ConfigCollector.from_version("1")
            module = cls(cfg_col, args.fold)
            module.model.load_org_ckpt(org_ckpt)
            return module
        # initialize model
        path = args.resume or args.test
        if path:
            cfg_col = gcngrasp.ConfigCollector.from_path(path)
            if args.fold is None:
                args.fold = cfg_col.fold
        else:
            cfg_col = gcngrasp.ConfigCollector.from_version(args.train)
        return cls(cfg_col, args.fold)
def parse_args():
parser = argparse.ArgumentParser()
mode = parser.add_mutually_exclusive_group(required=True)
mode.add_argument("--train", type=str, help="model version trained from scratch")
mode.add_argument("--resume", type=str, help="resume model path")
mode.add_argument("--test", type=str, help="test model path")
parser.add_argument("--fold", type=str, help="data fold, e.g. o0, t1")
parser.add_argument("--devices", type=int, nargs="+", help="GPU devices")
parser.add_argument("--test-bs-div", type=int, default=1, help="test batch size divider")
return parser.parse_args()
if __name__ == "__main__":
    args = parse_args()
    # NOTE(review): presumably blocks until the requested devices are usable,
    # polling every 60 s - confirm against utils.wait_for_cuda.
    utils.wait_for_cuda(args.devices, sleep=60, re_pattern="python")

    # main
    module = MainModule.from_command(args)
    # Bare attribute access kept on purpose.
    # NOTE(review): presumably forces the database to load its point clouds
    # before the main loop starts - confirm against the database class.
    module.cfg_col.database.pcds
    main_loop = module.cfg_col.func_main_loop(args, module.trainer_kwargs())
    main_loop(module)