[FEATURE] Use warp.jax_callable on CPU #476

Description

@jeertmans

Terms

  • Checked the existing issues to see if my suggestion has not already been suggested;

Description

As of #467, we use jax.lax.platform_dependent + jax.pure_callback as a (temporary) workaround to support both GPUs and CPUs. Once NVIDIA/warp#1446 is merged, we should be able to simplify the logic by having a single implementation for all platforms.
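
For context, the current dispatch pattern looks roughly like the sketch below. It is only an illustration, not the code from #467: `_host_double`, `_cpu_impl`, `_gpu_impl`, and `double` are hypothetical names, and the non-CPU branch is a plain JAX placeholder standing in for the Warp-backed function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_double(x):
    # Hypothetical host-side computation (stand-in for a kernel run on CPU).
    return np.asarray(x) * 2


def _cpu_impl(x):
    # CPU branch: call back into Python on the host via jax.pure_callback.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_double, out_type, x)


def _gpu_impl(x):
    # Placeholder for the non-CPU branch (assumption: in the real code this
    # would be the Warp-backed callable, not plain JAX arithmetic).
    return x * 2


@jax.jit
def double(x):
    # Select the branch matching the platform the computation is lowered for.
    return jax.lax.platform_dependent(x, cpu=_cpu_impl, default=_gpu_impl)


x = jnp.arange(4, dtype=jnp.float32)
y = double(x)
```

If a single callable works on every platform, both branches and the `jax.lax.platform_dependent` call could presumably collapse into one function call.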

Screenshots

No response

Additional information

No response

Metadata

Assignees

No one assigned

    Labels

    enhancement: New feature or request
