|
17 | 17 | import optax |
18 | 18 | import xax |
19 | 19 | from jaxtyping import Array, PRNGKeyArray |
20 | | -from kscale.web.gen.api import JointMetadataOutput |
| 20 | +from kscale.web.gen.api import RobotURDFMetadataOutput |
21 | 21 |
|
22 | 22 | NUM_JOINTS = 20 |
23 | 23 | NUM_ACTOR_INPUTS = 49 |
@@ -368,21 +368,23 @@ def get_mujoco_model(self) -> mujoco.MjModel: |
368 | 368 | mjcf_path = asyncio.run(ksim.get_mujoco_model_path("kbot", name="robot")) |
369 | 369 | return mujoco_scenes.mjcf.load_mjmodel(mjcf_path, scene="smooth") |
370 | 370 |
|
371 | | - def get_mujoco_model_metadata(self, mj_model: mujoco.MjModel) -> dict[str, JointMetadataOutput]: |
| 371 | + def get_mujoco_model_metadata(self, mj_model: mujoco.MjModel) -> RobotURDFMetadataOutput: |
372 | 372 | metadata = asyncio.run(ksim.get_mujoco_model_metadata("kbot")) |
373 | 373 | if metadata.joint_name_to_metadata is None: |
374 | 374 | raise ValueError("Joint metadata is not available") |
375 | | - return metadata.joint_name_to_metadata |
| 375 | + if metadata.actuator_type_to_metadata is None: |
| 376 | + raise ValueError("Actuator metadata is not available") |
| 377 | + return metadata |
376 | 378 |
|
377 | 379 | def get_actuators( |
378 | 380 | self, |
379 | 381 | physics_model: ksim.PhysicsModel, |
380 | | - metadata: dict[str, JointMetadataOutput] | None = None, |
| 382 | + metadata: RobotURDFMetadataOutput | None = None, |
381 | 383 | ) -> ksim.Actuators: |
382 | 384 | assert metadata is not None, "Metadata is required" |
383 | 385 | return ksim.MITPositionActuators( |
384 | 386 | physics_model=physics_model, |
385 | | - joint_name_to_metadata=metadata, |
| 387 | + metadata=metadata, |
386 | 388 | ) |
387 | 389 |
|
388 | 390 | def get_physics_randomizers(self, physics_model: ksim.PhysicsModel) -> list[ksim.PhysicsRandomizer]: |
|
0 commit comments