Skip to content

Commit 165634b

Browse files
committed
Initial commit
0 parents  commit 165634b

8 files changed

Lines changed: 788 additions & 0 deletions

File tree

.gitmodules

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[submodule "third_party/triton"]
2+
path = third_party/triton
3+
url = https://github.com/triton-lang/triton
4+
[submodule "third_party/llvm-project"]
5+
path = third_party/llvm-project
6+
url = https://github.com/llvm/llvm-project
7+
[submodule "third_party/pybind11"]
8+
path = third_party/pybind11
9+
url = https://github.com/pybind/pybind11
10+
[submodule "third_party/argparse"]
11+
path = third_party/argparse
12+
url = https://github.com/p-ranav/argparse

CMakeLists.txt

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
cmake_minimum_required(VERSION 3.28)
2+
3+
project(tritonc CXX)
4+
5+
set(CMAKE_CXX_STANDARD 17)
6+
7+
set(TRITON_CODEGEN_BACKENDS "nvidia")
8+
set(TRITON_BUILD_PYTHON_MODULE ON)
9+
set(TRITON_BUILD_PROTON OFF)
10+
set(PYTHON_INCLUDE_DIRS ON)
11+
12+
add_subdirectory(third_party/argparse)
13+
add_subdirectory(third_party/triton)
14+
15+
add_executable(tritonc src/main.cpp)
16+
target_link_libraries(tritonc PRIVATE argparse)
17+
18+
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
19+
20+
set(TRITON_LIBRARIES
21+
${triton_libs}
22+
23+
# mlir
24+
MLIRAMDGPUDialect
25+
MLIRNVVMDialect
26+
MLIRNVVMToLLVMIRTranslation
27+
MLIRGPUToNVVMTransforms
28+
MLIRGPUToGPURuntimeTransforms
29+
MLIRGPUTransforms
30+
MLIRIR
31+
MLIRControlFlowToLLVM
32+
MLIRBytecodeWriter
33+
MLIRPass
34+
MLIRTransforms
35+
MLIRLLVMDialect
36+
MLIRSupport
37+
MLIRTargetLLVMIRExport
38+
MLIRMathToLLVM
39+
MLIRROCDLToLLVMIRTranslation
40+
MLIRGPUDialect
41+
MLIRSCFToControlFlow
42+
MLIRIndexToLLVM
43+
MLIRGPUToROCDLTransforms
44+
45+
# LLVM
46+
LLVMPasses
47+
LLVMNVPTXCodeGen
48+
# LLVMNVPTXAsmPrinter
49+
LLVMAMDGPUCodeGen
50+
LLVMAMDGPUAsmParser
51+
52+
# Nvidia specific
53+
TritonNVIDIAGPUToLLVM NVGPUToLLVM MLIRNVGPUToNVVM
54+
MLIRNVVMDialect MLIRNVGPUDialect
55+
56+
)
57+
if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64
58+
CMAKE_SYSTEM_PROCESSOR MATCHES "arm64") # macOS arm64
59+
list(APPEND TRITON_LIBRARIES
60+
LLVMAArch64CodeGen
61+
LLVMAArch64AsmParser
62+
)
63+
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
64+
list(APPEND TRITON_LIBRARIES
65+
LLVMX86CodeGen
66+
LLVMX86AsmParser
67+
)
68+
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "ppc64le")
69+
list(APPEND TRITON_LIBRARIES
70+
LLVMPowerPCAsmParser
71+
LLVMPowerPCCodeGen
72+
)
73+
else ()
74+
message(FATAL_ERROR "LLVM codegen/ASM parser libs: This HW architecture (${CMAKE_SYSTEM_PROCESSOR}) is not configured in cmake lib dependencies.")
75+
endif ()
76+
77+
target_link_libraries(tritonc PRIVATE ${TRITON_LIBRARIES})
78+
79+
# triton doesn't use targets in the best practise way... (sigh)
80+
# hence we have to do this...
81+
set(TRITON_INCLUDE_DIRS
82+
"third_party/triton/include"
83+
"third_party/triton/third_party/nvidia/include"
84+
)
85+
set(TRITON_GENERATED_INCLUDE_DIRS
86+
"${CMAKE_BINARY_DIR}/third_party/triton/include"
87+
"${CMAKE_BINARY_DIR}/third_party/triton/third_party"
88+
)
89+
90+
target_include_directories(tritonc PRIVATE ${TRITON_INCLUDE_DIRS})
91+
target_include_directories(tritonc PRIVATE ${TRITON_GENERATED_INCLUDE_DIRS})

README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# tritonc
2+
3+
Your standalone commandline triton compiler.
4+
Write your triton kernels directly in MLIR and compile it to ptx with this handy tool without ever touching python.
5+
6+
## Example:
7+
8+
### add_kernel.ttir
9+
10+
```mlir
11+
module {
12+
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
13+
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
14+
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
15+
%arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
16+
%c1024_i32 = arith.constant 1024 : i32
17+
%0 = tt.get_program_id x : i32
18+
%1 = arith.muli %0, %c1024_i32 : i32
19+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
20+
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
21+
%4 = arith.addi %3, %2 : tensor<1024xi32>
22+
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32>
23+
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
24+
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
25+
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
26+
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>
27+
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
28+
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
29+
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>
30+
%13 = arith.addf %9, %12 : tensor<1024xf32>
31+
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
32+
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
33+
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>
34+
tt.return
35+
}
36+
}
37+
```
38+
39+
### Commandline
40+
```commandline
41+
tritonc add_kernel.ttir --compute-capability 89 --num-stages 3 --num-warps 4 -o out.ptx
42+
```

add_kernel.ttir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module {
2+
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
3+
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
4+
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
5+
%arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
6+
%c1024_i32 = arith.constant 1024 : i32
7+
%0 = tt.get_program_id x : i32
8+
%1 = arith.muli %0, %c1024_i32 : i32
9+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
10+
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
11+
%4 = arith.addi %3, %2 : tensor<1024xi32>
12+
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32>
13+
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
14+
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
15+
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
16+
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>
17+
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
18+
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
19+
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>
20+
%13 = arith.addf %9, %12 : tensor<1024xf32>
21+
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
22+
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
23+
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>
24+
tt.return
25+
}
26+
}

0 commit comments

Comments
 (0)