[WIP] perf(kernel): Optimize MoE prefill GEMMs for gfx950#464
Draft
Max191 wants to merge 59 commits into
Draft
Conversation
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
This reverts commit eb323bf. Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Squashes the fragile early partial SliceN/LDS edits into the first all-shape-correct source state equivalent to c0623c9. It includes the preshuffled W LDS path, BN256-over-BN128 SliceN compatibility, X reuse across halves, duplicate W copy removal, split W descriptors, two-buffer local prefetch, and the future-copy-before-MFMA wait placement. Benchmark evidence: E011 measured rocprof kernel deltas versus E010 of -35.121 us nonpreshuffled dispatch, -24.141 us nonpreshuffled combine, -13.880 us preshuffled dispatch, and -18.320 us preshuffled combine. E021 later measured the BN256 SliceN route improving 2048 preshuffled dispatch by 76.281 us versus E018. Validation: py_compile and git diff --check passed here. Fresh all-token correctness for this foundation source passed tokens=256,512,1024,2048,4096,8192, top_k=4, both variants: 12 results, max_abs_diff 0.03125.
Records E337 item 1 separately: host-preshuffled W still stages HBM -> LDS -> VGPR, not direct VGPR bypass. Benchmark evidence: E337/E348 aggregate is 3310.554 us versus 6390.386 us at branch base (1.93x). No tree change beyond the verified foundation commit. Fresh all-token correctness for the foundation passed tokens=256..8192, top_k=4, both variants, max_abs_diff 0.03125.
Records E337 item 2 separately: a 256-wide execution tile consumes two 128-wide packed W halves through USE_SLICE_N. Benchmark evidence: E021 improved 2048 preshuffled dispatch by 76.281 us versus E018. No tree change beyond the verified foundation commit. Fresh all-token correctness for the foundation passed tokens=256..8192, top_k=4, both variants, max_abs_diff 0.03125.
Records E337 item 3 separately: one X tile feeds both N halves in the BN256 SliceN CTA. Benchmark evidence: E021 reduced preshuffled dispatch VMEM reads by 22.71% and LDS instructions by 21.57%. No tree change beyond the verified foundation commit. Fresh all-token correctness for the foundation passed tokens=256..8192, top_k=4, both variants, max_abs_diff 0.03125.
Records E337 item 4 separately: the SliceN path fills the needed W LDS tile once instead of copying duplicate data. Benchmark evidence: E099 reduced preshuffled combine total instructions by 28.64% versus E095. No tree change beyond the verified foundation commit. Fresh all-token correctness for the foundation passed tokens=256..8192, top_k=4, both variants, max_abs_diff 0.03125.
Records the early dependency for E337 item 5; the later source commit installs the final shared-view W local-load path. Benchmark evidence is in the E136 source commit. No tree change beyond the verified foundation commit. Fresh all-token correctness for the foundation passed tokens=256..8192, top_k=4, both variants, max_abs_diff 0.03125.
Records E337 item 6 separately: top and bottom W descriptors address adjacent packed 128-wide halves for BN256 SliceN. Benchmark evidence: E099 confirmed the BN256 combine specialization and improved the priority preshuffled total by 52.361 us. No tree change beyond the verified foundation commit. Fresh all-token correctness for the foundation passed tokens=256..8192, top_k=4, both variants, max_abs_diff 0.03125.
Records the early dependency for E337 item 7; the later E133 source commit installs logical output N handling. Benchmark evidence is in that commit. No tree change beyond the verified foundation commit. Fresh all-token correctness for the foundation passed tokens=256..8192, top_k=4, both variants, max_abs_diff 0.03125.
Records E337 item 8 separately: the foundation uses the retained two-buffer local-prefetch structure. Benchmark evidence: E011 hot-loop waitcnt cycles dropped sharply versus E010 and preshuffled combine improved by 18.320 us. No tree change beyond the verified foundation commit. Fresh all-token correctness for the foundation passed tokens=256..8192, top_k=4, both variants, max_abs_diff 0.03125.
Records E337 item 9 separately: future async copy is issued before current MFMA and the wait leaves later work in flight. Benchmark evidence: E011 measured all four GEMM variants faster than E010 by 13.880-35.121 us. No tree change beyond the verified foundation commit. Fresh all-token correctness for the foundation passed tokens=256..8192, top_k=4, both variants, max_abs_diff 0.03125.
Applies the v6-style two-iteration unroll with alternating operand register sets for the two-buffer local-prefetch loop. Benchmark evidence: E017 measured the 2048-token preshuffled GEMM pair at dispatch 435.165 us and combine 254.523 us, 689.688 us total; this was 40.480 us faster than E016 (-5.54%) and removed 65 static v_mov_b32_e32 instructions from each preshuffled GEMM. Validation: py_compile and git diff --check passed here; saved E017 correctness passed tokens=16 and 2048, top_k=4, both variants.
Promotes the preshuffled default route to BLOCK_N=256 with USE_SLICE_N over two 128-wide packed W tiles. Benchmark evidence: E021 measured 2048-token preshuffled dispatch at 360.004 us versus 436.285 us in E018 (-17.48%). Preshuffled dispatch+combine improved from 688.208 us to 622.847 us (-9.50%). Validation: py_compile and git diff --check passed here; saved E021 correctness passed tokens=16 and 2048, top_k=4, both variants.
Changes the SliceN local-load order to issue both W-half local loads before the X local load at each staging point. Benchmark evidence: E033 measured 2048-token preshuffled dispatch 343.964 us and combine 241.603 us, 585.567 us total; this improved the E031 preshuffled total by 6.679 us (-1.13%) and reduced dispatch LDS latency cycles by 15.66% with unchanged instruction counts. Validation: py_compile and git diff --check passed here; saved E033 correctness passed tokens=16 and 2048, top_k=4, both variants.
Threads K_ITERS from the Python launcher into the Gluon kernel so the hot loop can specialize the fixed GPT-OSS K tile count. Benchmark evidence: E053 measured 2048-token preshuffled dispatch 340.683 us and combine 240.642 us, 581.325 us total; this was 3.881 us faster than E047 (-0.66%) and reduced preshuffled dispatch VGPR from 216 to 164. Validation: py_compile and git diff --check passed here; saved E053 correctness passed tokens=16 and 2048, top_k=4, both variants.
Uses the actual block-schedule M-block count loaded from block_offs instead of iterating the padded upper bound and guarding empty tiles. Benchmark evidence: E085 measured the 2048-token preshuffled pair at 337.843 us dispatch and 236.203 us combine, 574.046 us total; this was 4.560 us faster than E080 (-0.79%). Same-period repeat showed a 7.238 us total win (-1.20%). Validation: py_compile and git diff --check passed here; saved E085 correctness passed tokens=16 and 2048, top_k=4, both variants.
Adds the top-only tail path so invalid lower N halves skip lower-half local loads and MFMAs. Benchmark evidence: E095 measured 2048-token preshuffled dispatch 335.604 us and combine 234.242 us, 569.846 us total; this was 9.200 us faster than E094 (-1.59%) and removed lower-half dispatch MFMA work on the invalid tail. Validation: py_compile and git diff --check passed here; saved E095 correctness passed tokens=16 and 2048, top_k=4, both variants.
Promotes the preshuffled combine path to the same BN256 SliceN execution strategy used by dispatch, keeping the 128-wide packed W layout. Benchmark evidence: E099 measured 2048-token preshuffled combine 186.361 us versus 234.242 us in E095 (-47.881 us, -20.44%). Preshuffled dispatch+combine improved from 569.846 us to 517.485 us (-9.19%). Validation: py_compile and git diff --check passed here; saved E099 correctness passed tokens=16 and 2048, top_k=4, both variants.
Uses the retained wide-N SwiGLU split layout, size_per_thread=[1,8] and threads_per_warp=[4,16], for the dispatch epilogue. Benchmark evidence: E106 measured 2048-token preshuffled dispatch 322.083 us versus 336.723 us in E101 (-14.640 us, -4.35%), with the preshuffled pair improving 516.765 us -> 501.845 us (-2.89%). Validation: py_compile and git diff --check passed here; saved E106 correctness passed tokens=16 and 2048, top_k=4, both variants.
Adds the non-SwiGLU combine store-layout variant used by later route-specific store selections. Benchmark evidence: E113 same-period timing measured combine 170.522 us versus 173.962 us in E112 (-3.440 us, -1.98%) and restored combine VGPR from 172 to 168, while total timing was effectively neutral. Validation: py_compile and git diff --check passed here; saved E113 correctness passed tokens=16 and 2048, top_k=4, both variants.
Replaces the dispatch quantization divide with reciprocal multiply in the SwiGLU epilogue. Benchmark evidence: E123 same-period timing measured dispatch 313.024 us average versus 319.585 us in the E113 control (-6.560 us, -2.05%), with dispatch+combine improving by 6.037 us (-1.23%). Validation: py_compile and git diff --check passed here; saved E123 correctness passed tokens=16 and 2048, top_k=4, both variants.
Keeps padded preshuffled W N for tiling and W strides while allocating and storing only the caller-visible logical N for combine outputs. Benchmark evidence: E133 is GEMM-kernel neutral by design, but wrapper timing removed the padded-output slice/contiguous copy: combine wrapper median improved to 295.804 us versus recent 306.083/310.063 us baselines. Validation: py_compile and git diff --check passed here; saved E133 correctness passed tokens=16 and 2048, top_k=4, both variants.
Changes the preshuffled W LDS local-load path to use the packed dot operand view directly instead of explicit load/reshape/transpose conversion. Benchmark evidence: E136 measured the 2048-token preshuffled pair at 483.724 us median versus 485.245 us in the E135 rerun (-1.521 us). Static instruction counts stayed unchanged, confirming this is primarily a layout-path cleanup. Validation: py_compile and git diff --check passed here; saved E136 correctness passed tokens=16 and 2048, top_k=4, both variants.
Splits the SliceN future async copies into separate top and bottom groups so the bottom half can be waited independently. Benchmark evidence: E143 repeat measured dispatch 307.203 us and combine 167.402 us, 474.605 us total, the best measured sum in that sequence. ATT loop cycles improved versus E138: dispatch 27516 -> 24672 and combine 25540 -> 23372. Validation: py_compile and git diff --check passed here; saved E143 correctness passed tokens=16 and 2048, top_k=4, both variants.
Passes N_CONST/Y_N_CONST through the launcher so preshuffled padded-N and logical-output-N cases specialize instead of recomputing N-derived values dynamically. Benchmark evidence: E156 measured 2048-token preshuffled dispatch 305.963 us and combine 168.002 us, 473.965 us total; dispatch dynamic instructions dropped by 438024 with unchanged VMEM/LDS counts. Validation: py_compile and git diff --check passed here; saved E156 correctness passed tokens=16 and 2048, top_k=4, both variants.
Reorders the combine epilogue so gate scaling happens before the store conversion/layout path. Benchmark evidence: E203 repeat measured 2048-token preshuffled dispatch 306.244 us and combine 159.562 us, 465.806 us total. The improvement came from the non-hot-loop combine epilogue path, with all reported diffs exactly 0.0. Validation: py_compile and git diff --check passed here; saved E203 correctness passed tokens=16 and 2048 against the Triton reference for both variants.
Moves the future W-scale async copy ahead of W-top in the top SliceN copy group while keeping the local-load/MFMA/wait structure unchanged. Benchmark evidence: E208 measured 2048-token preshuffled sums of 460.445 us and 462.765 us in two passes, improving the E206 repeat by 1.360-3.680 us on median sum. Validation: py_compile and git diff --check passed here; saved E208 correctness passed tokens=16 and 2048 against the Triton reference for both variants.
Suppresses unnecessary X tail mask work in the preshuffled path while keeping the same tiling and schedule. Benchmark evidence: E217 repeat measured 2048-token preshuffled total 458.924 us versus 459.485 us in E215 repeat (-0.561 us). Dynamic instructions dropped by 68816 for dispatch and 269280 for combine with unchanged VMEM/LDS counts. Validation: py_compile and git diff --check passed here; saved E217 correctness passed tokens=16 and 2048 against the Triton reference for both variants.
Adds the retained X source-row swizzle and viewed local-load path for preshuffled BM64, with matching helpers for the other BM tiers. Benchmark evidence: E244 measured 2048-token preshuffled dispatch 299.523 us and combine 155.722 us, 455.245 us total; this improved E217 repeat by 3.679 us and reduced LDS bank conflicts to 0 for dispatch and combine. Validation: py_compile and git diff --check passed here; saved E244 correctness passed tokens=16 and 2048 against the Triton reference for both variants.
Adopts the retained hot-loop order: local_load X_next, compute current bottom, issue W_bottom_next, local_load W_top_next, wait bottom group, then local_load W_bottom_next. Benchmark evidence: E251 final measured 2048-token preshuffled dispatch 298.723 us and combine 155.161 us, 453.884 us total, with VGPR 148/148 and zero LDS bank conflicts retained. Validation: py_compile and git diff --check passed here; saved E251 correctness passed tokens=16 and 2048 for both variants with all max absolute diffs exactly 0.0.
Adds the viewed local-load helpers and matching source-row permutations for BM32 and BM128 preshuffled SliceN paths, extending the E244 conflict-free X LDS strategy beyond BM64. Benchmark evidence: E279 kept the 512-token BM32 route correctness-clean with 512 focused timing 341.124 us total; E289/E291 promoted BM128 large-shape routes with 4096 dispatch improving by 22.520 us and 8192 combine improving by 41.281 us in all-token timing. Validation: py_compile and git diff --check passed here; saved E279/E289/E291 correctness passed tokens=256,512,1024,2048,4096,8192, top_k=4, both variants.
Adds cache_modifier support to WPreshuffledLdsDescriptor and passes W_CACHE_MODIFIER into preshuffled W async copies. With the existing BLOCK_M<=32 policy, this enables .cg for BM16/BM32 preshuffled W routes. Benchmark evidence: E326 measured all-token sums improving to 284.003 us at 256 tokens (-30.520 us vs E300) and 310.084 us at 512 tokens (-30.039 us), with unchanged instruction counts and lower VMEM latency on focused counters. Validation: py_compile and git diff --check passed here; saved E326 correctness passed tokens=256,512,1024,2048,4096,8192, top_k=4, both variants.
Adds slice_size-driven BM16, BM32, and BM128 route tier selection for preshuffled dispatch/combine, while preserving the BM64 2048 checkpoint route. Benchmark evidence: E277 improved the 256-token route by about 5 us with BM16 nonpersistent selection; E279 targeted 512 with BM32; E289 promoted 4096 dispatch to BM128 with a 22.520 us all-token improvement; E291 promoted 8192 combine to BM128 with a 41.281 us all-token improvement. Validation: py_compile and git diff --check passed here; saved E277/E279/E289/E291 correctness passed tokens=256,512,1024,2048,4096,8192, top_k=4, both variants.
Adds the explicit group_m/xcd_swizzle selections for the retained preshuffled prefill routes: 256 BM16, 512 BM32, 1024 BM64, 2048 BM64, 4096 BM64/BM128, and 8192 BM128. Benchmark evidence: E299 improved the 256-token all-token sum by about 8 us; E298 improved the 512 focused route by 1.321 us; E300 improved 1024 dispatch by about 2.8 us; E294 improved 4096 combine by 4.001 us in the focused guard; E296 improved 1024 combine by 2.480 us; E291 kept 8192 combine on the verified BM128 path. Validation: py_compile and git diff --check passed here; this source state hash-matches saved E326 (acda2f12713640714424cd49bc1963a89c8be866676fa3e6d4fb8085b5915bc9), whose correctness passed tokens=256,512,1024,2048,4096,8192, top_k=4, both variants.
Replaces the broad BLOCK_M-derived cache choice with W_CACHE_CG plumbing so BM16/BM32 keep .cg and only the 1024 BM64 route opts in. The 2048 and 4096 BM64 checkpoint routes keep the empty cache modifier. Benchmark evidence: E328 measured 1024 focused timing 339.604 us versus the 358.243 us reference (-18.639 us), while 2048 stayed in band at 455.125 us and 4096 avoided the rejected broad-BM64 .cg regression. Validation: py_compile and git diff --check passed here; saved E328 correctness passed tokens=256,512,1024,2048,4096,8192, top_k=4, both variants.
Adds STORE_LAYOUT_VARIANT=2 and selects the [1,16]/[8,8] store layout for the 4096-token BM64/BN256 SliceN combine route. Benchmark evidence: E336 same-period focused timing measured 4096 combine 216.602 us versus 218.643 us for the E335-control (-2.041 us), while restoring 4096 combine VGPR from 156 to 148. Validation: py_compile and git diff --check passed here; saved E336 correctness passed tokens=256,512,1024,2048,4096,8192, top_k=4, both variants.
Applies the same [1,16]/[8,8] combine store variant to the 8192-token BM128/BN256 SliceN combine route while keeping the 4096 route from E336 and the 2048 checkpoint route unchanged. Benchmark evidence: E337 same-period focused timing measured 8192 combine 335.164 us versus 344.204 us for the E336-control (-9.040 us), lowering BM128 combine VGPR from 240 to 228. Fused-context E348 measured aggregate 3310.554 us versus 6390.386 us at branch base, a 1.93x speedup. Validation: py_compile and git diff --check passed here; saved E337 correctness passed tokens=256,512,1024,2048,4096,8192, top_k=4, both variants.
Evidence-only split marker. E337 item 18 is an environment/compiler requirement rather than a kernel source delta: the optimized path assumes TRITON_ENABLE_LLIR_SCHED=1 after the scheduler validity fixes. Benchmark evidence: E337/E348 measured the final fused-context aggregate at 3310.554 us versus 6390.386 us at branch base, a 1.93x speedup. Validation: no tree change. Fresh all-token correctness was run across every source-changing commit in this verified branch lineage; all tested states passed tokens=256,512,1024,2048,4096,8192, top_k=4, both variants.
Individual marker for E337 item 19. Benchmark evidence: E244 reduced 2048 dispatch/combine LDS bank conflicts to zero and improved the preshuffled median sum by 3.679 us versus E217 repeat. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 21. Benchmark evidence: E251 final retained zero LDS bank conflicts with 2048 dispatch 298.723 us, combine 155.161 us, VGPR 148/148. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 27. Benchmark evidence: E326 used this plumbing to improve the 256 and 512 all-token sums by about 30 us each versus E300. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 28. Benchmark evidence: E326 focused timing measured 256 sum 283.883 us and 512 sum 307.683 us, about 30 us faster than no-cache controls. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 29. Benchmark evidence: E328 focused timing measured 1024 sum 339.604 us versus 358.243 us reference, a 18.639 us win. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 30. Benchmark evidence: E328 kept 2048 in band at 455.125 us while avoiding the rejected broad-BM64 .cg regression. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 31: BM16/BN256 SliceN, nonpersistent, group_m=16, xcd_swizzle=8. Benchmark evidence: E348 fused-context dispatch is 196.042 us. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 32: BM16/BN256 SliceN, nonpersistent, group_m=16, xcd_swizzle=8. Benchmark evidence: E299 measured combine g16/xcd8 at 109.722 us versus 114.161 us auto. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 33: BM32/BN256 SliceN persistent dispatch. Benchmark evidence: E298 measured 512 dispatch g32/xcd4 at 217.763 us versus 219.122 us g4/xcd8. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 34: BM32/BN256 SliceN persistent combine. Benchmark evidence: E348 fused-context combine is 114.041 us. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 35: BM64/BN256 SliceN persistent dispatch, group_m=32, xcd_swizzle=8, W .cg. Benchmark evidence: E300 measured g32/xcd8 2.761 us faster than auto. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 36: BM64/BN256 SliceN persistent combine, group_m=4, xcd_swizzle=8, W .cg. Benchmark evidence: E296 measured g4/xcd8 2.480 us faster than g4/xcd4. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 37: BM64/BN256 SliceN persistent CU-count dispatch, VGPR 148, scratch 0. Benchmark evidence: E337 focused guard measured 296.443 us; guard repeat measured 293.963 us. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 39: BM128/BN256 SliceN dispatch, persistent. Benchmark evidence: E289 improved 4096 dispatch by 22.520 us; E337 focused timing measured 447.685 us. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 40: BM64/BN256 SliceN combine, group_m=16, xcd_swizzle=4, store-layout variant 2. Benchmark evidence: E336 improved same-period 4096 combine by 2.041 us. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 41: BM128/BN256 SliceN dispatch, persistent. Benchmark evidence: E337 focused timing measured 722.728 us; E348 fused-context timing measured 763.128 us. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
Individual marker for E337 item 42: BM128/BN256 SliceN combine, group_m=16, xcd_swizzle=4, store-layout variant 2. Benchmark evidence: E337 improved same-period 8192 combine by 9.040 us and lowered VGPR 240 -> 228. No tree change; this marker inherits the already-verified final source tree. The corrected branch sweep passed all 27 tree-changing commits for tokens=256,512,1024,2048,4096,8192, top_k=4, both variants, with 0 failures.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Benchmark Summary
Below are comparison tables for different relevant configurations at Po2 sequence lengths from 256-8192 with
topk=4. The different configurations include:The key results to compare are no LLIR yes preshuffle and AITER (vLLM). Enabling LLIR can happen later if needed, but is not part of this PR, so the preshuffled version without LLIR is what we get from this PR.
Overall, the optimized kernel beats AITER on
seqlen >= 1024, and comes close for smaller sequence lengths.Dispatch GEMM
Combine GEMM
Full Fused MoE
Test Plan
The included
test_gluon_moe_gemm_gfx950.pytest was used to verify numerical correctness of the kernel.