Terms
Description
As of #467, we use `jax.lax.platform_dependent` + `jax.pure_callback` as a temporary workaround to support both GPUs and CPUs. Once NVIDIA/warp#1446 is merged, we should be able to simplify this logic down to a single implementation for all platforms.
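For context, the dispatch pattern described above looks roughly like the sketch below. It is not the code from #467. The names `_host_impl`, `cpu_path`, `gpu_path` and `double` are hypothetical, and both branches simply double the input, standing in for the real kernels. The only real APIs used are `jax.lax.platform_dependent`, which picks a branch based on the platform the computation is lowered for, and `jax.pure_callback`, which calls back into host Python.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_impl(x):
    # Hypothetical host-side implementation; pure_callback passes it NumPy arrays.
    return (np.asarray(x) * 2).astype(x.dtype)


def cpu_path(x):
    # CPU branch: route through a host callback, declaring the output shape/dtype.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_impl, out_type, x)


def gpu_path(x):
    # GPU branch: stand-in for a native (e.g. Warp-backed) implementation.
    return x * 2


def double(x):
    # Select the branch for the platform the computation runs on.
    return jax.lax.platform_dependent(x, cpu=cpu_path, cuda=gpu_path)
```

Usage: `double(jnp.arange(3.0))` and `jax.jit(double)(...)` behave the same on either backend. The cost is maintaining two code paths, which is what a single cross-platform implementation would remove.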
Screenshots
No response
Additional information
No response