Skip to content

Commit f4cf46e

Browse files
authored
experiment with larger penalties (#19)
* experiment with larger penalties
* also penalize hip roll
* revert to previous scales
* bigger rewards
* reduce max speed
* update name
* misc cleanup
* update checkpoint
1 parent 821172c commit f4cf46e

2 files changed

Lines changed: 67 additions & 42 deletions

File tree

assets/ckpt.bin

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:4938765b0e9776e7800557bd118b209034afffc403959f15386b903c82d7d1b6
3-
size 12137848
2+
oid sha256:d68aabccd60572ebd0b3225e41c68da8da948d6d77737875ade95f2dfa21a636
3+
size 12129864

train.py

Lines changed: 65 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -49,45 +49,72 @@
4949

5050

5151
@attrs.define(frozen=True, kw_only=True)
52-
class BentArmPenalty(ksim.Reward):
53-
arm_indices: tuple[int, ...] = attrs.field()
54-
arm_targets: tuple[float, ...] = attrs.field()
52+
class JointPositionPenalty(ksim.JointDeviationPenalty):
53+
@classmethod
54+
def create_from_names(
55+
cls,
56+
names: list[str],
57+
physics_model: ksim.PhysicsModel,
58+
scale: float = -1.0,
59+
scale_by_curriculum: bool = False,
60+
) -> Self:
61+
zeros = {k: v for k, v in ZEROS}
62+
joint_targets = [zeros[name] for name in names]
63+
64+
return cls.create(
65+
physics_model=physics_model,
66+
joint_names=tuple(names),
67+
joint_targets=tuple(joint_targets),
68+
scale=scale,
69+
scale_by_curriculum=scale_by_curriculum,
70+
)
5571

56-
def get_reward(self, trajectory: ksim.Trajectory) -> Array:
57-
qpos = trajectory.qpos[..., self.arm_indices]
58-
qpos_targets = jnp.array(self.arm_targets)
59-
qpos_diff = qpos - qpos_targets
60-
return xax.get_norm(qpos_diff, "l1").mean(axis=-1)
6172

73+
@attrs.define(frozen=True, kw_only=True)
74+
class BentArmPenalty(JointPositionPenalty):
6275
@classmethod
63-
def create(
76+
def create_penalty(
6477
cls,
65-
model: ksim.PhysicsModel,
66-
scale: float,
78+
physics_model: ksim.PhysicsModel,
79+
scale: float = -1.0,
6780
scale_by_curriculum: bool = False,
6881
) -> Self:
69-
qpos_mapping = ksim.get_qpos_data_idxs_by_name(model)
70-
71-
names = [
72-
"dof_right_shoulder_pitch_03",
73-
"dof_right_shoulder_roll_03",
74-
"dof_right_shoulder_yaw_02",
75-
"dof_right_elbow_02",
76-
"dof_right_wrist_00",
77-
"dof_left_shoulder_pitch_03",
78-
"dof_left_shoulder_roll_03",
79-
"dof_left_shoulder_yaw_02",
80-
"dof_left_elbow_02",
81-
"dof_left_wrist_00",
82-
]
82+
return cls.create_from_names(
83+
names=[
84+
"dof_right_shoulder_pitch_03",
85+
"dof_right_shoulder_roll_03",
86+
"dof_right_shoulder_yaw_02",
87+
"dof_right_elbow_02",
88+
"dof_right_wrist_00",
89+
"dof_left_shoulder_pitch_03",
90+
"dof_left_shoulder_roll_03",
91+
"dof_left_shoulder_yaw_02",
92+
"dof_left_elbow_02",
93+
"dof_left_wrist_00",
94+
],
95+
physics_model=physics_model,
96+
scale=scale,
97+
scale_by_curriculum=scale_by_curriculum,
98+
)
8399

84-
zeros = {k: v for k, v in ZEROS}
85-
arm_indices = [qpos_mapping[name][0] for name in names]
86-
arm_targets = [zeros[name] for name in names]
87100

88-
return cls(
89-
arm_indices=tuple(arm_indices),
90-
arm_targets=tuple(arm_targets),
101+
@attrs.define(frozen=True, kw_only=True)
102+
class StraightLegPenalty(JointPositionPenalty):
103+
@classmethod
104+
def create_penalty(
105+
cls,
106+
physics_model: ksim.PhysicsModel,
107+
scale: float = -1.0,
108+
scale_by_curriculum: bool = False,
109+
) -> Self:
110+
return cls.create_from_names(
111+
names=[
112+
"dof_left_hip_roll_03",
113+
"dof_left_hip_yaw_03",
114+
"dof_right_hip_roll_03",
115+
"dof_right_hip_yaw_03",
116+
],
117+
physics_model=physics_model,
91118
scale=scale,
92119
scale_by_curriculum=scale_by_curriculum,
93120
)
@@ -436,16 +463,14 @@ def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:
436463
ksim.UprightReward(index="x", inverted=False, scale=0.1),
437464
# Normalization penalties.
438465
ksim.ActionInBoundsReward.create(physics_model, scale=0.01),
439-
ksim.ActionSmoothnessPenalty(scale=-0.01),
440-
ksim.ActuatorJerkPenalty(ctrl_dt=self.config.ctrl_dt, scale=-0.001),
441-
ksim.ActuatorRelativeForcePenalty.create(physics_model, scale=-0.001),
442-
ksim.AngularVelocityPenalty(index="x", scale=-0.0005),
443-
ksim.AngularVelocityPenalty(index="y", scale=-0.0005),
444-
ksim.AngularVelocityPenalty(index="z", scale=-0.0005),
445-
ksim.LinearVelocityPenalty(index="y", scale=-0.0005),
446-
ksim.LinearVelocityPenalty(index="z", scale=-0.0005),
466+
ksim.AngularVelocityPenalty(index="x", scale=-0.005),
467+
ksim.AngularVelocityPenalty(index="y", scale=-0.005),
468+
ksim.AngularVelocityPenalty(index="z", scale=-0.005),
469+
ksim.LinearVelocityPenalty(index="y", scale=-0.005),
470+
ksim.LinearVelocityPenalty(index="z", scale=-0.005),
447471
# Bespoke rewards.
448-
BentArmPenalty.create(physics_model, scale=-0.01),
472+
BentArmPenalty.create_penalty(physics_model, scale=-0.1),
473+
StraightLegPenalty.create_penalty(physics_model, scale=-0.01),
449474
]
450475

451476
def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termination]:

0 commit comments

Comments
 (0)