A PyTorch implementation of Tree Attention, because the existing implementation is in JAX for some reason
Usage: `python3 model.py`

License: MIT
- Implement flash attention from the official repo. I couldn't, because I could not find documentation for it that I could understand.