Mohamed Mekkouri committed · Commit 95d28ad · Parent(s): 51250cb

commit evtn
Browse files
- CMakeLists.txt +104 -0
- README.md +104 -0
- build.toml +22 -8
- cmake/compile-metal.cmake +86 -0
- cmake/metallib_to_header.py +73 -0
- cmake/utils.cmake +557 -0
- flake.lock +169 -0
- flake.nix +1 -1
- gptoss_kernels/CMakeLists.txt +0 -191
- gptoss_kernels/__init__.py +0 -6
- gptoss_kernels/examples/chat.py +0 -104
- gptoss_kernels/examples/generate.py +0 -34
- gptoss_kernels/source/context.c +0 -1115
- gptoss_kernels/source/generate.c +0 -317
- gptoss_kernels/source/include/internal/log.h +7 -0
- gptoss_kernels/source/include/internal/metal.h +0 -1
- gptoss_kernels/source/matmul.metal +8 -2
- gptoss_kernels/source/metal.m +0 -1
- gptoss_kernels/source/model.c +0 -581
- gptoss_kernels/source/tensor_wrappers.cpp +77 -0
- gptoss_kernels/source/tokenizer.c +0 -106
- pyproject.toml +10 -0
- setup.py +118 -0
- {gptoss_kernels/test → test}/bf16-f32-embeddings.cc +0 -0
- {gptoss_kernels/test → test}/embeddings-kernel-tester.hpp +0 -0
- {gptoss_kernels/test → test}/f32-bf16w-matmul.cc +0 -0
- {gptoss_kernels/test → test}/f32-bf16w-rmsnorm.cc +0 -0
- {gptoss_kernels/test → test}/f32-random.cc +0 -0
- {gptoss_kernels/test → test}/f32-rope.cc +0 -0
- {gptoss_kernels/test → test}/fill-random-kernel-tester.hpp +0 -0
- {gptoss_kernels/test → test}/matmul-kernel-tester.hpp +0 -0
- {gptoss_kernels/test → test}/mf4-f32-convert.cc +0 -0
- {gptoss_kernels/test → test}/rmsnorm-kernel-tester.hpp +0 -0
- {gptoss_kernels/test → test}/rope-kernel-tester.hpp +0 -0
- {gptoss_kernels/test → test}/u32-random.cc +0 -0
- torch-ext/gptoss_kernels/__init__.py +8 -0
- torch-ext/gptoss_kernels/__pycache__/__init__.cpython-313.pyc +0 -0
- torch-ext/gptoss_kernels/__pycache__/_ops.cpython-313.pyc +0 -0
- torch-ext/gptoss_kernels/_gptoss_kernels_931bc1b_dirty.abi3.so +3 -0
- torch-ext/gptoss_kernels/_ops.py +9 -0
- torch-ext/gptoss_kernels/test.py +6 -0
- torch-ext/registration.h +30 -0
- torch-ext/torch_binding.cpp +10 -0
- torch-ext/torch_binding.h +5 -0
CMakeLists.txt
ADDED
@@ -0,0 +1,104 @@
+cmake_minimum_required(VERSION 3.26)
+project(gptoss_kernels LANGUAGES CXX)
+
+set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum macOS deployment version")
+
+install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
+
+include(FetchContent)
+file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
+message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
+
+include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
+
+if(DEFINED Python3_EXECUTABLE)
+  # Allow passing through the interpreter (e.g. from setup.py).
+  find_package(Python3 COMPONENTS Development Development.SABIModule Interpreter)
+  if (NOT Python3_FOUND)
+    message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
+  endif()
+else()
+  find_package(Python3 REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
+endif()
+
+append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
+
+find_package(Torch REQUIRED)
+
+add_compile_definitions(METAL_KERNEL)
+
+# Initialize list for Metal shader sources
+set(ALL_METAL_SOURCES)
+#get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
+#list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
+
+set(TORCH_gptoss_kernels_SRC
+  torch-ext/torch_binding.cpp torch-ext/torch_binding.h
+)
+
+
+list(APPEND SRC "${TORCH_gptoss_kernels_SRC}")
+set(gptoss_kernels_SRC
+  "gptoss_kernels/source/accumulate.metal"
+  "gptoss_kernels/source/expert_routing_metadata.metal"
+  "gptoss_kernels/source/metal.m"
+  "gptoss_kernels/source/scatter.metal"
+  "gptoss_kernels/source/topk.metal"
+  "gptoss_kernels/source/embeddings.metal"
+  "gptoss_kernels/source/metal-kernels.c"
+  "gptoss_kernels/source/random.metal"
+  "gptoss_kernels/source/sdpa.metal"
+  "gptoss_kernels/source/matmul.metal"
+  "gptoss_kernels/source/rmsnorm.metal"
+  "gptoss_kernels/source/sample.metal"
+  "gptoss_kernels/source/moematmul.metal"
+  "gptoss_kernels/source/convert.metal"
+  "gptoss_kernels/source/rope.metal"
+  "gptoss_kernels/source/gather_and_accumulate.metal"
+  "gptoss_kernels/source/tensor_wrappers.cpp"
+  "gptoss_kernels/source/log.c"
+)
+
+# Separate Metal shader files from other sources
+set(gptoss_kernels_METAL_SRC)
+set(gptoss_kernels_CPP_SRC)
+
+foreach(src_file IN LISTS gptoss_kernels_SRC)
+  if(src_file MATCHES "\\.(metal|h)$")
+    list(APPEND gptoss_kernels_METAL_SRC ${src_file})
+  else()
+    list(APPEND gptoss_kernels_CPP_SRC ${src_file})
+  endif()
+endforeach()
+
+# TODO: check if CLion support this:
+# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
+set_source_files_properties(
+  ${gptoss_kernels_CPP_SRC}
+  PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/gptoss_kernels/source/include;${CMAKE_SOURCE_DIR}/gptoss_kernels/include;${CMAKE_SOURCE_DIR}/.")
+
+
+# Add C++ sources to main source list
+list(APPEND SRC "${gptoss_kernels_CPP_SRC}")
+
+# Keep track of Metal sources for later compilation
+if(gptoss_kernels_METAL_SRC)
+  list(APPEND ALL_METAL_SOURCES "${gptoss_kernels_METAL_SRC}")
+endif()
+# Include Metal shader compilation utilities
+include(${CMAKE_CURRENT_LIST_DIR}/cmake/compile-metal.cmake)
+
+define_gpu_extension_target(
+  _gptoss_kernels_931bc1b_dirty
+  DESTINATION _gptoss_kernels_931bc1b_dirty
+  LANGUAGE ${GPU_LANG}
+  SOURCES ${SRC}
+  COMPILE_FLAGS ${GPU_FLAGS}
+  ARCHITECTURES ${GPU_ARCHES}
+  USE_SABI 3
+  WITH_SOABI)
+
+# Compile Metal shaders if any were found
+if(ALL_METAL_SOURCES)
+  compile_metal_shaders(_gptoss_kernels_931bc1b_dirty "${ALL_METAL_SOURCES}")
+endif()
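For orientation, the `Python3_EXECUTABLE` branch above exists so that a driver such as setup.py can pin CMake to the interpreter it is running under. A minimal, hypothetical Python sketch of such a driver (the `build` directory name is an assumption, not part of this commit):

```python
# Hypothetical driver, mirroring the Python3_EXECUTABLE pass-through above.
import subprocess
import sys

# Passing -DPython3_EXECUTABLE makes CMake resolve Python3 (and, via
# append_cmake_prefix_path, torch) against this interpreter rather than
# whatever it would find first on the PATH.
subprocess.run(
    [
        "cmake",
        "-S", ".",
        "-B", "build",  # hypothetical build directory
        f"-DPython3_EXECUTABLE={sys.executable}",
    ],
    check=True,
)
subprocess.run(["cmake", "--build", "build"], check=True)
```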
README.md
CHANGED
@@ -8,3 +8,107 @@ tags:
 
 This is a build for some kernel released by OpenAI in the GPT-OSS repo : https://github.com/openai/gpt-oss
 
+```21:69:/Users/medmekk/projects/ai/kernels/gpt-oss/gpt_oss/metal/source/matmul.metal
+kernel void gptoss_f32_bf16w_matmul(
+    constant gptoss_matmul_args& args [[ buffer(0) ]],
+    const device float4* input [[ buffer(1) ]],
+    const device bfloat4* weight [[ buffer(2) ]],
+    const device bfloat* bias [[ buffer(3) ]],
+    device float* output [[ buffer(4) ]],
+    const device gptoss_control* control [[ buffer(5) ]],
+    uint2 gid [[threadgroup_position_in_grid]],
+    uint simdgroup_tid [[thread_index_in_simdgroup]],
+    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+    uint num_simdgroups [[simdgroups_per_threadgroup]])
+{
+    const uint simdgroup_size = 32;
+    if (control->abort != 0) {
+        return;
+    }
+
+    const uint num_column_vecs = args.num_column_vecs;
+    const uint row = gid.x * num_simdgroups + simdgroup_idx;
+
+    input += gid.y * num_column_vecs + simdgroup_tid;
+    weight += num_column_vecs * row + simdgroup_tid;
+    bias += row;
+    output += gid.y * args.num_rows + row;
+
+    uint num_iter = 0;
+    num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+
+    float4 sum4 = 0.0f;
+    do {
+        const bfloat4 w = *weight;
+        const float4 i = *input;
+        sum4 = metal::fma(static_cast<float4>(w), i, sum4);
+
+        weight += simdgroup_size;
+        input += simdgroup_size;
+    } while (--num_iter != 0);
+    const float2 sum2 = sum4.xy + sum4.zw;
+    float sum = sum2.x + sum2.y;
+    sum = metal::simd_sum(sum);
+    if (metal::simd_is_first()) {
+        sum += static_cast<float>(*bias);
+        if (args.add) {
+            *output += sum;
+        } else {
+            *output = sum;
+        }
+    }
+}
+```
+
+### What it computes
+- Computes Y = X · Wᵀ + b for a batch of tokens.
+- Types/layout:
+  - X is float32, shape [num_tokens, num_cols], viewed as `float4` vectors → num_column_vecs = num_cols/4.
+  - W is bfloat16, shape [num_rows, num_cols], viewed as `bfloat4` vectors per row (row-major).
+  - b is bfloat16, length num_rows.
+  - Y is float32, shape [num_tokens, num_rows].
+
+### Work decomposition
+- Grid Y (gid.y) = token index t in [0, num_tokens).
+- Grid X (gid.x) spans output rows in groups of `num_simdgroups`. Within a threadgroup:
+  - simdgroup_idx in [0, num_simdgroups) selects one output row r.
+  - Therefore row r = gid.x*num_simdgroups + simdgroup_idx.
+- Each simdgroup computes exactly one scalar Y[t, r].
+- Lanes inside a simdgroup (simdgroup_tid in [0, 31]) split the K dimension (num_column_vecs vectors) in a strided pattern: lane ℓ processes indices ℓ, ℓ+32, ℓ+64, ...
+
+### Pointer setup per simdgroup
+- input points to the start of token t's vector list, then the lane offset: `input += t*num_column_vecs + lane`.
+- weight points to the start of row r's vector list, then the lane offset: `weight += r*num_column_vecs + lane`.
+- bias points to `b[r]`; output points to `Y[t, r]`.
+
+### Inner loop
+- num_iter = ceil((num_column_vecs - lane)/32). Each lane loops over its share of the K/4 vectors.
+- On each iteration:
+  - Load one `float4` from input (4 consecutive columns) and one `bfloat4` from weight for row r.
+  - Fused multiply-add into `sum4` (vector-wise).
+  - Advance both pointers by 32 vectors (the next stripe for this lane).
+- After the loop, reduce the 4 components of `sum4` into a scalar: sum4.xy + sum4.zw → sum2, then sum2.x + sum2.y.
+- Reduce across all 32 lanes with `metal::simd_sum(sum)`.
+- Lane 0 adds bias[r] and writes `Y[t, r]` (add or overwrite depending on args.add).
+
+### Example mapping (num_tokens=2, num_cols=128, num_rows=4, threadgroup_size=32)
+- num_column_vecs = 128/4 = 32.
+- threadgroup_size=32 → num_simdgroups=1 → each threadgroup computes 1 row.
+- Grid:
+  - gid.y ∈ {0,1} (two tokens).
+  - gid.x ∈ {0,1,2,3} (four rows).
+- For a given (gid.x, gid.y), simdgroup_idx=0 computes one output scalar Y[t=gid.y, r=gid.x].
+- Per-lane work:
+  - lane ℓ loads X[t, cols 4ℓ..4ℓ+3] as float4 and W[r, cols 4ℓ..4ℓ+3] as bfloat4 (exactly one iteration, since there are 32 vectors total).
+  - Each lane accumulates the dot product over its 4 elements; lanes are then summed → the full dot(X_row, W_row).
+  - Lane 0 writes Y[t, r] = dot + b[r].
+
+### Example mapping (same shapes, threadgroup_size=64)
+- num_simdgroups=64/32=2 → each threadgroup computes 2 rows at once.
+- For gid.x=k:
+  - simdgroup_idx=0 computes row r=2k, simdgroup_idx=1 computes row r=2k+1.
+- Lanes split K identically; two output scalars are produced per threadgroup (one per simdgroup).
+
+### Which piece each unit owns
+- Token t (grid y) × row r (simdgroup within grid x) → one output scalar Y[t, r].
+- Lane ℓ within that simdgroup → a partial dot product over columns {4ℓ, 4ℓ+1, 4ℓ+2, 4ℓ+3}, plus any further stripes {4(ℓ+32), ...} if num_cols > 128.
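To make the walkthrough in the README diff concrete, here is a hedged NumPy reference sketch of the same semantics. It is not the project's API: the function name is mine, the per-lane striping is collapsed into a plain matmul (equivalent up to reduction order), and it assumes `ml_dtypes` for a bfloat16 dtype.

```python
# Reference semantics of gptoss_f32_bf16w_matmul, as a NumPy sketch.
import numpy as np
from ml_dtypes import bfloat16  # assumption: ml_dtypes supplies a numpy bfloat16

def f32_bf16w_matmul(X, W, b, Y=None, add=False):
    """Y[t, r] = dot(X[t, :], W[r, :]) + b[r], accumulated in float32."""
    assert X.shape[1] % 4 == 0           # the kernel reads float4/bfloat4 vectors
    out = X.astype(np.float32) @ W.astype(np.float32).T
    out += b.astype(np.float32)          # lane 0 adds the bias after simd_sum
    return Y + out if add else out       # args.add selects += vs = on the output

# Shapes from the first example mapping: 2 tokens, 128 cols, 4 rows.
X = np.random.randn(2, 128).astype(np.float32)
W = np.random.randn(4, 128).astype(bfloat16)
b = np.random.randn(4).astype(bfloat16)
print(f32_bf16w_matmul(X, W, b).shape)   # (2, 4)
```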
build.toml
CHANGED
@@ -9,15 +9,29 @@ src = [
 ]
 
 [kernel.gptoss_kernels]
+
 depends = ["torch"]
-backend = "
+backend = "metal"
 
 src = [
-
-
-
-
-
-
-
+  "gptoss_kernels/source/accumulate.metal",
+  "gptoss_kernels/source/expert_routing_metadata.metal",
+  "gptoss_kernels/source/metal.m",
+  "gptoss_kernels/source/scatter.metal",
+  "gptoss_kernels/source/topk.metal",
+  "gptoss_kernels/source/embeddings.metal",
+  "gptoss_kernels/source/metal-kernels.c",
+  "gptoss_kernels/source/random.metal",
+  "gptoss_kernels/source/sdpa.metal",
+  "gptoss_kernels/source/matmul.metal",
+  "gptoss_kernels/source/rmsnorm.metal",
+  "gptoss_kernels/source/sample.metal",
+  "gptoss_kernels/source/moematmul.metal",
+  "gptoss_kernels/source/convert.metal",
+  "gptoss_kernels/source/rope.metal",
+  "gptoss_kernels/source/gather_and_accumulate.metal",
+  "gptoss_kernels/source/tensor_wrappers.cpp",
+  "gptoss_kernels/source/log.c",
 ]
+
+include = ["gptoss_kernels/source/include", "gptoss_kernels/include", "."]
cmake/compile-metal.cmake
ADDED
@@ -0,0 +1,86 @@
+# Metal shader compilation function
+function(compile_metal_shaders TARGET_NAME METAL_SOURCES)
+  # Find the Metal compiler
+  find_program(METAL_COMPILER xcrun REQUIRED)
+
+  # Set Metal compiler flags
+  set(METAL_FLAGS "-std=metal3.0" "-O2")
+
+  # Output directory for compiled metallib
+  set(METALLIB_OUTPUT_DIR "${CMAKE_BINARY_DIR}/metallib")
+  file(MAKE_DIRECTORY ${METALLIB_OUTPUT_DIR})
+
+  # Separate .metal files from .h files and compile .metal files to .air
+  set(AIR_FILES)
+  set(METAL_FILES)
+  set(HEADER_FILES)
+
+  foreach(SOURCE_FILE ${METAL_SOURCES})
+    if(SOURCE_FILE MATCHES "\\.metal$")
+      list(APPEND METAL_FILES ${SOURCE_FILE})
+    elseif(SOURCE_FILE MATCHES "\\.h$")
+      list(APPEND HEADER_FILES ${SOURCE_FILE})
+    endif()
+  endforeach()
+
+  foreach(METAL_FILE ${METAL_FILES})
+    get_filename_component(METAL_NAME ${METAL_FILE} NAME_WE)
+    set(AIR_FILE "${CMAKE_BINARY_DIR}/${METAL_NAME}.air")
+
+    # Include header files as dependencies
+    set(ALL_DEPENDENCIES ${CMAKE_CURRENT_SOURCE_DIR}/${METAL_FILE})
+    foreach(HEADER_FILE ${HEADER_FILES})
+      list(APPEND ALL_DEPENDENCIES ${CMAKE_CURRENT_SOURCE_DIR}/${HEADER_FILE})
+    endforeach()
+
+    add_custom_command(
+      OUTPUT ${AIR_FILE}
+      COMMAND ${METAL_COMPILER} -sdk macosx metal ${METAL_FLAGS}
+              -c ${CMAKE_CURRENT_SOURCE_DIR}/${METAL_FILE}
+              -o ${AIR_FILE}
+      DEPENDS ${ALL_DEPENDENCIES}
+      COMMENT "Compiling Metal shader ${METAL_FILE} to ${AIR_FILE}"
+      VERBATIM
+    )
+
+    list(APPEND AIR_FILES ${AIR_FILE})
+  endforeach()
+
+  # Link all .air files into a single .metallib
+  set(METALLIB_FILE "${METALLIB_OUTPUT_DIR}/${TARGET_NAME}.metallib")
+  add_custom_command(
+    OUTPUT ${METALLIB_FILE}
+    COMMAND ${METAL_COMPILER} -sdk macosx metallib ${AIR_FILES}
+            -o ${METALLIB_FILE}
+    DEPENDS ${AIR_FILES}
+    COMMENT "Linking Metal library ${METALLIB_FILE}"
+    VERBATIM
+  )
+
+  # Generate C++ header with embedded metallib data
+  set(METALLIB_HEADER "${CMAKE_BINARY_DIR}/${TARGET_NAME}_metallib.h")
+  set(METALLIB_TO_HEADER_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/cmake/metallib_to_header.py")
+
+  add_custom_command(
+    OUTPUT ${METALLIB_HEADER}
+    COMMAND ${Python_EXECUTABLE} ${METALLIB_TO_HEADER_SCRIPT} ${METALLIB_FILE} ${METALLIB_HEADER} ${TARGET_NAME}
+    DEPENDS ${METALLIB_FILE} ${METALLIB_TO_HEADER_SCRIPT}
+    COMMENT "Generating embedded Metal library header ${METALLIB_HEADER}"
+    VERBATIM
+  )
+
+  # Create a custom target for the metallib
+  add_custom_target(${TARGET_NAME}_metallib ALL DEPENDS ${METALLIB_FILE} ${METALLIB_HEADER})
+
+  # Add dependency to main target
+  add_dependencies(${TARGET_NAME} ${TARGET_NAME}_metallib)
+
+  # Add the generated header to include directories
+  target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_BINARY_DIR})
+
+  # Pass the metallib header and namespace as compile definitions
+  target_compile_definitions(${TARGET_NAME} PRIVATE
+    EMBEDDED_METALLIB_HEADER="${TARGET_NAME}_metallib.h"
+    EMBEDDED_METALLIB_NAMESPACE=${TARGET_NAME}_metal
+  )
+endfunction()
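The function above encodes a three-stage pipeline: `.metal` → `.air` (compile), `.air` → `.metallib` (link), `.metallib` → C++ header (embed). A hedged standalone sketch of the same stages in Python, assuming Xcode's `xcrun` toolchain is installed; the file names are hypothetical, and the flags match `METAL_FLAGS` above:

```python
# Hypothetical standalone driver mirroring compile_metal_shaders' three stages.
import subprocess

sources = ["matmul.metal", "rmsnorm.metal"]   # hypothetical inputs
air_files = []
for src in sources:
    air = src.replace(".metal", ".air")
    # Stage 1: compile each Metal source to AIR.
    subprocess.run(
        ["xcrun", "-sdk", "macosx", "metal", "-std=metal3.0", "-O2",
         "-c", src, "-o", air],
        check=True,
    )
    air_files.append(air)

# Stage 2: link all AIR files into a single Metal library.
subprocess.run(
    ["xcrun", "-sdk", "macosx", "metallib", *air_files, "-o", "kernels.metallib"],
    check=True,
)

# Stage 3: embed the library into a C++ header via the script in this commit.
subprocess.run(
    ["python3", "cmake/metallib_to_header.py",
     "kernels.metallib", "kernels_metallib.h", "kernels"],
    check=True,
)
```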
cmake/metallib_to_header.py
ADDED
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+import sys
+import os
+
+def convert_metallib_to_header(metallib_path: str, header_path: str, target_name: str) -> None:
+    """Convert a metallib binary file to a C++ header with embedded data."""
+
+    # Read the metallib binary data
+    with open(metallib_path, 'rb') as f:
+        data: bytes = f.read()
+
+    # Generate the header content
+    header_content: str = """// Auto-generated file containing embedded Metal library
+#pragma once
+#include <cstddef>
+#include <Metal/Metal.h>
+
+namespace """ + target_name + """_metal {
+static const unsigned char metallib_data[] = {
+"""
+
+    # Convert binary data to C array format
+    bytes_per_line: int = 16
+    for i in range(0, len(data), bytes_per_line):
+        chunk: bytes = data[i:i + bytes_per_line]
+        hex_values: str = ', '.join('0x{:02x}'.format(b) for b in chunk)
+        header_content += " " + hex_values + ","
+        if i + bytes_per_line < len(data):
+            header_content += "\n"
+
+    header_content += """
+};
+static const size_t metallib_data_len = """ + str(len(data)) + """;
+
+// Convenience function to create Metal library from embedded data
+inline id<MTLLibrary> createLibrary(id<MTLDevice> device, NSError** error = nullptr) {
+    dispatch_data_t libraryData = dispatch_data_create(
+        metallib_data,
+        metallib_data_len,
+        dispatch_get_main_queue(),
+        ^{ /* No cleanup needed for static data */ });
+
+    NSError* localError = nil;
+    id<MTLLibrary> library = [device newLibraryWithData:libraryData error:&localError];
+
+    if (error) {
+        *error = localError;
+    }
+
+    return library;
+}
+} // namespace """ + target_name + """_metal
+"""
+
+    # Write the header file
+    dir_path: str = os.path.dirname(header_path)
+    if dir_path:
+        os.makedirs(dir_path, exist_ok=True)
+    with open(header_path, 'w') as f:
+        f.write(header_content)
+
+    print("Generated {} ({} bytes)".format(header_path, len(data)))
+
+if __name__ == "__main__":
+    if len(sys.argv) != 4:
+        print("Usage: metallib_to_header.py <metallib_path> <header_path> <target_name>")
+        sys.exit(1)
+
+    metallib_path: str = sys.argv[1]
+    header_path: str = sys.argv[2]
+    target_name: str = sys.argv[3]
+
+    convert_metallib_to_header(metallib_path, header_path, target_name)
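A hedged round-trip sketch of the generator above, using a stand-in payload (file names are hypothetical, and it assumes `cmake/` is importable from the working directory):

```python
# Round-trip check of metallib_to_header.py on a fake 20-byte "metallib".
from pathlib import Path
from metallib_to_header import convert_metallib_to_header  # assumes cmake/ on sys.path

Path("fake.metallib").write_bytes(bytes(range(20)))  # stand-in binary payload
convert_metallib_to_header("fake.metallib", "fake_metallib.h", "demo")

text = Path("fake_metallib.h").read_text()
assert "namespace demo_metal" in text        # namespace is <target_name>_metal
assert "metallib_data_len = 20" in text      # length matches the payload
assert "createLibrary" in text               # embedded loader helper is emitted
```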
cmake/utils.cmake
ADDED
@@ -0,0 +1,557 @@
+# Vendored from vLLM:
+#
+# https://github.com/vllm-project/vllm/blob/main/cmake/utils.cmake
+#
+# Attempt to find the python package that uses the same python executable as
+# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
+#
+macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
+  file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
+  set(Python3_EXECUTABLE ${EXECUTABLE})
+  find_package(Python3 COMPONENTS Interpreter Development.Module Development.SABIModule)
+  if (NOT Python3_FOUND)
+    message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
+  endif()
+  set(_VER "${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR}")
+  set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
+  if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
+    message(FATAL_ERROR
+      "Python version (${_VER}) is not one of the supported versions: "
+      "${_SUPPORTED_VERSIONS_LIST}.")
+  endif()
+  message(STATUS "Found python matching: ${EXECUTABLE}.")
+endmacro()
+
+#
+# Run `EXPR` in python. The standard output of python is stored in `OUT` and
+# has trailing whitespace stripped. If an error is encountered when running
+# python, a fatal message `ERR_MSG` is issued.
+#
+function (run_python OUT EXPR ERR_MSG)
+  execute_process(
+    COMMAND
+    "${Python3_EXECUTABLE}" "-c" "${EXPR}"
+    OUTPUT_VARIABLE PYTHON_OUT
+    RESULT_VARIABLE PYTHON_ERROR_CODE
+    ERROR_VARIABLE PYTHON_STDERR
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+  if(NOT PYTHON_ERROR_CODE EQUAL 0)
+    message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
+  endif()
+  set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
+endfunction()
+
+# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
+# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
+macro (append_cmake_prefix_path PKG EXPR)
+  run_python(_PREFIX_PATH
+    "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
+  list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
+endmacro()
+
+#
+# Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
+# of CUDA source files. The names of the corresponding "hipified" sources are
+# stored in `OUT_SRCS`.
+#
+function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
+  #
+  # Split into C++ and non-C++ (i.e. CUDA) sources.
+  #
+  set(NODUP_SRCS ${ORIG_SRCS})
+  list(REMOVE_DUPLICATES NODUP_SRCS)
+  set(SRCS ${NODUP_SRCS})
+  set(CXX_SRCS ${NODUP_SRCS})
+  list(FILTER SRCS INCLUDE REGEX "\.cu$")
+  list(FILTER CXX_SRCS EXCLUDE REGEX "\.cu$")
+
+  #
+  # Generate ROCm/HIP source file names from CUDA file names.
+  # Since HIP files are generated code, they will appear in the build area
+  # `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
+  #
+  set(HIP_SRCS)
+  foreach (SRC ${SRCS})
+    get_source_file_property(include_dirs "${SRC}" INCLUDE_DIRECTORIES)
+    get_source_file_property(compile_options "${SRC}" COMPILE_OPTIONS)
+    string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC})
+    string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
+
+    if(include_dirs)
+      # Copy over include directories from the original CUDA file.
+      set_source_files_properties(
+        ${SRC}
+        PROPERTIES INCLUDE_DIRECTORIES "${include_dirs}")
+    endif()
+
+    if(compile_options)
+      set_source_files_properties(
+        ${SRC}
+        PROPERTIES COMPILE_OPTIONS "${compile_options}")
+    endif()
+
+    list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
+  endforeach()
+
+  add_custom_target(
+    hipify${NAME}
+    COMMAND "${Python3_EXECUTABLE}" ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR} -o ${CMAKE_CURRENT_BINARY_DIR} ${SRCS}
+    DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
+    BYPRODUCTS ${HIP_SRCS}
+    COMMENT "Running hipify on ${NAME} extension source files.")
+
+  # Swap out original extension sources with hipified sources.
+  list(APPEND HIP_SRCS ${CXX_SRCS})
+  set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
+endfunction()
+
+#
+# Get additional GPU compiler flags from torch.
+#
+function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
+  if (${GPU_LANG} STREQUAL "CUDA")
+    #
+    # Get common NVCC flags from torch.
+    #
+    run_python(GPU_FLAGS
+      "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
+      "Failed to determine torch nvcc compiler flags")
+
+    if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
+      list(APPEND GPU_FLAGS "-DENABLE_FP8")
+      list(REMOVE_ITEM GPU_FLAGS
+        "-D__CUDA_NO_HALF_OPERATORS__"
+        "-D__CUDA_NO_HALF_CONVERSIONS__"
+        "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
+        "-D__CUDA_NO_HALF2_OPERATORS__")
+    endif()
+
+  elseif(${GPU_LANG} STREQUAL "HIP")
+    #
+    # Get common HIP/HIPCC flags from torch.
+    #
+    run_python(GPU_FLAGS
+      "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
+      "Failed to determine torch nvcc compiler flags")
+
+    list(APPEND GPU_FLAGS
+      "-DUSE_ROCM"
+      "-DENABLE_FP8"
+      "-U__HIP_NO_HALF_CONVERSIONS__"
+      "-U__HIP_NO_HALF_OPERATORS__"
+      "-fno-gpu-rdc")
+
+  endif()
+  set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
+endfunction()
+
+# Macro for converting a `gencode` version number to a cmake version number.
+macro(string_to_ver OUT_VER IN_STR)
+  string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
+endmacro()
+
+#
+# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
+# `CUDA_ARCH_FLAGS`.
+#
+# Example:
+#   CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
+#   clear_cuda_arches(CUDA_ARCH_FLAGS)
+#   CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
+#   CMAKE_CUDA_FLAGS="-Wall"
+#
+macro(clear_cuda_arches CUDA_ARCH_FLAGS)
+  # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
+  string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
+    ${CMAKE_CUDA_FLAGS})
+
+  # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
+  # and passed back via the `CUDA_ARCHITECTURES` property.
+  string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
+    ${CMAKE_CUDA_FLAGS})
+endmacro()
+
+#
+# Extract unique CUDA architectures from a list of compute capability codes in
+# the form `<major><minor>[<letter>]`, convert them to the form
+# `<major>.<minor>`, dedupe them, sort them in ascending order, and
+# store them in `OUT_ARCHES`.
+#
+# Example:
+#   CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
+#   extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
+#   OUT_ARCHES="7.5;...;9.0"
+function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
+  set(_CUDA_ARCHES)
+  foreach(_ARCH ${CUDA_ARCH_FLAGS})
+    string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
+    if (_COMPUTE)
+      set(_COMPUTE ${CMAKE_MATCH_1})
+    endif()
+
+    string_to_ver(_COMPUTE_VER ${_COMPUTE})
+    list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
+  endforeach()
+
+  list(REMOVE_DUPLICATES _CUDA_ARCHES)
+  list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
+  set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
+endfunction()
+
+#
+# For a specific file set the `-gencode` flag in compile options conditionally
+# for the CUDA language.
+#
+# Example:
+#   set_gencode_flag_for_srcs(
+#     SRCS "foo.cu"
+#     ARCH "compute_75"
+#     CODE "sm_75")
+#   adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
+#   `foo.cu` (only for the CUDA language).
+#
+macro(set_gencode_flag_for_srcs)
+  set(options)
+  set(oneValueArgs ARCH CODE)
+  set(multiValueArgs SRCS)
+  cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
+                        "${multiValueArgs}" ${ARGN} )
+  set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
+  set_property(
+    SOURCE ${arg_SRCS}
+    APPEND PROPERTY
+    COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
+  )
+
+  message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
+endmacro(set_gencode_flag_for_srcs)
+
+#
+# For a list of source files set the `-gencode` flags in the files specific
+# compile options (specifically for the CUDA language).
+#
+# arguments are:
+#   SRCS: list of source files
+#   CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
+#   BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
+#     for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
+#     that is larger than BUILD_PTX_FOR_ARCH.
+#
+macro(set_gencode_flags_for_srcs)
+  set(options)
+  set(oneValueArgs BUILD_PTX_FOR_ARCH)
+  set(multiValueArgs SRCS CUDA_ARCHS)
+  cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
+                        "${multiValueArgs}" ${ARGN} )
+
+  foreach(_ARCH ${arg_CUDA_ARCHS})
+    # handle +PTX suffix: generate both sm and ptx codes if requested
+    string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
+    if(NOT _HAS_PTX EQUAL -1)
+      string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
+      string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "sm_${_STRIPPED_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "compute_${_STRIPPED_ARCH}")
+    else()
+      string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "sm_${_STRIPPED_ARCH}")
+    endif()
+  endforeach()
+
+  if (${arg_BUILD_PTX_FOR_ARCH})
+    list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
+    list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
+    if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
+      string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_PTX_ARCH}"
+        CODE "compute_${_PTX_ARCH}")
+    endif()
+  endif()
+endmacro()
+
+#
+# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
+# `<major>.<minor>[letter]` compute the "loose intersection" with the
+# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
+# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
+# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
+# architecture in `SRC_CUDA_ARCHS`.
+# The loose intersection is defined as:
+#   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
+# where `<=` is the version comparison operator.
+# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
+# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
+# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
+# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
+# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
+# The result is stored in `OUT_CUDA_ARCHS`.
+#
+# Example:
+#   SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
+#   TGT_CUDA_ARCHS="8.0;8.9;9.0"
+#   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+#   OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
+#
+# Example With PTX:
+#   SRC_CUDA_ARCHS="8.0+PTX"
+#   TGT_CUDA_ARCHS="9.0"
+#   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+#   OUT_CUDA_ARCHS="8.0+PTX"
+#
+function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+  set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
+  set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
+
+  # handle +PTX suffix: separate base arch for matching, record PTX requests
+  set(_PTX_ARCHS)
+  foreach(_arch ${_SRC_CUDA_ARCHS})
+    if(_arch MATCHES "\\+PTX$")
+      string(REPLACE "+PTX" "" _base "${_arch}")
+      list(APPEND _PTX_ARCHS "${_base}")
+      list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
+      list(APPEND _SRC_CUDA_ARCHS "${_base}")
+    endif()
+  endforeach()
+  list(REMOVE_DUPLICATES _PTX_ARCHS)
+  list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
+
+  # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
+  # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
+  set(_CUDA_ARCHS)
+  foreach(_arch ${_SRC_CUDA_ARCHS})
+    if(_arch MATCHES "\\a$")
+      list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
+      string(REPLACE "a" "" _base "${_arch}")
+      if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
+        list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
+        list(APPEND _CUDA_ARCHS "${_arch}")
+      endif()
+    endif()
+  endforeach()
+
+  list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
+
+  # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
+  # is less or equal to ARCH (but has the same major version since SASS binary
+  # compatibility is only forward compatible within the same major version).
+  foreach(_ARCH ${_TGT_CUDA_ARCHS})
+    set(_TMP_ARCH)
+    # Extract the major version of the target arch
+    string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
+    foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
+      # Extract the major version of the source arch
+      string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
+      # Check version-less-or-equal, and allow PTX arches to match across majors
+      if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
+        if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
+          set(_TMP_ARCH "${_SRC_ARCH}")
+        endif()
+      else()
+        # If we hit a version greater than the target, we can break
+        break()
+      endif()
+    endforeach()
+
+    # If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
+    if (_TMP_ARCH)
+      list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
+    endif()
+  endforeach()
+
+  list(REMOVE_DUPLICATES _CUDA_ARCHS)
+
+  # reapply +PTX suffix to architectures that requested PTX
+  set(_FINAL_ARCHS)
+  foreach(_arch ${_CUDA_ARCHS})
+    if(_arch IN_LIST _PTX_ARCHS)
+      list(APPEND _FINAL_ARCHS "${_arch}+PTX")
+    else()
+      list(APPEND _FINAL_ARCHS "${_arch}")
+    endif()
+  endforeach()
+  set(_CUDA_ARCHS ${_FINAL_ARCHS})
+
+  set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
+endfunction()
+
+#
+# For the given `SRC_ROCM_ARCHS` list of architecture versions in the form
+# `<name>` compute the "loose intersection" with the `TGT_ROCM_ARCHS` list.
+# The loose intersection is defined as:
+#   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
+# where `<=` is the version comparison operator.
+# In other words, for each version in `TGT_ROCM_ARCHS` find the highest version
+# in `SRC_ROCM_ARCHS` that is less or equal to the version in `TGT_ROCM_ARCHS`.
+# The result is stored in `OUT_ROCM_ARCHS`.
+#
+# Example:
+#   SRC_ROCM_ARCHS="gfx900;gfx906;gfx908;gfx90a"
+#   TGT_ROCM_ARCHS="gfx906;gfx908;gfx1030"
+#   hip_archs_loose_intersection(OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
+#   OUT_ROCM_ARCHS="gfx906;gfx908"
+#
+function(hip_archs_loose_intersection OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
+  list(REMOVE_DUPLICATES SRC_ROCM_ARCHS)
+
+  # ROCm architectures are typically in format gfxNNN or gfxNNNx where N is a digit
+  # and x is a letter. We can sort them by string comparison which works for this format.
+  list(SORT SRC_ROCM_ARCHS COMPARE STRING ORDER ASCENDING)
+
+  set(_ROCM_ARCHS)
+
+  # Find the intersection of supported architectures
+  foreach(_SRC_ARCH ${SRC_ROCM_ARCHS})
+    if(_SRC_ARCH IN_LIST TGT_ROCM_ARCHS)
+      list(APPEND _ROCM_ARCHS ${_SRC_ARCH})
+    endif()
+  endforeach()
+
+  list(REMOVE_DUPLICATES _ROCM_ARCHS)
+  set(${OUT_ROCM_ARCHS} ${_ROCM_ARCHS} PARENT_SCOPE)
+endfunction()
+
+#
+# Override the GPU architectures detected by cmake/torch and filter them by
+# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
+# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
+# the architectures on a per file basis.
+#
+# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
+#
+macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
+  set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
+  message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")
+
+  if (${GPU_LANG} STREQUAL "HIP")
+    #
+    # `GPU_ARCHES` controls the `--offload-arch` flags.
+    #
+    # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
+    # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
+    # "rocm_agent_enumerator" in "enable_language(HIP)"
+    # (in file Modules/CMakeDetermineHIPCompiler.cmake)
+    #
+    if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+      set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
+    else()
+      set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
+    endif()
+    #
+    # Find the intersection of the supported + detected architectures to
+    # set the module architecture flags.
+    #
+    set(${GPU_ARCHES})
+    foreach (_ARCH ${HIP_ARCHITECTURES})
+      if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
+        list(APPEND ${GPU_ARCHES} ${_ARCH})
+      endif()
+    endforeach()
+
+    if(NOT ${GPU_ARCHES})
+      message(FATAL_ERROR
+        "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
+        " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
+    endif()
+  endif()
+endmacro()
+
+#
+# Define a target named `GPU_MOD_NAME` for a single extension. The
+# arguments are:
+#
+# DESTINATION <dest>         - Module destination directory.
+# LANGUAGE <lang>            - The GPU language for this module, e.g CUDA, HIP,
+#                              etc.
+# SOURCES <sources>          - List of source files relative to CMakeLists.txt
+#                              directory.
+#
+# Optional arguments:
+#
+# ARCHITECTURES <arches>     - A list of target GPU architectures in cmake
+#                              format.
+#                              Refer `CMAKE_CUDA_ARCHITECTURES` documentation
+#                              and `CMAKE_HIP_ARCHITECTURES` for more info.
+#                              ARCHITECTURES will use cmake's defaults if
+#                              not provided.
+# COMPILE_FLAGS <flags>      - Extra compiler flags passed to NVCC/hip.
+# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
+# LIBRARIES <libraries>      - Extra link libraries.
+# WITH_SOABI                 - Generate library with python SOABI suffix name.
+# USE_SABI <version>         - Use python stable api <version>
+#
+# Note: optimization level/debug info is set via cmake build type.
+#
+function (define_gpu_extension_target GPU_MOD_NAME)
+  cmake_parse_arguments(PARSE_ARGV 1
+    GPU
+    "WITH_SOABI"
+    "DESTINATION;LANGUAGE;USE_SABI"
+    "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
+
+  # Add hipify preprocessing step when building with HIP/ROCm.
+  if (GPU_LANGUAGE STREQUAL "HIP")
+    hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
+  endif()
+
+  if (GPU_WITH_SOABI)
+    set(GPU_WITH_SOABI WITH_SOABI)
+  else()
+    set(GPU_WITH_SOABI)
+  endif()
+
+  if (GPU_USE_SABI)
+    Python3_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  else()
+    Python3_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  endif()
+
+  if (GPU_LANGUAGE STREQUAL "HIP")
+    # Make this target dependent on the hipify preprocessor step.
+    add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
+  endif()
+
+  if (GPU_ARCHITECTURES)
+    if (GPU_LANGUAGE STREQUAL "HIP")
+      # Clear target architectures, we are passing arch flags per source file.
+      set_property(TARGET ${GPU_MOD_NAME} PROPERTY HIP_ARCHITECTURES off)
+    else()
+      set_target_properties(${GPU_MOD_NAME} PROPERTIES
+        ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
+    endif()
+  endif()
+
+  set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
+
+  target_compile_options(${GPU_MOD_NAME} PRIVATE
+    $<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
+
+  target_compile_definitions(${GPU_MOD_NAME} PRIVATE
+    "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
+
+  target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
+    ${GPU_INCLUDE_DIRECTORIES})
+
+  target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
+
+  # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
+  # dependencies that are not necessary and may not be installed.
+  if (GPU_LANGUAGE STREQUAL "CUDA")
+    target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart)
+  else()
+    target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
+  endif()
+
+  install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
+endfunction()
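The set-builder comment in `cuda_archs_loose_intersection` is easier to read as code. A hedged Python sketch of the core rule (the function name is mine; it ignores the `+PTX` and `x.0a` special cases handled by the CMake version):

```python
# Sketch of the "loose intersection": for each target arch, keep the highest
# source arch that is <= it and shares the same major version.
def loose_intersection(src, tgt):
    def ver(a):  # "8.6" -> (8, 6) for numeric comparison
        major, minor = a.split(".")
        return int(major), int(minor)

    out = []
    for t in sorted(tgt, key=ver):
        candidates = [s for s in src
                      if ver(s) <= ver(t) and ver(s)[0] == ver(t)[0]]
        if candidates:
            out.append(max(candidates, key=ver))
    return sorted(set(out), key=ver)

# Matches the comment's example, minus the 9.0a special case:
print(loose_intersection(["7.5", "8.0", "8.6", "9.0"], ["8.0", "8.9", "9.0"]))
# ['8.0', '8.6', '9.0']
```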
flake.lock
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nodes": {
|
| 3 |
+
"flake-compat": {
|
| 4 |
+
"locked": {
|
| 5 |
+
"lastModified": 1761588595,
|
| 6 |
+
"narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=",
|
| 7 |
+
"owner": "edolstra",
|
| 8 |
+
"repo": "flake-compat",
|
| 9 |
+
"rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5",
|
| 10 |
+
"type": "github"
|
| 11 |
+
},
|
| 12 |
+
"original": {
|
| 13 |
+
"owner": "edolstra",
|
| 14 |
+
"repo": "flake-compat",
|
| 15 |
+
"type": "github"
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"flake-compat_2": {
|
| 19 |
+
"locked": {
|
| 20 |
+
"lastModified": 1747046372,
|
| 21 |
+
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
|
| 22 |
+
"owner": "edolstra",
|
| 23 |
+
"repo": "flake-compat",
|
| 24 |
+
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
| 25 |
+
"type": "github"
|
| 26 |
+
},
|
| 27 |
+
"original": {
|
| 28 |
+
"owner": "edolstra",
|
| 29 |
+
"repo": "flake-compat",
|
| 30 |
+
"type": "github"
|
| 31 |
+
}
|
| 32 |
+
},
|
| 33 |
+
"flake-utils": {
|
| 34 |
+
"inputs": {
|
| 35 |
+
"systems": "systems"
|
| 36 |
+
},
|
| 37 |
+
"locked": {
|
| 38 |
+
"lastModified": 1731533236,
|
| 39 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 40 |
+
"owner": "numtide",
|
| 41 |
+
"repo": "flake-utils",
|
| 42 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 43 |
+
"type": "github"
|
| 44 |
+
},
|
| 45 |
+
"original": {
|
| 46 |
+
"owner": "numtide",
|
| 47 |
+
"repo": "flake-utils",
|
| 48 |
+
"type": "github"
|
| 49 |
+
}
|
| 50 |
+
},
|
| 51 |
+
"flake-utils_2": {
|
| 52 |
+
"inputs": {
|
| 53 |
+
"systems": "systems_2"
|
| 54 |
+
},
|
| 55 |
+
"locked": {
|
| 56 |
+
"lastModified": 1731533236,
|
| 57 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 58 |
+
"owner": "numtide",
|
| 59 |
+
"repo": "flake-utils",
|
| 60 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 61 |
+
"type": "github"
|
| 62 |
+
},
|
| 63 |
+
"original": {
|
| 64 |
+
"owner": "numtide",
|
| 65 |
+
"repo": "flake-utils",
|
| 66 |
+
"type": "github"
|
| 67 |
+
}
|
| 68 |
+
},
|
| 69 |
+
"hf-nix": {
|
| 70 |
+
"inputs": {
|
| 71 |
+
"flake-compat": "flake-compat_2",
|
| 72 |
+
"flake-utils": "flake-utils_2",
|
| 73 |
+
"nixpkgs": "nixpkgs"
|
| 74 |
+
},
|
| 75 |
+
"locked": {
|
| 76 |
+
"lastModified": 1761756835,
|
| 77 |
+
"narHash": "sha256-Vjrv8ZIhkQRgQ3MHGVFaj/fUcE4yuGr+vnoKYRwWmYw=",
|
| 78 |
+
"owner": "huggingface",
|
| 79 |
+
"repo": "hf-nix",
|
| 80 |
+
"rev": "6839b6998be18679992978c2f3abddc902276280",
|
| 81 |
+
"type": "github"
|
| 82 |
+
},
|
| 83 |
+
"original": {
|
| 84 |
+
"owner": "huggingface",
|
| 85 |
+
"repo": "hf-nix",
|
| 86 |
+
"type": "github"
|
| 87 |
+
}
|
| 88 |
+
},
|
| 89 |
+
"kernel-builder": {
|
| 90 |
+
"inputs": {
|
| 91 |
+
"flake-compat": "flake-compat",
|
| 92 |
+
"flake-utils": "flake-utils",
|
| 93 |
+
"hf-nix": "hf-nix",
|
| 94 |
+
"nixpkgs": [
|
| 95 |
+
"kernel-builder",
|
| 96 |
+
"hf-nix",
|
| 97 |
+
"nixpkgs"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
"locked": {
|
| 101 |
+
"lastModified": 1761991868,
|
| 102 |
+
"narHash": "sha256-+csvkWC9jC4mwq1LNfK4O6m3Qg4dCCXjP5JGdPa3TEo=",
|
| 103 |
+
"owner": "huggingface",
|
| 104 |
+
"repo": "kernel-builder",
|
| 105 |
+
"rev": "79cbfcdfde82c8847551f67f4b951a410794a5c6",
|
| 106 |
+
"type": "github"
|
| 107 |
+
},
|
| 108 |
+
"original": {
|
| 109 |
+
"owner": "huggingface",
|
| 110 |
+
"ref": "metal_kernels",
|
| 111 |
+
"repo": "kernel-builder",
|
| 112 |
+
"type": "github"
|
| 113 |
+
}
|
| 114 |
+
},
|
| 115 |
+
"nixpkgs": {
|
| 116 |
+
"locked": {
|
| 117 |
+
"lastModified": 1755963616,
|
| 118 |
+
"narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
|
| 119 |
+
"owner": "nixos",
|
| 120 |
+
"repo": "nixpkgs",
|
| 121 |
+
"rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
|
| 122 |
+
"type": "github"
|
| 123 |
+
},
|
| 124 |
+
"original": {
|
| 125 |
+
"owner": "nixos",
|
| 126 |
+
"ref": "nixos-unstable-small",
|
| 127 |
+
"repo": "nixpkgs",
|
| 128 |
+
"type": "github"
|
| 129 |
+
}
|
| 130 |
+
},
|
| 131 |
+
"root": {
|
| 132 |
+
"inputs": {
|
| 133 |
+
"kernel-builder": "kernel-builder"
|
| 134 |
+
}
|
| 135 |
+
},
|
| 136 |
+
"systems": {
|
| 137 |
+
"locked": {
|
| 138 |
+
"lastModified": 1681028828,
|
| 139 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 140 |
+
"owner": "nix-systems",
|
| 141 |
+
"repo": "default",
|
| 142 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 143 |
+
"type": "github"
|
| 144 |
+
},
|
| 145 |
+
"original": {
|
| 146 |
+
"owner": "nix-systems",
|
| 147 |
+
"repo": "default",
|
| 148 |
+
"type": "github"
|
| 149 |
+
}
|
| 150 |
+
},
|
| 151 |
+
"systems_2": {
|
| 152 |
+
"locked": {
|
| 153 |
+
"lastModified": 1681028828,
|
| 154 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 155 |
+
"owner": "nix-systems",
|
| 156 |
+
"repo": "default",
|
| 157 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 158 |
+
"type": "github"
|
| 159 |
+
},
|
| 160 |
+
"original": {
|
| 161 |
+
"owner": "nix-systems",
|
| 162 |
+
"repo": "default",
|
| 163 |
+
"type": "github"
|
| 164 |
+
}
|
| 165 |
+
}
|
| 166 |
+
},
|
| 167 |
+
"root": "root",
|
| 168 |
+
"version": 7
|
| 169 |
+
}
|
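The lock entries above pin every transitive input of the flake (kernel-builder, its hf-nix and flake-utils dependencies, nixpkgs, and the nix-systems helpers) to an exact git revision plus NAR hash, which is what makes the kernel build reproducible. The file is generated by Nix rather than written by hand; to re-pin a single input after its upstream branch moves, something like the following should work (exact syntax depends on the Nix version):

    nix flake update kernel-builder                # Nix 2.19 and newer
    nix flake lock --update-input kernel-builder   # older Nix releases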
flake.nix
CHANGED
@@ -2,7 +2,7 @@
   description = "Flake for Torch kernel extension";
 
   inputs = {
-    kernel-builder.url = "github:huggingface/kernel-builder";
+    kernel-builder.url = "github:huggingface/kernel-builder?ref=metal_kernels";
   };
 
   outputs = { self, kernel-builder, }:
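The only functional change to flake.nix is pinning the kernel-builder input to its metal_kernels branch via the `?ref=` query parameter of the flake URL; this is what produces the "ref": "metal_kernels" entry in the lock file above. The same pin can also be spelled in attribute form, which is easier to extend with further parameters later. A sketch, not part of this commit:

    inputs.kernel-builder = {
      type = "github";
      owner = "huggingface";
      repo = "kernel-builder";
      ref = "metal_kernels";
    };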
gptoss_kernels/CMakeLists.txt
DELETED
@@ -1,191 +0,0 @@
-cmake_minimum_required(VERSION 3.24)
-project(GPTOSS
-    VERSION 1.0
-    DESCRIPTION "Local GPT-OSS inference"
-    LANGUAGES C CXX OBJC)
-
-set(CMAKE_C_STANDARD 11)
-set(CMAKE_CXX_STANDARD 20)
-set(CMAKE_OBJC_STANDARD 11)
-set(CMAKE_OBJC_STANDARD_REQUIRED ON)
-
-find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED)
-find_library(METAL_FRAMEWORK Metal REQUIRED)
-find_library(IOKIT_FRAMEWORK IOKit REQUIRED)
-
-set(METAL_SOURCES
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal
-    ${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal
-)
-set(METAL_LIB default.metallib)
-
-include_directories(BEFORE include source/include)
-
-add_custom_command(
-    OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
-    COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/source/"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/expert_routing_metadata.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/random.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air"
-    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air"
-    COMMAND xcrun -sdk macosx metallib "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air" "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air" "${CMAKE_CURRENT_BINARY_DIR}/source/expert_routing_metadata.air" "${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/random.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air" "${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air" "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air" -o "${METAL_LIB}"
-    DEPENDS ${METAL_SOURCES}
-    COMMENT "Compiling Metal compute library"
-)
-
-add_custom_target(build_metallib ALL
-    DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB})
-
-add_library(log OBJECT source/log.c)
-
-add_library(metal-kernels STATIC source/metal.m source/metal-kernels.c)
-target_link_libraries(metal-kernels PRIVATE log)
-
-add_dependencies(metal-kernels build_metallib)
-add_custom_command(TARGET metal-kernels POST_BUILD
-    COMMAND ${CMAKE_COMMAND} -E copy
-        ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
-        $<TARGET_FILE_DIR:metal-kernels>)
-
-target_link_libraries(metal-kernels PRIVATE ${FOUNDATION_FRAMEWORK} ${METAL_FRAMEWORK} ${IOKIT_FRAMEWORK})
-
-add_library(gptoss STATIC source/model.c source/tokenizer.c source/context.c)
-target_link_libraries(gptoss PRIVATE log metal-kernels)
-
-add_executable(generate source/generate.c)
-target_link_libraries(generate gptoss)
-
-# --- [ Tests
-include(FetchContent)
-FetchContent_Declare(
-    googletest
-    URL https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip
-    DOWNLOAD_EXTRACT_TIMESTAMP OFF
-)
-# For Windows: Prevent overriding the parent project's compiler/linker settings
-set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
-set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
-FetchContent_MakeAvailable(googletest)
-
-enable_testing()
-
-add_executable(u32-random-test test/u32-random.cc)
-target_link_libraries(u32-random-test PRIVATE GTest::gtest_main metal-kernels)
-target_include_directories(u32-random-test PRIVATE source/include)
-add_test(NAME u32-random-test COMMAND u32-random-test)
-
-add_executable(f32-random-test test/f32-random.cc)
-target_link_libraries(f32-random-test PRIVATE GTest::gtest_main metal-kernels)
-target_include_directories(f32-random-test PRIVATE source/include)
-add_test(NAME f32-random-test COMMAND f32-random-test)
-
-add_executable(mf4-f32-convert-test test/mf4-f32-convert.cc)
-target_link_libraries(mf4-f32-convert-test PRIVATE GTest::gtest_main metal-kernels)
-target_include_directories(mf4-f32-convert-test PRIVATE source/include)
-add_test(NAME mf4-f32-convert-test COMMAND mf4-f32-convert-test)
-
-add_executable(bf16-f32-embeddings-test test/bf16-f32-embeddings.cc)
-target_link_libraries(bf16-f32-embeddings-test PRIVATE GTest::gtest_main metal-kernels)
-target_include_directories(bf16-f32-embeddings-test PRIVATE source/include)
-add_test(NAME bf16-f32-embeddings-test COMMAND bf16-f32-embeddings-test)
-
-add_executable(f32-bf16w-rmsnorm-test test/f32-bf16w-rmsnorm.cc)
-target_link_libraries(f32-bf16w-rmsnorm-test PRIVATE GTest::gtest_main metal-kernels)
-target_include_directories(f32-bf16w-rmsnorm-test PRIVATE source/include)
-add_test(NAME f32-bf16w-rmsnorm-test COMMAND f32-bf16w-rmsnorm-test)
-
-add_executable(f32-bf16w-matmul-test test/f32-bf16w-matmul.cc)
-target_link_libraries(f32-bf16w-matmul-test PRIVATE GTest::gtest_main metal-kernels)
-target_include_directories(f32-bf16w-matmul-test PRIVATE source/include)
-add_test(NAME f32-bf16w-matmul-test COMMAND f32-bf16w-matmul-test)
-
-add_executable(f32-rope-test test/f32-rope.cc)
-target_link_libraries(f32-rope-test PRIVATE GTest::gtest_main metal-kernels)
-target_include_directories(f32-rope-test PRIVATE source/include)
-add_test(NAME f32-rope-test COMMAND f32-rope-test)
-
-# --- [ Benchmarks
-include(FetchContent)
-set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable self-tests in Google Benchmark" FORCE)
-set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable installation of Google Benchmark" FORCE)
-FetchContent_Declare(
-    benchmark
-    URL https://github.com/google/benchmark/archive/refs/tags/v1.9.4.zip
-    DOWNLOAD_EXTRACT_TIMESTAMP OFF
-)
-FetchContent_MakeAvailable(benchmark)
-
-add_executable(f32-random-bench benchmark/f32-random.cc)
-target_link_libraries(f32-random-bench PRIVATE benchmark::benchmark metal-kernels)
-target_include_directories(f32-random-bench PRIVATE source/include)
-
-add_executable(u32-random-bench benchmark/u32-random.cc)
-target_link_libraries(u32-random-bench PRIVATE benchmark::benchmark metal-kernels)
-target_include_directories(u32-random-bench PRIVATE source/include)
-
-add_executable(mf4-f32-convert-bench benchmark/mf4-f32-convert.cc)
-target_link_libraries(mf4-f32-convert-bench PRIVATE benchmark::benchmark metal-kernels)
-target_include_directories(mf4-f32-convert-bench PRIVATE source/include)
-
-add_executable(f32-bf16w-rmsnorm-bench benchmark/f32-bf16w-rmsnorm.cc)
-target_link_libraries(f32-bf16w-rmsnorm-bench PRIVATE benchmark::benchmark metal-kernels)
-target_include_directories(f32-bf16w-rmsnorm-bench PRIVATE source/include)
-
-add_executable(end-to-end-bench benchmark/end-to-end.cc)
-target_link_libraries(end-to-end-bench PRIVATE benchmark::benchmark gptoss)
-target_include_directories(end-to-end-bench PRIVATE source/include)
-
-add_executable(end-to-end-threadgroup-bench benchmark/end-to-end-threadgroup.cc)
-target_link_libraries(end-to-end-threadgroup-bench PRIVATE benchmark::benchmark gptoss)
-target_include_directories(end-to-end-threadgroup-bench PRIVATE source/include)
-
-# --- [ Python extension ] -----------------------------------------------
-find_package(pybind11 CONFIG REQUIRED)  # provides pybind11_add_module
-
-pybind11_add_module(_metal
-    python/module.c
-    python/context.c
-    python/model.c
-    python/tokenizer.c
-)
-set_target_properties(_metal PROPERTIES PREFIX "")
-
-target_link_libraries(_metal PRIVATE gptoss)
-add_dependencies(_metal build_metallib)
-target_link_options(_metal PRIVATE
-    LINKER:-sectcreate,__METAL,__shaders,${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
-)
-add_custom_command(TARGET _metal POST_BUILD
-    COMMAND ${CMAKE_COMMAND} -E copy
-        ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
-        $<TARGET_FILE_DIR:_metal>)
-
-# 1️⃣ install the extension module into the Python package
-install(TARGETS _metal LIBRARY DESTINATION gpt_oss/metal)
-
-# 2️⃣ make sure the Metal shader archive travels with it
-install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
-        DESTINATION gpt_oss/metal)
-# ------------------------------------------------------------------------
gptoss_kernels/__init__.py
DELETED
@@ -1,6 +0,0 @@
-from importlib import import_module as _im
-
-# Load the compiled extension (gpt_oss.metal._metal)
-_ext = _im(f"{__name__}._metal")
-globals().update({k: v for k, v in _ext.__dict__.items() if not k.startswith("_")})
-del _im, _ext
gptoss_kernels/examples/chat.py
DELETED
@@ -1,104 +0,0 @@
-#!/usr/bin/env python
-
-import argparse
-import sys
-
-from datetime import date
-from gpt_oss.metal import Context, Model
-
-
-DEFAULT_PROMPT = f"""You are ChatGPT, a large language model trained by OpenAI.
-Knowledge cutoff: 2024-06
-Current date: {date.today().isoformat()}
-
-reasoning effort high
-
-# Valid channels: analysis, final. Channel must be included for every message."""
-
-
-parser = argparse.ArgumentParser(description="Chat with gpt-oss", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument("model", metavar="PATH", type=str, help="Path to gpt-oss model in Metal inference format")
-parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="System prompt")
-parser.add_argument(
-    "--context-length", type=int, default=0, help="The maximum context length"
-)
-parser.add_argument(
-    "--temperature", type=float, default=1.0, help="Sampling temperature"
-)
-parser.add_argument(
-    "--seed", type=int, default=0, help="Sampling seed"
-)
-
-
-GREY = "\33[90m"
-BOLD = "\33[1m"
-RESET = "\33[0m"
-
-
-def main(args):
-    options = parser.parse_args(args)
-    model = Model(options.model)
-    tokenizer = model.tokenizer
-    start_token = tokenizer.encode_special_token("<|start|>")
-    message_token = tokenizer.encode_special_token("<|message|>")
-    end_token = tokenizer.encode_special_token("<|end|>")
-    return_token = tokenizer.encode_special_token("<|return|>")
-    channel_token = tokenizer.encode_special_token("<|channel|>")
-
-    context = Context(model, context_length=options.context_length)
-    context.append(start_token)
-    context.append("system")
-    context.append(message_token)
-    context.append(options.prompt)
-    context.append(end_token)
-
-    while True:
-        context.append(start_token)
-        context.append("user")
-        context.append(message_token)
-        message = input(f"{BOLD}User:{RESET} ").rstrip()
-        context.append(message)
-        context.append(end_token)
-        print(f"{BOLD}Assistant:{RESET} {GREY}", end="", flush=True)
-        context.append(start_token)
-        context.append("assistant")
-        context.append(channel_token)
-
-        inside_start_block = True
-        inside_channel_block = True
-        role = "assistant"
-        channel = ""
-        while True:
-            token = context.sample(
-                temperature=options.temperature,
-                seed=options.seed,
-            )
-            context.append(token)
-            if token == return_token:
-                print(flush=True)
-                break
-            elif token == start_token:
-                inside_start_block = True
-                role = ""
-                channel = ""
-            elif token == message_token:
-                inside_start_block = False
-                inside_channel_block = False
-                if channel == "analysis":
-                    print(f"{GREY}", end="", flush=True)
-            elif token == end_token:
-                print(f"{RESET}", flush=True)
-            elif token == channel_token:
-                inside_channel_block = True
-            elif token < tokenizer.num_text_tokens:
-                if inside_channel_block:
-                    channel += str(tokenizer.decode(token), encoding="utf-8")
-                elif inside_start_block:
-                    role += str(tokenizer.decode(token), encoding="utf-8")
-                else:
-                    sys.stdout.buffer.write(tokenizer.decode(token))
-                    sys.stdout.buffer.flush()
-
-
-if __name__ == "__main__":
-    main(sys.argv[1:])
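The deleted chat example drives the token protocol by hand: it appends <|start|>role<|message|>…<|end|> frames for the system and user turns, then samples until <|return|>, using a small state machine (inside_start_block / inside_channel_block) to decide whether each decoded token belongs to the role header, the channel name, or the visible message, and dims analysis-channel output in grey. A typical invocation, with the model path as a placeholder:

    python chat.py /path/to/model.bin --temperature 0.7 --seed 42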
gptoss_kernels/examples/generate.py
DELETED
@@ -1,34 +0,0 @@
-#!/usr/bin/env python
-
-import argparse
-import sys
-
-from gpt_oss.metal import Context, Model
-
-
-parser = argparse.ArgumentParser(description='Chat with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument('model', metavar='PATH', type=str, help='Path to gpt-oss checkpoint')
-parser.add_argument('-p', '--prompt', type=str, required=True, help='Prompt')
-parser.add_argument('-l', '--limit', type=int, default=100, help='Number of tokens to generate')
-parser.add_argument('--context-length', type=int, default=0, help='The maximum context length')
-
-
-def main(args):
-    options = parser.parse_args(args)
-    model = Model(options.model)
-
-    context = Context(model, context_length=options.context_length)
-    context.append(options.prompt)
-    print(context.tokens)
-    prompt_tokens = context.num_tokens
-
-    tokenizer = model.tokenizer
-
-    while context.num_tokens - prompt_tokens < options.limit:
-        token = context.sample()
-        context.append(token)
-        print(str(tokenizer.decode(token), encoding="utf-8"), end='', flush=True)
-
-
-if __name__ == '__main__':
-    main(sys.argv[1:])
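generate.py is the minimal counterpart to chat.py: append the prompt, then sample and print tokens until the limit is reached. A typical invocation, again with a placeholder model path:

    python generate.py /path/to/model.bin -p "Once upon a time" -l 50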
gptoss_kernels/source/context.c
DELETED
|
@@ -1,1115 +0,0 @@
|
|
| 1 |
-
#include <assert.h>
|
| 2 |
-
#include <float.h>
|
| 3 |
-
#include <inttypes.h>
|
| 4 |
-
#include <stdbool.h>
|
| 5 |
-
#include <stdint.h>
|
| 6 |
-
#include <stdlib.h>
|
| 7 |
-
#include <string.h>
|
| 8 |
-
|
| 9 |
-
#include <gpt-oss.h>
|
| 10 |
-
|
| 11 |
-
#include "internal/datatype.h"
|
| 12 |
-
#include "internal/model.h"
|
| 13 |
-
#include "internal/metal.h"
|
| 14 |
-
#include "internal/metal-kernels.h"
|
| 15 |
-
#include "internal/log.h"
|
| 16 |
-
#include "internal/rng.h"
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
enum gptoss_status GPTOSS_ABI gptoss_context_create(
|
| 20 |
-
gptoss_model_t model,
|
| 21 |
-
size_t context_length,
|
| 22 |
-
size_t max_batch_tokens,
|
| 23 |
-
gptoss_context_t* context_out)
|
| 24 |
-
{
|
| 25 |
-
*context_out = NULL;
|
| 26 |
-
|
| 27 |
-
enum gptoss_status status = gptoss_status_success;
|
| 28 |
-
struct gptoss_context* context = NULL;
|
| 29 |
-
|
| 30 |
-
// Validate context_length
|
| 31 |
-
if (context_length == 0) {
|
| 32 |
-
context_length = model->context_length;
|
| 33 |
-
} else if (context_length > model->context_length) {
|
| 34 |
-
GPTOSS_LOG_ERROR("requested context length %zu exceeds model context length %" PRIu32,
|
| 35 |
-
context_length, model->context_length);
|
| 36 |
-
status = gptoss_status_invalid_argument;
|
| 37 |
-
goto cleanup;
|
| 38 |
-
}
|
| 39 |
-
assert(context_length != 0);
|
| 40 |
-
assert(context_length <= model->context_length);
|
| 41 |
-
|
| 42 |
-
// Validate max_batch_tokens
|
| 43 |
-
if (max_batch_tokens == 0) {
|
| 44 |
-
max_batch_tokens = GPTOSS_DEFAULT_BATCH_SIZE;
|
| 45 |
-
} else if (max_batch_tokens > context_length) {
|
| 46 |
-
GPTOSS_LOG_ERROR("requested max batch tokens %zu exceeds context length %zu",
|
| 47 |
-
max_batch_tokens, context_length);
|
| 48 |
-
status = gptoss_status_invalid_argument;
|
| 49 |
-
goto cleanup;
|
| 50 |
-
}
|
| 51 |
-
assert(max_batch_tokens != 0);
|
| 52 |
-
assert(max_batch_tokens <= context_length);
|
| 53 |
-
|
| 54 |
-
context = malloc(sizeof(struct gptoss_context));
|
| 55 |
-
if (context == NULL) {
|
| 56 |
-
GPTOSS_LOG_ERROR("failed to allocate %zu bytes for Context object",
|
| 57 |
-
sizeof(struct gptoss_context));
|
| 58 |
-
status = gptoss_status_insufficient_memory;
|
| 59 |
-
goto cleanup;
|
| 60 |
-
}
|
| 61 |
-
memset(context, 0, sizeof(struct gptoss_context));
|
| 62 |
-
|
| 63 |
-
atomic_store_explicit(&context->ref_count, 1, memory_order_relaxed);
|
| 64 |
-
context->max_tokens = context_length;
|
| 65 |
-
context->max_batch_tokens = max_batch_tokens;
|
| 66 |
-
|
| 67 |
-
// Activation buffers
|
| 68 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->residual_activation_buffer);
|
| 69 |
-
if (status != gptoss_status_success) {
|
| 70 |
-
goto cleanup;
|
| 71 |
-
}
|
| 72 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->rmsnorm_activation_buffer);
|
| 73 |
-
if (status != gptoss_status_success) {
|
| 74 |
-
goto cleanup;
|
| 75 |
-
}
|
| 76 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &context->qkv_activation_buffer);
|
| 77 |
-
if (status != gptoss_status_success) {
|
| 78 |
-
goto cleanup;
|
| 79 |
-
}
|
| 80 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &context->sdpa_activation_buffer);
|
| 81 |
-
if (status != gptoss_status_success) {
|
| 82 |
-
goto cleanup;
|
| 83 |
-
}
|
| 84 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(float), NULL, &context->gate_activation_buffer);
|
| 85 |
-
if (status != gptoss_status_success) {
|
| 86 |
-
goto cleanup;
|
| 87 |
-
}
|
| 88 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &context->expert_activation_buffer);
|
| 89 |
-
if (status != gptoss_status_success) {
|
| 90 |
-
goto cleanup;
|
| 91 |
-
}
|
| 92 |
-
// The last entry will hold the total number of tokens.
|
| 93 |
-
status = gptoss_metal_buffer_create(&model->device, (1 + model->num_experts) * sizeof(uint32_t), NULL, &context->expert_offset_buffer);
|
| 94 |
-
if (status != gptoss_status_success) {
|
| 95 |
-
goto cleanup;
|
| 96 |
-
}
|
| 97 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * sizeof(uint32_t), NULL, &context->token_to_expert_routing_buffer);
|
| 98 |
-
if (status != gptoss_status_success) {
|
| 99 |
-
goto cleanup;
|
| 100 |
-
}
|
| 101 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->swiglu_input_buffer);
|
| 102 |
-
if (status != gptoss_status_success) {
|
| 103 |
-
goto cleanup;
|
| 104 |
-
}
|
| 105 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &context->swiglu_activation_buffer);
|
| 106 |
-
if (status != gptoss_status_success) {
|
| 107 |
-
goto cleanup;
|
| 108 |
-
}
|
| 109 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->moe_activation_buffer);
|
| 110 |
-
if (status != gptoss_status_success) {
|
| 111 |
-
goto cleanup;
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
// Input/output buffers
|
| 115 |
-
status = gptoss_metal_buffer_create(&model->device, sizeof(struct gptoss_control), NULL, &context->control_buffer);
|
| 116 |
-
if (status != gptoss_status_success) {
|
| 117 |
-
goto cleanup;
|
| 118 |
-
}
|
| 119 |
-
status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer);
|
| 120 |
-
if (status != gptoss_status_success) {
|
| 121 |
-
goto cleanup;
|
| 122 |
-
}
|
| 123 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->score_buffer);
|
| 124 |
-
if (status != gptoss_status_success) {
|
| 125 |
-
goto cleanup;
|
| 126 |
-
}
|
| 127 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->prob_buffer);
|
| 128 |
-
if (status != gptoss_status_success) {
|
| 129 |
-
goto cleanup;
|
| 130 |
-
}
|
| 131 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->max_threadgroups * sizeof(float), NULL, &context->sum_buffer);
|
| 132 |
-
if (status != gptoss_status_success) {
|
| 133 |
-
goto cleanup;
|
| 134 |
-
}
|
| 135 |
-
status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * sizeof(uint64_t), NULL, &context->argmax_buffer);
|
| 136 |
-
if (status != gptoss_status_success) {
|
| 137 |
-
goto cleanup;
|
| 138 |
-
}
|
| 139 |
-
status = gptoss_metal_buffer_create(&model->device, model->num_blocks * context_length * 2 * model->num_kv_heads * model->head_dim * sizeof(float), NULL, &context->kvcache_buffer);
|
| 140 |
-
if (status != gptoss_status_success) {
|
| 141 |
-
goto cleanup;
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
context->kvcache_size = context->kvcache_buffer.size;
|
| 145 |
-
context->allocation_size =
|
| 146 |
-
context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size +
|
| 147 |
-
context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size +
|
| 148 |
-
context->gate_activation_buffer.size + context->expert_activation_buffer.size +
|
| 149 |
-
context->expert_offset_buffer.size + context->token_to_expert_routing_buffer.size + context->swiglu_input_buffer.size +
|
| 150 |
-
context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
|
| 151 |
-
context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;
|
| 152 |
-
|
| 153 |
-
context->model = model;
|
| 154 |
-
gptoss_model_retain(model);
|
| 155 |
-
*context_out = context;
|
| 156 |
-
context = NULL;
|
| 157 |
-
|
| 158 |
-
cleanup:
|
| 159 |
-
gptoss_context_release(context);
|
| 160 |
-
return status;
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
enum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(
|
| 164 |
-
gptoss_context_t context,
|
| 165 |
-
size_t* num_tokens_out)
|
| 166 |
-
{
|
| 167 |
-
*num_tokens_out = context->num_tokens;
|
| 168 |
-
return gptoss_status_success;
|
| 169 |
-
}
|
| 170 |
-
|
| 171 |
-
enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
|
| 172 |
-
gptoss_context_t context,
|
| 173 |
-
size_t* max_tokens_out)
|
| 174 |
-
{
|
| 175 |
-
*max_tokens_out = context->max_tokens;
|
| 176 |
-
return gptoss_status_success;
|
| 177 |
-
}
|
| 178 |
-
|
| 179 |
-
enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
|
| 180 |
-
gptoss_context_t context,
|
| 181 |
-
uint32_t* tokens_out,
|
| 182 |
-
size_t max_tokens,
|
| 183 |
-
size_t* num_tokens_out)
|
| 184 |
-
{
|
| 185 |
-
*num_tokens_out = context->num_tokens;
|
| 186 |
-
if (max_tokens < context->num_tokens) {
|
| 187 |
-
return gptoss_status_insufficient_memory;
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
if (context->num_tokens != 0) {
|
| 191 |
-
memcpy(tokens_out, context->token_buffer.ptr, context->num_tokens * sizeof(uint32_t));
|
| 192 |
-
}
|
| 193 |
-
return gptoss_status_success;
|
| 194 |
-
}
|
| 195 |
-
|
| 196 |
-
// Prefill: input_tokens_offset = number of tokens in KV cache, num_input_tokens > 0, num_output_tokens = 0.
|
| 197 |
-
// Sampling: input_tokens_offset = number of tokens in the context - 1, num_input_tokens = 1, num_output_tokens = 1.
|
| 198 |
-
// Perplexity: input_tokens_offset = 0, num_input_tokens > 1, num_output_tokens = num_input_tokens.
|
| 199 |
-
static enum gptoss_status process_tokens(
|
| 200 |
-
gptoss_context_t context,
|
| 201 |
-
struct gptoss_metal_command_buffer* command_buffer,
|
| 202 |
-
size_t input_tokens_offset,
|
| 203 |
-
size_t num_input_tokens,
|
| 204 |
-
size_t num_output_tokens)
|
| 205 |
-
{
|
| 206 |
-
assert(num_input_tokens != 0);
|
| 207 |
-
assert(num_input_tokens <= context->max_batch_tokens);
|
| 208 |
-
assert(num_output_tokens <= context->max_batch_tokens);
|
| 209 |
-
assert(num_input_tokens >= num_output_tokens);
|
| 210 |
-
const size_t dense_matmul_kernel_token_multiple_constraint = 64;
|
| 211 |
-
const size_t min_tokens_for_dense_moe_kernels = 64;
|
| 212 |
-
|
| 213 |
-
enum gptoss_status status = gptoss_status_success;
|
| 214 |
-
const struct gptoss_model* model = context->model;
|
| 215 |
-
|
| 216 |
-
const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
|
| 217 |
-
|
| 218 |
-
const size_t input_tokens_end = input_tokens_offset + num_input_tokens;
|
| 219 |
-
for (size_t input_batch_start = input_tokens_offset;
|
| 220 |
-
input_batch_start < input_tokens_end;
|
| 221 |
-
input_batch_start += context->max_batch_tokens)
|
| 222 |
-
{
|
| 223 |
-
const size_t input_batch_size = math_min(context->max_batch_tokens, input_tokens_end - input_batch_start);
|
| 224 |
-
const size_t input_batch_end = input_batch_start + input_batch_size;
|
| 225 |
-
const size_t output_batch_size = math_sub_sat(num_output_tokens, input_tokens_end - input_batch_end);
|
| 226 |
-
|
| 227 |
-
status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
|
| 228 |
-
command_buffer,
|
| 229 |
-
&model->bf16_f32_embeddings_fn,
|
| 230 |
-
model->embeddings_threadgroup_size,
|
| 231 |
-
&context->token_buffer,
|
| 232 |
-
input_batch_start * sizeof(uint32_t),
|
| 233 |
-
&model->shared_weight_buffer,
|
| 234 |
-
/*weight_offset=*/0,
|
| 235 |
-
&context->residual_activation_buffer,
|
| 236 |
-
/*output_offset=*/0,
|
| 237 |
-
&context->control_buffer,
|
| 238 |
-
/*control_offset=*/0,
|
| 239 |
-
/*num_tokens=*/input_batch_size,
|
| 240 |
-
/*num_channels=*/model->embedding_dim);
|
| 241 |
-
if (status != gptoss_status_success) {
|
| 242 |
-
GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch");
|
| 243 |
-
return status;
|
| 244 |
-
}
|
| 245 |
-
for (uint32_t n = 0; n < model->num_blocks; n++) {
|
| 246 |
-
const bool last_block = n + 1 == model->num_blocks;
|
| 247 |
-
const size_t num_block_output_tokens = last_block ? output_batch_size : input_batch_size;
|
| 248 |
-
|
| 249 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
| 250 |
-
command_buffer,
|
| 251 |
-
&model->f32_bf16w_rmsnorm_fn,
|
| 252 |
-
&context->residual_activation_buffer,
|
| 253 |
-
/*input_offset=*/0,
|
| 254 |
-
&model->shared_weight_buffer,
|
| 255 |
-
/*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
|
| 256 |
-
&context->rmsnorm_activation_buffer,
|
| 257 |
-
/*output_offset=*/0,
|
| 258 |
-
&context->control_buffer,
|
| 259 |
-
/*control_offset=*/0,
|
| 260 |
-
/*num_tokens=*/input_batch_size,
|
| 261 |
-
/*num_channels=*/model->embedding_dim,
|
| 262 |
-
model->rmsnorm_epsilon);
|
| 263 |
-
if (status != gptoss_status_success) {
|
| 264 |
-
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
|
| 265 |
-
return status;
|
| 266 |
-
}
|
| 267 |
-
|
| 268 |
-
if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
|
| 269 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
|
| 270 |
-
command_buffer,
|
| 271 |
-
&model->f32_bf16w_dense_matmul_qkv_fn,
|
| 272 |
-
&context->rmsnorm_activation_buffer,
|
| 273 |
-
/*input_offset=*/0,
|
| 274 |
-
&model->shared_weight_buffer,
|
| 275 |
-
/*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
|
| 276 |
-
&model->shared_weight_buffer,
|
| 277 |
-
/*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
|
| 278 |
-
&context->qkv_activation_buffer,
|
| 279 |
-
/*output_offset=*/0,
|
| 280 |
-
&context->control_buffer,
|
| 281 |
-
/*control_offset=*/0,
|
| 282 |
-
/*num_tokens=*/input_batch_size,
|
| 283 |
-
/*num_cols=*/model->embedding_dim,
|
| 284 |
-
/*num_rows=*/attn_qkv_dim);
|
| 285 |
-
if (status != gptoss_status_success) {
|
| 286 |
-
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_qkv kernel launch");
|
| 287 |
-
return status;
|
| 288 |
-
}
|
| 289 |
-
|
| 290 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_rope(
|
| 291 |
-
command_buffer,
|
| 292 |
-
&model->f32_rope_fn,
|
| 293 |
-
/*threadgroup_size=*/32,
|
| 294 |
-
&context->qkv_activation_buffer,
|
| 295 |
-
/*input_offset=*/0,
|
| 296 |
-
&context->control_buffer,
|
| 297 |
-
/*control_offset=*/0,
|
| 298 |
-
model->rope_theta,
|
| 299 |
-
model->interpolation_scale,
|
| 300 |
-
model->yarn_offset,
|
| 301 |
-
model->yarn_scale,
|
| 302 |
-
model->yarn_multiplier,
|
| 303 |
-
input_batch_size,
|
| 304 |
-
model->num_heads,
|
| 305 |
-
model->num_kv_heads,
|
| 306 |
-
model->head_dim,
|
| 307 |
-
/*token_offset=*/input_batch_start);
|
| 308 |
-
if (status != gptoss_status_success) {
|
| 309 |
-
GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
|
| 310 |
-
return status;
|
| 311 |
-
}
|
| 312 |
-
|
| 313 |
-
for (uint32_t t = 0; t < input_batch_size; t++) {
|
| 314 |
-
for (uint32_t kv = 0; kv < 2; kv++) {
|
| 315 |
-
for (uint32_t h = 0; h < model->num_kv_heads; h++) {
|
| 316 |
-
status = gptoss_metal_command_buffer_encode_copy_buffer(
|
| 317 |
-
command_buffer,
|
| 318 |
-
&context->qkv_activation_buffer,
|
| 319 |
-
/*input_offset=*/(t * attn_qkv_dim + (model->num_heads + kv * model->num_kv_heads + h) * model->head_dim) * sizeof(float),
|
| 320 |
-
&context->kvcache_buffer,
|
| 321 |
-
/*output_offset=*/(((n * model->num_kv_heads + h) * context->max_tokens + input_batch_start + t) * 2 + kv) * model->head_dim * sizeof(float),
|
| 322 |
-
/*size=*/model->head_dim * sizeof(float));
|
| 323 |
-
if (status != gptoss_status_success) {
|
| 324 |
-
GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
|
| 325 |
-
return status;
|
| 326 |
-
}
|
| 327 |
-
}
|
| 328 |
-
}
|
| 329 |
-
}
|
| 330 |
-
} else {
|
| 331 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
|
| 332 |
-
command_buffer,
|
| 333 |
-
&model->f32_bf16w_matmul_qkv_fn,
|
| 334 |
-
model->attn_qkv_threadgroup_size,
|
| 335 |
-
&context->rmsnorm_activation_buffer,
|
| 336 |
-
/*input_offset=*/0,
|
| 337 |
-
&model->shared_weight_buffer,
|
| 338 |
-
/*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
|
| 339 |
-
&model->shared_weight_buffer,
|
| 340 |
-
/*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
|
| 341 |
-
&context->qkv_activation_buffer,
|
| 342 |
-
/*output_offset=*/0,
|
| 343 |
-
&context->kvcache_buffer,
|
| 344 |
-
/*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
|
| 345 |
-
&context->control_buffer,
|
| 346 |
-
/*control_offset=*/0,
|
| 347 |
-
/*num_tokens=*/input_batch_size,
|
| 348 |
-
/*num_cols=*/model->embedding_dim,
|
| 349 |
-
/*num_q_heads=*/model->num_heads,
|
| 350 |
-
/*num_kv_heads=*/model->num_kv_heads,
|
| 351 |
-
/*attn_head_dim=*/model->head_dim,
|
| 352 |
-
/*token_offset=*/input_batch_start,
|
| 353 |
-
/*max_tokens=*/context->max_tokens,
|
| 354 |
-
/*rope_base=*/model->rope_theta,
|
| 355 |
-
/*interpolation_scale=*/model->interpolation_scale,
|
| 356 |
-
/*yarn_offset=*/model->yarn_offset,
|
| 357 |
-
/*yarn_scale=*/model->yarn_scale,
|
| 358 |
-
/*yarn_multiplier=*/model->yarn_multiplier);
|
| 359 |
-
if (status != gptoss_status_success) {
|
| 360 |
-
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch");
|
| 361 |
-
return status;
|
| 362 |
-
}
|
| 363 |
-
}
|
| 364 |
-
|
| 365 |
-
if (num_block_output_tokens != 0) {
|
| 366 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(
|
| 367 |
-
command_buffer,
|
| 368 |
-
&model->f32_sdpa_q8_d64_fn,
|
| 369 |
-
&context->qkv_activation_buffer,
|
| 370 |
-
/*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
|
| 371 |
-
&context->kvcache_buffer,
|
| 372 |
-
/*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
|
| 373 |
-
&model->shared_weight_buffer,
|
| 374 |
-
/*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
|
| 375 |
-
&context->sdpa_activation_buffer,
|
| 376 |
-
/*output_offset=*/0,
|
| 377 |
-
&context->control_buffer,
|
| 378 |
-
/*control_offset=*/0,
|
| 379 |
-
/*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
|
| 380 |
-
/*kv_stride=*/2 * context->max_tokens * model->head_dim,
|
| 381 |
-
num_block_output_tokens,
|
| 382 |
-
input_batch_start + input_batch_size - num_block_output_tokens,
|
| 383 |
-
model->num_heads, model->num_kv_heads, model->head_dim);
|
| 384 |
-
if (status != gptoss_status_success) {
|
| 385 |
-
GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch");
|
| 386 |
-
return status;
|
| 387 |
-
}
|
| 388 |
-
|
| 389 |
-
if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
|
| 390 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
|
| 391 |
-
command_buffer,
|
| 392 |
-
&model->f32_bf16w_dense_matmul_attn_output_fn,
|
| 393 |
-
&context->sdpa_activation_buffer,
|
| 394 |
-
/*input_offset=*/0,
|
| 395 |
-
&model->shared_weight_buffer,
|
| 396 |
-
/*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
|
| 397 |
-
&model->shared_weight_buffer,
|
| 398 |
-
/*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
|
| 399 |
-
&context->residual_activation_buffer,
|
| 400 |
-
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
|
| 401 |
-
&context->control_buffer,
|
| 402 |
-
/*control_offset=*/0,
|
| 403 |
-
/*num_tokens=*/num_block_output_tokens,
|
| 404 |
-
/*num_cols=*/model->num_heads * model->head_dim,
|
| 405 |
-
/*num_rows=*/model->embedding_dim);
|
| 406 |
-
if (status != gptoss_status_success) {
|
| 407 |
-
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_attn_output kernel launch");
|
| 408 |
-
return status;
|
| 409 |
-
}
|
| 410 |
-
} else {
|
| 411 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
|
| 412 |
-
command_buffer,
|
| 413 |
-
&model->f32_bf16w_matmul_fn,
|
| 414 |
-
model->attn_out_threadgroup_size,
|
| 415 |
-
&context->sdpa_activation_buffer,
|
| 416 |
-
/*input_offset=*/0,
|
| 417 |
-
&model->shared_weight_buffer,
|
| 418 |
-
/*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
|
| 419 |
-
&model->shared_weight_buffer,
|
| 420 |
-
/*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
|
| 421 |
-
&context->residual_activation_buffer,
|
| 422 |
-
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
|
| 423 |
-
&context->control_buffer,
|
| 424 |
-
/*control_offset=*/0,
|
| 425 |
-
/*num_tokens=*/num_block_output_tokens,
|
| 426 |
-
/*num_cols=*/model->num_heads * model->head_dim,
|
| 427 |
-
/*num_rows=*/model->embedding_dim);
|
| 428 |
-
if (status != gptoss_status_success) {
|
| 429 |
-
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch");
|
| 430 |
-
return status;
|
| 431 |
-
}
|
| 432 |
-
}
|
| 433 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
| 434 |
-
command_buffer,
|
| 435 |
-
&model->f32_bf16w_rmsnorm_fn,
|
| 436 |
-
&context->residual_activation_buffer,
|
| 437 |
-
/*input_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
|
| 438 |
-
&model->shared_weight_buffer,
|
| 439 |
-
/*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
|
| 440 |
-
&context->rmsnorm_activation_buffer,
|
| 441 |
-
/*output_offset=*/0,
|
| 442 |
-
&context->control_buffer,
|
| 443 |
-
/*control_offset=*/0,
|
| 444 |
-
num_block_output_tokens,
|
| 445 |
-
model->embedding_dim,
|
| 446 |
-
model->rmsnorm_epsilon);
|
| 447 |
-
if (status != gptoss_status_success) {
|
| 448 |
-
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
|
| 449 |
-
return status;
|
| 450 |
-
}
|
| 451 |
-
if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
|
| 452 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
|
| 453 |
-
command_buffer,
|
| 454 |
-
&model->f32_bf16w_dense_matmul_mlp_gate_fn,
|
| 455 |
-
&context->rmsnorm_activation_buffer,
|
| 456 |
-
/*input_offset=*/0,
|
| 457 |
-
&model->shared_weight_buffer,
|
| 458 |
-
/*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
|
| 459 |
-
&model->shared_weight_buffer,
|
| 460 |
-
/*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
|
| 461 |
-
&context->gate_activation_buffer,
|
| 462 |
-
/*output_offset=*/0,
|
| 463 |
-
&context->control_buffer,
|
| 464 |
-
/*control_offset=*/0,
|
| 465 |
-
num_block_output_tokens,
|
| 466 |
-
model->embedding_dim,
|
| 467 |
-
model->num_experts);
|
| 468 |
-
if (status != gptoss_status_success) {
|
| 469 |
-
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_mlp_gate kernel launch");
|
| 470 |
-
return status;
|
| 471 |
-
}
|
| 472 |
-
} else {
|
| 473 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
|
| 474 |
-
command_buffer,
|
| 475 |
-
&model->f32_bf16w_matmul_fn,
|
| 476 |
-
model->mlp_gate_threadgroup_size,
|
| 477 |
-
&context->rmsnorm_activation_buffer,
|
| 478 |
-
/*input_offset=*/0,
|
| 479 |
-
&model->shared_weight_buffer,
|
| 480 |
-
/*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
|
| 481 |
-
&model->shared_weight_buffer,
|
| 482 |
-
/*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
|
| 483 |
-
&context->gate_activation_buffer,
|
| 484 |
-
/*output_offset=*/0,
|
| 485 |
-
&context->control_buffer,
|
| 486 |
-
/*control_offset=*/0,
|
| 487 |
-
/*num_tokens=*/num_block_output_tokens,
|
| 488 |
-
/*num_cols=*/model->embedding_dim,
|
| 489 |
-
/*num_rows=*/model->num_experts);
|
| 490 |
-
if (status != gptoss_status_success) {
|
| 491 |
-
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
|
| 492 |
-
return status;
|
| 493 |
-
}
|
| 494 |
-
}
|
| 495 |
-
|
| 496 |
-
const char* kernel_name = NULL;
|
| 497 |
-
switch (model->num_experts) {
|
| 498 |
-
case 32:
|
| 499 |
-
kernel_name = "f32_topk_softmax_e32_k4_fn";
|
| 500 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
|
| 501 |
-
command_buffer,
|
| 502 |
-
&model->f32_topk_softmax_e32_k4_fn,
|
| 503 |
-
&context->gate_activation_buffer, /*input_offset=*/0,
|
| 504 |
-
&context->expert_activation_buffer, /*output_offset=*/0,
|
| 505 |
-
&context->control_buffer, /*control_offset=*/0,
|
| 506 |
-
num_block_output_tokens,
|
| 507 |
-
model->num_experts,
|
| 508 |
-
model->num_active_experts);
|
| 509 |
-
break;
|
| 510 |
-
case 128:
|
| 511 |
-
kernel_name = "f32_topk_softmax_e128_k4_fn";
|
| 512 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
|
| 513 |
-
command_buffer,
|
| 514 |
-
&model->f32_topk_softmax_e128_k4_fn,
|
| 515 |
-
&context->gate_activation_buffer, /*input_offset=*/0,
|
| 516 |
-
&context->expert_activation_buffer, /*output_offset=*/0,
|
| 517 |
-
&context->control_buffer, /*control_offset=*/0,
|
| 518 |
-
num_block_output_tokens,
|
| 519 |
-
model->num_experts,
|
| 520 |
-
model->num_active_experts);
|
| 521 |
-
break;
|
| 522 |
-
default:
|
| 523 |
-
status = gptoss_status_unsupported_argument;
|
| 524 |
-
GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts);
|
| 525 |
-
return status;
|
| 526 |
-
}
|
| 527 |
-
if (status != gptoss_status_success) {
|
| 528 |
-
GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name);
|
| 529 |
-
return status;
|
| 530 |
-
}
|
| 531 |
-
|
| 532 |
-
// If we have enough tokens in prefill, we will pick the prefill-optimized kernels.
|
| 533 |
-
if (num_block_output_tokens >= min_tokens_for_dense_moe_kernels) {
|
| 534 |
-
status = gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
|
| 535 |
-
command_buffer,
|
| 536 |
-
&model->f32_expert_routing_metadata_fn,
|
| 537 |
-
&context->expert_activation_buffer,
|
| 538 |
-
/*expert_predictions_offset=*/0,
|
| 539 |
-
&context->expert_offset_buffer,
|
| 540 |
-
/*expert_offsets_offset=*/0,
|
| 541 |
-
&context->token_to_expert_routing_buffer,
|
| 542 |
-
/*intra_expert_offsets_offset=*/0,
|
| 543 |
-
num_block_output_tokens * model->num_active_experts,
|
| 544 |
-
model->num_experts);
|
| 545 |
-
if (status != gptoss_status_success) {
|
| 546 |
-
GPTOSS_LOG_ERROR("failed to encode f32_expert_routing_metadata kernel launch");
|
| 547 |
-
return status;
|
| 548 |
-
}
|
| 549 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_scatter(
|
| 550 |
-
command_buffer,
|
| 551 |
-
&model->f32_scatter_e4_fn,
|
| 552 |
-
&context->rmsnorm_activation_buffer,
|
| 553 |
-
/*input_offset=*/0,
|
| 554 |
-
&context->expert_activation_buffer,
|
| 555 |
-
/*expert_predictions_offset=*/0,
|
| 556 |
-
&context->expert_offset_buffer,
|
| 557 |
-
/*expert_offsets_offset=*/0,
|
| 558 |
-
&context->token_to_expert_routing_buffer,
|
| 559 |
-
/*intra_expert_offsets_offset=*/0,
|
| 560 |
-
&context->swiglu_input_buffer,
|
| 561 |
-
/*output_offset=*/0,
|
| 562 |
-
model->embedding_dim,
|
| 563 |
-
num_block_output_tokens,
|
| 564 |
-
model->num_active_experts);
|
| 565 |
-
if (status != gptoss_status_success) {
|
| 566 |
-
GPTOSS_LOG_ERROR("failed to encode f32_scatter kernel launch");
|
| 567 |
-
return status;
|
| 568 |
-
}
|
| 569 |
-
// Dense MoE SwiGLU matmul.
|
| 570 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
|
| 571 |
-
command_buffer,
|
| 572 |
-
&model->f32_mf4w_moe_dense_matmul_swiglu_fn,
|
| 573 |
-
&context->expert_offset_buffer,
|
| 574 |
-
/*expert_offsets_offset=*/0,
|
| 575 |
-
&context->swiglu_input_buffer,
|
| 576 |
-
/*input_offset=*/0,
|
| 577 |
-
&model->block_weight_buffers[n],
|
| 578 |
-
/*weight_block_offset=*/0,
|
| 579 |
-
&model->block_weight_buffers[n],
|
| 580 |
-
/*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
|
| 581 |
-
&model->block_weight_buffers[n],
|
| 582 |
-
/*bias_offset=*/model->mlp_swiglu_bias_offset,
|
| 583 |
-
&context->swiglu_activation_buffer,
|
| 584 |
-
/*output_offset=*/0,
|
| 585 |
-
model->swiglu_limit,
|
| 586 |
-
/*expert_stride_bytes=*/model->per_expert_block_weight_size,
|
| 587 |
-
num_block_output_tokens,
|
| 588 |
-
model->num_experts,
|
| 589 |
-
model->embedding_dim,
|
| 590 |
-
2 * model->mlp_dim);
|
| 591 |
-
if (status != gptoss_status_success) {
|
| 592 |
-
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
|
| 593 |
-
return status;
|
| 594 |
-
}
|
| 595 |
-
|
| 596 |
-
// Dense MoE proj matmul.
|
| 597 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
|
| 598 |
-
command_buffer,
|
| 599 |
-
&model->f32_mf4w_moe_dense_matmul_fn,
|
| 600 |
-
&context->expert_offset_buffer,
|
| 601 |
-
/*expert_offsets_offset=*/0,
|
| 602 |
-
&context->swiglu_activation_buffer,
|
| 603 |
-
/*input_offset=*/0,
|
| 604 |
-
&model->block_weight_buffers[n],
|
| 605 |
-
/*weight_block_offset=*/model->mlp_out_block_offset,
|
| 606 |
-
&model->block_weight_buffers[n],
|
| 607 |
-
/*weight_scale_offset=*/model->mlp_out_scale_offset,
|
| 608 |
-
&model->block_weight_buffers[n],
|
| 609 |
-
/*bias_offset=*/model->mlp_out_bias_offset,
|
| 610 |
-
&context->moe_activation_buffer,
|
| 611 |
-
/*output_offset=*/0,
|
| 612 |
-
/*expert_stride_bytes=*/model->per_expert_block_weight_size,
|
| 613 |
-
num_block_output_tokens,
|
| 614 |
-
model->num_experts,
|
| 615 |
-
model->mlp_dim,
|
| 616 |
-
model->embedding_dim);
|
| 617 |
-
if (status != gptoss_status_success) {
|
| 618 |
-
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
|
| 619 |
-
return status;
|
| 620 |
-
}
|
| 621 |
-
// Gather and accumulate.
|
| 622 |
-
status = gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
|
| 623 |
-
command_buffer,
|
| 624 |
-
&model->f32_gather_and_accumulate_e4_fn,
|
| 625 |
-
&context->moe_activation_buffer,
|
| 626 |
-
/*input_offset=*/0,
|
| 627 |
-
&context->expert_activation_buffer,
|
| 628 |
-
/*expert_predictions_offset=*/0,
|
| 629 |
-
&context->expert_offset_buffer,
|
| 630 |
-
/*expert_offsets_offset=*/0,
|
| 631 |
-
&context->token_to_expert_routing_buffer,
|
| 632 |
-
-                    /*intra_expert_offsets_offset=*/0,
-                    &context->residual_activation_buffer,
-                    /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
-                    model->embedding_dim,
-                    num_block_output_tokens,
-                    model->num_active_experts);
-                if (status != gptoss_status_success) {
-                    GPTOSS_LOG_ERROR("failed to encode f32_gather_and_accumulate_e4 kernel launch");
-                    return status;
-                }
-
-            } else {
-                status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
-                    command_buffer,
-                    &model->f32_mf4w_moe_matmul_swiglu_fn,
-                    model->mlp_swiglu_threadgroup_size,
-                    &context->rmsnorm_activation_buffer,
-                    /*input_offset=*/0,
-                    &context->expert_activation_buffer,
-                    /*expert_offset=*/0,
-                    &model->block_weight_buffers[n],
-                    /*weight_block_offset=*/0,
-                    &model->block_weight_buffers[n],
-                    /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
-                    &model->block_weight_buffers[n],
-                    /*bias_offset=*/model->mlp_swiglu_bias_offset,
-                    &context->swiglu_activation_buffer,
-                    /*output_offset=*/0,
-                    &context->control_buffer,
-                    /*control_offset=*/0,
-                    model->swiglu_limit,
-                    model->per_expert_block_weight_size,
-                    num_block_output_tokens,
-                    model->num_active_experts,
-                    model->embedding_dim,
-                    model->mlp_dim);
-                if (status != gptoss_status_success) {
-                    GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
-                    return status;
-                }
-
-                status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
-                    command_buffer,
-                    &model->f32_mf4w_moe_matmul_fn,
-                    model->mlp_out_threadgroup_size,
-                    &context->swiglu_activation_buffer,
-                    /*input_offset=*/0,
-                    &context->expert_activation_buffer,
-                    /*expert_offset=*/0,
-                    &model->block_weight_buffers[n],
-                    /*weight_block_offset=*/model->mlp_out_block_offset,
-                    &model->block_weight_buffers[n],
-                    /*weight_scale_offset=*/model->mlp_out_scale_offset,
-                    &model->block_weight_buffers[n],
-                    /*bias_offset=*/model->mlp_out_bias_offset,
-                    &context->moe_activation_buffer,
-                    /*output_offset=*/0,
-                    &context->control_buffer,
-                    /*control_offset=*/0,
-                    model->per_expert_block_weight_size,
-                    num_block_output_tokens,
-                    model->num_active_experts,
-                    model->mlp_dim,
-                    model->embedding_dim);
-                if (status != gptoss_status_success) {
-                    GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
-                    return status;
-                }
-
-                status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
-                    command_buffer,
-                    &model->f32_accumulate_e4_fn,
-                    model->mlp_acc_threadgroup_size,
-                    model->max_threadgroups,
-                    &context->moe_activation_buffer,
-                    /*input_offset=*/0,
-                    &context->expert_activation_buffer,
-                    /*expert_offset=*/0,
-                    &context->residual_activation_buffer,
-                    /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
-                    &context->control_buffer,
-                    /*control_offset=*/0,
-                    model->embedding_dim,
-                    num_block_output_tokens,
-                    model->num_active_experts);
-                if (status != gptoss_status_success) {
-                    GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
-                    return status;
-                }
-            }
-        }
-    }
-
-    if (output_batch_size != 0) {
-        status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
-            command_buffer,
-            &model->f32_bf16w_rmsnorm_fn,
-            &context->residual_activation_buffer,
-            /*input_offset=*/model->embedding_dim * (input_batch_size - output_batch_size) * sizeof(float),
-            &model->shared_weight_buffer,
-            /*weight_offset=*/model->rmsnorm_weight_offset,
-            &context->rmsnorm_activation_buffer,
-            /*output_offset=*/0,
-            &context->control_buffer,
-            /*control_offset=*/0,
-            /*num_tokens=*/output_batch_size,
-            /*num_channels=*/model->embedding_dim,
-            model->rmsnorm_epsilon);
-        if (status != gptoss_status_success) {
-            GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
-            return status;
-        }
-
-        status = gptoss_metal_command_buffer_encode_fill_buffer(
-            command_buffer,
-            &context->argmax_buffer,
-            /*offset=*/0,
-            /*size=*/sizeof(uint64_t) * output_batch_size,
-            /*fill_value=*/0xFF);
-        if (status != gptoss_status_success) {
-            GPTOSS_LOG_ERROR("failed to encode fill buffer command");
-            return status;
-        }
-
-        status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
-            command_buffer,
-            &model->f32_bf16w_unembedding_fn,
-            model->unembedding_threadgroup_size,
-            model->max_threadgroups,
-            &context->rmsnorm_activation_buffer,
-            /*input_offset=*/0,
-            &model->shared_weight_buffer,
-            /*weight_offset=*/model->unembedding_weight_offset,
-            &context->score_buffer,
-            /*output_offset=*/0,
-            &context->argmax_buffer,
-            /*argmax_offset=*/0,
-            &context->control_buffer,
-            /*control_offset=*/0,
-            /*num_tokens=*/output_batch_size,
-            /*num_cols=*/model->embedding_dim,
-            /*num_rows=*/model->vocabulary_size);
-        if (status != gptoss_status_success) {
-            GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch");
-            return status;
-        }
-    }
-    return gptoss_status_success;
-}
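
The deleted process_tokens repeatedly computes the same byte offset so that per-block MoE results land on the last num_block_output_tokens rows of the residual activation buffer. A minimal sketch of that computation, assuming a row-major [num_batch_tokens x embedding_dim] float32 layout; trailing_rows_offset is a hypothetical helper, not part of the library:

    #include <stddef.h>

    // Byte offset of the first of the trailing `num_output_tokens` rows in a
    // row-major [num_batch_tokens x embedding_dim] float32 buffer; mirrors the
    // /*output_offset=*/ expressions in the deleted code above.
    static size_t trailing_rows_offset(size_t embedding_dim,
                                       size_t num_batch_tokens,
                                       size_t num_output_tokens) {
        return embedding_dim * (num_batch_tokens - num_output_tokens) * sizeof(float);
    }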
-
-enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
-    gptoss_context_t context,
-    const char* text,
-    size_t text_length,
-    size_t* num_tokens_out)
-{
-    enum gptoss_status status = gptoss_status_success;
-    const struct gptoss_model* model = context->model;
-    const struct gptoss_tokenizer* tokenizer = model->tokenizer;
-    size_t num_appended_tokens = 0;
-    while (text_length != 0) {
-        if (context->num_tokens == context->max_tokens) {
-            status = gptoss_status_context_overflow;
-            break;
-        }
-        const char* tokens = tokenizer->tokens_ptr;
-        uint32_t best_token = UINT32_MAX;
-        uint32_t best_token_length = 0;
-        for (size_t t = 0; t < tokenizer->num_text_tokens; t++) {
-            uint16_t token_length;
-            memcpy(&token_length, tokens, sizeof(uint16_t));
-            tokens += sizeof(uint16_t);
-            if (token_length <= text_length && token_length > best_token_length) {
-                if (memcmp(text, tokens, token_length) == 0) {
-                    if (token_length > best_token_length) {
-                        best_token = (uint32_t) t;
-                        best_token_length = token_length;
-                    }
-                }
-            }
-            tokens += token_length;
-        }
-
-        if (best_token == UINT32_MAX) {
-            GPTOSS_LOG_ERROR("failed to tokenize text \"%.*s\"", (int) text_length, text);
-            return gptoss_status_invalid_argument;
-        }
-
-        uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
-        if (context->num_kv_tokens > context->num_tokens) {
-            if (input_tokens[context->num_tokens] != best_token) {
-                input_tokens[context->num_tokens] = best_token;
-
-                // Invalidate the KV cache starting with the newly added token.
-                context->num_kv_tokens = context->num_tokens;
-            }
-            context->num_tokens++;
-        } else {
-            input_tokens[context->num_tokens++] = best_token;
-        }
-        num_appended_tokens++;
-        text += best_token_length;
-        text_length -= best_token_length;
-    }
-    if (num_tokens_out != NULL) {
-        *num_tokens_out = num_appended_tokens;
-    }
-    return status;
-}
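
gptoss_context_append_chars above tokenizes greedily: it scans the whole packed token table for the longest entry that prefixes the remaining text, consumes it, and repeats. A self-contained sketch of that lookup, with table and num_entries standing in for tokenizer->tokens_ptr and tokenizer->num_text_tokens, and with the duplicated length check in the deleted code folded into one condition:

    #include <stdint.h>
    #include <string.h>

    // Greedy longest-match lookup over a packed token table: each entry is a
    // uint16_t length followed by that many bytes of token text.
    // Returns UINT32_MAX if no entry prefixes `text`.
    static uint32_t longest_match(const char* text, size_t text_length,
                                  const char* table, size_t num_entries,
                                  uint16_t* match_length_out) {
        uint32_t best_token = UINT32_MAX;
        uint16_t best_length = 0;
        for (size_t t = 0; t < num_entries; t++) {
            uint16_t token_length;
            memcpy(&token_length, table, sizeof(uint16_t));  // unaligned-safe read
            table += sizeof(uint16_t);
            if (token_length <= text_length && token_length > best_length &&
                memcmp(text, table, token_length) == 0) {
                best_token = (uint32_t) t;
                best_length = token_length;
            }
            table += token_length;  // skip to the next entry
        }
        *match_length_out = best_length;
        return best_token;
    }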
-
-enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
-    gptoss_context_t context,
-    size_t num_tokens,
-    const uint32_t* tokens)
-{
-    const struct gptoss_model* model = context->model;
-
-    // Validate all tokens
-    for (size_t t = 0; t < num_tokens; t++) {
-        const uint32_t token = tokens[t];
-        if (token >= model->vocabulary_size) {
-            GPTOSS_LOG_ERROR("token %" PRIu32 " at index %zu is out of bounds for vocabulary size %" PRIu32,
-                token, t, context->model->vocabulary_size);
-            return gptoss_status_invalid_argument;
-        }
-    }
-
-    enum gptoss_status status = gptoss_status_success;
-    uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
-    while (num_tokens != 0) {
-        if (context->num_tokens == context->max_tokens) {
-            status = gptoss_status_context_overflow;
-            break;
-        }
-
-        if (context->num_kv_tokens > context->num_tokens) {
-            const size_t num_tokens_to_verify = math_min(context->num_kv_tokens - context->num_tokens, num_tokens);
-            size_t num_verified_tokens = 0;
-            for (; num_verified_tokens < num_tokens_to_verify; num_verified_tokens++) {
-                if (input_tokens[context->num_tokens + num_verified_tokens] != tokens[num_verified_tokens]) {
-                    // Invalidate the KV cache starting with the newly added tokens.
-                    context->num_kv_tokens = context->num_tokens + num_verified_tokens;
-                    break;
-                }
-            }
-
-            context->num_tokens += num_verified_tokens;
-            tokens += num_verified_tokens;
-            num_tokens -= num_verified_tokens;
-        } else {
-            const size_t num_tokens_to_copy = math_min(context->max_tokens - context->num_tokens, num_tokens);
-            memcpy(input_tokens + context->num_tokens, tokens, num_tokens_to_copy * sizeof(uint32_t));
-            context->num_tokens += num_tokens_to_copy;
-            tokens += num_tokens_to_copy;
-            num_tokens -= num_tokens_to_copy;
-        }
-    }
-
-    return status;
-}
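
The verification branch above is what makes resubmitting a prompt cheap: incoming tokens that match the cached prefix only advance num_tokens, and the KV cache is invalidated from the first mismatch onward. The same comparison in isolation, as a sketch:

    #include <stddef.h>
    #include <stdint.h>

    // Length of the common prefix between cached tokens and a new token
    // stream; everything past it must be recomputed. This is the check the
    // deleted gptoss_context_append_tokens performs before touching
    // num_kv_tokens.
    static size_t reusable_prefix(const uint32_t* cached, size_t num_cached,
                                  const uint32_t* incoming, size_t num_incoming) {
        size_t n = 0;
        while (n < num_cached && n < num_incoming && cached[n] == incoming[n]) {
            n++;
        }
        return n;
    }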
-
-enum gptoss_status GPTOSS_ABI gptoss_context_process(
-    gptoss_context_t context)
-{
-    if (context->num_tokens > context->num_kv_tokens) {
-        struct gptoss_metal_command_buffer command_buffer = {0};
-
-        enum gptoss_status status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
-        if (status != gptoss_status_success) {
-            goto cleanup;
-        }
-
-        struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
-        control->abort = 0;
-
-        status = process_tokens(
-            context,
-            &command_buffer,
-            /*input_tokens_offset=*/context->num_kv_tokens,
-            /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
-            /*num_output_tokens=*/0);
-        if (status != gptoss_status_success) {
-            goto cleanup;
-        }
-
-        status = gptoss_metal_command_buffer_commit(&command_buffer);
-        if (status != gptoss_status_success) {
-            goto cleanup;
-        }
-
-        status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
-        if (status != gptoss_status_success) {
-            goto cleanup;
-        }
-
-        context->num_kv_tokens = context->num_tokens;
-
-cleanup:
-        gptoss_metal_command_buffer_release(&command_buffer);
-        return status;
-    }
-
-    return gptoss_status_success;
-}
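
gptoss_context_process follows a create / encode / commit / wait / release discipline with a single cleanup label, so the Metal command buffer is released on every exit path (the zero-initialized buffer makes release safe even when create fails). A skeleton of the pattern, assuming the internal gptoss_metal_* API from the headers in this repository; encode_work is a hypothetical placeholder for the encoding calls:

    #include "internal/metal.h"  // gptoss_metal_* declarations (internal header)

    // Hypothetical encoder standing in for process_tokens / kernel launches.
    extern enum gptoss_status encode_work(struct gptoss_metal_command_buffer* command_buffer);

    enum gptoss_status run_once(struct gptoss_metal_command_queue* queue) {
        struct gptoss_metal_command_buffer command_buffer = {0};
        enum gptoss_status status = gptoss_metal_command_buffer_create(queue, &command_buffer);
        if (status != gptoss_status_success) {
            goto cleanup;
        }
        status = encode_work(&command_buffer);
        if (status != gptoss_status_success) {
            goto cleanup;
        }
        status = gptoss_metal_command_buffer_commit(&command_buffer);
        if (status != gptoss_status_success) {
            goto cleanup;
        }
        status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
    cleanup:
        gptoss_metal_command_buffer_release(&command_buffer);  // safe on all paths
        return status;
    }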
-
-enum gptoss_status GPTOSS_ABI gptoss_context_sample(
-    gptoss_context_t context,
-    float temperature,
-    uint64_t seed,
-    size_t max_tokens,
-    uint32_t* tokens_out,
-    size_t* num_tokens_out)
-{
-    enum gptoss_status status = gptoss_status_success;
-    const struct gptoss_model* model = context->model;
-    struct gptoss_metal_command_buffer command_buffer = {0};
-
-    *num_tokens_out = 0;
-
-    const uint32_t num_original_tokens = context->num_tokens;
-
-    status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-
-    struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
-    control->abort = 0;
-
-    for (size_t t = 0; t < max_tokens; t++) {
-        if (context->num_kv_tokens < context->num_tokens) {
-            status = process_tokens(
-                context,
-                &command_buffer,
-                /*input_tokens_offset=*/context->num_kv_tokens,
-                /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
-                /*num_output_tokens=*/1);
-            context->num_kv_tokens = context->num_tokens;
-        } else {
-            status = process_tokens(
-                context,
-                &command_buffer,
-                /*input_tokens_offset=*/context->num_tokens - 1,
-                /*num_input_tokens=*/1,
-                /*num_output_tokens=*/1);
-        }
-        if (status != gptoss_status_success) {
-            goto cleanup;
-        }
-
-        if (temperature != 0.0f) {
-            assert(context->num_processed_tokens != 0);
-            uint32_t num_threadgroups = 0;
-            uint32_t num_dims_per_threadgroup = 0;
-            status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
-                &command_buffer,
-                &model->f32_softmax_fn,
-                /*threadgroup_size=*/512,
-                model->max_threadgroups,
-                &context->score_buffer,
-                /*score_offset=*/0,
-                &context->argmax_buffer,
-                /*argmax_offset=*/0,
-                &context->prob_buffer,
-                /*prob_offset=*/0,
-                &context->sum_buffer,
-                /*sum_offset=*/0,
-                &context->control_buffer,
-                /*control_offset=*/0,
-                model->vocabulary_size,
-                /*num_tokens=*/1,
-                temperature,
-                &num_threadgroups,
-                &num_dims_per_threadgroup);
-            if (status != gptoss_status_success) {
-                GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
-                goto cleanup;
-            }
-
-            status = gptoss_metal_command_buffer_encode_launch_f32_sample(
-                &command_buffer,
-                &model->f32_sample_fn,
-                /*min_threadgroup_size=*/512,
-                &context->prob_buffer,
-                /*prob_offset=*/0,
-                &context->sum_buffer,
-                /*sum_offset=*/0,
-                &context->token_buffer,
-                /*token_offset=*/context->num_tokens * sizeof(uint32_t),
-                &context->control_buffer,
-                /*control_offset=*/0,
-                /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF),
-                /*rng_offset=*/context->num_tokens,
-                /*num_blocks=*/num_threadgroups,
-                /*num_channels=*/model->vocabulary_size,
-                /*num_channels_per_block=*/num_dims_per_threadgroup);
-            if (status != gptoss_status_success) {
-                GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch");
-                goto cleanup;
-            }
-        } else {
-            status = gptoss_metal_command_buffer_encode_copy_buffer(
-                &command_buffer,
-                &context->argmax_buffer,
-                /*input_offset=*/0,
-                &context->token_buffer,
-                /*output_offset=*/context->num_tokens * sizeof(uint32_t),
-                /*size=*/sizeof(uint32_t));
-            if (status != gptoss_status_success) {
-                GPTOSS_LOG_ERROR("failed to encode copy buffer");
-                goto cleanup;
-            }
-        }
-        context->num_tokens += 1;
-        context->num_kv_tokens = context->num_tokens;
-    }
-
-    gptoss_metal_command_buffer_commit(&command_buffer);
-    gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
-
-    const uint32_t* token_ptr = (const uint32_t*) context->token_buffer.ptr;
-    const uint32_t num_generated_tokens = context->num_tokens - num_original_tokens;
-    memcpy(tokens_out, token_ptr + num_original_tokens, num_generated_tokens * sizeof(uint32_t));
-    *num_tokens_out = num_generated_tokens;
-
-cleanup:
-    gptoss_metal_command_buffer_release(&command_buffer);
-    return status;
-}
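
For intuition, a CPU reference of what the two GPU paths above compute: temperature zero copies the running argmax into the token buffer, while a nonzero temperature applies a scaled softmax and draws from the resulting distribution. A sketch only, with assumptions: the deleted code does this in Metal kernels and uses a seeded counter-based RNG, not rand():

    #include <math.h>
    #include <stddef.h>
    #include <stdint.h>
    #include <stdlib.h>

    static uint32_t sample_token(const float* scores, size_t n, float temperature) {
        // temperature == 0 degenerates to argmax (the copy-from-argmax_buffer path).
        if (temperature == 0.0f) {
            uint32_t best = 0;
            for (size_t i = 1; i < n; i++) {
                if (scores[i] > scores[best]) best = (uint32_t) i;
            }
            return best;
        }
        // Subtract the max for numerical stability, as any softmax kernel must.
        float max_score = scores[0];
        for (size_t i = 1; i < n; i++) {
            if (scores[i] > max_score) max_score = scores[i];
        }
        double sum = 0.0;
        for (size_t i = 0; i < n; i++) {
            sum += exp((double) (scores[i] - max_score) / temperature);
        }
        // Inverse-CDF sampling over the unnormalized probabilities.
        double u = (double) rand() / ((double) RAND_MAX + 1.0) * sum;
        double acc = 0.0;
        for (size_t i = 0; i < n; i++) {
            acc += exp((double) (scores[i] - max_score) / temperature);
            if (acc >= u) return (uint32_t) i;
        }
        return (uint32_t) (n - 1);
    }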
-
-enum gptoss_status GPTOSS_ABI gptoss_context_reset(
-    gptoss_context_t context)
-{
-    context->num_tokens = 0;
-
-    // Note: context->num_kv_tokens is not reset and context->input_tokens_buffer is not cleared.
-    // If the subsequently added tokens match the tokens already in the KV cache, we reuse the KV cache.
-
-    return gptoss_status_success;
-}
-
-enum gptoss_status GPTOSS_ABI gptoss_context_retain(
-    gptoss_context_t context)
-{
-    atomic_fetch_add_explicit(&context->ref_count, 1, memory_order_relaxed);
-    return gptoss_status_success;
-}
-
-enum gptoss_status GPTOSS_ABI gptoss_context_release(
-    gptoss_context_t context)
-{
-    if (context != NULL) {
-        if (atomic_fetch_sub_explicit(&context->ref_count, 1, memory_order_acq_rel) == 1) {
-            // Activation buffers
-            gptoss_metal_buffer_release(&context->residual_activation_buffer);
-            gptoss_metal_buffer_release(&context->rmsnorm_activation_buffer);
-            gptoss_metal_buffer_release(&context->qkv_activation_buffer);
-            gptoss_metal_buffer_release(&context->sdpa_activation_buffer);
-            gptoss_metal_buffer_release(&context->gate_activation_buffer);
-            gptoss_metal_buffer_release(&context->expert_activation_buffer);
-            gptoss_metal_buffer_release(&context->swiglu_activation_buffer);
-            gptoss_metal_buffer_release(&context->moe_activation_buffer);
-            gptoss_metal_buffer_release(&context->expert_offset_buffer);
-            gptoss_metal_buffer_release(&context->token_to_expert_routing_buffer);
-            gptoss_metal_buffer_release(&context->swiglu_input_buffer);
-
-            // Input/output buffers
-            gptoss_metal_buffer_release(&context->control_buffer);
-            gptoss_metal_buffer_release(&context->token_buffer);
-            gptoss_metal_buffer_release(&context->score_buffer);
-            gptoss_metal_buffer_release(&context->prob_buffer);
-            gptoss_metal_buffer_release(&context->sum_buffer);
-            gptoss_metal_buffer_release(&context->argmax_buffer);
-            gptoss_metal_buffer_release(&context->kvcache_buffer);
-
-            gptoss_model_release(context->model);
-
-            memset(context, 0, sizeof(struct gptoss_context));
-            free(context);
-        }
-    }
-    return gptoss_status_success;
-}
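
gptoss_context_retain/release above use the usual C11 atomics recipe: relaxed increments, an acquire-release decrement, and teardown on the 1 -> 0 transition. The pattern in isolation, with struct object as a hypothetical stand-in for gptoss_context:

    #include <stdatomic.h>
    #include <stdlib.h>

    struct object {
        atomic_int ref_count;  // starts at 1 when the object is created
    };

    void object_retain(struct object* obj) {
        // Relaxed is enough: taking a reference needs no ordering by itself.
        atomic_fetch_add_explicit(&obj->ref_count, 1, memory_order_relaxed);
    }

    void object_release(struct object* obj) {
        // acq_rel makes the last decrement synchronize with all prior releases,
        // so the freeing thread sees every write made while references were held.
        if (obj != NULL &&
            atomic_fetch_sub_explicit(&obj->ref_count, 1, memory_order_acq_rel) == 1) {
            free(obj);  // last owner: release owned resources, then the object
        }
    }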
gptoss_kernels/source/generate.c
DELETED
@@ -1,317 +0,0 @@
-#include <assert.h>
-#include <inttypes.h>
-#include <math.h>
-#include <signal.h>
-#include <stdatomic.h>
-#include <stdbool.h>
-#include <stdio.h>
-#include <stdint.h>
-#include <stdlib.h>
-#include <string.h>
-
-#include <mach/mach_time.h>
-
-#include <gpt-oss.h>
-
-#include "internal/model.h"
-
-struct {
-    atomic_uint_least64_t inference_bytes;
-    atomic_size_t num_prefill_tokens;
-    atomic_uint_least64_t prefill_microseconds;
-    atomic_size_t num_generated_tokens;
-    atomic_uint_least64_t generation_microseconds;
-} globals = {
-    .inference_bytes = 0,
-    .num_prefill_tokens = 0,
-    .prefill_microseconds = 0,
-    .num_generated_tokens = 0,
-    .generation_microseconds = 0,
-};
-
-struct options {
-    const char* model;
-    const char* prompt;
-    size_t context_length;
-    size_t max_tokens;
-    float temperature;
-    bool verbose;
-};
-
-static inline double mach_timestamp_diff_to_seconds(uint64_t start_timestamp, uint64_t end_timestamp) {
-    static mach_timebase_info_data_t timebase_info = {0};
-    if (timebase_info.denom == 0) {
-        mach_timebase_info(&timebase_info);
-    }
-    const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;
-    return ((double) elapsed_mach_time * (double) timebase_info.numer) / ((double) timebase_info.denom * 1.0e+9);
-}
-
-static inline uint64_t mach_timestamp_diff_to_microseconds(uint64_t start_timestamp, uint64_t end_timestamp) {
-    static mach_timebase_info_data_t timebase_info = {0};
-    if (timebase_info.denom == 0) {
-        mach_timebase_info(&timebase_info);
-    }
-    const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;
-    const uint64_t denominator = timebase_info.denom * UINT64_C(1000);
-    return (elapsed_mach_time * timebase_info.numer + denominator / 2) / denominator;
-}
-
-static void print_usage(const char* program_name) {
-    printf("Usage: %s <model-path> [-p <prompt>] [-n <tokens>]\n", program_name);
-}
-
-struct options parse_options(int argc, char** argv) {
-    struct options options = (struct options) {
-        .model = NULL,
-        .prompt = NULL,
-        .context_length = 0,
-        .max_tokens = 0,
-        .temperature = 0.0f,
-        .verbose = false,
-    };
-    if (argc < 2) {
-        fprintf(stderr, "Error: missing required command-line argument\n");
-        print_usage(argv[0]);
-        exit(EXIT_FAILURE);
-    }
-    for (int i = 1; i < argc; i++) {
-        if (strcmp(argv[i], "--help") == 0) {
-            print_usage(argv[0]);
-            exit(EXIT_SUCCESS);
-        } else if (strcmp(argv[i], "-p") == 0 || strcmp(argv[i], "--prompt") == 0) {
-            if (i + 1 >= argc) {
-                fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
-                print_usage(argv[0]);
-                exit(EXIT_FAILURE);
-            }
-            options.prompt = argv[++i];
-        } else if (strcmp(argv[i], "--context-length") == 0) {
-            if (i + 1 >= argc) {
-                fprintf(stderr, "Error: missing argument for --context-length\n");
-                print_usage(argv[0]);
-                exit(EXIT_FAILURE);
-            }
-            char* context_length_start = argv[++i];
-            char* context_length_end = context_length_start;
-            options.context_length = strtoul(context_length_start, &context_length_end, 10);
-            if (context_length_end == context_length_start || *context_length_end != 0) {
-                fprintf(stderr, "Error: failed to parse context length value \"%s\"\n", context_length_start);
-                exit(EXIT_FAILURE);
-            }
-        } else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--max-tokens") == 0) {
-            if (i + 1 >= argc) {
-                fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
-                print_usage(argv[0]);
-                exit(EXIT_FAILURE);
-            }
-            char* max_tokens_start = argv[++i];
-            char* max_tokens_end = max_tokens_start;
-            options.max_tokens = strtoul(max_tokens_start, &max_tokens_end, 10);
-            if (max_tokens_end == max_tokens_start || *max_tokens_end != 0) {
-                fprintf(stderr, "Error: failed to max tokens value \"%s\"\n", max_tokens_start);
-                exit(EXIT_FAILURE);
-            }
-            if (options.max_tokens == 0) {
-                fprintf(stderr, "Error: invalid max tokens value %zu\n", options.max_tokens);
-                exit(EXIT_FAILURE);
-            }
-        } else if (strcmp(argv[i], "-t") == 0 || strcmp(argv[i], "--temperature") == 0) {
-            if (i + 1 >= argc) {
-                fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
-                print_usage(argv[0]);
-                exit(EXIT_FAILURE);
-            }
-            char* temperature_start = argv[++i];
-            char* temperature_end = temperature_start;
-            options.temperature = strtof(temperature_start, &temperature_end);
-            if (temperature_end == temperature_start || *temperature_end != 0) {
-                fprintf(stderr, "Error: failed to parse temperature value \"%s\"\n", temperature_start);
-                exit(EXIT_FAILURE);
-            }
-            if (signbit(options.temperature) != 0 || !(options.temperature <= 2.0f)) {
-                fprintf(stderr, "Error: invalid temperature value %f\n", options.temperature);
-                exit(EXIT_FAILURE);
-            }
-        } else if (strcmp(argv[i], "-v") == 0 || strcmp(argv[i], "--verbose") == 0) {
-            options.verbose = true;
-        } else {
-            if (options.model == NULL) {
-                options.model = argv[i];
-            } else {
-                fprintf(stderr, "Error: unexpected command-line argument %s\n", argv[i]);
-                print_usage(argv[0]);
-                exit(EXIT_FAILURE);
-            }
-        }
-    }
-    if (options.model == NULL) {
-        fprintf(stderr, "Error: missing required model argument\n");
-        print_usage(argv[0]);
-        exit(EXIT_FAILURE);
-    }
-    if (options.prompt == NULL) {
-        fprintf(stderr, "Error: missing required prompt argument\n");
-        print_usage(argv[0]);
-        exit(EXIT_FAILURE);
-    }
-    return options;
-}
-
-
-static void print_profile() {
-    const size_t num_prefill_tokens = atomic_load(&globals.num_prefill_tokens);
-    const uint64_t prefill_microseconds = atomic_load(&globals.prefill_microseconds);
-    const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens);
-    const uint64_t generation_microseconds = atomic_load(&globals.generation_microseconds);
-    const uint64_t inference_bytes = atomic_load(&globals.inference_bytes);
-    if (num_prefill_tokens != 0 || num_generated_tokens != 0) {
-        printf("\n");
-    }
-    if (num_prefill_tokens != 0) {
-        printf("Prefill speed (%zu tokens): %.1f tokens/second\n",
-            num_prefill_tokens,
-            (double) num_prefill_tokens / (double) prefill_microseconds * 1.0e+6);
-    }
-    if (num_generated_tokens != 0) {
-        printf("Generation speed (%zu tokens): %.1f tokens/second\n",
-            num_generated_tokens,
-            (double) num_generated_tokens / (double) generation_microseconds * 1.0e+6);
-    }
-}
-
-static void ctrl_c_handler(int signum) {
-    print_profile();
-    exit(EXIT_SUCCESS);
-}
-
-int main(int argc, char *argv[]) {
-    enum gptoss_status status;
-    gptoss_model_t model = NULL;
-    gptoss_tokenizer_t tokenizer = NULL;
-    gptoss_context_t context = NULL;
-
-    struct sigaction act;
-    act.sa_handler = ctrl_c_handler;
-    sigaction(SIGINT, &act, NULL);
-
-    setvbuf(stdout, NULL, _IONBF, 0);
-
-    struct options options = parse_options(argc, argv);
-
-    const uint64_t load_start_time = mach_continuous_time();
-    status = gptoss_model_create_from_file(options.model, &model);
-    if (status != gptoss_status_success) {
-        fprintf(stderr, "Error: failed to load model from file %s\n", options.model);
-        goto error;
-    }
-    size_t max_model_context_length = 0;
-    status = gptoss_model_get_max_context_length(model, &max_model_context_length);
-    if (status != gptoss_status_success) {
-        fprintf(stderr, "Error: failed to query maximum context length\n");
-        goto error;
-    }
-    assert(max_model_context_length != 0);
-    if (options.context_length == 0) {
-        options.context_length = max_model_context_length;
-    } else if (options.context_length > max_model_context_length) {
-        fprintf(stderr, "Error: context length %zu exceeds maximum context length %zu supported by the model\n", options.context_length, max_model_context_length);
-        goto error;
-    }
-
-    status = gptoss_model_get_tokenizer(model, &tokenizer);
-    if (status != gptoss_status_success) {
-        fprintf(stderr, "Error: failed to retrieve Tokenizer\n");
-        goto error;
-    }
-
-    uint32_t return_token_id = UINT32_MAX;
-    status = gptoss_tokenizer_get_special_token_id(tokenizer, gptoss_special_token_return, &return_token_id);
-    if (status != gptoss_status_success) {
-        fprintf(stderr, "Error: failed to query end-of-text token ID\n");
-        goto error;
-    }
-
-    status = gptoss_context_create(model, options.context_length, /*max_batch_tokens=*/0, &context);
-    if (status != gptoss_status_success) {
-        fprintf(stderr, "Error: failed to create Context object\n");
-        goto error;
-    }
-    if (options.verbose) {
-        printf("Model weights size: %.2lf MB\n", (double) model->weights_size * 0x1.0p-20);
-        printf("Model allocation size: %.2lf MB\n", (double) model->allocation_size * 0x1.0p-20);
-        printf("Context allocation size: %.2lf MB\n", (double) context->allocation_size * 0x1.0p-20);
-        printf("  Including KV cache: %.2lf MB\n", (double) context->kvcache_size * 0x1.0p-20);
-    }
-
-    const uint64_t load_end_time = mach_continuous_time();
-    const double load_elapsed_seconds = mach_timestamp_diff_to_seconds(load_start_time, load_end_time);
-    if (options.verbose) {
-        printf("Loaded model in %.3f seconds\n", load_elapsed_seconds);
-    }
-
-    const uint64_t prefill_start_time = mach_continuous_time();
-    size_t num_prefill_tokens = 0;
-    status = gptoss_context_append_chars(context, options.prompt, strlen(options.prompt), &num_prefill_tokens);
-    if (status != gptoss_status_success) {
-        fprintf(stderr, "Error: failed to tokenize prompt \"%s\"\n", options.prompt);
-        goto error;
-    }
-    atomic_store(&globals.num_prefill_tokens, num_prefill_tokens);
-    status = gptoss_context_process(context);
-    if (status != gptoss_status_success) {
-        fprintf(stderr, "Error: failed to process Context object\n");
-        goto error;
-    }
-    const uint64_t prefill_end_time = mach_continuous_time();
-
-    while (options.max_tokens == 0 || atomic_load(&globals.num_generated_tokens) < options.max_tokens) {
-
-        uint32_t predicted_token = UINT32_MAX;
-        size_t num_predicted_tokens = 0;
-        const uint64_t inference_start_timestamp = mach_continuous_time();
-        status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, /*num_tokens=*/1, &predicted_token, &num_predicted_tokens);
-        if (status != gptoss_status_success) {
-            fprintf(stderr, "Error: failed to sample from the Context object\n");
-            goto error;
-        }
-        const uint64_t inference_end_timestamp = mach_continuous_time();
-
-        if (predicted_token == return_token_id) {
-            // Yield token -> stop generation
-            break;
-        }
-
-        // Unembedding: detokenize
-        size_t token_size = 0;
-        const void* token_ptr = NULL;
-        status = gptoss_tokenizer_decode(tokenizer, predicted_token, &token_ptr, &token_size);
-        if (status != gptoss_status_success) {
-            fprintf(stderr, "Error: failed to detokenize predicted token %" PRIu32 "\n", predicted_token);
-            goto error;
-        }
-        const size_t previous_num_generated_tokens = atomic_fetch_add(&globals.num_generated_tokens, 1);
-        if (previous_num_generated_tokens == 0) {
-            atomic_fetch_add(&globals.prefill_microseconds, mach_timestamp_diff_to_microseconds(prefill_start_time, prefill_end_time));
-        } else {
-            atomic_fetch_add(&globals.generation_microseconds, mach_timestamp_diff_to_microseconds(inference_start_timestamp, inference_end_timestamp));
-        }
-        printf("%.*s", (int) token_size, (const char*) token_ptr);
-
-        status = gptoss_context_append_tokens(context, 1, &predicted_token);
-        if (status != gptoss_status_success) {
-            fprintf(stderr, "Error: failed to append predicted token %" PRIu32 " to context\n", predicted_token);
-            goto error;
-        }
-    }
-
-    print_profile();
-
-    return EXIT_SUCCESS;
-
-error:
-    gptoss_context_release(context);
-    gptoss_tokenizer_release(tokenizer);
-    gptoss_model_release(model);
-    return EXIT_FAILURE;
-}
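
The profiling in the deleted generate.c hinges on mach_timebase_info: mach_continuous_time() returns ticks, and the numer/denom ratio converts ticks to nanoseconds. A standalone version of the microsecond conversion, including the + denominator/2 term that rounds to nearest instead of truncating:

    #include <mach/mach_time.h>
    #include <stdint.h>

    static uint64_t ticks_to_microseconds(uint64_t start, uint64_t end) {
        static mach_timebase_info_data_t timebase = {0};
        if (timebase.denom == 0) {
            mach_timebase_info(&timebase);  // ticks -> nanoseconds ratio (numer/denom)
        }
        const uint64_t elapsed = end - start;
        // Divide by an extra 1000 to go from nanoseconds to microseconds;
        // adding denominator/2 before dividing rounds to the nearest value.
        const uint64_t denominator = (uint64_t) timebase.denom * UINT64_C(1000);
        return (elapsed * timebase.numer + denominator / 2) / denominator;
    }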
gptoss_kernels/source/include/internal/log.h
CHANGED
@@ -2,6 +2,9 @@
 
 #include <stdarg.h>
 
+#ifdef __cplusplus
+extern "C" {
+#endif
 
 void gptoss_format_log(const char* format, va_list args);
 
@@ -13,6 +16,10 @@ inline static void gptoss_log(const char* format, ...) {
     va_end(args);
 }
 
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
 #define GPTOSS_LOG_ERROR(message, ...) \
     gptoss_log("Error: " message "\n", ##__VA_ARGS__)
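
The change above is the standard C/C++ interop guard: when a C header is pulled into a C++ translation unit, its declarations must keep C linkage so they link against the C object files. The generic shape of the pattern, with some_c_function as a hypothetical example declaration:

    #pragma once

    #ifdef __cplusplus
    extern "C" {   /* seen only by C++: give these declarations C linkage */
    #endif

    void some_c_function(int arg);  /* no C++ name mangling applied */

    #ifdef __cplusplus
    }  /* extern "C" */
    #endif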
gptoss_kernels/source/include/internal/metal.h
CHANGED
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <stddef.h>
-
 #include <gpt-oss/types.h>
 
 #ifdef __cplusplus
gptoss_kernels/source/matmul.metal
CHANGED
@@ -43,7 +43,10 @@ kernel void gptoss_f32_bf16w_matmul(
     bias += row;
     output += gid.y * args.num_rows + row;
 
-    uint num_iter =
+    uint num_iter = 0;
+    if (simdgroup_tid < num_column_vecs) {
+        num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+    }
 
     float4 sum4 = 0.0f;
     do {
@@ -97,7 +100,10 @@ kernel void gptoss_f32_bf16w_matmul_qkv(
     bias += row;
     q += gid.y * args.num_rows;
 
-    uint num_iter =
+    uint num_iter = 0;
+    if (simdgroup_tid < num_column_vecs) {
+        num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+    }
 
     float4 sum4 = 0.0f;
    do {
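
The matmul.metal fix above guards a ceiling division that could otherwise underflow: with unsigned arithmetic, num_column_vecs - simdgroup_tid wraps around to a huge value whenever the lane index exceeds the column count, producing an enormous iteration count for that lane. The same computation in plain C:

    #include <stdint.h>

    // Per-lane iteration count for a strided loop in which lane `tid` of a
    // group of `width` lanes processes columns tid, tid + width, tid + 2*width, ...
    static uint32_t lane_iterations(uint32_t num_columns, uint32_t tid, uint32_t width) {
        uint32_t num_iter = 0;
        if (tid < num_columns) {
            // ceil((num_columns - tid) / width); the guard keeps the
            // subtraction from wrapping when tid >= num_columns.
            num_iter = (num_columns - tid + (width - 1)) / width;
        }
        return num_iter;
    }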
gptoss_kernels/source/metal.m
CHANGED
@@ -9,7 +9,6 @@
 #include <internal/log.h>
 #include <internal/metal.h>
 
-
 static size_t gptoss_metal_device_get_core_count(id<MTLDevice> device) {
     if (!device) {
         return 0;
gptoss_kernels/source/model.c
DELETED
@@ -1,581 +0,0 @@
-#include <assert.h>
-#include <inttypes.h>
-#include <stdatomic.h>
-#include <stdint.h>
-#include <stdlib.h>
-#include <string.h>
-
-#include <errno.h>              // errno, EISDIR, ENOENT, ENOTDIR
-#include <fcntl.h>              // open
-#include <mach/vm_page_size.h>  // vm_page_size
-#include <sys/mman.h>           // mmap, PROT_READ, MAP_PRIVATE
-#include <sys/stat.h>           // fstat, stat
-#include <sys/types.h>          // off_t, ssize_t
-#include <unistd.h>             // close
-
-#include <gpt-oss.h>
-
-#include "internal/datatype.h"
-#include "internal/kernel-args.h"  // gptoss_expert_prediction
-#include "internal/log.h"
-#include "internal/uuid.h"
-#include "internal/storage.h"
-#include "internal/math.h"
-#include "internal/model.h"
-
-
-static size_t round_up_to_page_size(size_t bytes) {
-    const size_t page_size_mask = (size_t) vm_page_size - 1;
-    if ((bytes & page_size_mask) != 0) {
-        bytes |= page_size_mask;
-        bytes += 1;
-    }
-    return bytes;
-}
-
-static size_t round_down_to_page_size(size_t bytes) {
-    const size_t page_size_mask = (size_t) vm_page_size - 1;
-    return bytes & ~page_size_mask;
-}
-
-static enum gptoss_status read_fd(int fd, void* data, size_t size, const char* path) {
-    assert(fd != -1);
-    assert(data != NULL);
-    assert(size != 0);
-
-    size_t bytes_to_read = size;
-    char* current_byte = (char*) data;
-    do {
-        const ssize_t read_result = read(fd, current_byte, bytes_to_read);
-        if (read_result < 0) {
-            GPTOSS_LOG_ERROR("reading %zu bytes from file %s failed with error %d",
-                size, path, errno);
-            return gptoss_status_io_error;
-        }
-        current_byte += (size_t) read_result;
-        bytes_to_read -= (size_t) read_result;
-    } while (bytes_to_read != 0);
-    return gptoss_status_success;
-}
-
-static void prefetch_fd(int fd, size_t offset, size_t size, const char* path) {
-    // radvisory.ra_count is int, so we can't prefetch 2GB+ at once
-    const size_t prefetch_max = round_down_to_page_size((size_t) INT_MAX);
-    do {
-        const size_t prefetch_size = math_min(size, prefetch_max);
-        const struct radvisory ra = {
-            .ra_offset = offset,
-            .ra_count = (int) prefetch_size,
-        };
-        if (fcntl(fd, F_RDADVISE, &ra) == -1) {
-            GPTOSS_LOG_WARNING("fcntl(%s, F_RDADVISE, .ra_offset=%zu, .ra_count=%d) failed with error %d\n",
-                path, (size_t) ra.ra_offset, ra.ra_count, errno);
-            return;
-        }
-        offset += prefetch_size;
-        size -= prefetch_size;
-    } while (size != 0);
-}
-
-enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
-    const char* path,
-    gptoss_model_t* model_out)
-{
-    *model_out = NULL;
-
-    enum gptoss_status status = gptoss_status_success;
-    struct gptoss_model* model = NULL;
-    struct gptoss_tokenizer* tokenizer = NULL;
-    int fd = -1;
-    size_t file_offset = 0;
-
-    fd = open(path, O_RDONLY);
-    if (fd == -1) {
-        GPTOSS_LOG_ERROR("open(%s) failed with error %d", path, errno);
-        switch (errno) {
-            case EISDIR:
-            case ENOENT:
-            case ENOTDIR:
-                status = gptoss_status_invalid_argument;
-                break;
-            default:
-                status = gptoss_status_io_error;
-                break;
-        }
-        goto cleanup;
-    }
-
-    struct gptoss_file_header file_header;
-    status = read_fd(fd, &file_header, sizeof(file_header), path);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    file_offset += sizeof(file_header);
-
-    if (file_header.magic[0] != 'G' ||
-        file_header.magic[1] != 'P' ||
-        file_header.magic[2] != 'T' ||
-        file_header.magic[3] != '-' ||
-        file_header.magic[4] != 'O' ||
-        file_header.magic[5] != 'S' ||
-        file_header.magic[6] != 'S' ||
-        file_header.magic[7] != ' ' ||
-        file_header.magic[8] != 'v' ||
-        file_header.magic[9] != '1' ||
-        file_header.magic[10] != '.' ||
-        file_header.magic[11] != '0' ||
-        file_header.zero != 0)
-    {
-        GPTOSS_LOG_ERROR("invalid magic in file %s", path);
-        status = gptoss_status_invalid_argument;
-        goto cleanup;
-    }
-
-    struct gptoss_uuid model_uuid;
-    status = read_fd(fd, &model_uuid, sizeof(model_uuid), path);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    file_offset += sizeof(model_uuid);
-
-    if (!gptoss_is_gptoss_model_uuid(&model_uuid)) {
-        GPTOSS_LOG_ERROR("unsupported model UUID " UUID_FORMAT, UUID_ARGS(model_uuid));
-        status = gptoss_status_invalid_argument;
-        goto cleanup;
-    }
-
-    struct gptoss_gptoss_model_header model_header;
-    status = read_fd(fd, &model_header, sizeof(model_header), path);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    file_offset += sizeof(model_header);
-
-    struct gptoss_uuid layout_uuid;
-    status = read_fd(fd, &layout_uuid, sizeof(layout_uuid), path);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    file_offset += sizeof(layout_uuid);
-
-    if (!gptoss_is_applegpu_layout_uuid(&layout_uuid)) {
-        GPTOSS_LOG_ERROR("unsupported layout UUID " UUID_FORMAT, UUID_ARGS(layout_uuid));
-        status = gptoss_status_invalid_argument;
-        goto cleanup;
-    }
-
-    const size_t model_size = sizeof(struct gptoss_model) + model_header.num_blocks * sizeof(struct gptoss_metal_buffer);
-    model = malloc(model_size);
-    if (model == NULL) {
-        GPTOSS_LOG_ERROR("failed to allocate %zu bytes for model descriptor", model_size);
-        status = gptoss_status_insufficient_memory;
-        goto cleanup;
-    }
-    memset(model, 0, model_size);
-
-    atomic_store_explicit(&model->ref_count, 1, memory_order_relaxed);
-    model->context_length = model_header.context_length;
-    model->num_blocks = model_header.num_blocks;
-    model->num_experts = model_header.num_experts;
-    model->num_active_experts = model_header.num_active_experts;
-    model->embedding_dim = model_header.embedding_dim;
-    model->mlp_dim = model_header.mlp_dim;
-    model->swiglu_limit = model_header.swiglu_limit;
-    model->head_dim = model_header.head_dim;
-    model->num_heads = model_header.num_heads;
-    model->num_kv_heads = model_header.num_kv_heads;
-    model->attention_window = model_header.attention_window;
-    model->rope_theta = model_header.rope_theta;
-    model->interpolation_scale = model_header.interpolation_scale;
-    model->yarn_offset = model_header.yarn_offset;
-    model->yarn_scale = model_header.yarn_scale;
-    model->yarn_multiplier = model_header.yarn_multiplier;
-    model->rmsnorm_epsilon = model_header.rmsnorm_epsilon;
-
-    struct gptoss_uuid tokenizer_uuid;
-    status = read_fd(fd, &tokenizer_uuid, sizeof(tokenizer_uuid), path);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    file_offset += sizeof(tokenizer_uuid);
-
-    if (!gptoss_is_tiktoken_tokenizer_uuid(&tokenizer_uuid)) {
-        GPTOSS_LOG_ERROR("unsupported tokenizer UUID " UUID_FORMAT, UUID_ARGS(tokenizer_uuid));
-        status = gptoss_status_invalid_argument;
-        goto cleanup;
-    }
-
-    struct gptoss_tiktoken_tokenizer_header tokenizer_header;
-    status = read_fd(fd, &tokenizer_header, sizeof(tokenizer_header), path);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    file_offset += sizeof(tokenizer_header);
-
-    tokenizer = malloc(sizeof(struct gptoss_tokenizer));
-    if (tokenizer == NULL) {
-        GPTOSS_LOG_ERROR("failed to allocate %zu bytes for tokenizer descriptor", sizeof(struct gptoss_tokenizer));
-        status = gptoss_status_insufficient_memory;
-        goto cleanup;
-    }
-    memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));
-    // Initialize all special token IDs to UINT32_MAX (0xFF in all bytes)
-    memset(tokenizer->special_token_id, 0xFF, sizeof(tokenizer->special_token_id));
-
-    atomic_store_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
-    tokenizer->num_special_tokens = tokenizer_header.num_special_tokens;
-    tokenizer->num_text_tokens = tokenizer_header.num_text_tokens;
-    model->vocabulary_size = tokenizer_header.num_special_tokens + tokenizer_header.num_text_tokens;
-    for (uint32_t t = 0; t < tokenizer_header.num_special_tokens; t++) {
-        struct gptoss_uuid token_uuid;
-        status = read_fd(fd, &token_uuid, sizeof(token_uuid), path);
-        if (status != gptoss_status_success) {
-            goto cleanup;
-        }
-        file_offset += sizeof(token_uuid);
-
-        const enum gptoss_special_token token = gptoss_special_token_decode_uuid(&token_uuid);
-        if (token != gptoss_special_token_invalid) {
-            tokenizer->special_token_id[token - 1] = tokenizer_header.num_text_tokens + t;
-        }
-    }
-
-    const size_t tokenizer_start_offset = file_offset;
-    const size_t tokenizer_end_offset = tokenizer_start_offset + tokenizer_header.regex_size + tokenizer_header.tokens_size;
-    const size_t tokenizer_mapping_start = round_down_to_page_size(tokenizer_start_offset);
-    const size_t tokenizer_mapping_size = round_up_to_page_size(tokenizer_end_offset) - tokenizer_mapping_start;
-    void* tokenizer_mapping_ptr = mmap(NULL, tokenizer_mapping_size, PROT_READ, MAP_PRIVATE, fd, tokenizer_mapping_start);
-    if (tokenizer_mapping_ptr == (void*) -1) {
-        GPTOSS_LOG_ERROR("failed to mmap(%s) tokenizer at offset %zu size %zu",
-            path, tokenizer_mapping_start, tokenizer_mapping_size);
-        status = gptoss_status_io_error;
-        goto cleanup;
-    }
-    tokenizer->mapping_ptr = tokenizer_mapping_ptr;
-    tokenizer->mapping_size = tokenizer_mapping_size;
-    tokenizer->regex_ptr = (const char*) tokenizer_mapping_ptr + (tokenizer_start_offset - tokenizer_mapping_start);
-    tokenizer->tokens_ptr = tokenizer->regex_ptr + tokenizer_header.regex_size;
-
-    if (madvise(tokenizer_mapping_ptr, tokenizer_mapping_size, MADV_RANDOM | MADV_WILLNEED) != 0) {
-        GPTOSS_LOG_WARNING("madvise(%s, size=%zu) failed with error %d", path, tokenizer_mapping_size, errno);
-    }
-
-    prefetch_fd(fd, tokenizer_mapping_start, tokenizer_mapping_size, path);
-
-    struct stat model_stat = {0};
-    int stat_result = fstat(fd, &model_stat);
-    if (stat_result != 0) {
-        GPTOSS_LOG_ERROR("stat(%s) failed with error %d", path, errno);
-        status = gptoss_status_io_error;
-        goto cleanup;
-    }
-
-    const size_t model_mapping_start = round_up_to_page_size(tokenizer_end_offset);
-    const size_t model_mapping_size = round_up_to_page_size((size_t) model_stat.st_size) - model_mapping_start;
-    void* model_mapping_ptr = mmap(NULL, model_mapping_size, PROT_READ, MAP_PRIVATE, fd, model_mapping_start);
-    if (model_mapping_ptr == (void*) -1) {
-        GPTOSS_LOG_ERROR("failed to mmap(%s) model weights at offset %zu size %zu",
-            path, model_mapping_start, model_mapping_size);
-        status = gptoss_status_io_error;
-        goto cleanup;
-    }
-    model->mapping_ptr = model_mapping_ptr;
-    model->mapping_size = model_mapping_size;
-
-    if (madvise(model_mapping_ptr, model_mapping_size, MADV_SEQUENTIAL | MADV_WILLNEED) != 0) {
-        GPTOSS_LOG_WARNING("madvise(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
-    }
-
-    prefetch_fd(fd, model_mapping_start, model_mapping_size, path);
-
-    if (mlock(model_mapping_ptr, model_mapping_size) != 0) {
-        GPTOSS_LOG_WARNING("mlock(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
-    } else {
-        model->lock_memory = true;
-    }
-
-    // Initialize Metal
-    status = gptoss_metal_device_create_system_default(&model->device);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    model->max_threadgroups = model->device.num_cores * 3;
-    status = gptoss_metal_command_queue_create(&model->device, &model->command_queue);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-
-    // Metal kernels
-    status = gptoss_metal_library_create_default(&model->device, &model->library);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_bf16_f32_embeddings", &model->bf16_f32_embeddings_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_rmsnorm", &model->f32_bf16w_rmsnorm_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul", &model->f32_bf16w_matmul_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul_qkv", &model->f32_bf16w_matmul_qkv_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_qkv", &model->f32_bf16w_dense_matmul_qkv_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_attn_output", &model->f32_bf16w_dense_matmul_attn_output_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_mlp_gate", &model->f32_bf16w_dense_matmul_mlp_gate_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_unembedding", &model->f32_bf16w_unembedding_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_rope", &model->f32_rope_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_expert_routing_metadata", &model->f32_expert_routing_metadata_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_scatter_e4", &model->f32_scatter_e4_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_dense_matmul_swiglu", &model->f32_mf4w_moe_dense_matmul_swiglu_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_dense_matmul", &model->f32_mf4w_moe_dense_matmul_fn);
-    if (status != gptoss_status_success) {
-        goto cleanup;
-    }
-    status = gptoss_metal_function_create(&model->library, "gptoss_f32_gather_and_accumulate_e4", &model->f32_gather_and_accumulate_e4_fn);
-    if (status != gptoss_status_success) {
|
| 367 |
-
goto cleanup;
|
| 368 |
-
}
|
| 369 |
-
status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_matmul_swiglu", &model->f32_mf4w_moe_matmul_swiglu_fn);
|
| 370 |
-
if (status != gptoss_status_success) {
|
| 371 |
-
goto cleanup;
|
| 372 |
-
}
|
| 373 |
-
status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_matmul", &model->f32_mf4w_moe_matmul_fn);
|
| 374 |
-
if (status != gptoss_status_success) {
|
| 375 |
-
goto cleanup;
|
| 376 |
-
}
|
| 377 |
-
status = gptoss_metal_function_create(&model->library, "gptoss_f32_accumulate_e4", &model->f32_accumulate_e4_fn);
|
| 378 |
-
if (status != gptoss_status_success) {
|
| 379 |
-
goto cleanup;
|
| 380 |
-
}
|
| 381 |
-
status = gptoss_metal_function_create(&model->library, "gptoss_f32_topk_softmax_e32_k4", &model->f32_topk_softmax_e32_k4_fn);
|
| 382 |
-
if (status != gptoss_status_success) {
|
| 383 |
-
goto cleanup;
|
| 384 |
-
}
|
| 385 |
-
status = gptoss_metal_function_create(&model->library, "gptoss_f32_topk_softmax_e128_k4", &model->f32_topk_softmax_e128_k4_fn);
|
| 386 |
-
if (status != gptoss_status_success) {
|
| 387 |
-
goto cleanup;
|
| 388 |
-
}
|
| 389 |
-
status = gptoss_metal_function_create(&model->library, "gptoss_f32_softmax", &model->f32_softmax_fn);
|
| 390 |
-
if (status != gptoss_status_success) {
|
| 391 |
-
goto cleanup;
|
| 392 |
-
}
|
| 393 |
-
status = gptoss_metal_function_create(&model->library, "gptoss_f32_sample", &model->f32_sample_fn);
|
| 394 |
-
if (status != gptoss_status_success) {
|
| 395 |
-
goto cleanup;
|
| 396 |
-
}
|
| 397 |
-
status = gptoss_metal_function_create(&model->library, "gptoss_f32_sdpa_q8_d64", &model->f32_sdpa_q8_d64_fn);
|
| 398 |
-
if (status != gptoss_status_success) {
|
| 399 |
-
goto cleanup;
|
| 400 |
-
}
|
| 401 |
-
|
| 402 |
-
// Kernel launch parameters
|
| 403 |
-
model->embeddings_threadgroup_size = 512;
|
| 404 |
-
model->attn_qkv_threadgroup_size = 1024;
|
| 405 |
-
model->attn_out_threadgroup_size = 768;
|
| 406 |
-
model->mlp_gate_threadgroup_size = 256;
|
| 407 |
-
model->mlp_swiglu_threadgroup_size = 192;
|
| 408 |
-
model->mlp_out_threadgroup_size = 192;
|
| 409 |
-
model->mlp_acc_threadgroup_size = 768;
|
| 410 |
-
model->unembedding_threadgroup_size = 416;
|
| 411 |
-
|
| 412 |
-
// Weight buffers
|
| 413 |
-
const char* current_ptr = (const char*) model->mapping_ptr;
|
| 414 |
-
|
| 415 |
-
const size_t embedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 416 |
-
model->attn_rmsnorm_gain_offset = embedding_weight_size;
|
| 417 |
-
const size_t rmsnorm_weight_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 418 |
-
model->attn_qkv_weight_offset = model->attn_rmsnorm_gain_offset + rmsnorm_weight_size;
|
| 419 |
-
const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
|
| 420 |
-
const size_t attn_qkv_weight_size = math_round_up_po2(attn_qkv_dim * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 421 |
-
model->attn_qkv_bias_offset = model->attn_qkv_weight_offset + attn_qkv_weight_size;
|
| 422 |
-
const size_t attn_qkv_bias_size = math_round_up_po2(attn_qkv_dim * sizeof(gptoss_bfloat16), 16);
|
| 423 |
-
model->attn_sdpa_sink_offset = model->attn_qkv_bias_offset + attn_qkv_bias_size;
|
| 424 |
-
const size_t attn_sink_weight_size = math_round_up_po2(model->num_heads * sizeof(gptoss_bfloat16), 16);
|
| 425 |
-
model->attn_out_weight_offset = model->attn_sdpa_sink_offset + attn_sink_weight_size;
|
| 426 |
-
const size_t attn_out_weight_size = math_round_up_po2(model->embedding_dim * model->num_heads * model->head_dim * sizeof(gptoss_bfloat16), 16);
|
| 427 |
-
model->attn_out_bias_offset = model->attn_out_weight_offset + attn_out_weight_size;
|
| 428 |
-
const size_t attn_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 429 |
-
model->mlp_rmsnorm_gain_offset = model->attn_out_bias_offset + attn_out_bias_size;
|
| 430 |
-
model->mlp_gate_weight_offset = model->mlp_rmsnorm_gain_offset + rmsnorm_weight_size;
|
| 431 |
-
const size_t mlp_gate_weight_size = math_round_up_po2(model->num_experts * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 432 |
-
model->mlp_gate_bias_offset = model->mlp_gate_weight_offset + mlp_gate_weight_size;
|
| 433 |
-
const size_t mlp_gate_bias_size = math_round_up_po2(model->num_experts * sizeof(gptoss_bfloat16), 16);
|
| 434 |
-
const size_t per_block_shared_weights_size =
|
| 435 |
-
rmsnorm_weight_size + attn_qkv_weight_size + attn_qkv_bias_size + attn_sink_weight_size + attn_out_weight_size + attn_out_bias_size +
|
| 436 |
-
rmsnorm_weight_size + mlp_gate_weight_size + mlp_gate_bias_size;
|
| 437 |
-
model->rmsnorm_weight_offset = embedding_weight_size + model->num_blocks * per_block_shared_weights_size;
|
| 438 |
-
model->unembedding_weight_offset = model->rmsnorm_weight_offset + rmsnorm_weight_size;
|
| 439 |
-
const size_t unembedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 440 |
-
|
| 441 |
-
model->per_block_shared_weights_size = per_block_shared_weights_size;
|
| 442 |
-
const size_t shared_weights_size =
|
| 443 |
-
round_up_to_page_size(embedding_weight_size + rmsnorm_weight_size + unembedding_weight_size + model->num_blocks * per_block_shared_weights_size);
|
| 444 |
-
|
| 445 |
-
status = gptoss_metal_buffer_wrap(&model->device, shared_weights_size, current_ptr, &model->shared_weight_buffer);
|
| 446 |
-
if (status != gptoss_status_success) {
|
| 447 |
-
GPTOSS_LOG_ERROR("failed to map expert-shared weight of size %zu onto a Metal buffer", shared_weights_size);
|
| 448 |
-
goto cleanup;
|
| 449 |
-
}
|
| 450 |
-
current_ptr += shared_weights_size;
|
| 451 |
-
model->weights_size += shared_weights_size;
|
| 452 |
-
|
| 453 |
-
const size_t mlp_swiglu_weight_block_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 2, 16);
|
| 454 |
-
model->mlp_swiglu_scale_offset = mlp_swiglu_weight_block_size;
|
| 455 |
-
const size_t mlp_swiglu_weight_scale_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 32, 16);
|
| 456 |
-
model->mlp_swiglu_bias_offset = model->mlp_swiglu_scale_offset + mlp_swiglu_weight_scale_size;
|
| 457 |
-
const size_t mlp_swiglu_bias_size = math_round_up_po2(2 * model->mlp_dim * sizeof(gptoss_bfloat16), 16);
|
| 458 |
-
model->mlp_out_block_offset = model->mlp_swiglu_bias_offset + mlp_swiglu_bias_size;
|
| 459 |
-
const size_t mlp_out_weight_block_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 2, 16);
|
| 460 |
-
model->mlp_out_scale_offset = model->mlp_out_block_offset + mlp_out_weight_block_size;
|
| 461 |
-
const size_t mlp_out_weight_scale_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 32, 16);
|
| 462 |
-
model->mlp_out_bias_offset = model->mlp_out_scale_offset + mlp_out_weight_scale_size;
|
| 463 |
-
const size_t mlp_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 464 |
-
model->per_expert_block_weight_size =
|
| 465 |
-
mlp_swiglu_weight_block_size + mlp_swiglu_weight_scale_size + mlp_swiglu_bias_size + mlp_out_weight_block_size + mlp_out_weight_scale_size + mlp_out_bias_size;
|
| 466 |
-
const size_t moe_block_weight_size = round_up_to_page_size(model->num_experts * model->per_expert_block_weight_size);
|
| 467 |
-
for (uint32_t n = 0; n < model->num_blocks; n++) {
|
| 468 |
-
status = gptoss_metal_buffer_wrap(&model->device, moe_block_weight_size, current_ptr, &model->block_weight_buffers[n]);
|
| 469 |
-
if (status != gptoss_status_success) {
|
| 470 |
-
GPTOSS_LOG_ERROR("failed to map block #%" PRIu32 " MoE weight of size %zu onto a Metal buffer",
|
| 471 |
-
n, moe_block_weight_size);
|
| 472 |
-
goto cleanup;
|
| 473 |
-
}
|
| 474 |
-
current_ptr += moe_block_weight_size;
|
| 475 |
-
model->weights_size += moe_block_weight_size;
|
| 476 |
-
}
|
| 477 |
-
|
| 478 |
-
// Commit tokenizer
|
| 479 |
-
model->tokenizer = tokenizer;
|
| 480 |
-
tokenizer = NULL;
|
| 481 |
-
|
| 482 |
-
// Commit model
|
| 483 |
-
*model_out = model;
|
| 484 |
-
model = NULL;
|
| 485 |
-
|
| 486 |
-
cleanup:
|
| 487 |
-
if (fd != -1) {
|
| 488 |
-
close(fd);
|
| 489 |
-
fd = -1;
|
| 490 |
-
}
|
| 491 |
-
gptoss_model_release(model); // does nothing if model is NULL
|
| 492 |
-
gptoss_tokenizer_release(tokenizer); // does nothing if tokenizer is NULL
|
| 493 |
-
return status;
|
| 494 |
-
}
|
| 495 |
-
|
| 496 |
-
enum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(
|
| 497 |
-
gptoss_model_t model,
|
| 498 |
-
gptoss_tokenizer_t* tokenizer_out)
|
| 499 |
-
{
|
| 500 |
-
gptoss_tokenizer_t tokenizer = model->tokenizer;
|
| 501 |
-
atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
|
| 502 |
-
*tokenizer_out = tokenizer;
|
| 503 |
-
return gptoss_status_success;
|
| 504 |
-
}
|
| 505 |
-
|
| 506 |
-
enum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(
|
| 507 |
-
gptoss_model_t model,
|
| 508 |
-
size_t* max_context_length_out)
|
| 509 |
-
{
|
| 510 |
-
*max_context_length_out = model->context_length;
|
| 511 |
-
return gptoss_status_success;
|
| 512 |
-
}
|
| 513 |
-
|
| 514 |
-
enum gptoss_status GPTOSS_ABI gptoss_model_retain(
|
| 515 |
-
gptoss_model_t model)
|
| 516 |
-
{
|
| 517 |
-
atomic_fetch_add_explicit(&model->ref_count, 1, memory_order_relaxed);
|
| 518 |
-
return gptoss_status_success;
|
| 519 |
-
}
|
| 520 |
-
|
| 521 |
-
enum gptoss_status GPTOSS_ABI gptoss_model_release(
|
| 522 |
-
gptoss_model_t model)
|
| 523 |
-
{
|
| 524 |
-
if (model != NULL) {
|
| 525 |
-
if (atomic_fetch_sub_explicit(&model->ref_count, 1, memory_order_acq_rel) == 1) {
|
| 526 |
-
gptoss_tokenizer_release(model->tokenizer);
|
| 527 |
-
|
| 528 |
-
// Weight buffers
|
| 529 |
-
gptoss_metal_buffer_release(&model->shared_weight_buffer);
|
| 530 |
-
for (uint32_t n = 0; n < model->num_blocks; n++) {
|
| 531 |
-
gptoss_metal_buffer_release(&model->block_weight_buffers[n]);
|
| 532 |
-
}
|
| 533 |
-
|
| 534 |
-
// Metal kernels
|
| 535 |
-
gptoss_metal_function_release(&model->bf16_f32_embeddings_fn);
|
| 536 |
-
gptoss_metal_function_release(&model->f32_bf16w_rmsnorm_fn);
|
| 537 |
-
gptoss_metal_function_release(&model->f32_bf16w_matmul_fn);
|
| 538 |
-
gptoss_metal_function_release(&model->f32_bf16w_matmul_qkv_fn);
|
| 539 |
-
gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_qkv_fn);
|
| 540 |
-
gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_attn_output_fn);
|
| 541 |
-
gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_mlp_gate_fn);
|
| 542 |
-
gptoss_metal_function_release(&model->f32_bf16w_unembedding_fn);
|
| 543 |
-
gptoss_metal_function_release(&model->f32_rope_fn);
|
| 544 |
-
gptoss_metal_function_release(&model->f32_expert_routing_metadata_fn);
|
| 545 |
-
gptoss_metal_function_release(&model->f32_scatter_e4_fn);
|
| 546 |
-
gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_swiglu_fn);
|
| 547 |
-
gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_fn);
|
| 548 |
-
gptoss_metal_function_release(&model->f32_gather_and_accumulate_e4_fn);
|
| 549 |
-
gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_swiglu_fn);
|
| 550 |
-
gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_fn);
|
| 551 |
-
gptoss_metal_function_release(&model->f32_accumulate_e4_fn);
|
| 552 |
-
gptoss_metal_function_release(&model->f32_topk_softmax_e32_k4_fn);
|
| 553 |
-
gptoss_metal_function_release(&model->f32_topk_softmax_e128_k4_fn);
|
| 554 |
-
gptoss_metal_function_release(&model->f32_softmax_fn);
|
| 555 |
-
gptoss_metal_function_release(&model->f32_sample_fn);
|
| 556 |
-
gptoss_metal_function_release(&model->f32_sdpa_q8_d64_fn);
|
| 557 |
-
gptoss_metal_library_release(&model->library);
|
| 558 |
-
|
| 559 |
-
gptoss_metal_command_queue_release(&model->command_queue);
|
| 560 |
-
gptoss_metal_device_release(&model->device);
|
| 561 |
-
// Weight buffers
|
| 562 |
-
|
| 563 |
-
if (model->mapping_ptr != NULL && model->mapping_size != 0) {
|
| 564 |
-
if (model->lock_memory) {
|
| 565 |
-
if (munlock(model->mapping_ptr, model->mapping_size) != 0) {
|
| 566 |
-
GPTOSS_LOG_WARNING("munlock for model weight mapping failed with error %d", errno);
|
| 567 |
-
}
|
| 568 |
-
}
|
| 569 |
-
|
| 570 |
-
if (munmap(model->mapping_ptr, model->mapping_size) != 0) {
|
| 571 |
-
GPTOSS_LOG_WARNING("munmap for model weight mapping failed with error %d", errno);
|
| 572 |
-
}
|
| 573 |
-
}
|
| 574 |
-
|
| 575 |
-
const size_t model_size = sizeof(struct gptoss_model) + model->num_blocks * sizeof(struct gptoss_metal_buffer);
|
| 576 |
-
memset(model, 0, model_size);
|
| 577 |
-
free(model);
|
| 578 |
-
}
|
| 579 |
-
}
|
| 580 |
-
return gptoss_status_success;
|
| 581 |
-
}
|
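For intuition, the mapping arithmetic above (round_down_to_page_size / round_up_to_page_size) just aligns the tokenizer and weight regions to page boundaries so the mmap offsets are valid. A minimal Python sketch, assuming 16 KiB pages (an assumption for illustration; the C helpers use the real OS page size):

    PAGE_SIZE = 16384  # assumption: 16 KiB pages, as on Apple Silicon

    def round_down_to_page_size(offset: int) -> int:
        # Clear the low bits: start of the page containing `offset`.
        return offset & ~(PAGE_SIZE - 1)

    def round_up_to_page_size(offset: int) -> int:
        # Advance to the next page boundary unless already aligned.
        return (offset + PAGE_SIZE - 1) & ~(PAGE_SIZE - 1)

    # e.g. the tokenizer mapping above, with hypothetical file offsets:
    start, end = 20_000, 95_000
    mapping_start = round_down_to_page_size(start)             # 16384
    mapping_size = round_up_to_page_size(end) - mapping_start  # 81920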
gptoss_kernels/source/tensor_wrappers.cpp ADDED
@@ -0,0 +1,77 @@
#include <internal/metal-kernels.h>
#include <internal/metal.h>
#include <ATen/Tensor.h>
#include <cstring>  // std::memcpy

void f32_bf16w_matmul_torch(const at::Tensor &input,
                            const at::Tensor &weight_bf16,
                            const at::Tensor &bias_bf16,
                            at::Tensor &output,
                            int64_t num_tokens, int64_t num_cols, int64_t num_rows, int64_t threadgroup_size)
{
    TORCH_CHECK(input.dtype() == at::kFloat, "input must be float32");
    TORCH_CHECK(weight_bf16.dtype() == at::kBFloat16, "weight must be bfloat16");
    TORCH_CHECK(bias_bf16.dtype() == at::kBFloat16, "bias must be bfloat16");
    TORCH_CHECK(output.dtype() == at::kFloat, "output must be float32");

    TORCH_CHECK(input.dim() == 2, "input must be 2D");
    TORCH_CHECK(weight_bf16.dim() == 2, "weight must be 2D");
    TORCH_CHECK(bias_bf16.dim() == 1, "bias must be 1D");
    TORCH_CHECK(output.dim() == 2, "output must be 2D");
    TORCH_CHECK(input.size(0) == num_tokens && input.size(1) == num_cols,
                "input shape must be [num_tokens, num_cols]");
    TORCH_CHECK(weight_bf16.size(0) == num_cols && weight_bf16.size(1) == num_rows,
                "weight shape must be [num_cols, num_rows]");
    TORCH_CHECK(bias_bf16.size(0) == num_rows, "bias length must be num_rows");
    TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_rows,
                "output shape must be [num_tokens, num_rows]");

    auto input_cpu = input.contiguous().to(at::kCPU);
    auto weight_cpu = weight_bf16.transpose(0, 1).contiguous().to(at::kCPU);
    auto bias_cpu = bias_bf16.contiguous().to(at::kCPU);
    auto out_cpu = output.detach().to(at::kCPU).contiguous().clone();

    gptoss_metal_device device{}; gptoss_metal_library library{};
    gptoss_metal_function fn{}; gptoss_metal_command_queue cq{};
    gptoss_metal_command_buffer cb{};

    TORCH_CHECK(gptoss_metal_device_create_system_default(&device) == gptoss_status_success, "device_create failed");
    TORCH_CHECK(gptoss_metal_library_create_default(&device, &library) == gptoss_status_success, "library_create failed");
    TORCH_CHECK(gptoss_metal_function_create(&library, "gptoss_f32_bf16w_matmul", &fn) == gptoss_status_success, "function_create failed");
    TORCH_CHECK(gptoss_metal_command_queue_create(&device, &cq) == gptoss_status_success, "cq_create failed");
    TORCH_CHECK(gptoss_metal_command_buffer_create(&cq, &cb) == gptoss_status_success, "cb_create failed");

    const size_t in_bytes = (size_t)num_tokens * (size_t)num_cols * sizeof(float);
    const size_t wt_bytes = (size_t)num_rows * (size_t)num_cols * sizeof(uint16_t);
    const size_t bs_bytes = (size_t)num_rows * sizeof(uint16_t);
    const size_t out_bytes = (size_t)num_tokens * (size_t)num_rows * sizeof(float);

    gptoss_metal_buffer in_buf{}, wt_buf{}, bs_buf{}, out_buf{}, ctrl_buf{};
    TORCH_CHECK(gptoss_metal_buffer_wrap(&device, in_bytes, input_cpu.data_ptr(), &in_buf) == gptoss_status_success, "wrap input failed");
    TORCH_CHECK(gptoss_metal_buffer_wrap(&device, wt_bytes, weight_cpu.data_ptr(), &wt_buf) == gptoss_status_success, "wrap weight failed");
    TORCH_CHECK(gptoss_metal_buffer_wrap(&device, bs_bytes, bias_cpu.data_ptr(), &bs_buf) == gptoss_status_success, "wrap bias failed");
    TORCH_CHECK(gptoss_metal_buffer_create(&device, out_bytes, nullptr, &out_buf) == gptoss_status_success, "alloc out failed");
    uint32_t ctrl_zero = 0;
    TORCH_CHECK(gptoss_metal_buffer_create(&device, sizeof(uint32_t), &ctrl_zero, &ctrl_buf) == gptoss_status_success, "alloc ctrl failed");

    TORCH_CHECK(gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
        &cb, &fn, (size_t)threadgroup_size,
        &in_buf, 0, &wt_buf, 0, &bs_buf, 0, &out_buf, 0, &ctrl_buf, 0,
        (uint32_t)num_tokens, (uint32_t)num_cols, (uint32_t)num_rows) == gptoss_status_success, "encode failed");

    TORCH_CHECK(gptoss_metal_command_buffer_commit(&cb) == gptoss_status_success, "commit failed");
    TORCH_CHECK(gptoss_metal_command_buffer_wait_completion(&cb, nullptr) == gptoss_status_success, "wait failed");

    std::memcpy(out_cpu.data_ptr(), out_buf.ptr, out_bytes);
    output.copy_(out_cpu.to(output.device(), /*non_blocking=*/false, /*copy=*/true));

    (void) gptoss_metal_command_buffer_release(&cb);
    (void) gptoss_metal_command_queue_release(&cq);
    (void) gptoss_metal_function_release(&fn);
    (void) gptoss_metal_library_release(&library);
    (void) gptoss_metal_device_release(&device);
    (void) gptoss_metal_buffer_release(&ctrl_buf);
    (void) gptoss_metal_buffer_release(&out_buf);
    (void) gptoss_metal_buffer_release(&bs_buf);
    (void) gptoss_metal_buffer_release(&wt_buf);
    (void) gptoss_metal_buffer_release(&in_buf);
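The wrapper above stages everything through CPU copies (note the weight is transposed to row-major [num_rows, num_cols] before being handed to the kernel), wraps them in Metal buffers, and launches gptoss_f32_bf16w_matmul; numerically it is a plain matmul plus bias. A rough PyTorch reference of the same computation, as a sketch for intuition:

    import torch

    def f32_bf16w_matmul_reference(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # x: [num_tokens, num_cols] float32; w: [num_cols, num_rows] bfloat16; b: [num_rows] bfloat16.
        # Upcasting bf16 to f32 is exact, so differences vs. the kernel come only
        # from accumulation order.
        return x @ w.float() + b.float()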
gptoss_kernels/source/tokenizer.c DELETED
@@ -1,106 +0,0 @@
#include <assert.h>
#include <stdatomic.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <errno.h>
#include <sys/mman.h>

#include <gpt-oss.h>

#include "internal/log.h"
#include "internal/model.h"


enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(
    gptoss_tokenizer_t tokenizer,
    enum gptoss_special_token token_type,
    uint32_t* token_id_out)
{
    uint32_t token_id = UINT32_MAX;
    if (token_type != gptoss_special_token_invalid && token_type < gptoss_special_token_max)
    {
        token_id = tokenizer->special_token_id[(uint32_t) token_type - 1];
    }
    if (token_id == UINT32_MAX) {
        return gptoss_status_invalid_argument;
    }

    *token_id_out = token_id;
    return gptoss_status_success;
}

enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(
    gptoss_tokenizer_t tokenizer,
    uint32_t* num_text_tokens_out)
{
    *num_text_tokens_out = tokenizer->num_text_tokens;
    return gptoss_status_success;
}

enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(
    gptoss_tokenizer_t tokenizer,
    uint32_t* num_special_tokens_out)
{
    *num_special_tokens_out = tokenizer->num_special_tokens;
    return gptoss_status_success;
}

enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(
    gptoss_tokenizer_t tokenizer,
    uint32_t* num_tokens_out)
{
    *num_tokens_out = tokenizer->num_text_tokens + tokenizer->num_special_tokens;
    return gptoss_status_success;
}

enum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(
    gptoss_tokenizer_t tokenizer,
    uint32_t token_id,
    const void** token_ptr_out,
    size_t* token_size_out)
{
    if (token_id >= tokenizer->num_text_tokens) {
        return gptoss_status_invalid_argument;
    }

    const char* token_ptr = (const char*) tokenizer->tokens_ptr;
    for (uint32_t t = 0; t < token_id; t++) {
        // Reading unaligned uint16_t
        uint16_t token_length;
        memcpy(&token_length, token_ptr, sizeof(token_length));

        token_ptr += (size_t) token_length + sizeof(uint16_t);
    }

    *token_ptr_out = (const void*) (token_ptr + sizeof(uint16_t));
    *token_size_out = (size_t) *token_ptr;
    return gptoss_status_success;
}

enum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(
    gptoss_tokenizer_t tokenizer)
{
    atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
    return gptoss_status_success;
}

enum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(
    gptoss_tokenizer_t tokenizer)
{
    if (tokenizer != NULL) {
        if (atomic_fetch_sub_explicit(&tokenizer->ref_count, 1, memory_order_acquire) == 1) {
            if (tokenizer->mapping_ptr != NULL && tokenizer->mapping_size != 0) {
                if (munmap(tokenizer->mapping_ptr, tokenizer->mapping_size) != 0) {
                    GPTOSS_LOG_WARNING("munmap for tokenizer mapping failed with error %d", errno);
                }
            }

            memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));
            free(tokenizer);
        }
    }
    return gptoss_status_success;
}
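gptoss_tokenizer_decode above walks a length-prefixed token table: each entry is an unaligned little-endian uint16 length followed by that many raw bytes. A sketch of the same walk in Python (note the C code reads the final length through a char pointer, i.e. only its first byte; the sketch reads the full uint16):

    import struct

    def decode_token(tokens_blob: bytes, token_id: int) -> bytes:
        # Skip entries 0..token_id-1; each is a uint16 length plus payload.
        off = 0
        for _ in range(token_id):
            (length,) = struct.unpack_from("<H", tokens_blob, off)
            off += 2 + length
        (length,) = struct.unpack_from("<H", tokens_blob, off)
        return tokens_blob[off + 2 : off + 2 + length]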
pyproject.toml ADDED
@@ -0,0 +1,10 @@
[build-system]
requires = [
    "cmake>=3.26",
    "ninja",
    "packaging",
    "setuptools>=61",
    "torch",
    "wheel",
]
build-backend = "setuptools.build_meta"
setup.py ADDED
@@ -0,0 +1,118 @@
import logging
import os
from shutil import which, move
import subprocess
import sys
from pathlib import Path

from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext

logger = logging.getLogger(__name__)


def is_sccache_available() -> bool:
    return which("sccache") is not None


def is_ccache_available() -> bool:
    return which("ccache") is not None


def is_ninja_available() -> bool:
    return which("ninja") is not None


class CMakeExtension(Extension):
    def __init__(self, name: str, sourcedir: str = "") -> None:
        super().__init__(name, sources=[], py_limited_api=True)
        self.sourcedir = os.fspath(Path(sourcedir).resolve())


class CMakeBuild(build_ext):
    def build_extension(self, ext: CMakeExtension) -> None:
        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)
        extdir = ext_fullpath.parent.resolve()

        debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
        cfg = "Debug" if debug else "Release"

        cmake_generator = os.environ.get("CMAKE_GENERATOR", "")

        # Set Python3_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
        # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
        # from Python.
        cmake_args = [
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
            f"-DPython3_EXECUTABLE={sys.executable}",
            f"-DCMAKE_BUILD_TYPE={cfg}",  # not used on MSVC, but no harm
        ]
        build_args = []
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]

        if not cmake_generator or cmake_generator == "Ninja":
            try:
                import ninja

                ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
                cmake_args += [
                    "-GNinja",
                    f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
                ]
            except ImportError:
                pass

        if is_sccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_OBJC_COMPILER_LAUNCHER=sccache",
                "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=sccache",
            ]
        elif is_ccache_available():
            cmake_args += [
                "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_OBJC_COMPILER_LAUNCHER=ccache",
                "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=ccache",
            ]

        num_jobs = os.getenv("MAX_JOBS", None)
        if num_jobs is not None:
            num_jobs = int(num_jobs)
            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
        else:
            try:
                # os.sched_getaffinity() isn't universally available, so fall
                # back to os.cpu_count() if we get an error here.
                num_jobs = len(os.sched_getaffinity(0))
            except AttributeError:
                num_jobs = os.cpu_count()

        build_temp = Path(self.build_temp) / ext.name
        if not build_temp.exists():
            build_temp.mkdir(parents=True)

        subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
        )
        subprocess.run(
            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
        )


setup(
    name="gptoss_kernels",
    # The version is just a stub, it's not used by the final build artefact.
    version="0.1.0",
    ext_modules=[CMakeExtension("gptoss_kernels._gptoss_kernels_931bc1b_dirty")],
    cmdclass={"build_ext": CMakeBuild},
    packages=find_packages(where="torch-ext", include=["gptoss_kernels*"]),
    package_dir={"": "torch-ext"},
    zip_safe=False,
    install_requires=["torch"],
    python_requires=">=3.9",
)
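As the build driver above shows, the CMake configure step honors CMAKE_ARGS, CMAKE_GENERATOR, DEBUG, and MAX_JOBS from the environment and prefers sccache/ccache launchers when available, so a plain `pip install -e .` should pick them up; note that num_jobs is computed but not yet forwarded to `cmake --build`.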
{gptoss_kernels/test → test}/bf16-f32-embeddings.cc RENAMED (file without changes)
{gptoss_kernels/test → test}/embeddings-kernel-tester.hpp RENAMED (file without changes)
{gptoss_kernels/test → test}/f32-bf16w-matmul.cc RENAMED (file without changes)
{gptoss_kernels/test → test}/f32-bf16w-rmsnorm.cc RENAMED (file without changes)
{gptoss_kernels/test → test}/f32-random.cc RENAMED (file without changes)
{gptoss_kernels/test → test}/f32-rope.cc RENAMED (file without changes)
{gptoss_kernels/test → test}/fill-random-kernel-tester.hpp RENAMED (file without changes)
{gptoss_kernels/test → test}/matmul-kernel-tester.hpp RENAMED (file without changes)
{gptoss_kernels/test → test}/mf4-f32-convert.cc RENAMED (file without changes)
{gptoss_kernels/test → test}/rmsnorm-kernel-tester.hpp RENAMED (file without changes)
{gptoss_kernels/test → test}/rope-kernel-tester.hpp RENAMED (file without changes)
{gptoss_kernels/test → test}/u32-random.cc RENAMED (file without changes)
torch-ext/gptoss_kernels/__init__.py CHANGED
@@ -0,0 +1,8 @@
from ._ops import ops
import torch

def f32_bf16w_matmul(input: torch.Tensor, weight_bf16: torch.Tensor, bias_bf16: torch.Tensor, output: torch.Tensor, num_tokens: int, num_cols: int, num_rows: int, threadgroup_size: int) -> torch.Tensor:
    ops.f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, num_tokens, num_cols, num_rows, threadgroup_size)
    return output

__all__ = ["f32_bf16w_matmul"]
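A hypothetical smoke test for this wrapper (shapes, tolerances, and threadgroup_size=256 are illustrative assumptions; the op is registered for MPS, so this presumes an Apple-silicon Mac with the built extension importable):

    import torch
    from gptoss_kernels import f32_bf16w_matmul

    num_tokens, num_cols, num_rows = 4, 512, 256  # hypothetical sizes
    x = torch.randn(num_tokens, num_cols, dtype=torch.float32, device="mps")
    w = torch.randn(num_cols, num_rows, dtype=torch.bfloat16, device="mps")
    b = torch.randn(num_rows, dtype=torch.bfloat16, device="mps")
    out = torch.empty(num_tokens, num_rows, dtype=torch.float32, device="mps")

    f32_bf16w_matmul(x, w, b, out, num_tokens, num_cols, num_rows, 256)
    # Accumulation order differs from the Metal kernel, hence the loose tolerance.
    print(torch.allclose(out, x @ w.float() + b.float(), atol=1e-2, rtol=1e-2))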
torch-ext/gptoss_kernels/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (868 Bytes)

torch-ext/gptoss_kernels/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (552 Bytes)
torch-ext/gptoss_kernels/_gptoss_kernels_931bc1b_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31cddc1925c6c7901a5619ff55420ae6249d2c48de202a23a7c4534e4ccdcd4c
size 126536
torch-ext/gptoss_kernels/_ops.py ADDED
@@ -0,0 +1,9 @@
import torch
from . import _gptoss_kernels_931bc1b_dirty
ops = torch.ops._gptoss_kernels_931bc1b_dirty

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_gptoss_kernels_931bc1b_dirty::{op_name}"
torch-ext/gptoss_kernels/test.py ADDED
@@ -0,0 +1,6 @@
import _gptoss_kernels_931bc1b_dirty
import torch

print(dir(_gptoss_kernels_931bc1b_dirty))

from gptoss_kernels import _gptoss_kernels_931bc1b_dirty
torch-ext/registration.h ADDED
@@ -0,0 +1,30 @@
// Registration macros from vLLM:
// https://github.com/vllm-project/vllm/blob/main/csrc/core/registration.h

#pragma once

#include <Python.h>

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
    TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                               \
    PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                   \
        static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,             \
                                            STRINGIFY(NAME), nullptr, 0, nullptr}; \
        return PyModule_Create(&module);                                       \
    }
torch-ext/torch_binding.cpp CHANGED
@@ -0,0 +1,10 @@
#include <ATen/Tensor.h>
#include "torch_binding.h"
#include "registration.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("f32_bf16w_matmul(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor output, int num_tokens, int num_cols, int num_rows, int threadgroup_size) -> ()");
  ops.impl("f32_bf16w_matmul", torch::kMPS, &f32_bf16w_matmul_torch);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
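With this registration the schema lives under the hashed namespace, so the op is reachable as torch.ops._gptoss_kernels_931bc1b_dirty.f32_bf16w_matmul (which _ops.py above aliases as `ops`), and the MPS dispatch key routes it to f32_bf16w_matmul_torch from tensor_wrappers.cpp.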
torch-ext/torch_binding.h CHANGED
@@ -0,0 +1,5 @@
#pragma once

#include <torch/torch.h>

void f32_bf16w_matmul_torch(const at::Tensor &input, const at::Tensor &weight_bf16, const at::Tensor &bias_bf16, at::Tensor &output, int64_t num_tokens, int64_t num_cols, int64_t num_rows, int64_t threadgroup_size);