|
49 | 49 |
|
50 | 50 |
|
51 | 51 | @attrs.define(frozen=True, kw_only=True) |
52 | | -class BentArmPenalty(ksim.Reward): |
53 | | - arm_indices: tuple[int, ...] = attrs.field() |
54 | | - arm_targets: tuple[float, ...] = attrs.field() |
| 52 | +class JointPositionPenalty(ksim.JointDeviationPenalty): |
| 53 | + @classmethod |
| 54 | + def create_from_names( |
| 55 | + cls, |
| 56 | + names: list[str], |
| 57 | + physics_model: ksim.PhysicsModel, |
| 58 | + scale: float = -1.0, |
| 59 | + scale_by_curriculum: bool = False, |
| 60 | + ) -> Self: |
| 61 | + zeros = {k: v for k, v in ZEROS} |
| 62 | + joint_targets = [zeros[name] for name in names] |
| 63 | + |
| 64 | + return cls.create( |
| 65 | + physics_model=physics_model, |
| 66 | + joint_names=tuple(names), |
| 67 | + joint_targets=tuple(joint_targets), |
| 68 | + scale=scale, |
| 69 | + scale_by_curriculum=scale_by_curriculum, |
| 70 | + ) |
55 | 71 |
|
56 | | - def get_reward(self, trajectory: ksim.Trajectory) -> Array: |
57 | | - qpos = trajectory.qpos[..., self.arm_indices] |
58 | | - qpos_targets = jnp.array(self.arm_targets) |
59 | | - qpos_diff = qpos - qpos_targets |
60 | | - return xax.get_norm(qpos_diff, "l1").mean(axis=-1) |
61 | 72 |
|
| 73 | +@attrs.define(frozen=True, kw_only=True) |
| 74 | +class BentArmPenalty(JointPositionPenalty): |
62 | 75 | @classmethod |
63 | | - def create( |
| 76 | + def create_penalty( |
64 | 77 | cls, |
65 | | - model: ksim.PhysicsModel, |
66 | | - scale: float, |
| 78 | + physics_model: ksim.PhysicsModel, |
| 79 | + scale: float = -1.0, |
67 | 80 | scale_by_curriculum: bool = False, |
68 | 81 | ) -> Self: |
69 | | - qpos_mapping = ksim.get_qpos_data_idxs_by_name(model) |
70 | | - |
71 | | - names = [ |
72 | | - "dof_right_shoulder_pitch_03", |
73 | | - "dof_right_shoulder_roll_03", |
74 | | - "dof_right_shoulder_yaw_02", |
75 | | - "dof_right_elbow_02", |
76 | | - "dof_right_wrist_00", |
77 | | - "dof_left_shoulder_pitch_03", |
78 | | - "dof_left_shoulder_roll_03", |
79 | | - "dof_left_shoulder_yaw_02", |
80 | | - "dof_left_elbow_02", |
81 | | - "dof_left_wrist_00", |
82 | | - ] |
| 82 | + return cls.create_from_names( |
| 83 | + names=[ |
| 84 | + "dof_right_shoulder_pitch_03", |
| 85 | + "dof_right_shoulder_roll_03", |
| 86 | + "dof_right_shoulder_yaw_02", |
| 87 | + "dof_right_elbow_02", |
| 88 | + "dof_right_wrist_00", |
| 89 | + "dof_left_shoulder_pitch_03", |
| 90 | + "dof_left_shoulder_roll_03", |
| 91 | + "dof_left_shoulder_yaw_02", |
| 92 | + "dof_left_elbow_02", |
| 93 | + "dof_left_wrist_00", |
| 94 | + ], |
| 95 | + physics_model=physics_model, |
| 96 | + scale=scale, |
| 97 | + scale_by_curriculum=scale_by_curriculum, |
| 98 | + ) |
83 | 99 |
|
84 | | - zeros = {k: v for k, v in ZEROS} |
85 | | - arm_indices = [qpos_mapping[name][0] for name in names] |
86 | | - arm_targets = [zeros[name] for name in names] |
87 | 100 |
|
88 | | - return cls( |
89 | | - arm_indices=tuple(arm_indices), |
90 | | - arm_targets=tuple(arm_targets), |
| 101 | +@attrs.define(frozen=True, kw_only=True) |
| 102 | +class StraightLegPenalty(JointPositionPenalty): |
| 103 | + @classmethod |
| 104 | + def create_penalty( |
| 105 | + cls, |
| 106 | + physics_model: ksim.PhysicsModel, |
| 107 | + scale: float = -1.0, |
| 108 | + scale_by_curriculum: bool = False, |
| 109 | + ) -> Self: |
| 110 | + return cls.create_from_names( |
| 111 | + names=[ |
| 112 | + "dof_left_hip_roll_03", |
| 113 | + "dof_left_hip_yaw_03", |
| 114 | + "dof_right_hip_roll_03", |
| 115 | + "dof_right_hip_yaw_03", |
| 116 | + ], |
| 117 | + physics_model=physics_model, |
91 | 118 | scale=scale, |
92 | 119 | scale_by_curriculum=scale_by_curriculum, |
93 | 120 | ) |
@@ -436,16 +463,14 @@ def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]: |
436 | 463 | ksim.UprightReward(index="x", inverted=False, scale=0.1), |
437 | 464 | # Normalization penalties. |
438 | 465 | ksim.ActionInBoundsReward.create(physics_model, scale=0.01), |
439 | | - ksim.ActionSmoothnessPenalty(scale=-0.01), |
440 | | - ksim.ActuatorJerkPenalty(ctrl_dt=self.config.ctrl_dt, scale=-0.001), |
441 | | - ksim.ActuatorRelativeForcePenalty.create(physics_model, scale=-0.001), |
442 | | - ksim.AngularVelocityPenalty(index="x", scale=-0.0005), |
443 | | - ksim.AngularVelocityPenalty(index="y", scale=-0.0005), |
444 | | - ksim.AngularVelocityPenalty(index="z", scale=-0.0005), |
445 | | - ksim.LinearVelocityPenalty(index="y", scale=-0.0005), |
446 | | - ksim.LinearVelocityPenalty(index="z", scale=-0.0005), |
| 466 | + ksim.AngularVelocityPenalty(index="x", scale=-0.005), |
| 467 | + ksim.AngularVelocityPenalty(index="y", scale=-0.005), |
| 468 | + ksim.AngularVelocityPenalty(index="z", scale=-0.005), |
| 469 | + ksim.LinearVelocityPenalty(index="y", scale=-0.005), |
| 470 | + ksim.LinearVelocityPenalty(index="z", scale=-0.005), |
447 | 471 | # Bespoke rewards. |
448 | | - BentArmPenalty.create(physics_model, scale=-0.01), |
| 472 | + BentArmPenalty.create_penalty(physics_model, scale=-0.1), |
| 473 | + StraightLegPenalty.create_penalty(physics_model, scale=-0.01), |
449 | 474 | ] |
450 | 475 |
|
451 | 476 | def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termination]: |
|
0 commit comments