-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_nbv.py
More file actions
123 lines (102 loc) · 4.43 KB
/
Copy pathmain_nbv.py
File metadata and controls
123 lines (102 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import annotations
import argparse
import colorama
import pandas as pd
import sklearn.metrics as skm
import trimesh
from logdir import LogDir
from gcngrasp.api import Path
from gcngrasp.system import process_config
from gcngrasp.system.scene import Scene
from gcngrasp.system.system import System
from gcngrasp.system.virtual_robot import VirtualRobot
def _run_trial(system: System,
               logdir: LogDir,
               robot: VirtualRobot,
               trial: str,
               init_frame: int,
               seed: int,
               max_iter: int) -> pd.DataFrame:
    """Run one NBV trial and return its metrics, padded to ``max_iter`` rows.

    The robot is reset to ``init_frame`` with ``seed``, a fresh ``Scene`` is
    built, and up to ``max_iter`` integrate/plan/move iterations are run.
    Each iteration's scene info and ``(tog, y_true)`` pair are saved under
    ``<trial>/iter<j>``.

    Returns:
        DataFrame with columns ``AP`` and ``pre_greedy``, one row per
        iteration. If the trial stops early, the last row is repeated up to
        ``max_iter`` rows. If no row was recorded at all, it is returned empty.
    """
    robot.set_cur_frame(init_frame, seed=seed)
    metrics = pd.DataFrame(columns=["AP", "pre_greedy"])
    scene = Scene(system.config, **robot.kwargs_scene)
    for j in range(max_iter):
        folder = f"{trial}/iter{j}"
        # forward
        frame = robot.get_frame()
        # NOTE(review): identity (not equality) check -- presumably the robot
        # hands back the very pose object the scene already holds when there
        # is no new view; confirm against VirtualRobot.get_frame.
        if frame.Tcw is scene.Tcw:
            break
        frame.metadata["seed"] = seed
        scene.integrate(frame)
        # Skip planning on the last iteration: its action would never be executed.
        act = system.planner.solve(scene) if j + 1 != max_iter else None
        # save
        scene.save_info(
            folder=logdir.pdir(folder),
            tog=True, traj=True, field=not system.cfg_nbv_proxy,
        )
        # evaluate: match each predicted pose to its nearest labelled pose and
        # use that label's score as ground truth.
        tog = scene.tog_poses[0]
        y_true = robot.tog_label[robot.tog_label.nearest_search(tog)].score
        y_score = tog.score
        logdir.save_data((tog, y_true), f"{folder}/tog_and_true.pkl")
        metrics.loc[j] = [
            # AP is undefined without any positive label; record 0 instead.
            skm.average_precision_score(y_true, y_score).item() if y_true.any() else 0.,
            # Ground-truth score of the first predicted pose (presumably the
            # top-ranked / greedy choice -- confirm against Scene.tog_poses).
            y_true[0],
        ]
        # execute
        if not act:
            break
        robot.move(act.Teb)
    print(colorama.Fore.CYAN + f"{trial}: " + colorama.Style.RESET_ALL
          + str(metrics.to_dict(orient='list')))
    # Pad early-terminated trials with their last row so every trial has
    # max_iter rows; an empty frame has nothing to pad with.
    if len(metrics):
        for j in range(len(metrics), max_iter):
            metrics.loc[j] = metrics.iloc[-1]
    return metrics


def run(system: System,
        logdir: LogDir,
        data_path: Path,
        max_iter: int,
        seeds: list[int],
        init_frames: list[int] | None = None,
        config_path: Path | None = None):
    """Run the next-best-view evaluation over every scene in ``data_path``.

    For each scene (a ``VirtualRobot`` instance), each initial frame and each
    seed, one trial is run (see ``_run_trial``). Each trial records, per
    iteration:

    * ``AP``: the average precision of the predicted grasp scores against the
      nearest-label scores.
    * ``pre_greedy``: the ground-truth score of the first predicted pose.

    Per-scene results are written to ``<scene_name>/results.yaml``. They
    contain one entry per trial plus an ``"avg"`` entry.

    Args:
        system: System whose ``planner`` proposes views and whose ``config``
            builds each ``Scene``.
        logdir: Log directory that receives the config copy, per-iteration
            outputs and per-scene results.
        data_path: Dataset root passed to ``VirtualRobot.load_all_instances``.
        max_iter: Maximum number of iterations per trial.
        seeds: Seeds passed to ``robot.set_cur_frame``; one trial per seed.
        init_frames: Initial frame indices. When empty/None, falls back to
            ``robot.config["init_frames"]`` and then to every frame.
        config_path: Config file copied into ``logdir``. Defaults to the
            module-level ``CONFIG`` that the CLI entry point sets.

    Raises:
        ValueError: if ``config_path`` is not given and no module-level
            ``CONFIG`` exists (e.g. when ``run`` is imported from elsewhere).
    """
    if config_path is None:
        # Backward compatibility: the CLI entry point sets a module-level CONFIG.
        config_path = globals().get("CONFIG")
    if config_path is None:
        raise ValueError("run() needs `config_path` (or a module-level CONFIG) "
                         "to copy the config into the log directory")
    # prepare
    logdir.copy(config_path, config_path.name)
    for robot in VirtualRobot.load_all_instances(data_path, check_attr=True):
        scene_name = robot.root.name
        print(colorama.Fore.BLUE + colorama.Style.BRIGHT + scene_name)
        results = {}
        frames = init_frames or robot.config["init_frames"] or list(range(len(robot)))
        for i in frames:
            for s in seeds:
                trial = f"{scene_name}/init{i}/seed{s}"
                results[trial] = _run_trial(system, logdir, robot, trial, i, s, max_iter)
        if not results:
            # No trials (empty seeds or no frames): averaging would divide by zero.
            print(colorama.Fore.YELLOW + f"{scene_name}: no trials ran; skipping results"
                  + colorama.Style.RESET_ALL)
            continue
        # aggregate
        results["avg"] = sum(results.values()) / len(results)
        results = {k: v.to_dict(orient='list') for k, v in results.items()}
        logdir.save_data(results, f"{scene_name}/results.yaml")
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", default="gcngrasp-vp", help="project name")
    parser.add_argument("--config", type=str, default="config/system/params.yaml", help="config file")
    parser.add_argument("--data-path", type=str, default="./dataset", help="data directory")
    parser.add_argument("--max-iter", type=int, default=4, help="max iteration")
    parser.add_argument("--n-seed", type=int, default=4, help="number of seeds")
    parser.add_argument("--init-frames", type=int, nargs="+", help="init frames")
    parser.add_argument("--vis-path", type=str, help="visualize directory")
    args = parser.parse_args()
    # Module-level global: run() copies this config into the log directory.
    CONFIG = Path(args.config)

    if args.vis_path:
        # Visualization-only mode: open every saved visualize.glb under vis_path.
        found = False
        for f in Path(args.vis_path).rglob("visualize.glb"):
            found = True
            print(f)
            scene = trimesh.load(f)
            scene.show()
        if not found:
            print(f"no visualize.glb found under {args.vis_path}")
    else:
        # Zero seeds or zero iterations would produce no metrics at all, so
        # reject those values before loading anything.
        if args.max_iter < 1:
            parser.error("--max-iter must be >= 1")
        if args.n_seed < 1:
            parser.error("--n-seed must be >= 1")
        # Planner overrides applied on top of the config file's "planner" section.
        cfg_planner = dict(
            border_margin=.05,
            max_stride=.50,
            min_stride=.10,
            n_candidates=50,
            dist_obs=0.,
            elev_range=None,
        )
        cfg = process_config(CONFIG, with_api=True)
        cfg["planner"].update(cfg_planner)
        # The default project name runs without an NBV proxy; any other --name
        # enables the proxy, configured with the same planner overrides.
        system = System(cfg, cfg_nbv_proxy=None if args.name == "gcngrasp-vp" else cfg_planner)
        logdir = LogDir(args.name, rootdir="runs/nbv_logs")
        run(
            system, logdir,
            data_path=Path(args.data_path),
            max_iter=args.max_iter,
            seeds=list(range(args.n_seed)),
            init_frames=args.init_frames,
        )