Mohamed Mekkouri committed on
Commit
95d28ad
·
1 Parent(s): 51250cb

commit evtn

Files changed (44)
  1. CMakeLists.txt +104 -0
  2. README.md +104 -0
  3. build.toml +22 -8
  4. cmake/compile-metal.cmake +86 -0
  5. cmake/metallib_to_header.py +73 -0
  6. cmake/utils.cmake +557 -0
  7. flake.lock +169 -0
  8. flake.nix +1 -1
  9. gptoss_kernels/CMakeLists.txt +0 -191
  10. gptoss_kernels/__init__.py +0 -6
  11. gptoss_kernels/examples/chat.py +0 -104
  12. gptoss_kernels/examples/generate.py +0 -34
  13. gptoss_kernels/source/context.c +0 -1115
  14. gptoss_kernels/source/generate.c +0 -317
  15. gptoss_kernels/source/include/internal/log.h +7 -0
  16. gptoss_kernels/source/include/internal/metal.h +0 -1
  17. gptoss_kernels/source/matmul.metal +8 -2
  18. gptoss_kernels/source/metal.m +0 -1
  19. gptoss_kernels/source/model.c +0 -581
  20. gptoss_kernels/source/tensor_wrappers.cpp +77 -0
  21. gptoss_kernels/source/tokenizer.c +0 -106
  22. pyproject.toml +10 -0
  23. setup.py +118 -0
  24. {gptoss_kernels/test → test}/bf16-f32-embeddings.cc +0 -0
  25. {gptoss_kernels/test → test}/embeddings-kernel-tester.hpp +0 -0
  26. {gptoss_kernels/test → test}/f32-bf16w-matmul.cc +0 -0
  27. {gptoss_kernels/test → test}/f32-bf16w-rmsnorm.cc +0 -0
  28. {gptoss_kernels/test → test}/f32-random.cc +0 -0
  29. {gptoss_kernels/test → test}/f32-rope.cc +0 -0
  30. {gptoss_kernels/test → test}/fill-random-kernel-tester.hpp +0 -0
  31. {gptoss_kernels/test → test}/matmul-kernel-tester.hpp +0 -0
  32. {gptoss_kernels/test → test}/mf4-f32-convert.cc +0 -0
  33. {gptoss_kernels/test → test}/rmsnorm-kernel-tester.hpp +0 -0
  34. {gptoss_kernels/test → test}/rope-kernel-tester.hpp +0 -0
  35. {gptoss_kernels/test → test}/u32-random.cc +0 -0
  36. torch-ext/gptoss_kernels/__init__.py +8 -0
  37. torch-ext/gptoss_kernels/__pycache__/__init__.cpython-313.pyc +0 -0
  38. torch-ext/gptoss_kernels/__pycache__/_ops.cpython-313.pyc +0 -0
  39. torch-ext/gptoss_kernels/_gptoss_kernels_931bc1b_dirty.abi3.so +3 -0
  40. torch-ext/gptoss_kernels/_ops.py +9 -0
  41. torch-ext/gptoss_kernels/test.py +6 -0
  42. torch-ext/registration.h +30 -0
  43. torch-ext/torch_binding.cpp +10 -0
  44. torch-ext/torch_binding.h +5 -0
CMakeLists.txt ADDED
@@ -0,0 +1,104 @@
+ cmake_minimum_required(VERSION 3.26)
+ project(gptoss_kernels LANGUAGES CXX)
+
+ set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum macOS deployment version")
+
+ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
+
+ include(FetchContent)
+ file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
+ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
+
+ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
+
+ if(DEFINED Python3_EXECUTABLE)
+   # Allow passing through the interpreter (e.g. from setup.py).
+   find_package(Python3 COMPONENTS Development Development.SABIModule Interpreter)
+   if (NOT Python3_FOUND)
+     message(FATAL_ERROR "Unable to find python matching: ${Python3_EXECUTABLE}.")
+   endif()
+ else()
+   find_package(Python3 REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
+ endif()
+
+ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
+
+ find_package(Torch REQUIRED)
+
+ add_compile_definitions(METAL_KERNEL)
+
+ # Initialize list for Metal shader sources
+ set(ALL_METAL_SOURCES)
+ #get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
+ #list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
+
+ set(TORCH_gptoss_kernels_SRC
+   torch-ext/torch_binding.cpp torch-ext/torch_binding.h
+ )
+
+
+ list(APPEND SRC "${TORCH_gptoss_kernels_SRC}")
+ set(gptoss_kernels_SRC
+   "gptoss_kernels/source/accumulate.metal"
+   "gptoss_kernels/source/expert_routing_metadata.metal"
+   "gptoss_kernels/source/metal.m"
+   "gptoss_kernels/source/scatter.metal"
+   "gptoss_kernels/source/topk.metal"
+   "gptoss_kernels/source/embeddings.metal"
+   "gptoss_kernels/source/metal-kernels.c"
+   "gptoss_kernels/source/random.metal"
+   "gptoss_kernels/source/sdpa.metal"
+   "gptoss_kernels/source/matmul.metal"
+   "gptoss_kernels/source/rmsnorm.metal"
+   "gptoss_kernels/source/sample.metal"
+   "gptoss_kernels/source/moematmul.metal"
+   "gptoss_kernels/source/convert.metal"
+   "gptoss_kernels/source/rope.metal"
+   "gptoss_kernels/source/gather_and_accumulate.metal"
+   "gptoss_kernels/source/tensor_wrappers.cpp"
+   "gptoss_kernels/source/log.c"
+ )
+
+ # Separate Metal shader files from other sources
+ set(gptoss_kernels_METAL_SRC)
+ set(gptoss_kernels_CPP_SRC)
+
+ foreach(src_file IN LISTS gptoss_kernels_SRC)
+   if(src_file MATCHES "\\.(metal|h)$")
+     list(APPEND gptoss_kernels_METAL_SRC ${src_file})
+   else()
+     list(APPEND gptoss_kernels_CPP_SRC ${src_file})
+   endif()
+ endforeach()
+
+ # TODO: check if CLion supports this:
+ # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
+ set_source_files_properties(
+   ${gptoss_kernels_CPP_SRC}
+   PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/gptoss_kernels/source/include;${CMAKE_SOURCE_DIR}/gptoss_kernels/include;${CMAKE_SOURCE_DIR}/.")
+
+
+ # Add C++ sources to the main source list
+ list(APPEND SRC "${gptoss_kernels_CPP_SRC}")
+
+ # Keep track of Metal sources for later compilation
+ if(gptoss_kernels_METAL_SRC)
+   list(APPEND ALL_METAL_SOURCES "${gptoss_kernels_METAL_SRC}")
+ endif()
+ # Include Metal shader compilation utilities
+ include(${CMAKE_CURRENT_LIST_DIR}/cmake/compile-metal.cmake)
+
+ define_gpu_extension_target(
+   _gptoss_kernels_931bc1b_dirty
+   DESTINATION _gptoss_kernels_931bc1b_dirty
+   LANGUAGE ${GPU_LANG}
+   SOURCES ${SRC}
+   COMPILE_FLAGS ${GPU_FLAGS}
+   ARCHITECTURES ${GPU_ARCHES}
+   USE_SABI 3
+   WITH_SOABI)
+
+ # Compile Metal shaders if any were found
+ if(ALL_METAL_SOURCES)
+   compile_metal_shaders(_gptoss_kernels_931bc1b_dirty "${ALL_METAL_SOURCES}")
+ endif()
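
The `if(DEFINED Python3_EXECUTABLE)` branch above exists so a driver script can pin the interpreter. The commit's `setup.py` is not rendered on this page, so the following is only a minimal sketch (not the actual file) of how a build script might configure this `CMakeLists.txt`:

```python
# Hypothetical driver, not the commit's setup.py: configure and build the
# extension with the current interpreter pinned via -DPython3_EXECUTABLE.
import subprocess
import sys

subprocess.check_call([
    "cmake", "-S", ".", "-B", "build",
    f"-DPython3_EXECUTABLE={sys.executable}",  # consumed by the if(DEFINED ...) branch
    "-DCMAKE_BUILD_TYPE=Release",
])
subprocess.check_call(["cmake", "--build", "build", "-j"])
```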
README.md CHANGED
@@ -8,3 +8,107 @@ tags:
 
 This is a build of some kernels released by OpenAI in the GPT-OSS repo: https://github.com/openai/gpt-oss
 
+ ```21:69:/Users/medmekk/projects/ai/kernels/gpt-oss/gpt_oss/metal/source/matmul.metal
+ kernel void gptoss_f32_bf16w_matmul(
+     constant gptoss_matmul_args& args [[ buffer(0) ]],
+     const device float4* input [[ buffer(1) ]],
+     const device bfloat4* weight [[ buffer(2) ]],
+     const device bfloat* bias [[ buffer(3) ]],
+     device float* output [[ buffer(4) ]],
+     const device gptoss_control* control [[ buffer(5) ]],
+     uint2 gid [[threadgroup_position_in_grid]],
+     uint simdgroup_tid [[thread_index_in_simdgroup]],
+     uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+     uint num_simdgroups [[simdgroups_per_threadgroup]])
+ {
+     const uint simdgroup_size = 32;
+     if (control->abort != 0) {
+         return;
+     }
+
+     const uint num_column_vecs = args.num_column_vecs;
+     const uint row = gid.x * num_simdgroups + simdgroup_idx;
+
+     input += gid.y * num_column_vecs + simdgroup_tid;
+     weight += num_column_vecs * row + simdgroup_tid;
+     bias += row;
+     output += gid.y * args.num_rows + row;
+
+     uint num_iter = 0;
+     num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+
+     float4 sum4 = 0.0f;
+     do {
+         const bfloat4 w = *weight;
+         const float4 i = *input;
+         sum4 = metal::fma(static_cast<float4>(w), i, sum4);
+
+         weight += simdgroup_size;
+         input += simdgroup_size;
+     } while (--num_iter != 0);
+     const float2 sum2 = sum4.xy + sum4.zw;
+     float sum = sum2.x + sum2.y;
+     sum = metal::simd_sum(sum);
+     if (metal::simd_is_first()) {
+         sum += static_cast<float>(*bias);
+         if (args.add) {
+             *output += sum;
+         } else {
+             *output = sum;
+         }
+     }
+ }
+ ```
+
+ ### What it computes
+ - Computes Y = X · W + b for a batch of tokens.
+ - Types/layout:
+   - X is float32, shape [num_tokens, num_cols], viewed as `float4` vectors → num_column_vecs = num_cols/4.
+   - W is bfloat16, shape [num_rows, num_cols], viewed as `bfloat4` vectors per row (row-major).
+   - b is bfloat16, length num_rows.
+   - Y is float32, shape [num_tokens, num_rows].
+
+ ### Work decomposition
+ - Grid Y (gid.y) = token index t in [0, num_tokens).
+ - Grid X (gid.x) spans output rows in groups of `num_simdgroups`. Within a threadgroup:
+   - simdgroup_idx in [0, num_simdgroups) selects one output row r.
+   - Therefore row r = gid.x*num_simdgroups + simdgroup_idx.
+ - Each simdgroup computes exactly one scalar Y[t, r].
+ - Lanes inside a simdgroup (simdgroup_tid in [0, 31]) split the K dimension (num_column_vecs vectors) in a strided pattern: lane ℓ processes indices ℓ, ℓ+32, ℓ+64, ...
+
+ ### Pointer setup per simdgroup
+ - input points to the start of token t's vector-list, then lane offset: `input += t*num_column_vecs + lane`.
+ - weight points to the start of row r's vector-list, then lane offset: `weight += r*num_column_vecs + lane`.
+ - bias points to `b[r]`. output points to `Y[t, r]`.
+
+ ### Inner loop
+ - num_iter = ceil((num_column_vecs - lane)/32). Each lane loops over its share of the K/4 vectors.
+ - On each iteration:
+   - Load one `float4` from input (4 consecutive columns) and one `bfloat4` from weight for row r.
+   - Fused multiply-add into `sum4` (vector-wise).
+   - Advance both pointers by 32 vectors (the next stripe for this lane).
+ - After the loop, reduce the 4 lanes of `sum4` into a scalar: sum4.xy + sum4.zw → sum2, then sum2.x + sum2.y.
+ - Reduce across all 32 lanes with `metal::simd_sum(sum)`.
+ - Lane 0 adds bias[r] and writes `Y[t, r]` (add or overwrite depending on args.add).
+
+ ### Example mapping (num_tokens=2, num_cols=128, num_rows=4, threadgroup_size=32)
+ - num_column_vecs = 128/4 = 32.
+ - threadgroup_size=32 ⇒ num_simdgroups=1 ⇒ each threadgroup computes 1 row.
+ - Grid:
+   - gid.y ∈ {0,1} (two tokens).
+   - gid.x ∈ {0,1,2,3} (four rows).
+ - For a given (gid.x, gid.y), simdgroup_idx=0 computes one output scalar Y[t=gid.y, r=gid.x].
+ - Per-lane work:
+   - lane ℓ loads X[t, cols 4ℓ..4ℓ+3] as float4 and W[r, cols 4ℓ..4ℓ+3] as bfloat4 (exactly one iteration since there are 32 vectors total).
+   - Each lane accumulates the dot over its 4 elements; lanes are then summed → the full dot(X_row, W_row).
+ - Lane 0 writes Y[t, r] = dot + b[r].
+
+ ### Example mapping (same shapes, threadgroup_size=64)
+ - num_simdgroups=64/32=2 ⇒ each threadgroup computes 2 rows at once.
+ - For gid.x=k:
+   - simdgroup_idx=0 computes row r=2k, simdgroup_idx=1 computes row r=2k+1.
+ - Lanes split K identically; two output scalars are produced per threadgroup (one per simdgroup).
+
+ ### Which piece each unit owns
+ - Token t (grid y) × Row r (simdgroup within grid x) → one output scalar Y[t, r].
+ - Lane ℓ within that simdgroup → partial dot over columns {4ℓ, 4ℓ+1, 4ℓ+2, 4ℓ+3}, plus any further stripes {4(ℓ+32), ...} if num_cols > 128.
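
The walkthrough above maps cleanly onto a scalar reference. The following NumPy sketch is not part of the commit; it emulates bfloat16 by rounding float32 to its top 16 bits (NumPy has no bfloat16 dtype) and reproduces the lane/simdgroup decomposition literally rather than efficiently:

```python
import numpy as np

def to_bf16(a: np.ndarray) -> np.ndarray:
    """Round float32 to bfloat16 precision (round-to-nearest-even)."""
    bits = a.astype(np.float32).view(np.uint32)
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return bits.view(np.float32)

def matmul_ref(x, w, b, add=False, y_prev=None, simd=32):
    """Y[t, r] = dot(X[t, :], bf16(W[r, :])) + bf16(b[r]), one simdgroup per (t, r)."""
    num_tokens, num_cols = x.shape
    num_rows = w.shape[0]
    num_column_vecs = num_cols // 4
    y = np.zeros((num_tokens, num_rows), dtype=np.float32)
    for t in range(num_tokens):                        # gid.y
        for r in range(num_rows):                      # gid.x * num_simdgroups + simdgroup_idx
            partial = np.zeros((simd, 4), np.float32)  # per-lane float4 sum4
            for lane in range(simd):                   # simdgroup_tid
                for k in range(lane, num_column_vecs, simd):  # strided K split
                    partial[lane] += x[t, 4*k:4*k+4] * to_bf16(w[r, 4*k:4*k+4])
            s = partial.sum() + to_bf16(b[r:r+1])[0]   # simd_sum, then lane 0 adds bias
            y[t, r] = (y_prev[t, r] + s) if add else s  # args.add: accumulate
    return y
```

For the shapes in the first example above (2×128 inputs, 4 rows, simd=32), each lane executes exactly one iteration of the inner `k` loop, matching the note in the per-lane work section.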
build.toml CHANGED
@@ -9,15 +9,29 @@ src = [
 ]
 
 [kernel.gptoss_kernels]
+
 depends = ["torch"]
- backend = "cuda"
+ backend = "metal"
 
 src = [
-   "gptoss_kernels/attention_cuda_fwd.cu",
-   "gptoss_kernels/attention_cuda_bwd.cu",
-   "gptoss_kernels/attention_cuda_utils.cu",
-   "gptoss_kernels/attention_cuda_utils.cuh",
-   "gptoss_kernels/attention_cuda.cuh",
-   "gptoss_kernels/attention.h",
-   "gptoss_kernels/cudamacro.h",
+   "gptoss_kernels/source/accumulate.metal",
+   "gptoss_kernels/source/expert_routing_metadata.metal",
+   "gptoss_kernels/source/metal.m",
+   "gptoss_kernels/source/scatter.metal",
+   "gptoss_kernels/source/topk.metal",
+   "gptoss_kernels/source/embeddings.metal",
+   "gptoss_kernels/source/metal-kernels.c",
+   "gptoss_kernels/source/random.metal",
+   "gptoss_kernels/source/sdpa.metal",
+   "gptoss_kernels/source/matmul.metal",
+   "gptoss_kernels/source/rmsnorm.metal",
+   "gptoss_kernels/source/sample.metal",
+   "gptoss_kernels/source/moematmul.metal",
+   "gptoss_kernels/source/convert.metal",
+   "gptoss_kernels/source/rope.metal",
+   "gptoss_kernels/source/gather_and_accumulate.metal",
+   "gptoss_kernels/source/tensor_wrappers.cpp",
+   "gptoss_kernels/source/log.c",
 ]
+
+ include = ["gptoss_kernels/source/include", "gptoss_kernels/include", "."]
cmake/compile-metal.cmake ADDED
@@ -0,0 +1,86 @@
+ # Metal shader compilation function
+ function(compile_metal_shaders TARGET_NAME METAL_SOURCES)
+   # Find the Metal compiler
+   find_program(METAL_COMPILER xcrun REQUIRED)
+
+   # Set Metal compiler flags
+   set(METAL_FLAGS "-std=metal3.0" "-O2")
+
+   # Output directory for the compiled metallib
+   set(METALLIB_OUTPUT_DIR "${CMAKE_BINARY_DIR}/metallib")
+   file(MAKE_DIRECTORY ${METALLIB_OUTPUT_DIR})
+
+   # Separate .metal files from .h files and compile .metal files to .air
+   set(AIR_FILES)
+   set(METAL_FILES)
+   set(HEADER_FILES)
+
+   foreach(SOURCE_FILE ${METAL_SOURCES})
+     if(SOURCE_FILE MATCHES "\\.metal$")
+       list(APPEND METAL_FILES ${SOURCE_FILE})
+     elseif(SOURCE_FILE MATCHES "\\.h$")
+       list(APPEND HEADER_FILES ${SOURCE_FILE})
+     endif()
+   endforeach()
+
+   foreach(METAL_FILE ${METAL_FILES})
+     get_filename_component(METAL_NAME ${METAL_FILE} NAME_WE)
+     set(AIR_FILE "${CMAKE_BINARY_DIR}/${METAL_NAME}.air")
+
+     # Include header files as dependencies
+     set(ALL_DEPENDENCIES ${CMAKE_CURRENT_SOURCE_DIR}/${METAL_FILE})
+     foreach(HEADER_FILE ${HEADER_FILES})
+       list(APPEND ALL_DEPENDENCIES ${CMAKE_CURRENT_SOURCE_DIR}/${HEADER_FILE})
+     endforeach()
+
+     add_custom_command(
+       OUTPUT ${AIR_FILE}
+       COMMAND ${METAL_COMPILER} -sdk macosx metal ${METAL_FLAGS}
+               -c ${CMAKE_CURRENT_SOURCE_DIR}/${METAL_FILE}
+               -o ${AIR_FILE}
+       DEPENDS ${ALL_DEPENDENCIES}
+       COMMENT "Compiling Metal shader ${METAL_FILE} to ${AIR_FILE}"
+       VERBATIM
+     )
+
+     list(APPEND AIR_FILES ${AIR_FILE})
+   endforeach()
+
+   # Link all .air files into a single .metallib
+   set(METALLIB_FILE "${METALLIB_OUTPUT_DIR}/${TARGET_NAME}.metallib")
+   add_custom_command(
+     OUTPUT ${METALLIB_FILE}
+     COMMAND ${METAL_COMPILER} -sdk macosx metallib ${AIR_FILES}
+             -o ${METALLIB_FILE}
+     DEPENDS ${AIR_FILES}
+     COMMENT "Linking Metal library ${METALLIB_FILE}"
+     VERBATIM
+   )
+
+   # Generate a C++ header with the embedded metallib data
+   set(METALLIB_HEADER "${CMAKE_BINARY_DIR}/${TARGET_NAME}_metallib.h")
+   set(METALLIB_TO_HEADER_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/cmake/metallib_to_header.py")
+
+   add_custom_command(
+     OUTPUT ${METALLIB_HEADER}
+     COMMAND ${Python3_EXECUTABLE} ${METALLIB_TO_HEADER_SCRIPT} ${METALLIB_FILE} ${METALLIB_HEADER} ${TARGET_NAME}
+     DEPENDS ${METALLIB_FILE} ${METALLIB_TO_HEADER_SCRIPT}
+     COMMENT "Generating embedded Metal library header ${METALLIB_HEADER}"
+     VERBATIM
+   )
+
+   # Create a custom target for the metallib
+   add_custom_target(${TARGET_NAME}_metallib ALL DEPENDS ${METALLIB_FILE} ${METALLIB_HEADER})
+
+   # Add a dependency to the main target
+   add_dependencies(${TARGET_NAME} ${TARGET_NAME}_metallib)
+
+   # Add the generated header to include directories
+   target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_BINARY_DIR})
+
+   # Pass the metallib header and namespace as compile definitions
+   target_compile_definitions(${TARGET_NAME} PRIVATE
+     EMBEDDED_METALLIB_HEADER="${TARGET_NAME}_metallib.h"
+     EMBEDDED_METALLIB_NAMESPACE=${TARGET_NAME}_metal
+   )
+ endfunction()
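
The function above is build-system plumbing around three tool invocations. As a rough standalone illustration (not part of the commit; paths and target name are placeholders), the same pipeline can be driven directly:

```python
# Sketch only: compile .metal -> .air, link .air -> .metallib, then embed
# the library into a header, mirroring compile_metal_shaders() above.
import subprocess
import sys
from pathlib import Path

def build_metallib(metal_sources, out_dir="build", name="_gptoss_kernels"):
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    air_files = []
    for src in map(Path, metal_sources):
        air = out / (src.stem + ".air")
        # Mirrors the per-file custom command: xcrun -sdk macosx metal -std=metal3.0 -O2 -c ...
        subprocess.check_call(["xcrun", "-sdk", "macosx", "metal", "-std=metal3.0",
                               "-O2", "-c", str(src), "-o", str(air)])
        air_files.append(str(air))
    metallib = out / f"{name}.metallib"
    # Mirrors the link step: xcrun -sdk macosx metallib *.air -o <name>.metallib
    subprocess.check_call(["xcrun", "-sdk", "macosx", "metallib", *air_files,
                           "-o", str(metallib)])
    # Mirrors the header-generation step that the C++ target then #includes.
    subprocess.check_call([sys.executable, "cmake/metallib_to_header.py",
                           str(metallib), str(out / f"{name}_metallib.h"), name])
    return metallib
```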
cmake/metallib_to_header.py ADDED
@@ -0,0 +1,73 @@
+ #!/usr/bin/env python3
+ import sys
+ import os
+
+ def convert_metallib_to_header(metallib_path: str, header_path: str, target_name: str) -> None:
+     """Convert a metallib binary file to a C++ header with embedded data."""
+
+     # Read the metallib binary data
+     with open(metallib_path, 'rb') as f:
+         data: bytes = f.read()
+
+     # Generate the header content
+     header_content: str = """// Auto-generated file containing embedded Metal library
+ #pragma once
+ #include <cstddef>
+ #include <Metal/Metal.h>
+
+ namespace """ + target_name + """_metal {
+ static const unsigned char metallib_data[] = {
+ """
+
+     # Convert binary data to C array format
+     bytes_per_line: int = 16
+     for i in range(0, len(data), bytes_per_line):
+         chunk: bytes = data[i:i + bytes_per_line]
+         hex_values: str = ', '.join('0x{:02x}'.format(b) for b in chunk)
+         header_content += "    " + hex_values + ","
+         if i + bytes_per_line < len(data):
+             header_content += "\n"
+
+     header_content += """
+ };
+ static const size_t metallib_data_len = """ + str(len(data)) + """;
+
+ // Convenience function to create Metal library from embedded data
+ inline id<MTLLibrary> createLibrary(id<MTLDevice> device, NSError** error = nullptr) {
+     dispatch_data_t libraryData = dispatch_data_create(
+         metallib_data,
+         metallib_data_len,
+         dispatch_get_main_queue(),
+         ^{ /* No cleanup needed for static data */ });
+
+     NSError* localError = nil;
+     id<MTLLibrary> library = [device newLibraryWithData:libraryData error:&localError];
+
+     if (error) {
+         *error = localError;
+     }
+
+     return library;
+ }
+ } // namespace """ + target_name + """_metal
+ """
+
+     # Write the header file
+     dir_path: str = os.path.dirname(header_path)
+     if dir_path:
+         os.makedirs(dir_path, exist_ok=True)
+     with open(header_path, 'w') as f:
+         f.write(header_content)
+
+     print("Generated {} ({} bytes)".format(header_path, len(data)))
+
+ if __name__ == "__main__":
+     if len(sys.argv) != 4:
+         print("Usage: metallib_to_header.py <metallib_path> <header_path> <target_name>")
+         sys.exit(1)
+
+     metallib_path: str = sys.argv[1]
+     header_path: str = sys.argv[2]
+     target_name: str = sys.argv[3]
+
+     convert_metallib_to_header(metallib_path, header_path, target_name)
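
Embedding the metallib as a byte array means the shared object carries its shaders with it instead of shipping a separate `.metallib` file. A quick way to sanity-check the generator is a round trip on a throwaway payload; this test is not in the commit, and the import path is an assumption:

```python
# Round-trip smoke test for the converter above (sketch, not part of the commit).
import sys
from pathlib import Path

sys.path.insert(0, "cmake")  # assumes this runs from the repo root
from metallib_to_header import convert_metallib_to_header

payload = bytes(range(256))                     # stand-in for a real .metallib
Path("demo.metallib").write_bytes(payload)
convert_metallib_to_header("demo.metallib", "demo_metallib.h", "demo")

header = Path("demo_metallib.h").read_text()
assert f"metallib_data_len = {len(payload)}" in header
assert "namespace demo_metal" in header
```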
cmake/utils.cmake ADDED
@@ -0,0 +1,557 @@
+ # Vendored from vLLM:
+ #
+ # https://github.com/vllm-project/vllm/blob/main/cmake/utils.cmake
+ #
+ # Attempt to find the python package that uses the same python executable as
+ # `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
+ #
+ macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
+   file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
+   set(Python3_EXECUTABLE ${EXECUTABLE})
+   find_package(Python3 COMPONENTS Interpreter Development.Module Development.SABIModule)
+   if (NOT Python3_FOUND)
+     message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
+   endif()
+   set(_VER "${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR}")
+   set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
+   if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
+     message(FATAL_ERROR
+       "Python version (${_VER}) is not one of the supported versions: "
+       "${_SUPPORTED_VERSIONS_LIST}.")
+   endif()
+   message(STATUS "Found python matching: ${EXECUTABLE}.")
+ endmacro()
+
+ #
+ # Run `EXPR` in python. The standard output of python is stored in `OUT` and
+ # has trailing whitespace stripped. If an error is encountered when running
+ # python, a fatal message `ERR_MSG` is issued.
+ #
+ function (run_python OUT EXPR ERR_MSG)
+   execute_process(
+     COMMAND
+     "${Python3_EXECUTABLE}" "-c" "${EXPR}"
+     OUTPUT_VARIABLE PYTHON_OUT
+     RESULT_VARIABLE PYTHON_ERROR_CODE
+     ERROR_VARIABLE PYTHON_STDERR
+     OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+   if(NOT PYTHON_ERROR_CODE EQUAL 0)
+     message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
+   endif()
+   set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
+ endfunction()
+
+ # Run `EXPR` in python after importing `PKG`. Use the result of this to extend
+ # `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
+ macro (append_cmake_prefix_path PKG EXPR)
+   run_python(_PREFIX_PATH
+     "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
+   list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
+ endmacro()
+
+ #
+ # Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
+ # of CUDA source files. The names of the corresponding "hipified" sources are
+ # stored in `OUT_SRCS`.
+ #
+ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
+   #
+   # Split into C++ and non-C++ (i.e. CUDA) sources.
+   #
+   set(NODUP_SRCS ${ORIG_SRCS})
+   list(REMOVE_DUPLICATES NODUP_SRCS)
+   set(SRCS ${NODUP_SRCS})
+   set(CXX_SRCS ${NODUP_SRCS})
+   list(FILTER SRCS INCLUDE REGEX "\.cu$")
+   list(FILTER CXX_SRCS EXCLUDE REGEX "\.cu$")
+
+   #
+   # Generate ROCm/HIP source file names from CUDA file names.
+   # Since HIP files are generated code, they will appear in the build area
+   # `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
+   #
+   set(HIP_SRCS)
+   foreach (SRC ${SRCS})
+     get_source_file_property(include_dirs "${SRC}" INCLUDE_DIRECTORIES)
+     get_source_file_property(compile_options "${SRC}" COMPILE_OPTIONS)
+     string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC})
+     string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
+
+     if(include_dirs)
+       # Copy over include directories from the original CUDA file.
+       set_source_files_properties(
+         ${SRC}
+         PROPERTIES INCLUDE_DIRECTORIES "${include_dirs}")
+     endif()
+
+     if(compile_options)
+       set_source_files_properties(
+         ${SRC}
+         PROPERTIES COMPILE_OPTIONS "${compile_options}")
+     endif()
+
+     list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
+   endforeach()
+
+   add_custom_target(
+     hipify${NAME}
+     COMMAND "${Python3_EXECUTABLE}" ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR} -o ${CMAKE_CURRENT_BINARY_DIR} ${SRCS}
+     DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
+     BYPRODUCTS ${HIP_SRCS}
+     COMMENT "Running hipify on ${NAME} extension source files.")
+
+   # Swap out original extension sources with hipified sources.
+   list(APPEND HIP_SRCS ${CXX_SRCS})
+   set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
+ endfunction()
+
+ #
+ # Get additional GPU compiler flags from torch.
+ #
+ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
+   if (${GPU_LANG} STREQUAL "CUDA")
+     #
+     # Get common NVCC flags from torch.
+     #
+     run_python(GPU_FLAGS
+       "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
+       "Failed to determine torch nvcc compiler flags")
+
+     if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
+       list(APPEND GPU_FLAGS "-DENABLE_FP8")
+       list(REMOVE_ITEM GPU_FLAGS
+         "-D__CUDA_NO_HALF_OPERATORS__"
+         "-D__CUDA_NO_HALF_CONVERSIONS__"
+         "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
+         "-D__CUDA_NO_HALF2_OPERATORS__")
+     endif()
+
+   elseif(${GPU_LANG} STREQUAL "HIP")
+     #
+     # Get common HIP/HIPCC flags from torch.
+     #
+     run_python(GPU_FLAGS
+       "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
+       "Failed to determine torch nvcc compiler flags")
+
+     list(APPEND GPU_FLAGS
+       "-DUSE_ROCM"
+       "-DENABLE_FP8"
+       "-U__HIP_NO_HALF_CONVERSIONS__"
+       "-U__HIP_NO_HALF_OPERATORS__"
+       "-fno-gpu-rdc")
+
+   endif()
+   set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
+ endfunction()
+
+ # Macro for converting a `gencode` version number to a cmake version number.
+ macro(string_to_ver OUT_VER IN_STR)
+   string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
+ endmacro()
+
+ #
+ # Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
+ # `CUDA_ARCH_FLAGS`.
+ #
+ # Example:
+ #   CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
+ #   clear_cuda_arches(CUDA_ARCH_FLAGS)
+ #   CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
+ #   CMAKE_CUDA_FLAGS="-Wall"
+ #
+ macro(clear_cuda_arches CUDA_ARCH_FLAGS)
+   # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
+   string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
+     ${CMAKE_CUDA_FLAGS})
+
+   # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
+   # and passed back via the `CUDA_ARCHITECTURES` property.
+   string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
+     ${CMAKE_CUDA_FLAGS})
+ endmacro()
+
+ #
+ # Extract unique CUDA architectures from a list of compute capability codes in
+ # the form `<major><minor>[<letter>]`, convert them to the form
+ # `<major>.<minor>`, dedupe them, and then sort them in ascending order and
+ # store them in `OUT_ARCHES`.
+ #
+ # Example:
+ #   CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
+ #   extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
+ #   OUT_ARCHES="7.5;...;9.0"
+ function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
+   set(_CUDA_ARCHES)
+   foreach(_ARCH ${CUDA_ARCH_FLAGS})
+     string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
+     if (_COMPUTE)
+       set(_COMPUTE ${CMAKE_MATCH_1})
+     endif()
+
+     string_to_ver(_COMPUTE_VER ${_COMPUTE})
+     list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
+   endforeach()
+
+   list(REMOVE_DUPLICATES _CUDA_ARCHES)
+   list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
+   set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
+ endfunction()
+
+ #
+ # For a specific file set the `-gencode` flag in compile options conditionally
+ # for the CUDA language.
+ #
+ # Example:
+ #   set_gencode_flag_for_srcs(
+ #     SRCS "foo.cu"
+ #     ARCH "compute_75"
+ #     CODE "sm_75")
+ #   adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
+ #   `foo.cu` (only for the CUDA language).
+ #
+ macro(set_gencode_flag_for_srcs)
+   set(options)
+   set(oneValueArgs ARCH CODE)
+   set(multiValueArgs SRCS)
+   cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
+     "${multiValueArgs}" ${ARGN} )
+   set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
+   set_property(
+     SOURCE ${arg_SRCS}
+     APPEND PROPERTY
+     COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
+   )
+
+   message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
+ endmacro(set_gencode_flag_for_srcs)
+
+ #
+ # For a list of source files set the `-gencode` flags in the files' specific
+ # compile options (specifically for the CUDA language).
+ #
+ # arguments are:
+ #   SRCS: list of source files
+ #   CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
+ #   BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
+ #     for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
+ #     that is larger than BUILD_PTX_FOR_ARCH.
+ #
+ macro(set_gencode_flags_for_srcs)
+   set(options)
+   set(oneValueArgs BUILD_PTX_FOR_ARCH)
+   set(multiValueArgs SRCS CUDA_ARCHS)
+   cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
+     "${multiValueArgs}" ${ARGN} )
+
+   foreach(_ARCH ${arg_CUDA_ARCHS})
+     # handle +PTX suffix: generate both sm and ptx codes if requested
+     string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
+     if(NOT _HAS_PTX EQUAL -1)
+       string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
+       string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
+       set_gencode_flag_for_srcs(
+         SRCS ${arg_SRCS}
+         ARCH "compute_${_STRIPPED_ARCH}"
+         CODE "sm_${_STRIPPED_ARCH}")
+       set_gencode_flag_for_srcs(
+         SRCS ${arg_SRCS}
+         ARCH "compute_${_STRIPPED_ARCH}"
+         CODE "compute_${_STRIPPED_ARCH}")
+     else()
+       string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
+       set_gencode_flag_for_srcs(
+         SRCS ${arg_SRCS}
+         ARCH "compute_${_STRIPPED_ARCH}"
+         CODE "sm_${_STRIPPED_ARCH}")
+     endif()
+   endforeach()
+
+   if (${arg_BUILD_PTX_FOR_ARCH})
+     list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
+     list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
+     if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
+       string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
+       set_gencode_flag_for_srcs(
+         SRCS ${arg_SRCS}
+         ARCH "compute_${_PTX_ARCH}"
+         CODE "compute_${_PTX_ARCH}")
+     endif()
+   endif()
+ endmacro()
+
+ #
+ # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
+ # `<major>.<minor>[letter]` compute the "loose intersection" with the
+ # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
+ # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
+ # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
+ # architecture in `SRC_CUDA_ARCHS`.
+ # The loose intersection is defined as:
+ #   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
+ # where `<=` is the version comparison operator.
+ # In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
+ # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
+ # We have special handling for x.0a: if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
+ # in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
+ # x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
+ # The result is stored in `OUT_CUDA_ARCHS`.
+ #
+ # Example:
+ #   SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
+ #   TGT_CUDA_ARCHS="8.0;8.9;9.0"
+ #   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+ #   OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
+ #
+ # Example With PTX:
+ #   SRC_CUDA_ARCHS="8.0+PTX"
+ #   TGT_CUDA_ARCHS="9.0"
+ #   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+ #   OUT_CUDA_ARCHS="8.0+PTX"
+ #
+ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+   set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
+   set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
+
+   # handle +PTX suffix: separate base arch for matching, record PTX requests
+   set(_PTX_ARCHS)
+   foreach(_arch ${_SRC_CUDA_ARCHS})
+     if(_arch MATCHES "\\+PTX$")
+       string(REPLACE "+PTX" "" _base "${_arch}")
+       list(APPEND _PTX_ARCHS "${_base}")
+       list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
+       list(APPEND _SRC_CUDA_ARCHS "${_base}")
+     endif()
+   endforeach()
+   list(REMOVE_DUPLICATES _PTX_ARCHS)
+   list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
+
+   # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
+   # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
+   set(_CUDA_ARCHS)
+   foreach(_arch ${_SRC_CUDA_ARCHS})
+     if(_arch MATCHES "\\a$")
+       list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
+       string(REPLACE "a" "" _base "${_arch}")
+       if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
+         list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
+         list(APPEND _CUDA_ARCHS "${_arch}")
+       endif()
+     endif()
+   endforeach()
+
+   list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
+
+   # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
+   # is less or equal to ARCH (but has the same major version since SASS binary
+   # compatibility is only forward compatible within the same major version).
+   foreach(_ARCH ${_TGT_CUDA_ARCHS})
+     set(_TMP_ARCH)
+     # Extract the major version of the target arch
+     string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
+     foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
+       # Extract the major version of the source arch
+       string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
+       # Check version-less-or-equal, and allow PTX arches to match across majors
+       if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
+         if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
+           set(_TMP_ARCH "${_SRC_ARCH}")
+         endif()
+       else()
+         # If we hit a version greater than the target, we can break
+         break()
+       endif()
+     endforeach()
+
+     # If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
+     if (_TMP_ARCH)
+       list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
+     endif()
+   endforeach()
+
+   list(REMOVE_DUPLICATES _CUDA_ARCHS)
+
+   # reapply +PTX suffix to architectures that requested PTX
+   set(_FINAL_ARCHS)
+   foreach(_arch ${_CUDA_ARCHS})
+     if(_arch IN_LIST _PTX_ARCHS)
+       list(APPEND _FINAL_ARCHS "${_arch}+PTX")
+     else()
+       list(APPEND _FINAL_ARCHS "${_arch}")
+     endif()
+   endforeach()
+   set(_CUDA_ARCHS ${_FINAL_ARCHS})
+
+   set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
+ endfunction()
+
+ #
+ # For the given `SRC_ROCM_ARCHS` list of architecture versions in the form
+ # `<name>` compute the "loose intersection" with the `TGT_ROCM_ARCHS` list.
+ # The loose intersection is defined as:
+ #   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
+ # where `<=` is the version comparison operator.
+ # In other words, for each version in `TGT_ROCM_ARCHS` find the highest version
+ # in `SRC_ROCM_ARCHS` that is less or equal to the version in `TGT_ROCM_ARCHS`.
+ # The result is stored in `OUT_ROCM_ARCHS`.
+ #
+ # Example:
+ #   SRC_ROCM_ARCHS="gfx900;gfx906;gfx908;gfx90a"
+ #   TGT_ROCM_ARCHS="gfx906;gfx908;gfx1030"
+ #   hip_archs_loose_intersection(OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
+ #   OUT_ROCM_ARCHS="gfx906;gfx908"
+ #
+ function(hip_archs_loose_intersection OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
+   list(REMOVE_DUPLICATES SRC_ROCM_ARCHS)
+
+   # ROCm architectures are typically in format gfxNNN or gfxNNNx where N is a digit
+   # and x is a letter. We can sort them by string comparison which works for this format.
+   list(SORT SRC_ROCM_ARCHS COMPARE STRING ORDER ASCENDING)
+
+   set(_ROCM_ARCHS)
+
+   # Find the intersection of supported architectures
+   foreach(_SRC_ARCH ${SRC_ROCM_ARCHS})
+     if(_SRC_ARCH IN_LIST TGT_ROCM_ARCHS)
+       list(APPEND _ROCM_ARCHS ${_SRC_ARCH})
+     endif()
+   endforeach()
+
+   list(REMOVE_DUPLICATES _ROCM_ARCHS)
+   set(${OUT_ROCM_ARCHS} ${_ROCM_ARCHS} PARENT_SCOPE)
+ endfunction()
+
+ #
+ # Override the GPU architectures detected by cmake/torch and filter them by
+ # `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
+ # `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
+ # the architectures on a per file basis.
+ #
+ # Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
+ #
+ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
+   set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
+   message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")
+
+   if (${GPU_LANG} STREQUAL "HIP")
+     #
+     # `GPU_ARCHES` controls the `--offload-arch` flags.
+     #
+     # If the PYTORCH_ROCM_ARCH env variable exists, then we take it as a list;
+     # if not, then we use CMAKE_HIP_ARCHITECTURES, which was generated by calling
+     # "rocm_agent_enumerator" in "enable_language(HIP)"
+     # (in file Modules/CMakeDetermineHIPCompiler.cmake)
+     #
+     if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+       set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
+     else()
+       set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
+     endif()
+     #
+     # Find the intersection of the supported + detected architectures to
+     # set the module architecture flags.
+     #
+     set(${GPU_ARCHES})
+     foreach (_ARCH ${HIP_ARCHITECTURES})
+       if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
+         list(APPEND ${GPU_ARCHES} ${_ARCH})
+       endif()
+     endforeach()
+
+     if(NOT ${GPU_ARCHES})
+       message(FATAL_ERROR
+         "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
+         " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
+     endif()
+   endif()
+ endmacro()
+
+ #
+ # Define a target named `GPU_MOD_NAME` for a single extension. The
+ # arguments are:
+ #
+ # DESTINATION <dest>         - Module destination directory.
+ # LANGUAGE <lang>            - The GPU language for this module, e.g. CUDA, HIP,
+ #                              etc.
+ # SOURCES <sources>          - List of source files relative to CMakeLists.txt
+ #                              directory.
+ #
+ # Optional arguments:
+ #
+ # ARCHITECTURES <arches>     - A list of target GPU architectures in cmake
+ #                              format.
+ #                              Refer to `CMAKE_CUDA_ARCHITECTURES` documentation
+ #                              and `CMAKE_HIP_ARCHITECTURES` for more info.
+ #                              ARCHITECTURES will use cmake's defaults if
+ #                              not provided.
+ # COMPILE_FLAGS <flags>      - Extra compiler flags passed to NVCC/hip.
+ # INCLUDE_DIRECTORIES <dirs> - Extra include directories.
+ # LIBRARIES <libraries>      - Extra link libraries.
+ # WITH_SOABI                 - Generate the library with the python SOABI suffix name.
+ # USE_SABI <version>         - Use python stable api <version>
+ #
+ # Note: optimization level/debug info is set via cmake build type.
+ #
+ function (define_gpu_extension_target GPU_MOD_NAME)
+   cmake_parse_arguments(PARSE_ARGV 1
+     GPU
+     "WITH_SOABI"
+     "DESTINATION;LANGUAGE;USE_SABI"
+     "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
+
+   # Add hipify preprocessing step when building with HIP/ROCm.
+   if (GPU_LANGUAGE STREQUAL "HIP")
+     hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
+   endif()
+
+   if (GPU_WITH_SOABI)
+     set(GPU_WITH_SOABI WITH_SOABI)
+   else()
+     set(GPU_WITH_SOABI)
+   endif()
+
+   if (GPU_USE_SABI)
+     Python3_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+   else()
+     Python3_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+   endif()
+
+   if (GPU_LANGUAGE STREQUAL "HIP")
+     # Make this target dependent on the hipify preprocessor step.
+     add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
+   endif()
+
+   if (GPU_ARCHITECTURES)
+     if (GPU_LANGUAGE STREQUAL "HIP")
+       # Clear target architectures; we are passing arch flags per source file.
+       set_property(TARGET ${GPU_MOD_NAME} PROPERTY HIP_ARCHITECTURES off)
+     else()
+       set_target_properties(${GPU_MOD_NAME} PROPERTIES
+         ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
+     endif()
+   endif()
+
+   set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
+
+   target_compile_options(${GPU_MOD_NAME} PRIVATE
+     $<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
+
+   target_compile_definitions(${GPU_MOD_NAME} PRIVATE
+     "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
+
+   target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
+     ${GPU_INCLUDE_DIRECTORIES})
+
+   target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
+
+   # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
+   # dependencies that are not necessary and may not be installed.
+   if (GPU_LANGUAGE STREQUAL "CUDA")
+     target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart)
+   else()
+     target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
+   endif()
+
+   install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
+ endfunction()
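
The trickiest helper above is `cuda_archs_loose_intersection`. A compact Python model of its core rule (not part of the commit; the `+PTX` and `x.0a` special cases are left out) may be easier to read than the set notation:

```python
def loose_intersection(src_archs, tgt_archs):
    """For each target arch, pick the highest source arch that is <= it and
    shares its major version (the +PTX and x.0a cases are omitted here)."""
    key = lambda v: tuple(map(int, v.split(".")))
    src = sorted(set(src_archs), key=key)
    out = []
    for tgt in tgt_archs:
        candidates = [s for s in src
                      if key(s) <= key(tgt) and key(s)[0] == key(tgt)[0]]
        if candidates:
            out.append(candidates[-1])      # highest arch <= target
    return sorted(set(out), key=key)

# Matches the example in the comment block above (modulo 9.0a handling):
assert loose_intersection(["7.5", "8.0", "8.6", "9.0"],
                          ["8.0", "8.9", "9.0"]) == ["8.0", "8.6", "9.0"]
```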
flake.lock ADDED
@@ -0,0 +1,169 @@
+ {
+   "nodes": {
+     "flake-compat": {
+       "locked": {
+         "lastModified": 1761588595,
+         "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=",
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5",
+         "type": "github"
+       },
+       "original": {
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "type": "github"
+       }
+     },
+     "flake-compat_2": {
+       "locked": {
+         "lastModified": 1747046372,
+         "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+         "type": "github"
+       },
+       "original": {
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "type": "github"
+       }
+     },
+     "flake-utils": {
+       "inputs": {
+         "systems": "systems"
+       },
+       "locked": {
+         "lastModified": 1731533236,
+         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "type": "github"
+       }
+     },
+     "flake-utils_2": {
+       "inputs": {
+         "systems": "systems_2"
+       },
+       "locked": {
+         "lastModified": 1731533236,
+         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "type": "github"
+       }
+     },
+     "hf-nix": {
+       "inputs": {
+         "flake-compat": "flake-compat_2",
+         "flake-utils": "flake-utils_2",
+         "nixpkgs": "nixpkgs"
+       },
+       "locked": {
+         "lastModified": 1761756835,
+         "narHash": "sha256-Vjrv8ZIhkQRgQ3MHGVFaj/fUcE4yuGr+vnoKYRwWmYw=",
+         "owner": "huggingface",
+         "repo": "hf-nix",
+         "rev": "6839b6998be18679992978c2f3abddc902276280",
+         "type": "github"
+       },
+       "original": {
+         "owner": "huggingface",
+         "repo": "hf-nix",
+         "type": "github"
+       }
+     },
+     "kernel-builder": {
+       "inputs": {
+         "flake-compat": "flake-compat",
+         "flake-utils": "flake-utils",
+         "hf-nix": "hf-nix",
+         "nixpkgs": [
+           "kernel-builder",
+           "hf-nix",
+           "nixpkgs"
+         ]
+       },
+       "locked": {
+         "lastModified": 1761991868,
+         "narHash": "sha256-+csvkWC9jC4mwq1LNfK4O6m3Qg4dCCXjP5JGdPa3TEo=",
+         "owner": "huggingface",
+         "repo": "kernel-builder",
+         "rev": "79cbfcdfde82c8847551f67f4b951a410794a5c6",
+         "type": "github"
+       },
+       "original": {
+         "owner": "huggingface",
+         "ref": "metal_kernels",
+         "repo": "kernel-builder",
+         "type": "github"
+       }
+     },
+     "nixpkgs": {
+       "locked": {
+         "lastModified": 1755963616,
+         "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
+         "owner": "nixos",
+         "repo": "nixpkgs",
+         "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nixos",
+         "ref": "nixos-unstable-small",
+         "repo": "nixpkgs",
+         "type": "github"
+       }
+     },
+     "root": {
+       "inputs": {
+         "kernel-builder": "kernel-builder"
+       }
+     },
+     "systems": {
+       "locked": {
+         "lastModified": 1681028828,
+         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+         "owner": "nix-systems",
+         "repo": "default",
+         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nix-systems",
+         "repo": "default",
+         "type": "github"
+       }
+     },
+     "systems_2": {
+       "locked": {
+         "lastModified": 1681028828,
+         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+         "owner": "nix-systems",
+         "repo": "default",
+         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nix-systems",
+         "repo": "default",
+         "type": "github"
+       }
+     }
+   },
+   "root": "root",
+   "version": 7
+ }
flake.nix CHANGED
@@ -2,7 +2,7 @@
   description = "Flake for Torch kernel extension";
 
   inputs = {
-   kernel-builder.url = "github:huggingface/kernel-builder";
+   kernel-builder.url = "github:huggingface/kernel-builder?ref=metal_kernels";
   };
 
   outputs = { self, kernel-builder, }:
gptoss_kernels/CMakeLists.txt DELETED
@@ -1,191 +0,0 @@
1
- cmake_minimum_required(VERSION 3.24)
2
- project(GPTOSS
3
- VERSION 1.0
4
- DESCRIPTION "Local GPT-OSS inference"
5
- LANGUAGES C CXX OBJC)
6
-
7
- set(CMAKE_C_STANDARD 11)
8
- set(CMAKE_CXX_STANDARD 20)
9
- set(CMAKE_OBJC_STANDARD 11)
10
- set(CMAKE_OBJC_STANDARD_REQUIRED ON)
11
-
12
- find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED)
13
- find_library(METAL_FRAMEWORK Metal REQUIRED)
14
- find_library(IOKIT_FRAMEWORK IOKit REQUIRED)
15
-
16
- set(METAL_SOURCES
17
- ${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal
18
- ${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal
19
- ${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal
20
- ${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal
21
- ${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal
22
- ${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal
23
- ${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal
24
- ${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal
25
- ${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal
26
- ${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal
27
- ${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal
28
- ${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal
29
- ${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal
30
- ${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal
31
- )
32
- set(METAL_LIB default.metallib)
33
-
34
- include_directories(BEFORE include source/include)
35
-
36
- add_custom_command(
37
- OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
38
- COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/source/"
39
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air"
40
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air"
41
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air"
42
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/expert_routing_metadata.air"
43
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air"
44
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air"
45
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air"
46
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/random.air"
47
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air"
48
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air"
49
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air"
50
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air"
51
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air"
52
- COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air"
53
- COMMAND xcrun -sdk macosx metallib "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air" "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air" "${CMAKE_CURRENT_BINARY_DIR}/source/expert_routing_metadata.air" "${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/random.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air" "${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air" "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air" -o "${METAL_LIB}"
54
- DEPENDS ${METAL_SOURCES}
55
- COMMENT "Compiling Metal compute library"
56
- )
57
-
58
- add_custom_target(build_metallib ALL
59
- DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB})
60
-
61
- add_library(log OBJECT source/log.c)
62
-
63
- add_library(metal-kernels STATIC source/metal.m source/metal-kernels.c)
64
- target_link_libraries(metal-kernels PRIVATE log)
65
-
66
- add_dependencies(metal-kernels build_metallib)
67
- add_custom_command(TARGET metal-kernels POST_BUILD
68
- COMMAND ${CMAKE_COMMAND} -E copy
69
- ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
70
- $<TARGET_FILE_DIR:metal-kernels>)
71
-
72
- target_link_libraries(metal-kernels PRIVATE ${FOUNDATION_FRAMEWORK} ${METAL_FRAMEWORK} ${IOKIT_FRAMEWORK})
-
- add_library(gptoss STATIC source/model.c source/tokenizer.c source/context.c)
- target_link_libraries(gptoss PRIVATE log metal-kernels)
-
- add_executable(generate source/generate.c)
- target_link_libraries(generate gptoss)
-
- # --- [ Tests
- include(FetchContent)
- FetchContent_Declare(
-   googletest
-   URL https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip
-   DOWNLOAD_EXTRACT_TIMESTAMP OFF
- )
- # For Windows: Prevent overriding the parent project's compiler/linker settings
- set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
- set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
- FetchContent_MakeAvailable(googletest)
-
- enable_testing()
-
- add_executable(u32-random-test test/u32-random.cc)
- target_link_libraries(u32-random-test PRIVATE GTest::gtest_main metal-kernels)
- target_include_directories(u32-random-test PRIVATE source/include)
- add_test(NAME u32-random-test COMMAND u32-random-test)
-
- add_executable(f32-random-test test/f32-random.cc)
- target_link_libraries(f32-random-test PRIVATE GTest::gtest_main metal-kernels)
- target_include_directories(f32-random-test PRIVATE source/include)
- add_test(NAME f32-random-test COMMAND f32-random-test)
-
- add_executable(mf4-f32-convert-test test/mf4-f32-convert.cc)
- target_link_libraries(mf4-f32-convert-test PRIVATE GTest::gtest_main metal-kernels)
- target_include_directories(mf4-f32-convert-test PRIVATE source/include)
- add_test(NAME mf4-f32-convert-test COMMAND mf4-f32-convert-test)
-
- add_executable(bf16-f32-embeddings-test test/bf16-f32-embeddings.cc)
- target_link_libraries(bf16-f32-embeddings-test PRIVATE GTest::gtest_main metal-kernels)
- target_include_directories(bf16-f32-embeddings-test PRIVATE source/include)
- add_test(NAME bf16-f32-embeddings-test COMMAND bf16-f32-embeddings-test)
-
- add_executable(f32-bf16w-rmsnorm-test test/f32-bf16w-rmsnorm.cc)
- target_link_libraries(f32-bf16w-rmsnorm-test PRIVATE GTest::gtest_main metal-kernels)
- target_include_directories(f32-bf16w-rmsnorm-test PRIVATE source/include)
- add_test(NAME f32-bf16w-rmsnorm-test COMMAND f32-bf16w-rmsnorm-test)
-
- add_executable(f32-bf16w-matmul-test test/f32-bf16w-matmul.cc)
- target_link_libraries(f32-bf16w-matmul-test PRIVATE GTest::gtest_main metal-kernels)
- target_include_directories(f32-bf16w-matmul-test PRIVATE source/include)
- add_test(NAME f32-bf16w-matmul-test COMMAND f32-bf16w-matmul-test)
-
- add_executable(f32-rope-test test/f32-rope.cc)
- target_link_libraries(f32-rope-test PRIVATE GTest::gtest_main metal-kernels)
- target_include_directories(f32-rope-test PRIVATE source/include)
- add_test(NAME f32-rope-test COMMAND f32-rope-test)
-
- # --- [ Benchmarks
- include(FetchContent)
- set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable self-tests in Google Benchmark" FORCE)
- set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable installation of Google Benchmark" FORCE)
- FetchContent_Declare(
-   benchmark
-   URL https://github.com/google/benchmark/archive/refs/tags/v1.9.4.zip
-   DOWNLOAD_EXTRACT_TIMESTAMP OFF
- )
- FetchContent_MakeAvailable(benchmark)
-
- add_executable(f32-random-bench benchmark/f32-random.cc)
- target_link_libraries(f32-random-bench PRIVATE benchmark::benchmark metal-kernels)
- target_include_directories(f32-random-bench PRIVATE source/include)
-
- add_executable(u32-random-bench benchmark/u32-random.cc)
- target_link_libraries(u32-random-bench PRIVATE benchmark::benchmark metal-kernels)
- target_include_directories(u32-random-bench PRIVATE source/include)
-
- add_executable(mf4-f32-convert-bench benchmark/mf4-f32-convert.cc)
- target_link_libraries(mf4-f32-convert-bench PRIVATE benchmark::benchmark metal-kernels)
- target_include_directories(mf4-f32-convert-bench PRIVATE source/include)
-
- add_executable(f32-bf16w-rmsnorm-bench benchmark/f32-bf16w-rmsnorm.cc)
- target_link_libraries(f32-bf16w-rmsnorm-bench PRIVATE benchmark::benchmark metal-kernels)
- target_include_directories(f32-bf16w-rmsnorm-bench PRIVATE source/include)
-
- add_executable(end-to-end-bench benchmark/end-to-end.cc)
- target_link_libraries(end-to-end-bench PRIVATE benchmark::benchmark gptoss)
- target_include_directories(end-to-end-bench PRIVATE source/include)
-
- add_executable(end-to-end-threadgroup-bench benchmark/end-to-end-threadgroup.cc)
- target_link_libraries(end-to-end-threadgroup-bench PRIVATE benchmark::benchmark gptoss)
- target_include_directories(end-to-end-threadgroup-bench PRIVATE source/include)
-
- # --- [ Python extension ] -----------------------------------------------
- find_package(pybind11 CONFIG REQUIRED) # provides pybind11_add_module
-
- pybind11_add_module(_metal
-   python/module.c
-   python/context.c
-   python/model.c
-   python/tokenizer.c
- )
- set_target_properties(_metal PROPERTIES PREFIX "")
-
- target_link_libraries(_metal PRIVATE gptoss)
- add_dependencies(_metal build_metallib)
- target_link_options(_metal PRIVATE
-   LINKER:-sectcreate,__METAL,__shaders,${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
- )
- add_custom_command(TARGET _metal POST_BUILD
-   COMMAND ${CMAKE_COMMAND} -E copy
-     ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
-     $<TARGET_FILE_DIR:_metal>)
-
- # 1️⃣ install the extension module into the Python package
- install(TARGETS _metal LIBRARY DESTINATION gpt_oss/metal)
-
- # 2️⃣ make sure the Metal shader archive travels with it
- install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
-   DESTINATION gpt_oss/metal)
- # ------------------------------------------------------------------------
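Note: the `-sectcreate,__METAL,__shaders` linker flag above embeds the compiled shader archive directly into the `_metal` extension binary, in addition to the loose copy installed next to it. A minimal sketch of reading that embedded section back at runtime, assuming only standard macOS APIs (`dladdr`, `getsectiondata`); the symbol used to locate the image is a placeholder for any symbol defined in the extension:

#include <dlfcn.h>
#include <mach-o/getsect.h>

// Returns a pointer to the __METAL,__shaders section of the Mach-O image
// that contains `any_symbol_in_image`, or NULL if it cannot be found.
static const void* embedded_metallib(const void* any_symbol_in_image, unsigned long* size_out) {
    Dl_info info;
    if (dladdr(any_symbol_in_image, &info) == 0) {
        return NULL;  // symbol not found in any loaded image
    }
    return getsectiondata((const struct mach_header_64*) info.dli_fbase,
                          "__METAL", "__shaders", size_out);
}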
gptoss_kernels/__init__.py DELETED
@@ -1,6 +0,0 @@
- from importlib import import_module as _im
-
- # Load the compiled extension (gpt_oss.metal._metal)
- _ext = _im(f"{__name__}._metal")
- globals().update({k: v for k, v in _ext.__dict__.items() if not k.startswith("_")})
- del _im, _ext
gptoss_kernels/examples/chat.py DELETED
@@ -1,104 +0,0 @@
- #!/usr/bin/env python
-
- import argparse
- import sys
-
- from datetime import date
- from gpt_oss.metal import Context, Model
-
-
- DEFAULT_PROMPT = f"""You are ChatGPT, a large language model trained by OpenAI.
- Knowledge cutoff: 2024-06
- Current date: {date.today().isoformat()}
-
- reasoning effort high
-
- # Valid channels: analysis, final. Channel must be included for every message."""
-
-
- parser = argparse.ArgumentParser(description="Chat with gpt-oss", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument("model", metavar="PATH", type=str, help="Path to gpt-oss model in Metal inference format")
- parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="System prompt")
- parser.add_argument(
-     "--context-length", type=int, default=0, help="The maximum context length"
- )
- parser.add_argument(
-     "--temperature", type=float, default=1.0, help="Sampling temperature"
- )
- parser.add_argument(
-     "--seed", type=int, default=0, help="Sampling seed"
- )
-
-
- GREY = "\33[90m"
- BOLD = "\33[1m"
- RESET = "\33[0m"
-
-
- def main(args):
-     options = parser.parse_args(args)
-     model = Model(options.model)
-     tokenizer = model.tokenizer
-     start_token = tokenizer.encode_special_token("<|start|>")
-     message_token = tokenizer.encode_special_token("<|message|>")
-     end_token = tokenizer.encode_special_token("<|end|>")
-     return_token = tokenizer.encode_special_token("<|return|>")
-     channel_token = tokenizer.encode_special_token("<|channel|>")
-
-     context = Context(model, context_length=options.context_length)
-     context.append(start_token)
-     context.append("system")
-     context.append(message_token)
-     context.append(options.prompt)
-     context.append(end_token)
-
-     while True:
-         context.append(start_token)
-         context.append("user")
-         context.append(message_token)
-         message = input(f"{BOLD}User:{RESET} ").rstrip()
-         context.append(message)
-         context.append(end_token)
-         print(f"{BOLD}Assistant:{RESET} {GREY}", end="", flush=True)
-         context.append(start_token)
-         context.append("assistant")
-         context.append(channel_token)
-
-         inside_start_block = True
-         inside_channel_block = True
-         role = "assistant"
-         channel = ""
-         while True:
-             token = context.sample(
-                 temperature=options.temperature,
-                 seed=options.seed,
-             )
-             context.append(token)
-             if token == return_token:
-                 print(flush=True)
-                 break
-             elif token == start_token:
-                 inside_start_block = True
-                 role = ""
-                 channel = ""
-             elif token == message_token:
-                 inside_start_block = False
-                 inside_channel_block = False
-                 if channel == "analysis":
-                     print(f"{GREY}", end="", flush=True)
-             elif token == end_token:
-                 print(f"{RESET}", flush=True)
-             elif token == channel_token:
-                 inside_channel_block = True
-             elif token < tokenizer.num_text_tokens:
-                 if inside_channel_block:
-                     channel += str(tokenizer.decode(token), encoding="utf-8")
-                 elif inside_start_block:
-                     role += str(tokenizer.decode(token), encoding="utf-8")
-                 else:
-                     sys.stdout.buffer.write(tokenizer.decode(token))
-                     sys.stdout.buffer.flush()
-
-
- if __name__ == "__main__":
-     main(sys.argv[1:])
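For reference, every turn the example above appends follows the same <|start|>{role}<|message|>{text}<|end|> framing; the assistant turn deliberately stops after <|channel|> so the model generates the channel name and message itself. A minimal sketch of that framing through the C entry points (see the context.c hunk below); `append_turn` is a hypothetical helper, and the special-token ids are assumed to come from the tokenizer's special-token lookup in the removed tokenizer.c:

#include <stdint.h>
#include <string.h>

#include <gpt-oss.h>

// Hypothetical helper: frames one chat turn as <|start|>{role}<|message|>{text}<|end|>.
static enum gptoss_status append_turn(
    gptoss_context_t context,
    uint32_t start_token, uint32_t message_token, uint32_t end_token,
    const char* role, const char* text)
{
    enum gptoss_status status;
    if ((status = gptoss_context_append_tokens(context, 1, &start_token)) != gptoss_status_success) return status;
    if ((status = gptoss_context_append_chars(context, role, strlen(role), NULL)) != gptoss_status_success) return status;
    if ((status = gptoss_context_append_tokens(context, 1, &message_token)) != gptoss_status_success) return status;
    if ((status = gptoss_context_append_chars(context, text, strlen(text), NULL)) != gptoss_status_success) return status;
    return gptoss_context_append_tokens(context, 1, &end_token);
}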
gptoss_kernels/examples/generate.py DELETED
@@ -1,34 +0,0 @@
- #!/usr/bin/env python
-
- import argparse
- import sys
-
- from gpt_oss.metal import Context, Model
-
-
- parser = argparse.ArgumentParser(description='Chat with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument('model', metavar='PATH', type=str, help='Path to gpt-oss checkpoint')
- parser.add_argument('-p', '--prompt', type=str, required=True, help='Prompt')
- parser.add_argument('-l', '--limit', type=int, default=100, help='Number of tokens to generate')
- parser.add_argument('--context-length', type=int, default=0, help='The maximum context length')
-
-
- def main(args):
-     options = parser.parse_args(args)
-     model = Model(options.model)
-
-     context = Context(model, context_length=options.context_length)
-     context.append(options.prompt)
-     print(context.tokens)
-     prompt_tokens = context.num_tokens
-
-     tokenizer = model.tokenizer
-
-     while context.num_tokens - prompt_tokens < options.limit:
-         token = context.sample()
-         context.append(token)
-         print(str(tokenizer.decode(token), encoding="utf-8"), end='', flush=True)
-
-
- if __name__ == '__main__':
-     main(sys.argv[1:])
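The loop above samples one token at a time from Python. The C-level sampler in the context.c hunk below batches this: a single gptoss_context_sample call appends up to max_tokens tokens and copies them out. A minimal sketch of the same bounded generation through that call, assuming a context that already holds the prompt:

#include <inttypes.h>
#include <stdio.h>

#include <gpt-oss.h>

// Greedily generates up to `limit` tokens (temperature 0) and prints their ids;
// decoding ids to text is left to the tokenizer, as in generate.py above.
static void generate_n(gptoss_context_t context, size_t limit) {
    uint32_t tokens[128];  // sketch assumes limit <= 128
    size_t num_tokens = 0;
    if (limit > 128) {
        limit = 128;
    }
    if (gptoss_context_sample(context, /*temperature=*/0.0f, /*seed=*/0,
                              limit, tokens, &num_tokens) == gptoss_status_success) {
        for (size_t i = 0; i < num_tokens; i++) {
            printf("%" PRIu32 " ", tokens[i]);
        }
        printf("\n");
    }
}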
gptoss_kernels/source/context.c DELETED
@@ -1,1115 +0,0 @@
- #include <assert.h>
- #include <float.h>
- #include <inttypes.h>
- #include <stdbool.h>
- #include <stdint.h>
- #include <stdlib.h>
- #include <string.h>
-
- #include <gpt-oss.h>
-
- #include "internal/datatype.h"
- #include "internal/model.h"
- #include "internal/metal.h"
- #include "internal/metal-kernels.h"
- #include "internal/log.h"
- #include "internal/rng.h"
-
-
- enum gptoss_status GPTOSS_ABI gptoss_context_create(
-     gptoss_model_t model,
-     size_t context_length,
-     size_t max_batch_tokens,
-     gptoss_context_t* context_out)
- {
-     *context_out = NULL;
-
-     enum gptoss_status status = gptoss_status_success;
-     struct gptoss_context* context = NULL;
-
-     // Validate context_length
-     if (context_length == 0) {
-         context_length = model->context_length;
-     } else if (context_length > model->context_length) {
-         GPTOSS_LOG_ERROR("requested context length %zu exceeds model context length %" PRIu32,
-             context_length, model->context_length);
-         status = gptoss_status_invalid_argument;
-         goto cleanup;
-     }
-     assert(context_length != 0);
-     assert(context_length <= model->context_length);
-
-     // Validate max_batch_tokens
-     if (max_batch_tokens == 0) {
-         max_batch_tokens = GPTOSS_DEFAULT_BATCH_SIZE;
-     } else if (max_batch_tokens > context_length) {
-         GPTOSS_LOG_ERROR("requested max batch tokens %zu exceeds context length %zu",
-             max_batch_tokens, context_length);
-         status = gptoss_status_invalid_argument;
-         goto cleanup;
-     }
-     assert(max_batch_tokens != 0);
-     assert(max_batch_tokens <= context_length);
-
-     context = malloc(sizeof(struct gptoss_context));
-     if (context == NULL) {
-         GPTOSS_LOG_ERROR("failed to allocate %zu bytes for Context object",
-             sizeof(struct gptoss_context));
-         status = gptoss_status_insufficient_memory;
-         goto cleanup;
-     }
-     memset(context, 0, sizeof(struct gptoss_context));
-
-     atomic_store_explicit(&context->ref_count, 1, memory_order_relaxed);
-     context->max_tokens = context_length;
-     context->max_batch_tokens = max_batch_tokens;
-
-     // Activation buffers
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->residual_activation_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->rmsnorm_activation_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &context->qkv_activation_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &context->sdpa_activation_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(float), NULL, &context->gate_activation_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &context->expert_activation_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     // The last entry will hold the total number of tokens.
-     status = gptoss_metal_buffer_create(&model->device, (1 + model->num_experts) * sizeof(uint32_t), NULL, &context->expert_offset_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * sizeof(uint32_t), NULL, &context->token_to_expert_routing_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->swiglu_input_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &context->swiglu_activation_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->moe_activation_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-
-     // Input/output buffers
-     status = gptoss_metal_buffer_create(&model->device, sizeof(struct gptoss_control), NULL, &context->control_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->score_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->prob_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->max_threadgroups * sizeof(float), NULL, &context->sum_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * sizeof(uint64_t), NULL, &context->argmax_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-     status = gptoss_metal_buffer_create(&model->device, model->num_blocks * context_length * 2 * model->num_kv_heads * model->head_dim * sizeof(float), NULL, &context->kvcache_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-
-     context->kvcache_size = context->kvcache_buffer.size;
-     context->allocation_size =
-         context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size +
-         context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size +
-         context->gate_activation_buffer.size + context->expert_activation_buffer.size +
-         context->expert_offset_buffer.size + context->token_to_expert_routing_buffer.size + context->swiglu_input_buffer.size +
-         context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
-         context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;
-
-     context->model = model;
-     gptoss_model_retain(model);
-     *context_out = context;
-     context = NULL;
-
- cleanup:
-     gptoss_context_release(context);
-     return status;
- }
-
- enum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(
-     gptoss_context_t context,
-     size_t* num_tokens_out)
- {
-     *num_tokens_out = context->num_tokens;
-     return gptoss_status_success;
- }
-
- enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
-     gptoss_context_t context,
-     size_t* max_tokens_out)
- {
-     *max_tokens_out = context->max_tokens;
-     return gptoss_status_success;
- }
-
- enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
-     gptoss_context_t context,
-     uint32_t* tokens_out,
-     size_t max_tokens,
-     size_t* num_tokens_out)
- {
-     *num_tokens_out = context->num_tokens;
-     if (max_tokens < context->num_tokens) {
-         return gptoss_status_insufficient_memory;
-     }
-
-     if (context->num_tokens != 0) {
-         memcpy(tokens_out, context->token_buffer.ptr, context->num_tokens * sizeof(uint32_t));
-     }
-     return gptoss_status_success;
- }
-
- // Prefill: input_tokens_offset = number of tokens in KV cache, num_input_tokens > 0, num_output_tokens = 0.
- // Sampling: input_tokens_offset = number of tokens in the context - 1, num_input_tokens = 1, num_output_tokens = 1.
- // Perplexity: input_tokens_offset = 0, num_input_tokens > 1, num_output_tokens = num_input_tokens.
- static enum gptoss_status process_tokens(
-     gptoss_context_t context,
-     struct gptoss_metal_command_buffer* command_buffer,
-     size_t input_tokens_offset,
-     size_t num_input_tokens,
-     size_t num_output_tokens)
- {
-     assert(num_input_tokens != 0);
-     assert(num_input_tokens <= context->max_batch_tokens);
-     assert(num_output_tokens <= context->max_batch_tokens);
-     assert(num_input_tokens >= num_output_tokens);
-     const size_t dense_matmul_kernel_token_multiple_constraint = 64;
-     const size_t min_tokens_for_dense_moe_kernels = 64;
-
-     enum gptoss_status status = gptoss_status_success;
-     const struct gptoss_model* model = context->model;
-
-     const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
-
-     const size_t input_tokens_end = input_tokens_offset + num_input_tokens;
-     for (size_t input_batch_start = input_tokens_offset;
-         input_batch_start < input_tokens_end;
-         input_batch_start += context->max_batch_tokens)
-     {
-         const size_t input_batch_size = math_min(context->max_batch_tokens, input_tokens_end - input_batch_start);
-         const size_t input_batch_end = input_batch_start + input_batch_size;
-         const size_t output_batch_size = math_sub_sat(num_output_tokens, input_tokens_end - input_batch_end);
-
-         status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
-             command_buffer,
-             &model->bf16_f32_embeddings_fn,
-             model->embeddings_threadgroup_size,
-             &context->token_buffer,
-             input_batch_start * sizeof(uint32_t),
-             &model->shared_weight_buffer,
-             /*weight_offset=*/0,
-             &context->residual_activation_buffer,
-             /*output_offset=*/0,
-             &context->control_buffer,
-             /*control_offset=*/0,
-             /*num_tokens=*/input_batch_size,
-             /*num_channels=*/model->embedding_dim);
-         if (status != gptoss_status_success) {
-             GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch");
-             return status;
-         }
-         for (uint32_t n = 0; n < model->num_blocks; n++) {
-             const bool last_block = n + 1 == model->num_blocks;
-             const size_t num_block_output_tokens = last_block ? output_batch_size : input_batch_size;
-
-             status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
-                 command_buffer,
-                 &model->f32_bf16w_rmsnorm_fn,
-                 &context->residual_activation_buffer,
-                 /*input_offset=*/0,
-                 &model->shared_weight_buffer,
-                 /*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
-                 &context->rmsnorm_activation_buffer,
-                 /*output_offset=*/0,
-                 &context->control_buffer,
-                 /*control_offset=*/0,
-                 /*num_tokens=*/input_batch_size,
-                 /*num_channels=*/model->embedding_dim,
-                 model->rmsnorm_epsilon);
-             if (status != gptoss_status_success) {
-                 GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
-                 return status;
-             }
-
-             if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
-                 status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
-                     command_buffer,
-                     &model->f32_bf16w_dense_matmul_qkv_fn,
-                     &context->rmsnorm_activation_buffer,
-                     /*input_offset=*/0,
-                     &model->shared_weight_buffer,
-                     /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
-                     &model->shared_weight_buffer,
-                     /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
-                     &context->qkv_activation_buffer,
-                     /*output_offset=*/0,
-                     &context->control_buffer,
-                     /*control_offset=*/0,
-                     /*num_tokens=*/input_batch_size,
-                     /*num_cols=*/model->embedding_dim,
-                     /*num_rows=*/attn_qkv_dim);
-                 if (status != gptoss_status_success) {
-                     GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_qkv kernel launch");
-                     return status;
-                 }
-
-                 status = gptoss_metal_command_buffer_encode_launch_f32_rope(
-                     command_buffer,
-                     &model->f32_rope_fn,
-                     /*threadgroup_size=*/32,
-                     &context->qkv_activation_buffer,
-                     /*input_offset=*/0,
-                     &context->control_buffer,
-                     /*control_offset=*/0,
-                     model->rope_theta,
-                     model->interpolation_scale,
-                     model->yarn_offset,
-                     model->yarn_scale,
-                     model->yarn_multiplier,
-                     input_batch_size,
-                     model->num_heads,
-                     model->num_kv_heads,
-                     model->head_dim,
-                     /*token_offset=*/input_batch_start);
-                 if (status != gptoss_status_success) {
-                     GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
-                     return status;
-                 }
-
-                 for (uint32_t t = 0; t < input_batch_size; t++) {
-                     for (uint32_t kv = 0; kv < 2; kv++) {
-                         for (uint32_t h = 0; h < model->num_kv_heads; h++) {
-                             status = gptoss_metal_command_buffer_encode_copy_buffer(
-                                 command_buffer,
-                                 &context->qkv_activation_buffer,
-                                 /*input_offset=*/(t * attn_qkv_dim + (model->num_heads + kv * model->num_kv_heads + h) * model->head_dim) * sizeof(float),
-                                 &context->kvcache_buffer,
-                                 /*output_offset=*/(((n * model->num_kv_heads + h) * context->max_tokens + input_batch_start + t) * 2 + kv) * model->head_dim * sizeof(float),
-                                 /*size=*/model->head_dim * sizeof(float));
-                             if (status != gptoss_status_success) {
-                                 GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
-                                 return status;
-                             }
-                         }
-                     }
-                 }
-             } else {
-                 status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
-                     command_buffer,
-                     &model->f32_bf16w_matmul_qkv_fn,
-                     model->attn_qkv_threadgroup_size,
-                     &context->rmsnorm_activation_buffer,
-                     /*input_offset=*/0,
-                     &model->shared_weight_buffer,
-                     /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
-                     &model->shared_weight_buffer,
-                     /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
-                     &context->qkv_activation_buffer,
-                     /*output_offset=*/0,
-                     &context->kvcache_buffer,
-                     /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
-                     &context->control_buffer,
-                     /*control_offset=*/0,
-                     /*num_tokens=*/input_batch_size,
-                     /*num_cols=*/model->embedding_dim,
-                     /*num_q_heads=*/model->num_heads,
-                     /*num_kv_heads=*/model->num_kv_heads,
-                     /*attn_head_dim=*/model->head_dim,
-                     /*token_offset=*/input_batch_start,
-                     /*max_tokens=*/context->max_tokens,
-                     /*rope_base=*/model->rope_theta,
-                     /*interpolation_scale=*/model->interpolation_scale,
-                     /*yarn_offset=*/model->yarn_offset,
-                     /*yarn_scale=*/model->yarn_scale,
-                     /*yarn_multiplier=*/model->yarn_multiplier);
-                 if (status != gptoss_status_success) {
-                     GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch");
-                     return status;
-                 }
-             }
-
-             if (num_block_output_tokens != 0) {
-                 status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(
-                     command_buffer,
-                     &model->f32_sdpa_q8_d64_fn,
-                     &context->qkv_activation_buffer,
-                     /*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
-                     &context->kvcache_buffer,
-                     /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
-                     &model->shared_weight_buffer,
-                     /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
-                     &context->sdpa_activation_buffer,
-                     /*output_offset=*/0,
-                     &context->control_buffer,
-                     /*control_offset=*/0,
-                     /*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
-                     /*kv_stride=*/2 * context->max_tokens * model->head_dim,
-                     num_block_output_tokens,
-                     input_batch_start + input_batch_size - num_block_output_tokens,
-                     model->num_heads, model->num_kv_heads, model->head_dim);
-                 if (status != gptoss_status_success) {
-                     GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch");
-                     return status;
-                 }
-
-                 if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
-                     status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
-                         command_buffer,
-                         &model->f32_bf16w_dense_matmul_attn_output_fn,
-                         &context->sdpa_activation_buffer,
-                         /*input_offset=*/0,
-                         &model->shared_weight_buffer,
-                         /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
-                         &model->shared_weight_buffer,
-                         /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
-                         &context->residual_activation_buffer,
-                         /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
-                         &context->control_buffer,
-                         /*control_offset=*/0,
-                         /*num_tokens=*/num_block_output_tokens,
-                         /*num_cols=*/model->num_heads * model->head_dim,
-                         /*num_rows=*/model->embedding_dim);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_attn_output kernel launch");
-                         return status;
-                     }
-                 } else {
-                     status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
-                         command_buffer,
-                         &model->f32_bf16w_matmul_fn,
-                         model->attn_out_threadgroup_size,
-                         &context->sdpa_activation_buffer,
-                         /*input_offset=*/0,
-                         &model->shared_weight_buffer,
-                         /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
-                         &model->shared_weight_buffer,
-                         /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
-                         &context->residual_activation_buffer,
-                         /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
-                         &context->control_buffer,
-                         /*control_offset=*/0,
-                         /*num_tokens=*/num_block_output_tokens,
-                         /*num_cols=*/model->num_heads * model->head_dim,
-                         /*num_rows=*/model->embedding_dim);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch");
-                         return status;
-                     }
-                 }
-                 status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
-                     command_buffer,
-                     &model->f32_bf16w_rmsnorm_fn,
-                     &context->residual_activation_buffer,
-                     /*input_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
-                     &model->shared_weight_buffer,
-                     /*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
-                     &context->rmsnorm_activation_buffer,
-                     /*output_offset=*/0,
-                     &context->control_buffer,
-                     /*control_offset=*/0,
-                     num_block_output_tokens,
-                     model->embedding_dim,
-                     model->rmsnorm_epsilon);
-                 if (status != gptoss_status_success) {
-                     GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
-                     return status;
-                 }
-                 if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
-                     status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
-                         command_buffer,
-                         &model->f32_bf16w_dense_matmul_mlp_gate_fn,
-                         &context->rmsnorm_activation_buffer,
-                         /*input_offset=*/0,
-                         &model->shared_weight_buffer,
-                         /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
-                         &model->shared_weight_buffer,
-                         /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
-                         &context->gate_activation_buffer,
-                         /*output_offset=*/0,
-                         &context->control_buffer,
-                         /*control_offset=*/0,
-                         num_block_output_tokens,
-                         model->embedding_dim,
-                         model->num_experts);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_mlp_gate kernel launch");
-                         return status;
-                     }
-                 } else {
-                     status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
-                         command_buffer,
-                         &model->f32_bf16w_matmul_fn,
-                         model->mlp_gate_threadgroup_size,
-                         &context->rmsnorm_activation_buffer,
-                         /*input_offset=*/0,
-                         &model->shared_weight_buffer,
-                         /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
-                         &model->shared_weight_buffer,
-                         /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
-                         &context->gate_activation_buffer,
-                         /*output_offset=*/0,
-                         &context->control_buffer,
-                         /*control_offset=*/0,
-                         /*num_tokens=*/num_block_output_tokens,
-                         /*num_cols=*/model->embedding_dim,
-                         /*num_rows=*/model->num_experts);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
-                         return status;
-                     }
-                 }
-
-                 const char* kernel_name = NULL;
-                 switch (model->num_experts) {
-                     case 32:
-                         kernel_name = "f32_topk_softmax_e32_k4_fn";
-                         status = gptoss_metal_command_buffer_encode_launch_f32_topk(
-                             command_buffer,
-                             &model->f32_topk_softmax_e32_k4_fn,
-                             &context->gate_activation_buffer, /*input_offset=*/0,
-                             &context->expert_activation_buffer, /*output_offset=*/0,
-                             &context->control_buffer, /*control_offset=*/0,
-                             num_block_output_tokens,
-                             model->num_experts,
-                             model->num_active_experts);
-                         break;
-                     case 128:
-                         kernel_name = "f32_topk_softmax_e128_k4_fn";
-                         status = gptoss_metal_command_buffer_encode_launch_f32_topk(
-                             command_buffer,
-                             &model->f32_topk_softmax_e128_k4_fn,
-                             &context->gate_activation_buffer, /*input_offset=*/0,
-                             &context->expert_activation_buffer, /*output_offset=*/0,
-                             &context->control_buffer, /*control_offset=*/0,
-                             num_block_output_tokens,
-                             model->num_experts,
-                             model->num_active_experts);
-                         break;
-                     default:
-                         status = gptoss_status_unsupported_argument;
-                         GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts);
-                         return status;
-                 }
-                 if (status != gptoss_status_success) {
-                     GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name);
-                     return status;
-                 }
-
-                 // If we have enough tokens in prefill, we will pick the prefill-optimized kernels.
-                 if (num_block_output_tokens >= min_tokens_for_dense_moe_kernels) {
-                     status = gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
-                         command_buffer,
-                         &model->f32_expert_routing_metadata_fn,
-                         &context->expert_activation_buffer,
-                         /*expert_predictions_offset=*/0,
-                         &context->expert_offset_buffer,
-                         /*expert_offsets_offset=*/0,
-                         &context->token_to_expert_routing_buffer,
-                         /*intra_expert_offsets_offset=*/0,
-                         num_block_output_tokens * model->num_active_experts,
-                         model->num_experts);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_expert_routing_metadata kernel launch");
-                         return status;
-                     }
-                     status = gptoss_metal_command_buffer_encode_launch_f32_scatter(
-                         command_buffer,
-                         &model->f32_scatter_e4_fn,
-                         &context->rmsnorm_activation_buffer,
-                         /*input_offset=*/0,
-                         &context->expert_activation_buffer,
-                         /*expert_predictions_offset=*/0,
-                         &context->expert_offset_buffer,
-                         /*expert_offsets_offset=*/0,
-                         &context->token_to_expert_routing_buffer,
-                         /*intra_expert_offsets_offset=*/0,
-                         &context->swiglu_input_buffer,
-                         /*output_offset=*/0,
-                         model->embedding_dim,
-                         num_block_output_tokens,
-                         model->num_active_experts);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_scatter kernel launch");
-                         return status;
-                     }
-                     // Dense MoE SwiGLU matmul.
-                     status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
-                         command_buffer,
-                         &model->f32_mf4w_moe_dense_matmul_swiglu_fn,
-                         &context->expert_offset_buffer,
-                         /*expert_offsets_offset=*/0,
-                         &context->swiglu_input_buffer,
-                         /*input_offset=*/0,
-                         &model->block_weight_buffers[n],
-                         /*weight_block_offset=*/0,
-                         &model->block_weight_buffers[n],
-                         /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
-                         &model->block_weight_buffers[n],
-                         /*bias_offset=*/model->mlp_swiglu_bias_offset,
-                         &context->swiglu_activation_buffer,
-                         /*output_offset=*/0,
-                         model->swiglu_limit,
-                         /*expert_stride_bytes=*/model->per_expert_block_weight_size,
-                         num_block_output_tokens,
-                         model->num_experts,
-                         model->embedding_dim,
-                         2 * model->mlp_dim);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
-                         return status;
-                     }
-
-                     // Dense MoE proj matmul.
-                     status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
-                         command_buffer,
-                         &model->f32_mf4w_moe_dense_matmul_fn,
-                         &context->expert_offset_buffer,
-                         /*expert_offsets_offset=*/0,
-                         &context->swiglu_activation_buffer,
-                         /*input_offset=*/0,
-                         &model->block_weight_buffers[n],
-                         /*weight_block_offset=*/model->mlp_out_block_offset,
-                         &model->block_weight_buffers[n],
-                         /*weight_scale_offset=*/model->mlp_out_scale_offset,
-                         &model->block_weight_buffers[n],
-                         /*bias_offset=*/model->mlp_out_bias_offset,
-                         &context->moe_activation_buffer,
-                         /*output_offset=*/0,
-                         /*expert_stride_bytes=*/model->per_expert_block_weight_size,
-                         num_block_output_tokens,
-                         model->num_experts,
-                         model->mlp_dim,
-                         model->embedding_dim);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
-                         return status;
-                     }
-                     // Gather and accumulate.
-                     status = gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
-                         command_buffer,
-                         &model->f32_gather_and_accumulate_e4_fn,
-                         &context->moe_activation_buffer,
-                         /*input_offset=*/0,
-                         &context->expert_activation_buffer,
-                         /*expert_predictions_offset=*/0,
-                         &context->expert_offset_buffer,
-                         /*expert_offsets_offset=*/0,
-                         &context->token_to_expert_routing_buffer,
-                         /*intra_expert_offsets_offset=*/0,
-                         &context->residual_activation_buffer,
-                         /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
-                         model->embedding_dim,
-                         num_block_output_tokens,
-                         model->num_active_experts);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_gather_and_accumulate_e4 kernel launch");
-                         return status;
-                     }
-
-                 } else {
-                     status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
-                         command_buffer,
-                         &model->f32_mf4w_moe_matmul_swiglu_fn,
-                         model->mlp_swiglu_threadgroup_size,
-                         &context->rmsnorm_activation_buffer,
-                         /*input_offset=*/0,
-                         &context->expert_activation_buffer,
-                         /*expert_offset=*/0,
-                         &model->block_weight_buffers[n],
-                         /*weight_block_offset=*/0,
-                         &model->block_weight_buffers[n],
-                         /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
-                         &model->block_weight_buffers[n],
-                         /*bias_offset=*/model->mlp_swiglu_bias_offset,
-                         &context->swiglu_activation_buffer,
-                         /*output_offset=*/0,
-                         &context->control_buffer,
-                         /*control_offset=*/0,
-                         model->swiglu_limit,
-                         model->per_expert_block_weight_size,
-                         num_block_output_tokens,
-                         model->num_active_experts,
-                         model->embedding_dim,
-                         model->mlp_dim);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
-                         return status;
-                     }
-
-                     status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
-                         command_buffer,
-                         &model->f32_mf4w_moe_matmul_fn,
-                         model->mlp_out_threadgroup_size,
-                         &context->swiglu_activation_buffer,
-                         /*input_offset=*/0,
-                         &context->expert_activation_buffer,
-                         /*expert_offset=*/0,
-                         &model->block_weight_buffers[n],
-                         /*weight_block_offset=*/model->mlp_out_block_offset,
-                         &model->block_weight_buffers[n],
-                         /*weight_scale_offset=*/model->mlp_out_scale_offset,
-                         &model->block_weight_buffers[n],
-                         /*bias_offset=*/model->mlp_out_bias_offset,
-                         &context->moe_activation_buffer,
-                         /*output_offset=*/0,
-                         &context->control_buffer,
-                         /*control_offset=*/0,
-                         model->per_expert_block_weight_size,
-                         num_block_output_tokens,
-                         model->num_active_experts,
-                         model->mlp_dim,
-                         model->embedding_dim);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
-                         return status;
-                     }
-
-                     status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
-                         command_buffer,
-                         &model->f32_accumulate_e4_fn,
-                         model->mlp_acc_threadgroup_size,
-                         model->max_threadgroups,
-                         &context->moe_activation_buffer,
-                         /*input_offset=*/0,
-                         &context->expert_activation_buffer,
-                         /*expert_offset=*/0,
-                         &context->residual_activation_buffer,
-                         /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
-                         &context->control_buffer,
-                         /*control_offset=*/0,
-                         model->embedding_dim,
-                         num_block_output_tokens,
-                         model->num_active_experts);
-                     if (status != gptoss_status_success) {
-                         GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
-                         return status;
-                     }
-                 }
-             }
-         }
-
-         if (output_batch_size != 0) {
-             status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
-                 command_buffer,
-                 &model->f32_bf16w_rmsnorm_fn,
-                 &context->residual_activation_buffer,
-                 /*input_offset=*/model->embedding_dim * (input_batch_size - output_batch_size) * sizeof(float),
-                 &model->shared_weight_buffer,
-                 /*weight_offset=*/model->rmsnorm_weight_offset,
-                 &context->rmsnorm_activation_buffer,
-                 /*output_offset=*/0,
-                 &context->control_buffer,
-                 /*control_offset=*/0,
-                 /*num_tokens=*/output_batch_size,
-                 /*num_channels=*/model->embedding_dim,
-                 model->rmsnorm_epsilon);
-             if (status != gptoss_status_success) {
-                 GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
-                 return status;
-             }
-
-             status = gptoss_metal_command_buffer_encode_fill_buffer(
-                 command_buffer,
-                 &context->argmax_buffer,
-                 /*offset=*/0,
-                 /*size=*/sizeof(uint64_t) * output_batch_size,
-                 /*fill_value=*/0xFF);
-             if (status != gptoss_status_success) {
-                 GPTOSS_LOG_ERROR("failed to encode fill buffer command");
-                 return status;
-             }
-
-             status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
-                 command_buffer,
-                 &model->f32_bf16w_unembedding_fn,
-                 model->unembedding_threadgroup_size,
-                 model->max_threadgroups,
-                 &context->rmsnorm_activation_buffer,
-                 /*input_offset=*/0,
-                 &model->shared_weight_buffer,
-                 /*weight_offset=*/model->unembedding_weight_offset,
-                 &context->score_buffer,
-                 /*output_offset=*/0,
-                 &context->argmax_buffer,
-                 /*argmax_offset=*/0,
-                 &context->control_buffer,
-                 /*control_offset=*/0,
-                 /*num_tokens=*/output_batch_size,
-                 /*num_cols=*/model->embedding_dim,
-                 /*num_rows=*/model->vocabulary_size);
-             if (status != gptoss_status_success) {
-                 GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch");
-                 return status;
-             }
-         }
-     }
-     return gptoss_status_success;
- }
-
- enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
-     gptoss_context_t context,
-     const char* text,
-     size_t text_length,
-     size_t* num_tokens_out)
- {
-     enum gptoss_status status = gptoss_status_success;
-     const struct gptoss_model* model = context->model;
-     const struct gptoss_tokenizer* tokenizer = model->tokenizer;
-     size_t num_appended_tokens = 0;
-     while (text_length != 0) {
-         if (context->num_tokens == context->max_tokens) {
-             status = gptoss_status_context_overflow;
-             break;
-         }
-         const char* tokens = tokenizer->tokens_ptr;
-         uint32_t best_token = UINT32_MAX;
-         uint32_t best_token_length = 0;
-         for (size_t t = 0; t < tokenizer->num_text_tokens; t++) {
-             uint16_t token_length;
-             memcpy(&token_length, tokens, sizeof(uint16_t));
-             tokens += sizeof(uint16_t);
-             if (token_length <= text_length && token_length > best_token_length) {
-                 if (memcmp(text, tokens, token_length) == 0) {
-                     if (token_length > best_token_length) {
-                         best_token = (uint32_t) t;
-                         best_token_length = token_length;
-                     }
-                 }
-             }
-             tokens += token_length;
-         }
-
-         if (best_token == UINT32_MAX) {
-             GPTOSS_LOG_ERROR("failed to tokenize text \"%.*s\"", (int) text_length, text);
-             return gptoss_status_invalid_argument;
-         }
-
-         uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
-         if (context->num_kv_tokens > context->num_tokens) {
-             if (input_tokens[context->num_tokens] != best_token) {
-                 input_tokens[context->num_tokens] = best_token;
-
-                 // Invalidate the KV cache starting with the newly added token.
-                 context->num_kv_tokens = context->num_tokens;
-             }
-             context->num_tokens++;
-         } else {
-             input_tokens[context->num_tokens++] = best_token;
-         }
-         num_appended_tokens++;
-         text += best_token_length;
-         text_length -= best_token_length;
-     }
-     if (num_tokens_out != NULL) {
-         *num_tokens_out = num_appended_tokens;
-     }
-     return status;
- }
-
- enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
-     gptoss_context_t context,
-     size_t num_tokens,
-     const uint32_t* tokens)
- {
-     const struct gptoss_model* model = context->model;
-
-     // Validate all tokens
-     for (size_t t = 0; t < num_tokens; t++) {
-         const uint32_t token = tokens[t];
-         if (token >= model->vocabulary_size) {
-             GPTOSS_LOG_ERROR("token %" PRIu32 " at index %zu is out of bounds for vocabulary size %" PRIu32,
-                 token, t, context->model->vocabulary_size);
-             return gptoss_status_invalid_argument;
-         }
-     }
-
-     enum gptoss_status status = gptoss_status_success;
-     uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
-     while (num_tokens != 0) {
-         if (context->num_tokens == context->max_tokens) {
-             status = gptoss_status_context_overflow;
-             break;
-         }
-
-         if (context->num_kv_tokens > context->num_tokens) {
-             const size_t num_tokens_to_verify = math_min(context->num_kv_tokens - context->num_tokens, num_tokens);
-             size_t num_verified_tokens = 0;
-             for (; num_verified_tokens < num_tokens_to_verify; num_verified_tokens++) {
-                 if (input_tokens[context->num_tokens + num_verified_tokens] != tokens[num_verified_tokens]) {
-                     // Invalidate the KV cache starting with the newly added tokens.
-                     context->num_kv_tokens = context->num_tokens + num_verified_tokens;
-                     break;
-                 }
-             }
-
-             context->num_tokens += num_verified_tokens;
-             tokens += num_verified_tokens;
-             num_tokens -= num_verified_tokens;
-         } else {
-             const size_t num_tokens_to_copy = math_min(context->max_tokens - context->num_tokens, num_tokens);
-             memcpy(input_tokens + context->num_tokens, tokens, num_tokens_to_copy * sizeof(uint32_t));
-             context->num_tokens += num_tokens_to_copy;
-             tokens += num_tokens_to_copy;
-             num_tokens -= num_tokens_to_copy;
-         }
-     }
-
-     return status;
- }
-
- enum gptoss_status GPTOSS_ABI gptoss_context_process(
-     gptoss_context_t context)
- {
-     if (context->num_tokens > context->num_kv_tokens) {
-         struct gptoss_metal_command_buffer command_buffer = {0};
-
-         enum gptoss_status status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
-         if (status != gptoss_status_success) {
-             goto cleanup;
-         }
-
-         struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
-         control->abort = 0;
-
-         status = process_tokens(
-             context,
-             &command_buffer,
-             /*input_tokens_offset=*/context->num_kv_tokens,
-             /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
-             /*num_output_tokens=*/0);
-         if (status != gptoss_status_success) {
-             goto cleanup;
-         }
-
-         status = gptoss_metal_command_buffer_commit(&command_buffer);
-         if (status != gptoss_status_success) {
-             goto cleanup;
-         }
-
-         status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
-         if (status != gptoss_status_success) {
-             goto cleanup;
-         }
-
-         context->num_kv_tokens = context->num_tokens;
-
-     cleanup:
-         gptoss_metal_command_buffer_release(&command_buffer);
-         return status;
-     }
-
-     return gptoss_status_success;
- }
-
- enum gptoss_status GPTOSS_ABI gptoss_context_sample(
-     gptoss_context_t context,
-     float temperature,
-     uint64_t seed,
-     size_t max_tokens,
-     uint32_t* tokens_out,
-     size_t* num_tokens_out)
- {
-     enum gptoss_status status = gptoss_status_success;
-     const struct gptoss_model* model = context->model;
-     struct gptoss_metal_command_buffer command_buffer = {0};
-
-     *num_tokens_out = 0;
-
-     const uint32_t num_original_tokens = context->num_tokens;
-
-     status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
-     if (status != gptoss_status_success) {
-         goto cleanup;
-     }
-
-     struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
-     control->abort = 0;
-
-     for (size_t t = 0; t < max_tokens; t++) {
-         if (context->num_kv_tokens < context->num_tokens) {
-             status = process_tokens(
-                 context,
-                 &command_buffer,
-                 /*input_tokens_offset=*/context->num_kv_tokens,
-                 /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
-                 /*num_output_tokens=*/1);
-             context->num_kv_tokens = context->num_tokens;
-         } else {
-             status = process_tokens(
-                 context,
-                 &command_buffer,
-                 /*input_tokens_offset=*/context->num_tokens - 1,
-                 /*num_input_tokens=*/1,
-                 /*num_output_tokens=*/1);
-         }
-         if (status != gptoss_status_success) {
-             goto cleanup;
-         }
-
-         if (temperature != 0.0f) {
-             assert(context->num_processed_tokens != 0);
-             uint32_t num_threadgroups = 0;
-             uint32_t num_dims_per_threadgroup = 0;
-             status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
-                 &command_buffer,
-                 &model->f32_softmax_fn,
-                 /*threadgroup_size=*/512,
-                 model->max_threadgroups,
-                 &context->score_buffer,
-                 /*score_offset=*/0,
-                 &context->argmax_buffer,
-                 /*argmax_offset=*/0,
-                 &context->prob_buffer,
-                 /*prob_offset=*/0,
-                 &context->sum_buffer,
-                 /*sum_offset=*/0,
-                 &context->control_buffer,
-                 /*control_offset=*/0,
-                 model->vocabulary_size,
-                 /*num_tokens=*/1,
-                 temperature,
-                 &num_threadgroups,
-                 &num_dims_per_threadgroup);
-             if (status != gptoss_status_success) {
-                 GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
-                 goto cleanup;
-             }
-
-             status = gptoss_metal_command_buffer_encode_launch_f32_sample(
-                 &command_buffer,
-                 &model->f32_sample_fn,
-                 /*min_threadgroup_size=*/512,
-                 &context->prob_buffer,
-                 /*prob_offset=*/0,
-                 &context->sum_buffer,
-                 /*sum_offset=*/0,
-                 &context->token_buffer,
-                 /*token_offset=*/context->num_tokens * sizeof(uint32_t),
-                 &context->control_buffer,
-                 /*control_offset=*/0,
-                 /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF),
-                 /*rng_offset=*/context->num_tokens,
-                 /*num_blocks=*/num_threadgroups,
-                 /*num_channels=*/model->vocabulary_size,
-                 /*num_channels_per_block=*/num_dims_per_threadgroup);
-             if (status != gptoss_status_success) {
-                 GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch");
-                 goto cleanup;
-             }
-         } else {
-             status = gptoss_metal_command_buffer_encode_copy_buffer(
-                 &command_buffer,
-                 &context->argmax_buffer,
-                 /*input_offset=*/0,
-                 &context->token_buffer,
-                 /*output_offset=*/context->num_tokens * sizeof(uint32_t),
-                 /*size=*/sizeof(uint32_t));
-             if (status != gptoss_status_success) {
-                 GPTOSS_LOG_ERROR("failed to encode copy buffer");
-                 goto cleanup;
-             }
-         }
-         context->num_tokens += 1;
-         context->num_kv_tokens = context->num_tokens;
-     }
-
-     gptoss_metal_command_buffer_commit(&command_buffer);
-     gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
-
-     const uint32_t* token_ptr = (const uint32_t*) context->token_buffer.ptr;
-     const uint32_t num_generated_tokens = context->num_tokens - num_original_tokens;
-     memcpy(tokens_out, token_ptr + num_original_tokens, num_generated_tokens * sizeof(uint32_t));
-     *num_tokens_out = num_generated_tokens;
-
- cleanup:
-     gptoss_metal_command_buffer_release(&command_buffer);
-     return status;
- }
-
- enum gptoss_status GPTOSS_ABI gptoss_context_reset(
-     gptoss_context_t context)
- {
-     context->num_tokens = 0;
-
-     // Note: context->num_kv_tokens is not reset and context->input_tokens_buffer is not cleared.
-     // If the subsequently added tokens match the tokens already in the KV cache, we reuse the KV cache.
-
-     return gptoss_status_success;
- }
-
- enum gptoss_status GPTOSS_ABI gptoss_context_retain(
-     gptoss_context_t context)
- {
-     atomic_fetch_add_explicit(&context->ref_count, 1, memory_order_relaxed);
-     return gptoss_status_success;
- }
-
- enum gptoss_status GPTOSS_ABI gptoss_context_release(
-     gptoss_context_t context)
- {
-     if (context != NULL) {
-         if (atomic_fetch_sub_explicit(&context->ref_count, 1, memory_order_acq_rel) == 1) {
-             // Activation buffers
-             gptoss_metal_buffer_release(&context->residual_activation_buffer);
-             gptoss_metal_buffer_release(&context->rmsnorm_activation_buffer);
-             gptoss_metal_buffer_release(&context->qkv_activation_buffer);
-             gptoss_metal_buffer_release(&context->sdpa_activation_buffer);
-             gptoss_metal_buffer_release(&context->gate_activation_buffer);
-             gptoss_metal_buffer_release(&context->expert_activation_buffer);
-             gptoss_metal_buffer_release(&context->swiglu_activation_buffer);
-             gptoss_metal_buffer_release(&context->moe_activation_buffer);
-             gptoss_metal_buffer_release(&context->expert_offset_buffer);
-             gptoss_metal_buffer_release(&context->token_to_expert_routing_buffer);
-             gptoss_metal_buffer_release(&context->swiglu_input_buffer);
-
-             // Input/output buffers
-             gptoss_metal_buffer_release(&context->control_buffer);
-             gptoss_metal_buffer_release(&context->token_buffer);
-             gptoss_metal_buffer_release(&context->score_buffer);
-             gptoss_metal_buffer_release(&context->prob_buffer);
-             gptoss_metal_buffer_release(&context->sum_buffer);
-             gptoss_metal_buffer_release(&context->argmax_buffer);
-             gptoss_metal_buffer_release(&context->kvcache_buffer);
-
-             gptoss_model_release(context->model);
-
-             memset(context, 0, sizeof(struct gptoss_context));
-             free(context);
-         }
-     }
-     return gptoss_status_success;
- }
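One behavior worth calling out from the hunk above: gptoss_context_reset only rewinds num_tokens, leaving num_kv_tokens and the token buffer intact, and the append paths verify re-added tokens against that buffer before invalidating anything. A minimal sketch of the prefix-reuse pattern this enables, assuming a context created as above:

#include <string.h>

#include <gpt-oss.h>

// Demonstrates KV-cache reuse across a reset: the second append of the same
// prefix only verifies tokens against the retained buffer, so the second
// process call has nothing left to prefill.
static void reuse_prefix(gptoss_context_t context, const char* system_prompt) {
    // First pass: tokenize, prefill, and populate the KV cache.
    gptoss_context_append_chars(context, system_prompt, strlen(system_prompt), NULL);
    gptoss_context_process(context);

    // Reset rewinds num_tokens but keeps num_kv_tokens and the token buffer.
    gptoss_context_reset(context);

    // Matching prefix: tokens are verified rather than rewritten, so the KV
    // cache stays valid and process() is a no-op (num_tokens == num_kv_tokens).
    gptoss_context_append_chars(context, system_prompt, strlen(system_prompt), NULL);
    gptoss_context_process(context);
}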
gptoss_kernels/source/generate.c DELETED
@@ -1,317 +0,0 @@
1
- #include <assert.h>
2
- #include <inttypes.h>
3
- #include <math.h>
4
- #include <signal.h>
5
- #include <stdatomic.h>
6
- #include <stdbool.h>
7
- #include <stdio.h>
8
- #include <stdint.h>
9
- #include <stdlib.h>
10
- #include <string.h>
11
-
12
- #include <mach/mach_time.h>
13
-
14
- #include <gpt-oss.h>
15
-
16
- #include "internal/model.h"
17
-
18
- struct {
19
- atomic_uint_least64_t inference_bytes;
20
- atomic_size_t num_prefill_tokens;
21
- atomic_uint_least64_t prefill_microseconds;
22
- atomic_size_t num_generated_tokens;
23
- atomic_uint_least64_t generation_microseconds;
24
- } globals = {
25
- .inference_bytes = 0,
26
- .num_prefill_tokens = 0,
27
- .prefill_microseconds = 0,
28
- .num_generated_tokens = 0,
29
- .generation_microseconds = 0,
30
- };
31
-
32
- struct options {
33
- const char* model;
34
- const char* prompt;
35
- size_t context_length;
36
- size_t max_tokens;
37
- float temperature;
38
- bool verbose;
39
- };
40
-
41
- static inline double mach_timestamp_diff_to_seconds(uint64_t start_timestamp, uint64_t end_timestamp) {
42
- static mach_timebase_info_data_t timebase_info = {0};
43
- if (timebase_info.denom == 0) {
44
- mach_timebase_info(&timebase_info);
45
- }
46
- const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;
47
- return ((double) elapsed_mach_time * (double) timebase_info.numer) / ((double) timebase_info.denom * 1.0e+9);
48
- }
49
-
50
- static inline uint64_t mach_timestamp_diff_to_microseconds(uint64_t start_timestamp, uint64_t end_timestamp) {
51
- static mach_timebase_info_data_t timebase_info = {0};
52
- if (timebase_info.denom == 0) {
53
- mach_timebase_info(&timebase_info);
54
- }
55
- const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;
56
- const uint64_t denominator = timebase_info.denom * UINT64_C(1000);
57
- return (elapsed_mach_time * timebase_info.numer + denominator / 2) / denominator;
58
- }
59
-
60
- static void print_usage(const char* program_name) {
61
- printf("Usage: %s <model-path> [-p <prompt>] [-n <tokens>]\n", program_name);
62
- }
63
-
64
- struct options parse_options(int argc, char** argv) {
65
- struct options options = (struct options) {
66
- .model = NULL,
67
- .prompt = NULL,
68
- .context_length = 0,
69
- .max_tokens = 0,
70
- .temperature = 0.0f,
71
- .verbose = false,
72
- };
73
- if (argc < 2) {
74
- fprintf(stderr, "Error: missing required command-line argument\n");
75
- print_usage(argv[0]);
76
- exit(EXIT_FAILURE);
77
- }
78
- for (int i = 1; i < argc; i++) {
79
- if (strcmp(argv[i], "--help") == 0) {
80
- print_usage(argv[0]);
81
- exit(EXIT_SUCCESS);
82
- } else if (strcmp(argv[i], "-p") == 0 || strcmp(argv[i], "--prompt") == 0) {
83
- if (i + 1 >= argc) {
84
- fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
85
- print_usage(argv[0]);
86
- exit(EXIT_FAILURE);
87
- }
88
- options.prompt = argv[++i];
89
- } else if (strcmp(argv[i], "--context-length") == 0) {
90
- if (i + 1 >= argc) {
91
- fprintf(stderr, "Error: missing argument for --context-length\n");
92
- print_usage(argv[0]);
93
- exit(EXIT_FAILURE);
94
- }
95
- char* context_length_start = argv[++i];
96
- char* context_length_end = context_length_start;
97
- options.context_length = strtoul(context_length_start, &context_length_end, 10);
98
- if (context_length_end == context_length_start || *context_length_end != 0) {
99
- fprintf(stderr, "Error: failed to parse context length value \"%s\"\n", context_length_start);
100
- exit(EXIT_FAILURE);
101
- }
102
- } else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--max-tokens") == 0) {
103
- if (i + 1 >= argc) {
104
- fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
105
- print_usage(argv[0]);
106
- exit(EXIT_FAILURE);
107
- }
108
- char* max_tokens_start = argv[++i];
109
- char* max_tokens_end = max_tokens_start;
110
- options.max_tokens = strtoul(max_tokens_start, &max_tokens_end, 10);
111
- if (max_tokens_end == max_tokens_start || *max_tokens_end != 0) {
112
- fprintf(stderr, "Error: failed to parse max tokens value \"%s\"\n", max_tokens_start);
113
- exit(EXIT_FAILURE);
114
- }
115
- if (options.max_tokens == 0) {
116
- fprintf(stderr, "Error: invalid max tokens value %zu\n", options.max_tokens);
117
- exit(EXIT_FAILURE);
118
- }
119
- } else if (strcmp(argv[i], "-t") == 0 || strcmp(argv[i], "--temperature") == 0) {
120
- if (i + 1 >= argc) {
121
- fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
122
- print_usage(argv[0]);
123
- exit(EXIT_FAILURE);
124
- }
125
- char* temperature_start = argv[++i];
126
- char* temperature_end = temperature_start;
127
- options.temperature = strtof(temperature_start, &temperature_end);
128
- if (temperature_end == temperature_start || *temperature_end != 0) {
129
- fprintf(stderr, "Error: failed to parse temperature value \"%s\"\n", temperature_start);
130
- exit(EXIT_FAILURE);
131
- }
132
- if (signbit(options.temperature) != 0 || !(options.temperature <= 2.0f)) {
133
- fprintf(stderr, "Error: invalid temperature value %f\n", options.temperature);
134
- exit(EXIT_FAILURE);
135
- }
136
- } else if (strcmp(argv[i], "-v") == 0 || strcmp(argv[i], "--verbose") == 0) {
137
- options.verbose = true;
138
- } else {
139
- if (options.model == NULL) {
140
- options.model = argv[i];
141
- } else {
142
- fprintf(stderr, "Error: unexpected command-line argument %s\n", argv[i]);
143
- print_usage(argv[0]);
144
- exit(EXIT_FAILURE);
145
- }
146
- }
147
- }
148
- if (options.model == NULL) {
149
- fprintf(stderr, "Error: missing required model argument\n");
150
- print_usage(argv[0]);
151
- exit(EXIT_FAILURE);
152
- }
153
- if (options.prompt == NULL) {
154
- fprintf(stderr, "Error: missing required prompt argument\n");
155
- print_usage(argv[0]);
156
- exit(EXIT_FAILURE);
157
- }
158
- return options;
159
- }
160
-
161
-
162
- static void print_profile() {
163
- const size_t num_prefill_tokens = atomic_load(&globals.num_prefill_tokens);
164
- const uint64_t prefill_microseconds = atomic_load(&globals.prefill_microseconds);
165
- const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens);
166
- const uint64_t generation_microseconds = atomic_load(&globals.generation_microseconds);
167
- const uint64_t inference_bytes = atomic_load(&globals.inference_bytes);
168
- if (num_prefill_tokens != 0 || num_generated_tokens != 0) {
169
- printf("\n");
170
- }
171
- if (num_prefill_tokens != 0) {
172
- printf("Prefill speed (%zu tokens): %.1f tokens/second\n",
173
- num_prefill_tokens,
174
- (double) num_prefill_tokens / (double) prefill_microseconds * 1.0e+6);
175
- }
176
- if (num_generated_tokens != 0) {
177
- printf("Generation speed (%zu tokens): %.1f tokens/second\n",
178
- num_generated_tokens,
179
- (double) num_generated_tokens / (double) generation_microseconds * 1.0e+6);
180
- }
181
- }
182
-
183
- static void ctrl_c_handler(int signum) {
184
- print_profile();
185
- exit(EXIT_SUCCESS);
186
- }
187
-
188
- int main(int argc, char *argv[]) {
189
- enum gptoss_status status;
190
- gptoss_model_t model = NULL;
191
- gptoss_tokenizer_t tokenizer = NULL;
192
- gptoss_context_t context = NULL;
193
-
194
- struct sigaction act;
195
- act.sa_handler = ctrl_c_handler;
196
- sigaction(SIGINT, &act, NULL);
197
-
198
- setvbuf(stdout, NULL, _IONBF, 0);
199
-
200
- struct options options = parse_options(argc, argv);
201
-
202
- const uint64_t load_start_time = mach_continuous_time();
203
- status = gptoss_model_create_from_file(options.model, &model);
204
- if (status != gptoss_status_success) {
205
- fprintf(stderr, "Error: failed to load model from file %s\n", options.model);
206
- goto error;
207
- }
208
- size_t max_model_context_length = 0;
209
- status = gptoss_model_get_max_context_length(model, &max_model_context_length);
210
- if (status != gptoss_status_success) {
211
- fprintf(stderr, "Error: failed to query maximum context length\n");
212
- goto error;
213
- }
214
- assert(max_model_context_length != 0);
215
- if (options.context_length == 0) {
216
- options.context_length = max_model_context_length;
217
- } else if (options.context_length > max_model_context_length) {
218
- fprintf(stderr, "Error: context length %zu exceeds maximum context length %zu supported by the model\n", options.context_length, max_model_context_length);
219
- goto error;
220
- }
221
-
222
- status = gptoss_model_get_tokenizer(model, &tokenizer);
223
- if (status != gptoss_status_success) {
224
- fprintf(stderr, "Error: failed to retrieve Tokenizer\n");
225
- goto error;
226
- }
227
-
228
- uint32_t return_token_id = UINT32_MAX;
229
- status = gptoss_tokenizer_get_special_token_id(tokenizer, gptoss_special_token_return, &return_token_id);
230
- if (status != gptoss_status_success) {
231
- fprintf(stderr, "Error: failed to query end-of-text token ID\n");
232
- goto error;
233
- }
234
-
235
- status = gptoss_context_create(model, options.context_length, /*max_batch_tokens=*/0, &context);
236
- if (status != gptoss_status_success) {
237
- fprintf(stderr, "Error: failed to create Context object\n");
238
- goto error;
239
- }
240
- if (options.verbose) {
241
- printf("Model weights size: %.2lf MB\n", (double) model->weights_size * 0x1.0p-20);
242
- printf("Model allocation size: %.2lf MB\n", (double) model->allocation_size * 0x1.0p-20);
243
- printf("Context allocation size: %.2lf MB\n", (double) context->allocation_size * 0x1.0p-20);
244
- printf(" Including KV cache: %.2lf MB\n", (double) context->kvcache_size * 0x1.0p-20);
245
- }
246
-
247
- const uint64_t load_end_time = mach_continuous_time();
248
- const double load_elapsed_seconds = mach_timestamp_diff_to_seconds(load_start_time, load_end_time);
249
- if (options.verbose) {
250
- printf("Loaded model in %.3f seconds\n", load_elapsed_seconds);
251
- }
252
-
253
- const uint64_t prefill_start_time = mach_continuous_time();
254
- size_t num_prefill_tokens = 0;
255
- status = gptoss_context_append_chars(context, options.prompt, strlen(options.prompt), &num_prefill_tokens);
256
- if (status != gptoss_status_success) {
257
- fprintf(stderr, "Error: failed to tokenize prompt \"%s\"\n", options.prompt);
258
- goto error;
259
- }
260
- atomic_store(&globals.num_prefill_tokens, num_prefill_tokens);
261
- status = gptoss_context_process(context);
262
- if (status != gptoss_status_success) {
263
- fprintf(stderr, "Error: failed to process Context object\n");
264
- goto error;
265
- }
266
- const uint64_t prefill_end_time = mach_continuous_time();
267
-
268
- while (options.max_tokens == 0 || atomic_load(&globals.num_generated_tokens) < options.max_tokens) {
269
-
270
- uint32_t predicted_token = UINT32_MAX;
271
- size_t num_predicted_tokens = 0;
272
- const uint64_t inference_start_timestamp = mach_continuous_time();
273
- status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, /*num_tokens=*/1, &predicted_token, &num_predicted_tokens);
274
- if (status != gptoss_status_success) {
275
- fprintf(stderr, "Error: failed to sample from the Context object\n");
276
- goto error;
277
- }
278
- const uint64_t inference_end_timestamp = mach_continuous_time();
279
-
280
- if (predicted_token == return_token_id) {
281
- // Yield token -> stop generation
282
- break;
283
- }
284
-
285
- // Unembedding: detokenize
286
- size_t token_size = 0;
287
- const void* token_ptr = NULL;
288
- status = gptoss_tokenizer_decode(tokenizer, predicted_token, &token_ptr, &token_size);
289
- if (status != gptoss_status_success) {
290
- fprintf(stderr, "Error: failed to detokenize predicted token %" PRIu32 "\n", predicted_token);
291
- goto error;
292
- }
293
- const size_t previous_num_generated_tokens = atomic_fetch_add(&globals.num_generated_tokens, 1);
294
- if (previous_num_generated_tokens == 0) {
295
- atomic_fetch_add(&globals.prefill_microseconds, mach_timestamp_diff_to_microseconds(prefill_start_time, prefill_end_time));
296
- } else {
297
- atomic_fetch_add(&globals.generation_microseconds, mach_timestamp_diff_to_microseconds(inference_start_timestamp, inference_end_timestamp));
298
- }
299
- printf("%.*s", (int) token_size, (const char*) token_ptr);
300
-
301
- status = gptoss_context_append_tokens(context, 1, &predicted_token);
302
- if (status != gptoss_status_success) {
303
- fprintf(stderr, "Error: failed to append predicted token %" PRIu32 " to context\n", predicted_token);
304
- goto error;
305
- }
306
- }
307
-
308
- print_profile();
309
-
310
- return EXIT_SUCCESS;
311
-
312
- error:
313
- gptoss_context_release(context);
314
- gptoss_tokenizer_release(tokenizer);
315
- gptoss_model_release(model);
316
- return EXIT_FAILURE;
317
- }
gptoss_kernels/source/include/internal/log.h CHANGED
@@ -2,6 +2,9 @@
 
 #include <stdarg.h>
 
+#ifdef __cplusplus
+extern "C" {
+#endif
 
 void gptoss_format_log(const char* format, va_list args);
 
@@ -13,6 +16,10 @@ inline static void gptoss_log(const char* format, ...) {
     va_end(args);
 }
 
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
 #define GPTOSS_LOG_ERROR(message, ...) \
     gptoss_log("Error: " message "\n", ##__VA_ARGS__)
 
gptoss_kernels/source/include/internal/metal.h CHANGED
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <stddef.h>
-
 #include <gpt-oss/types.h>
 
 #ifdef __cplusplus
gptoss_kernels/source/matmul.metal CHANGED
@@ -43,7 +43,10 @@ kernel void gptoss_f32_bf16w_matmul(
     bias += row;
     output += gid.y * args.num_rows + row;
 
-    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+    uint num_iter = 0;
+    if (simdgroup_tid < num_column_vecs) {
+        num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+    }
 
     float4 sum4 = 0.0f;
     do {
@@ -97,7 +100,10 @@ kernel void gptoss_f32_bf16w_matmul_qkv(
     bias += row;
     q += gid.y * args.num_rows;
 
-    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+    uint num_iter = 0;
+    if (simdgroup_tid < num_column_vecs) {
+        num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+    }
 
     float4 sum4 = 0.0f;
     do {
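Note on the num_iter change above: each simdgroup lane walks the column vectors with stride simdgroup_size, so the expression is a ceiling division of the lane's remaining column count; the new branch makes the zero-work case explicit instead of leaning on unsigned wraparound when simdgroup_tid exceeds num_column_vecs. A rough Python sketch of the resulting work distribution (illustrative values, not taken from the kernel):

def iterations_per_lane(num_column_vecs: int, simdgroup_size: int) -> list[int]:
    # Lane t processes column vectors t, t + simdgroup_size, t + 2 * simdgroup_size, ...
    counts = []
    for lane in range(simdgroup_size):
        if lane < num_column_vecs:  # the guard added in this commit
            counts.append((num_column_vecs - lane + simdgroup_size - 1) // simdgroup_size)
        else:
            counts.append(0)
    return counts

print(iterations_per_lane(70, 32))  # lanes 0-5 run 3 iterations, lanes 6-31 run 2
print(iterations_per_lane(4, 32))   # lanes 4-31 get an explicit 0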
gptoss_kernels/source/metal.m CHANGED
@@ -9,7 +9,6 @@
 #include <internal/log.h>
 #include <internal/metal.h>
 
-
 static size_t gptoss_metal_device_get_core_count(id<MTLDevice> device) {
     if (!device) {
         return 0;
gptoss_kernels/source/model.c DELETED
@@ -1,581 +0,0 @@
1
- #include <assert.h>
2
- #include <inttypes.h>
3
- #include <stdatomic.h>
4
- #include <stdint.h>
5
- #include <stdlib.h>
6
- #include <string.h>
7
-
8
- #include <errno.h> // errno, EISDIR, ENOENT, ENOTDIR
9
- #include <fcntl.h> // open
10
- #include <mach/vm_page_size.h> // vm_page_size
11
- #include <sys/mman.h> // mmap, PROT_READ, MAP_PRIVATE
12
- #include <sys/stat.h> // fstat, stat
13
- #include <sys/types.h> // off_t, ssize_t
14
- #include <unistd.h> // close
15
-
16
- #include <gpt-oss.h>
17
-
18
- #include "internal/datatype.h"
19
- #include "internal/kernel-args.h" // gptoss_expert_prediction
20
- #include "internal/log.h"
21
- #include "internal/uuid.h"
22
- #include "internal/storage.h"
23
- #include "internal/math.h"
24
- #include "internal/model.h"
25
-
26
-
27
- static size_t round_up_to_page_size(size_t bytes) {
28
- const size_t page_size_mask = (size_t) vm_page_size - 1;
29
- if ((bytes & page_size_mask) != 0) {
30
- bytes |= page_size_mask;
31
- bytes += 1;
32
- }
33
- return bytes;
34
- }
35
-
36
- static size_t round_down_to_page_size(size_t bytes) {
37
- const size_t page_size_mask = (size_t) vm_page_size - 1;
38
- return bytes & ~page_size_mask;
39
- }
40
-
41
- static enum gptoss_status read_fd(int fd, void* data, size_t size, const char* path) {
42
- assert(fd != -1);
43
- assert(data != NULL);
44
- assert(size != 0);
45
-
46
- size_t bytes_to_read = size;
47
- char* current_byte = (char*) data;
48
- do {
49
- const ssize_t read_result = read(fd, current_byte, bytes_to_read);
50
- if (read_result < 0) {
51
- GPTOSS_LOG_ERROR("reading %zu bytes from file %s failed with error %d",
52
- size, path, errno);
53
- return gptoss_status_io_error;
54
- }
55
- current_byte += (size_t) read_result;
56
- bytes_to_read -= (size_t) read_result;
57
- } while (bytes_to_read != 0);
58
- return gptoss_status_success;
59
- }
60
-
61
- static void prefetch_fd(int fd, size_t offset, size_t size, const char* path) {
62
- // radvisory.ra_count is int, so we can't prefetch 2GB+ at once
63
- const size_t prefetch_max = round_down_to_page_size((size_t) INT_MAX);
64
- do {
65
- const size_t prefetch_size = math_min(size, prefetch_max);
66
- const struct radvisory ra = {
67
- .ra_offset = offset,
68
- .ra_count = (int) prefetch_size,
69
- };
70
- if (fcntl(fd, F_RDADVISE, &ra) == -1) {
71
- GPTOSS_LOG_WARNING("fcntl(%s, F_RDADVISE, .ra_offset=%zu, .ra_count=%d) failed with error %d\n",
72
- path, (size_t) ra.ra_offset, ra.ra_count, errno);
73
- return;
74
- }
75
- offset += prefetch_size;
76
- size -= prefetch_size;
77
- } while (size != 0);
78
- }
79
-
80
- enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
81
- const char* path,
82
- gptoss_model_t* model_out)
83
- {
84
- *model_out = NULL;
85
-
86
- enum gptoss_status status = gptoss_status_success;
87
- struct gptoss_model* model = NULL;
88
- struct gptoss_tokenizer* tokenizer = NULL;
89
- int fd = -1;
90
- size_t file_offset = 0;
91
-
92
- fd = open(path, O_RDONLY);
93
- if (fd == -1) {
94
- GPTOSS_LOG_ERROR("open(%s) failed with error %d", path, errno);
95
- switch (errno) {
96
- case EISDIR:
97
- case ENOENT:
98
- case ENOTDIR:
99
- status = gptoss_status_invalid_argument;
100
- break;
101
- default:
102
- status = gptoss_status_io_error;
103
- break;
104
- }
105
- goto cleanup;
106
- }
107
-
108
- struct gptoss_file_header file_header;
109
- status = read_fd(fd, &file_header, sizeof(file_header), path);
110
- if (status != gptoss_status_success) {
111
- goto cleanup;
112
- }
113
- file_offset += sizeof(file_header);
114
-
115
- if (file_header.magic[0] != 'G' ||
116
- file_header.magic[1] != 'P' ||
117
- file_header.magic[2] != 'T' ||
118
- file_header.magic[3] != '-' ||
119
- file_header.magic[4] != 'O' ||
120
- file_header.magic[5] != 'S' ||
121
- file_header.magic[6] != 'S' ||
122
- file_header.magic[7] != ' ' ||
123
- file_header.magic[8] != 'v' ||
124
- file_header.magic[9] != '1' ||
125
- file_header.magic[10] != '.' ||
126
- file_header.magic[11] != '0' ||
127
- file_header.zero != 0)
128
- {
129
- GPTOSS_LOG_ERROR("invalid magic in file %s", path);
130
- status = gptoss_status_invalid_argument;
131
- goto cleanup;
132
- }
133
-
134
- struct gptoss_uuid model_uuid;
135
- status = read_fd(fd, &model_uuid, sizeof(model_uuid), path);
136
- if (status != gptoss_status_success) {
137
- goto cleanup;
138
- }
139
- file_offset += sizeof(model_uuid);
140
-
141
- if (!gptoss_is_gptoss_model_uuid(&model_uuid)) {
142
- GPTOSS_LOG_ERROR("unsupported model UUID " UUID_FORMAT, UUID_ARGS(model_uuid));
143
- status = gptoss_status_invalid_argument;
144
- goto cleanup;
145
- }
146
-
147
- struct gptoss_gptoss_model_header model_header;
148
- status = read_fd(fd, &model_header, sizeof(model_header), path);
149
- if (status != gptoss_status_success) {
150
- goto cleanup;
151
- }
152
- file_offset += sizeof(model_header);
153
-
154
- struct gptoss_uuid layout_uuid;
155
- status = read_fd(fd, &layout_uuid, sizeof(layout_uuid), path);
156
- if (status != gptoss_status_success) {
157
- goto cleanup;
158
- }
159
- file_offset += sizeof(layout_uuid);
160
-
161
- if (!gptoss_is_applegpu_layout_uuid(&layout_uuid)) {
162
- GPTOSS_LOG_ERROR("unsupported layout UUID " UUID_FORMAT, UUID_ARGS(layout_uuid));
163
- status = gptoss_status_invalid_argument;
164
- goto cleanup;
165
- }
166
-
167
- const size_t model_size = sizeof(struct gptoss_model) + model_header.num_blocks * sizeof(struct gptoss_metal_buffer);
168
- model = malloc(model_size);
169
- if (model == NULL) {
170
- GPTOSS_LOG_ERROR("failed to allocate %zu bytes for model descriptor", model_size);
171
- status = gptoss_status_insufficient_memory;
172
- goto cleanup;
173
- }
174
- memset(model, 0, model_size);
175
-
176
- atomic_store_explicit(&model->ref_count, 1, memory_order_relaxed);
177
- model->context_length = model_header.context_length;
178
- model->num_blocks = model_header.num_blocks;
179
- model->num_experts = model_header.num_experts;
180
- model->num_active_experts = model_header.num_active_experts;
181
- model->embedding_dim = model_header.embedding_dim;
182
- model->mlp_dim = model_header.mlp_dim;
183
- model->swiglu_limit = model_header.swiglu_limit;
184
- model->head_dim = model_header.head_dim;
185
- model->num_heads = model_header.num_heads;
186
- model->num_kv_heads = model_header.num_kv_heads;
187
- model->attention_window = model_header.attention_window;
188
- model->rope_theta = model_header.rope_theta;
189
- model->interpolation_scale = model_header.interpolation_scale;
190
- model->yarn_offset = model_header.yarn_offset;
191
- model->yarn_scale = model_header.yarn_scale;
192
- model->yarn_multiplier = model_header.yarn_multiplier;
193
- model->rmsnorm_epsilon = model_header.rmsnorm_epsilon;
194
-
195
- struct gptoss_uuid tokenizer_uuid;
196
- status = read_fd(fd, &tokenizer_uuid, sizeof(tokenizer_uuid), path);
197
- if (status != gptoss_status_success) {
198
- goto cleanup;
199
- }
200
- file_offset += sizeof(tokenizer_uuid);
201
-
202
- if (!gptoss_is_tiktoken_tokenizer_uuid(&tokenizer_uuid)) {
203
- GPTOSS_LOG_ERROR("unsupported tokenizer UUID " UUID_FORMAT, UUID_ARGS(tokenizer_uuid));
204
- status = gptoss_status_invalid_argument;
205
- goto cleanup;
206
- }
207
-
208
- struct gptoss_tiktoken_tokenizer_header tokenizer_header;
209
- status = read_fd(fd, &tokenizer_header, sizeof(tokenizer_header), path);
210
- if (status != gptoss_status_success) {
211
- goto cleanup;
212
- }
213
- file_offset += sizeof(tokenizer_header);
214
-
215
- tokenizer = malloc(sizeof(struct gptoss_tokenizer));
216
- if (tokenizer == NULL) {
217
- GPTOSS_LOG_ERROR("failed to allocate %zu bytes for tokenizer descriptor", sizeof(struct gptoss_tokenizer));
218
- status = gptoss_status_insufficient_memory;
219
- goto cleanup;
220
- }
221
- memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));
222
- // Initialize all special token IDs to UINT32_MAX (0xFF in all bytes)
223
- memset(tokenizer->special_token_id, 0xFF, sizeof(tokenizer->special_token_id));
224
-
225
- atomic_store_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
226
- tokenizer->num_special_tokens = tokenizer_header.num_special_tokens;
227
- tokenizer->num_text_tokens = tokenizer_header.num_text_tokens;
228
- model->vocabulary_size = tokenizer_header.num_special_tokens + tokenizer_header.num_text_tokens;
229
- for (uint32_t t = 0; t < tokenizer_header.num_special_tokens; t++) {
230
- struct gptoss_uuid token_uuid;
231
- status = read_fd(fd, &token_uuid, sizeof(token_uuid), path);
232
- if (status != gptoss_status_success) {
233
- goto cleanup;
234
- }
235
- file_offset += sizeof(token_uuid);
236
-
237
- const enum gptoss_special_token token = gptoss_special_token_decode_uuid(&token_uuid);
238
- if (token != gptoss_special_token_invalid) {
239
- tokenizer->special_token_id[token - 1] = tokenizer_header.num_text_tokens + t;
240
- }
241
- }
242
-
243
- const size_t tokenizer_start_offset = file_offset;
244
- const size_t tokenizer_end_offset = tokenizer_start_offset + tokenizer_header.regex_size + tokenizer_header.tokens_size;
245
- const size_t tokenizer_mapping_start = round_down_to_page_size(tokenizer_start_offset);
246
- const size_t tokenizer_mapping_size = round_up_to_page_size(tokenizer_end_offset) - tokenizer_mapping_start;
247
- void* tokenizer_mapping_ptr = mmap(NULL, tokenizer_mapping_size, PROT_READ, MAP_PRIVATE, fd, tokenizer_mapping_start);
248
- if (tokenizer_mapping_ptr == (void*) -1) {
249
- GPTOSS_LOG_ERROR("failed to mmap(%s) tokenizer at offset %zu size %zu",
250
- path, tokenizer_mapping_start, tokenizer_mapping_size);
251
- status = gptoss_status_io_error;
252
- goto cleanup;
253
- }
254
- tokenizer->mapping_ptr = tokenizer_mapping_ptr;
255
- tokenizer->mapping_size = tokenizer_mapping_size;
256
- tokenizer->regex_ptr = (const char*) tokenizer_mapping_ptr + (tokenizer_start_offset - tokenizer_mapping_start);
257
- tokenizer->tokens_ptr = tokenizer->regex_ptr + tokenizer_header.regex_size;
258
-
259
- if (madvise(tokenizer_mapping_ptr, tokenizer_mapping_size, MADV_RANDOM | MADV_WILLNEED) != 0) {
260
- GPTOSS_LOG_WARNING("madvise(%s, size=%zu) failed with error %d", path, tokenizer_mapping_size, errno);
261
- }
262
-
263
- prefetch_fd(fd, tokenizer_mapping_start, tokenizer_mapping_size, path);
264
-
265
- struct stat model_stat = {0};
266
- int stat_result = fstat(fd, &model_stat);
267
- if (stat_result != 0) {
268
- GPTOSS_LOG_ERROR("stat(%s) failed with error %d", path, errno);
269
- status = gptoss_status_io_error;
270
- goto cleanup;
271
- }
272
-
273
- const size_t model_mapping_start = round_up_to_page_size(tokenizer_end_offset);
274
- const size_t model_mapping_size = round_up_to_page_size((size_t) model_stat.st_size) - model_mapping_start;
275
- void* model_mapping_ptr = mmap(NULL, model_mapping_size, PROT_READ, MAP_PRIVATE, fd, model_mapping_start);
276
- if (model_mapping_ptr == (void*) -1) {
277
- GPTOSS_LOG_ERROR("failed to mmap(%s) model weights at offset %zu size %zu",
278
- path, model_mapping_start, model_mapping_size);
279
- status = gptoss_status_io_error;
280
- goto cleanup;
281
- }
282
- model->mapping_ptr = model_mapping_ptr;
283
- model->mapping_size = model_mapping_size;
284
-
285
- if (madvise(model_mapping_ptr, model_mapping_size, MADV_SEQUENTIAL | MADV_WILLNEED) != 0) {
286
- GPTOSS_LOG_WARNING("madvise(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
287
- }
288
-
289
- prefetch_fd(fd, model_mapping_start, model_mapping_size, path);
290
-
291
- if (mlock(model_mapping_ptr, model_mapping_size) != 0) {
292
- GPTOSS_LOG_WARNING("mlock(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
293
- } else {
294
- model->lock_memory = true;
295
- }
296
-
297
- // Initialize Metal
298
- status = gptoss_metal_device_create_system_default(&model->device);
299
- if (status != gptoss_status_success) {
300
- goto cleanup;
301
- }
302
- model->max_threadgroups = model->device.num_cores * 3;
303
- status = gptoss_metal_command_queue_create(&model->device, &model->command_queue);
304
- if (status != gptoss_status_success) {
305
- goto cleanup;
306
- }
307
-
308
- // Metal kernels
309
- status = gptoss_metal_library_create_default(&model->device, &model->library);
310
- if (status != gptoss_status_success) {
311
- goto cleanup;
312
- }
313
- status = gptoss_metal_function_create(&model->library, "gptoss_bf16_f32_embeddings", &model->bf16_f32_embeddings_fn);
314
- if (status != gptoss_status_success) {
315
- goto cleanup;
316
- }
317
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_rmsnorm", &model->f32_bf16w_rmsnorm_fn);
318
- if (status != gptoss_status_success) {
319
- goto cleanup;
320
- }
321
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul", &model->f32_bf16w_matmul_fn);
322
- if (status != gptoss_status_success) {
323
- goto cleanup;
324
- }
325
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul_qkv", &model->f32_bf16w_matmul_qkv_fn);
326
- if (status != gptoss_status_success) {
327
- goto cleanup;
328
- }
329
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_qkv", &model->f32_bf16w_dense_matmul_qkv_fn);
330
- if (status != gptoss_status_success) {
331
- goto cleanup;
332
- }
333
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_attn_output", &model->f32_bf16w_dense_matmul_attn_output_fn);
334
- if (status != gptoss_status_success) {
335
- goto cleanup;
336
- }
337
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_mlp_gate", &model->f32_bf16w_dense_matmul_mlp_gate_fn);
338
- if (status != gptoss_status_success) {
339
- goto cleanup;
340
- }
341
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_unembedding", &model->f32_bf16w_unembedding_fn);
342
- if (status != gptoss_status_success) {
343
- goto cleanup;
344
- }
345
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_rope", &model->f32_rope_fn);
346
- if (status != gptoss_status_success) {
347
- goto cleanup;
348
- }
349
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_expert_routing_metadata", &model->f32_expert_routing_metadata_fn);
350
- if (status != gptoss_status_success) {
351
- goto cleanup;
352
- }
353
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_scatter_e4", &model->f32_scatter_e4_fn);
354
- if (status != gptoss_status_success) {
355
- goto cleanup;
356
- }
357
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_dense_matmul_swiglu", &model->f32_mf4w_moe_dense_matmul_swiglu_fn);
358
- if (status != gptoss_status_success) {
359
- goto cleanup;
360
- }
361
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_dense_matmul", &model->f32_mf4w_moe_dense_matmul_fn);
362
- if (status != gptoss_status_success) {
363
- goto cleanup;
364
- }
365
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_gather_and_accumulate_e4", &model->f32_gather_and_accumulate_e4_fn);
366
- if (status != gptoss_status_success) {
367
- goto cleanup;
368
- }
369
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_matmul_swiglu", &model->f32_mf4w_moe_matmul_swiglu_fn);
370
- if (status != gptoss_status_success) {
371
- goto cleanup;
372
- }
373
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_matmul", &model->f32_mf4w_moe_matmul_fn);
374
- if (status != gptoss_status_success) {
375
- goto cleanup;
376
- }
377
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_accumulate_e4", &model->f32_accumulate_e4_fn);
378
- if (status != gptoss_status_success) {
379
- goto cleanup;
380
- }
381
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_topk_softmax_e32_k4", &model->f32_topk_softmax_e32_k4_fn);
382
- if (status != gptoss_status_success) {
383
- goto cleanup;
384
- }
385
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_topk_softmax_e128_k4", &model->f32_topk_softmax_e128_k4_fn);
386
- if (status != gptoss_status_success) {
387
- goto cleanup;
388
- }
389
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_softmax", &model->f32_softmax_fn);
390
- if (status != gptoss_status_success) {
391
- goto cleanup;
392
- }
393
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_sample", &model->f32_sample_fn);
394
- if (status != gptoss_status_success) {
395
- goto cleanup;
396
- }
397
- status = gptoss_metal_function_create(&model->library, "gptoss_f32_sdpa_q8_d64", &model->f32_sdpa_q8_d64_fn);
398
- if (status != gptoss_status_success) {
399
- goto cleanup;
400
- }
401
-
402
- // Kernel launch parameters
403
- model->embeddings_threadgroup_size = 512;
404
- model->attn_qkv_threadgroup_size = 1024;
405
- model->attn_out_threadgroup_size = 768;
406
- model->mlp_gate_threadgroup_size = 256;
407
- model->mlp_swiglu_threadgroup_size = 192;
408
- model->mlp_out_threadgroup_size = 192;
409
- model->mlp_acc_threadgroup_size = 768;
410
- model->unembedding_threadgroup_size = 416;
411
-
412
- // Weight buffers
413
- const char* current_ptr = (const char*) model->mapping_ptr;
414
-
415
- const size_t embedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
416
- model->attn_rmsnorm_gain_offset = embedding_weight_size;
417
- const size_t rmsnorm_weight_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
418
- model->attn_qkv_weight_offset = model->attn_rmsnorm_gain_offset + rmsnorm_weight_size;
419
- const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
420
- const size_t attn_qkv_weight_size = math_round_up_po2(attn_qkv_dim * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
421
- model->attn_qkv_bias_offset = model->attn_qkv_weight_offset + attn_qkv_weight_size;
422
- const size_t attn_qkv_bias_size = math_round_up_po2(attn_qkv_dim * sizeof(gptoss_bfloat16), 16);
423
- model->attn_sdpa_sink_offset = model->attn_qkv_bias_offset + attn_qkv_bias_size;
424
- const size_t attn_sink_weight_size = math_round_up_po2(model->num_heads * sizeof(gptoss_bfloat16), 16);
425
- model->attn_out_weight_offset = model->attn_sdpa_sink_offset + attn_sink_weight_size;
426
- const size_t attn_out_weight_size = math_round_up_po2(model->embedding_dim * model->num_heads * model->head_dim * sizeof(gptoss_bfloat16), 16);
427
- model->attn_out_bias_offset = model->attn_out_weight_offset + attn_out_weight_size;
428
- const size_t attn_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
429
- model->mlp_rmsnorm_gain_offset = model->attn_out_bias_offset + attn_out_bias_size;
430
- model->mlp_gate_weight_offset = model->mlp_rmsnorm_gain_offset + rmsnorm_weight_size;
431
- const size_t mlp_gate_weight_size = math_round_up_po2(model->num_experts * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
432
- model->mlp_gate_bias_offset = model->mlp_gate_weight_offset + mlp_gate_weight_size;
433
- const size_t mlp_gate_bias_size = math_round_up_po2(model->num_experts * sizeof(gptoss_bfloat16), 16);
434
- const size_t per_block_shared_weights_size =
435
- rmsnorm_weight_size + attn_qkv_weight_size + attn_qkv_bias_size + attn_sink_weight_size + attn_out_weight_size + attn_out_bias_size +
436
- rmsnorm_weight_size + mlp_gate_weight_size + mlp_gate_bias_size;
437
- model->rmsnorm_weight_offset = embedding_weight_size + model->num_blocks * per_block_shared_weights_size;
438
- model->unembedding_weight_offset = model->rmsnorm_weight_offset + rmsnorm_weight_size;
439
- const size_t unembedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
440
-
441
- model->per_block_shared_weights_size = per_block_shared_weights_size;
442
- const size_t shared_weights_size =
443
- round_up_to_page_size(embedding_weight_size + rmsnorm_weight_size + unembedding_weight_size + model->num_blocks * per_block_shared_weights_size);
444
-
445
- status = gptoss_metal_buffer_wrap(&model->device, shared_weights_size, current_ptr, &model->shared_weight_buffer);
446
- if (status != gptoss_status_success) {
447
- GPTOSS_LOG_ERROR("failed to map expert-shared weight of size %zu onto a Metal buffer", shared_weights_size);
448
- goto cleanup;
449
- }
450
- current_ptr += shared_weights_size;
451
- model->weights_size += shared_weights_size;
452
-
453
- const size_t mlp_swiglu_weight_block_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 2, 16);
454
- model->mlp_swiglu_scale_offset = mlp_swiglu_weight_block_size;
455
- const size_t mlp_swiglu_weight_scale_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 32, 16);
456
- model->mlp_swiglu_bias_offset = model->mlp_swiglu_scale_offset + mlp_swiglu_weight_scale_size;
457
- const size_t mlp_swiglu_bias_size = math_round_up_po2(2 * model->mlp_dim * sizeof(gptoss_bfloat16), 16);
458
- model->mlp_out_block_offset = model->mlp_swiglu_bias_offset + mlp_swiglu_bias_size;
459
- const size_t mlp_out_weight_block_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 2, 16);
460
- model->mlp_out_scale_offset = model->mlp_out_block_offset + mlp_out_weight_block_size;
461
- const size_t mlp_out_weight_scale_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 32, 16);
462
- model->mlp_out_bias_offset = model->mlp_out_scale_offset + mlp_out_weight_scale_size;
463
- const size_t mlp_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
464
- model->per_expert_block_weight_size =
465
- mlp_swiglu_weight_block_size + mlp_swiglu_weight_scale_size + mlp_swiglu_bias_size + mlp_out_weight_block_size + mlp_out_weight_scale_size + mlp_out_bias_size;
466
- const size_t moe_block_weight_size = round_up_to_page_size(model->num_experts * model->per_expert_block_weight_size);
467
- for (uint32_t n = 0; n < model->num_blocks; n++) {
468
- status = gptoss_metal_buffer_wrap(&model->device, moe_block_weight_size, current_ptr, &model->block_weight_buffers[n]);
469
- if (status != gptoss_status_success) {
470
- GPTOSS_LOG_ERROR("failed to map block #%" PRIu32 " MoE weight of size %zu onto a Metal buffer",
471
- n, moe_block_weight_size);
472
- goto cleanup;
473
- }
474
- current_ptr += moe_block_weight_size;
475
- model->weights_size += moe_block_weight_size;
476
- }
477
-
478
- // Commit tokenizer
479
- model->tokenizer = tokenizer;
480
- tokenizer = NULL;
481
-
482
- // Commit model
483
- *model_out = model;
484
- model = NULL;
485
-
486
- cleanup:
487
- if (fd != -1) {
488
- close(fd);
489
- fd = -1;
490
- }
491
- gptoss_model_release(model); // does nothing if model is NULL
492
- gptoss_tokenizer_release(tokenizer); // does nothing if tokenizer is NULL
493
- return status;
494
- }
495
-
496
- enum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(
497
- gptoss_model_t model,
498
- gptoss_tokenizer_t* tokenizer_out)
499
- {
500
- gptoss_tokenizer_t tokenizer = model->tokenizer;
501
- atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
502
- *tokenizer_out = tokenizer;
503
- return gptoss_status_success;
504
- }
505
-
506
- enum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(
507
- gptoss_model_t model,
508
- size_t* max_context_length_out)
509
- {
510
- *max_context_length_out = model->context_length;
511
- return gptoss_status_success;
512
- }
513
-
514
- enum gptoss_status GPTOSS_ABI gptoss_model_retain(
515
- gptoss_model_t model)
516
- {
517
- atomic_fetch_add_explicit(&model->ref_count, 1, memory_order_relaxed);
518
- return gptoss_status_success;
519
- }
520
-
521
- enum gptoss_status GPTOSS_ABI gptoss_model_release(
522
- gptoss_model_t model)
523
- {
524
- if (model != NULL) {
525
- if (atomic_fetch_sub_explicit(&model->ref_count, 1, memory_order_acq_rel) == 1) {
526
- gptoss_tokenizer_release(model->tokenizer);
527
-
528
- // Weight buffers
529
- gptoss_metal_buffer_release(&model->shared_weight_buffer);
530
- for (uint32_t n = 0; n < model->num_blocks; n++) {
531
- gptoss_metal_buffer_release(&model->block_weight_buffers[n]);
532
- }
533
-
534
- // Metal kernels
535
- gptoss_metal_function_release(&model->bf16_f32_embeddings_fn);
536
- gptoss_metal_function_release(&model->f32_bf16w_rmsnorm_fn);
537
- gptoss_metal_function_release(&model->f32_bf16w_matmul_fn);
538
- gptoss_metal_function_release(&model->f32_bf16w_matmul_qkv_fn);
539
- gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_qkv_fn);
540
- gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_attn_output_fn);
541
- gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_mlp_gate_fn);
542
- gptoss_metal_function_release(&model->f32_bf16w_unembedding_fn);
543
- gptoss_metal_function_release(&model->f32_rope_fn);
544
- gptoss_metal_function_release(&model->f32_expert_routing_metadata_fn);
545
- gptoss_metal_function_release(&model->f32_scatter_e4_fn);
546
- gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_swiglu_fn);
547
- gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_fn);
548
- gptoss_metal_function_release(&model->f32_gather_and_accumulate_e4_fn);
549
- gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_swiglu_fn);
550
- gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_fn);
551
- gptoss_metal_function_release(&model->f32_accumulate_e4_fn);
552
- gptoss_metal_function_release(&model->f32_topk_softmax_e32_k4_fn);
553
- gptoss_metal_function_release(&model->f32_topk_softmax_e128_k4_fn);
554
- gptoss_metal_function_release(&model->f32_softmax_fn);
555
- gptoss_metal_function_release(&model->f32_sample_fn);
556
- gptoss_metal_function_release(&model->f32_sdpa_q8_d64_fn);
557
- gptoss_metal_library_release(&model->library);
558
-
559
- gptoss_metal_command_queue_release(&model->command_queue);
560
- gptoss_metal_device_release(&model->device);
561
- // Weight buffers
562
-
563
- if (model->mapping_ptr != NULL && model->mapping_size != 0) {
564
- if (model->lock_memory) {
565
- if (munlock(model->mapping_ptr, model->mapping_size) != 0) {
566
- GPTOSS_LOG_WARNING("munlock for model weight mapping failed with error %d", errno);
567
- }
568
- }
569
-
570
- if (munmap(model->mapping_ptr, model->mapping_size) != 0) {
571
- GPTOSS_LOG_WARNING("munmap for model weight mapping failed with error %d", errno);
572
- }
573
- }
574
-
575
- const size_t model_size = sizeof(struct gptoss_model) + model->num_blocks * sizeof(struct gptoss_metal_buffer);
576
- memset(model, 0, model_size);
577
- free(model);
578
- }
579
- }
580
- return gptoss_status_success;
581
- }
gptoss_kernels/source/tensor_wrappers.cpp ADDED
@@ -0,0 +1,77 @@
1
+ #include <internal/metal-kernels.h>
2
+ #include <internal/metal.h>
3
+ #include <ATen/Tensor.h>
+ #include <cstring> // for std::memcpy used below
4
+
5
+ void f32_bf16w_matmul_torch(const at::Tensor &input,
6
+ const at::Tensor &weight_bf16,
7
+ const at::Tensor &bias_bf16,
8
+ at::Tensor &output,
9
+ int64_t num_tokens, int64_t num_cols, int64_t num_rows, int64_t threadgroup_size)
10
+ {
11
+ TORCH_CHECK(input.dtype() == at::kFloat, "input must be float32");
12
+ TORCH_CHECK(weight_bf16.dtype() == at::kBFloat16, "weight must be bfloat16");
13
+ TORCH_CHECK(bias_bf16.dtype() == at::kBFloat16, "bias must be bfloat16");
14
+ TORCH_CHECK(output.dtype() == at::kFloat, "output must be float32");
15
+
16
+ TORCH_CHECK(input.dim() == 2, "input must be 2D");
17
+ TORCH_CHECK(weight_bf16.dim() == 2, "weight must be 2D");
18
+ TORCH_CHECK(bias_bf16.dim() == 1, "bias must be 1D");
19
+ TORCH_CHECK(output.dim() == 2, "output must be 2D");
20
+ TORCH_CHECK(input.size(0) == num_tokens && input.size(1) == num_cols,
21
+ "input shape must be [num_tokens, num_cols]");
22
+ TORCH_CHECK(weight_bf16.size(0) == num_cols && weight_bf16.size(1) == num_rows,
23
+ "weight shape must be [num_cols, num_rows]");
24
+ TORCH_CHECK(bias_bf16.size(0) == num_rows, "bias length must be num_rows");
25
+ TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_rows,
26
+ "output shape must be [num_tokens, num_rows]");
27
+
28
+ auto input_cpu = input.contiguous().to(at::kCPU);
29
+ auto weight_cpu = weight_bf16.transpose(0, 1).contiguous().to(at::kCPU);
30
+ auto bias_cpu = bias_bf16.contiguous().to(at::kCPU);
31
+ auto out_cpu = output.detach().to(at::kCPU).contiguous().clone();
32
+
33
+ gptoss_metal_device device{}; gptoss_metal_library library{};
34
+ gptoss_metal_function fn{}; gptoss_metal_command_queue cq{};
35
+ gptoss_metal_command_buffer cb{};
36
+
37
+ TORCH_CHECK(gptoss_metal_device_create_system_default(&device) == gptoss_status_success, "device_create failed");
38
+ TORCH_CHECK(gptoss_metal_library_create_default(&device, &library) == gptoss_status_success, "library_create failed");
39
+ TORCH_CHECK(gptoss_metal_function_create(&library, "gptoss_f32_bf16w_matmul", &fn) == gptoss_status_success, "function_create failed");
40
+ TORCH_CHECK(gptoss_metal_command_queue_create(&device, &cq) == gptoss_status_success, "cq_create failed");
41
+ TORCH_CHECK(gptoss_metal_command_buffer_create(&cq, &cb) == gptoss_status_success, "cb_create failed");
42
+
43
+ const size_t in_bytes = (size_t)num_tokens * (size_t)num_cols * sizeof(float);
44
+ const size_t wt_bytes = (size_t)num_rows * (size_t)num_cols * sizeof(uint16_t);
45
+ const size_t bs_bytes = (size_t)num_rows * sizeof(uint16_t);
46
+ const size_t out_bytes = (size_t)num_tokens * (size_t)num_rows * sizeof(float);
47
+
48
+ gptoss_metal_buffer in_buf{}, wt_buf{}, bs_buf{}, out_buf{}, ctrl_buf{};
49
+ TORCH_CHECK(gptoss_metal_buffer_wrap(&device, in_bytes, input_cpu.data_ptr(), &in_buf) == gptoss_status_success, "wrap input failed");
50
+ TORCH_CHECK(gptoss_metal_buffer_wrap(&device, wt_bytes, weight_cpu.data_ptr(), &wt_buf) == gptoss_status_success, "wrap weight failed");
51
+ TORCH_CHECK(gptoss_metal_buffer_wrap(&device, bs_bytes, bias_cpu.data_ptr(), &bs_buf) == gptoss_status_success, "wrap bias failed");
52
+ TORCH_CHECK(gptoss_metal_buffer_create(&device, out_bytes, nullptr, &out_buf) == gptoss_status_success, "alloc out failed");
53
+ uint32_t ctrl_zero = 0;
54
+ TORCH_CHECK(gptoss_metal_buffer_create(&device, sizeof(uint32_t), &ctrl_zero, &ctrl_buf) == gptoss_status_success, "alloc ctrl failed");
55
+
56
+ TORCH_CHECK(gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
57
+ &cb, &fn, (size_t)threadgroup_size,
58
+ &in_buf, 0, &wt_buf, 0, &bs_buf, 0, &out_buf, 0, &ctrl_buf, 0,
59
+ (uint32_t)num_tokens, (uint32_t)num_cols, (uint32_t)num_rows) == gptoss_status_success, "encode failed");
60
+
61
+ TORCH_CHECK(gptoss_metal_command_buffer_commit(&cb) == gptoss_status_success, "commit failed");
62
+ TORCH_CHECK(gptoss_metal_command_buffer_wait_completion(&cb, nullptr) == gptoss_status_success, "wait failed");
63
+
64
+ std::memcpy(out_cpu.data_ptr(), out_buf.ptr, out_bytes);
65
+ output.copy_(out_cpu.to(output.device(), /*non_blocking=*/false, /*copy=*/true));
66
+
67
+ (void) gptoss_metal_command_buffer_release(&cb);
68
+ (void) gptoss_metal_command_queue_release(&cq);
69
+ (void) gptoss_metal_function_release(&fn);
70
+ (void) gptoss_metal_library_release(&library);
71
+ (void) gptoss_metal_device_release(&device);
72
+ (void) gptoss_metal_buffer_release(&ctrl_buf);
73
+ (void) gptoss_metal_buffer_release(&out_buf);
74
+ (void) gptoss_metal_buffer_release(&bs_buf);
75
+ (void) gptoss_metal_buffer_release(&wt_buf);
76
+ (void) gptoss_metal_buffer_release(&in_buf);
77
+ }
gptoss_kernels/source/tokenizer.c DELETED
@@ -1,106 +0,0 @@
1
- #include <assert.h>
2
- #include <stdatomic.h>
3
- #include <stddef.h>
4
- #include <stdint.h>
5
- #include <stdlib.h>
6
- #include <string.h>
7
-
8
- #include <errno.h>
9
- #include <sys/mman.h>
10
-
11
- #include <gpt-oss.h>
12
-
13
- #include "internal/log.h"
14
- #include "internal/model.h"
15
-
16
-
17
- enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(
18
- gptoss_tokenizer_t tokenizer,
19
- enum gptoss_special_token token_type,
20
- uint32_t* token_id_out)
21
- {
22
- uint32_t token_id = UINT32_MAX;
23
- if (token_type != gptoss_special_token_invalid && token_type < gptoss_special_token_max)
24
- {
25
- token_id = tokenizer->special_token_id[(uint32_t) token_type - 1];
26
- }
27
- if (token_id == UINT32_MAX) {
28
- return gptoss_status_invalid_argument;
29
- }
30
-
31
- *token_id_out = token_id;
32
- return gptoss_status_success;
33
- }
34
-
35
- enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(
36
- gptoss_tokenizer_t tokenizer,
37
- uint32_t* num_text_tokens_out)
38
- {
39
- *num_text_tokens_out = tokenizer->num_text_tokens;
40
- return gptoss_status_success;
41
- }
42
-
43
- enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(
44
- gptoss_tokenizer_t tokenizer,
45
- uint32_t* num_special_tokens_out)
46
- {
47
- *num_special_tokens_out = tokenizer->num_special_tokens;
48
- return gptoss_status_success;
49
- }
50
-
51
- enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(
52
- gptoss_tokenizer_t tokenizer,
53
- uint32_t* num_tokens_out)
54
- {
55
- *num_tokens_out = tokenizer->num_text_tokens + tokenizer->num_special_tokens;
56
- return gptoss_status_success;
57
- }
58
-
59
- enum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(
60
- gptoss_tokenizer_t tokenizer,
61
- uint32_t token_id,
62
- const void** token_ptr_out,
63
- size_t* token_size_out)
64
- {
65
- if (token_id >= tokenizer->num_text_tokens) {
66
- return gptoss_status_invalid_argument;
67
- }
68
-
69
- const char* token_ptr = (const char*) tokenizer->tokens_ptr;
70
- for (uint32_t t = 0; t < token_id; t++) {
71
- // Reading unaligned uint16_t
72
- uint16_t token_length;
73
- memcpy(&token_length, token_ptr, sizeof(token_length));
74
-
75
- token_ptr += (size_t) token_length + sizeof(uint16_t);
76
- }
77
-
78
- *token_ptr_out = (const void*) (token_ptr + sizeof(uint16_t));
79
- *token_size_out = (size_t) *token_ptr;
80
- return gptoss_status_success;
81
- }
82
-
83
- enum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(
84
- gptoss_tokenizer_t tokenizer)
85
- {
86
- atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
87
- return gptoss_status_success;
88
- }
89
-
90
- enum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(
91
- gptoss_tokenizer_t tokenizer)
92
- {
93
- if (tokenizer != NULL) {
94
- if (atomic_fetch_sub_explicit(&tokenizer->ref_count, 1, memory_order_acquire) == 1) {
95
- if (tokenizer->mapping_ptr != NULL && tokenizer->mapping_size != 0) {
96
- if (munmap(tokenizer->mapping_ptr, tokenizer->mapping_size) != 0) {
97
- GPTOSS_LOG_WARNING("munmap for tokenizer mapping failed with error %d", errno);
98
- }
99
- }
100
-
101
- memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));
102
- free(tokenizer);
103
- }
104
- }
105
- return gptoss_status_success;
106
- }
pyproject.toml ADDED
@@ -0,0 +1,10 @@
1
+ [build-system]
2
+ requires = [
3
+ "cmake>=3.26",
4
+ "ninja",
5
+ "packaging",
6
+ "setuptools>=61",
7
+ "torch",
8
+ "wheel",
9
+ ]
10
+ build-backend = "setuptools.build_meta"
setup.py ADDED
@@ -0,0 +1,118 @@
1
+ import logging
2
+ import os
3
+ from shutil import which, move
4
+ import subprocess
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ from setuptools import Extension, find_packages, setup
9
+ from setuptools.command.build_ext import build_ext
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def is_sccache_available() -> bool:
15
+ return which("sccache") is not None
16
+
17
+
18
+ def is_ccache_available() -> bool:
19
+ return which("ccache") is not None
20
+
21
+
22
+ def is_ninja_available() -> bool:
23
+ return which("ninja") is not None
24
+
25
+
26
+ class CMakeExtension(Extension):
27
+ def __init__(self, name: str, sourcedir: str = "") -> None:
28
+ super().__init__(name, sources=[], py_limited_api=True)
29
+ self.sourcedir = os.fspath(Path(sourcedir).resolve())
30
+
31
+
32
+ class CMakeBuild(build_ext):
33
+ def build_extension(self, ext: CMakeExtension) -> None:
34
+ ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)
35
+ extdir = ext_fullpath.parent.resolve()
36
+
37
+ debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
38
+ cfg = "Debug" if debug else "Release"
39
+
40
+ cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
41
+
42
+ # Set Python3_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
43
+ # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
44
+ # from Python.
45
+ cmake_args = [
46
+ f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
47
+ f"-DPython3_EXECUTABLE={sys.executable}",
48
+ f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
49
+ ]
50
+ build_args = []
51
+ if "CMAKE_ARGS" in os.environ:
52
+ cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
53
+
54
+ if not cmake_generator or cmake_generator == "Ninja":
55
+ try:
56
+ import ninja
57
+
58
+ ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
59
+ cmake_args += [
60
+ "-GNinja",
61
+ f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
62
+ ]
63
+ except ImportError:
64
+ pass
65
+
66
+ if is_sccache_available():
67
+ cmake_args += [
68
+ "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
69
+ "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
70
+ "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
71
+ "-DCMAKE_OBJC_COMPILER_LAUNCHER=sccache",
72
+ "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=sccache",
73
+ ]
74
+ elif is_ccache_available():
75
+ cmake_args += [
76
+ "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
77
+ "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
78
+ "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
79
+ "-DCMAKE_OBJC_COMPILER_LAUNCHER=ccache",
80
+ "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=ccache",
81
+ ]
82
+
83
+ num_jobs = os.getenv("MAX_JOBS", None)
84
+ if num_jobs is not None:
85
+ num_jobs = int(num_jobs)
86
+ logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
87
+ else:
88
+ try:
89
+ # os.sched_getaffinity() isn't universally available, so fall
90
+ # back to os.cpu_count() if we get an error here.
91
+ num_jobs = len(os.sched_getaffinity(0))
92
+ except AttributeError:
93
+ num_jobs = os.cpu_count()
+
+ # Pass the computed job count to the underlying build tool.
+ build_args += ["--parallel", str(num_jobs)]
94
+
95
+ build_temp = Path(self.build_temp) / ext.name
96
+ if not build_temp.exists():
97
+ build_temp.mkdir(parents=True)
98
+
99
+ subprocess.run(
100
+ ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
101
+ )
102
+ subprocess.run(
103
+ ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
104
+ )
105
+
106
+
107
+ setup(
108
+ name="gptoss_kernels",
109
+ # The version is just a stub, it's not used by the final build artefact.
110
+ version="0.1.0",
111
+ ext_modules=[CMakeExtension("gptoss_kernels._gptoss_kernels_931bc1b_dirty")],
112
+ cmdclass={"build_ext": CMakeBuild},
113
+ packages=find_packages(where="torch-ext", include=["gptoss_kernels*"]),
114
+ package_dir={"": "torch-ext"},
115
+ zip_safe=False,
116
+ install_requires=["torch"],
117
+ python_requires=">=3.9",
118
+ )
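One possible way to drive this build from Python, using the environment variables the script actually reads (CMAKE_ARGS and MAX_JOBS; the flag values below are only examples):

import os
import subprocess
import sys

# CMAKE_ARGS is split on spaces and appended to the cmake configure call;
# MAX_JOBS caps the number of parallel build jobs.
env = dict(
    os.environ,
    CMAKE_ARGS="-DCMAKE_OSX_DEPLOYMENT_TARGET=15.0",
    MAX_JOBS="8",
)
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--no-build-isolation", "."],
    env=env,
    check=True,
)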
{gptoss_kernels/test β†’ test}/bf16-f32-embeddings.cc RENAMED
File without changes
{gptoss_kernels/test β†’ test}/embeddings-kernel-tester.hpp RENAMED
File without changes
{gptoss_kernels/test β†’ test}/f32-bf16w-matmul.cc RENAMED
File without changes
{gptoss_kernels/test β†’ test}/f32-bf16w-rmsnorm.cc RENAMED
File without changes
{gptoss_kernels/test β†’ test}/f32-random.cc RENAMED
File without changes
{gptoss_kernels/test β†’ test}/f32-rope.cc RENAMED
File without changes
{gptoss_kernels/test β†’ test}/fill-random-kernel-tester.hpp RENAMED
File without changes
{gptoss_kernels/test β†’ test}/matmul-kernel-tester.hpp RENAMED
File without changes
{gptoss_kernels/test β†’ test}/mf4-f32-convert.cc RENAMED
File without changes
{gptoss_kernels/test β†’ test}/rmsnorm-kernel-tester.hpp RENAMED
File without changes
{gptoss_kernels/test β†’ test}/rope-kernel-tester.hpp RENAMED
File without changes
{gptoss_kernels/test β†’ test}/u32-random.cc RENAMED
File without changes
torch-ext/gptoss_kernels/__init__.py CHANGED
@@ -0,0 +1,8 @@
1
+ from ._ops import ops
2
+ import torch
3
+
4
+ def f32_bf16w_matmul(input: torch.Tensor, weight_bf16: torch.Tensor, bias_bf16: torch.Tensor, output: torch.Tensor, num_tokens: int, num_cols: int, num_rows: int, threadgroup_size: int) -> torch.Tensor:
5
+ ops.f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, num_tokens, num_cols, num_rows, threadgroup_size)
6
+ return output
7
+
8
+ __all__ = ["f32_bf16w_matmul"]
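A minimal smoke test of the wrapper might look like the sketch below (hypothetical shapes and threadgroup size; assumes the extension is built and an MPS device is present, since torch_binding.cpp registers the op under the MPS dispatch key):

import torch
import gptoss_kernels

num_tokens, num_cols, num_rows = 4, 64, 32
x = torch.randn(num_tokens, num_cols, dtype=torch.float32, device="mps")
w = torch.randn(num_cols, num_rows, dtype=torch.bfloat16, device="mps")
b = torch.randn(num_rows, dtype=torch.bfloat16, device="mps")
out = torch.empty(num_tokens, num_rows, dtype=torch.float32, device="mps")

gptoss_kernels.f32_bf16w_matmul(x, w, b, out, num_tokens, num_cols, num_rows, 256)

# Plain PyTorch reference; loose tolerance since the weights are bfloat16.
ref = x @ w.float() + b.float()
print(torch.allclose(out.cpu(), ref.cpu(), atol=1e-2, rtol=1e-2))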
torch-ext/gptoss_kernels/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (868 Bytes).
 
torch-ext/gptoss_kernels/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (552 Bytes).
 
torch-ext/gptoss_kernels/_gptoss_kernels_931bc1b_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31cddc1925c6c7901a5619ff55420ae6249d2c48de202a23a7c4534e4ccdcd4c
3
+ size 126536
torch-ext/gptoss_kernels/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _gptoss_kernels_931bc1b_dirty
3
+ ops = torch.ops._gptoss_kernels_931bc1b_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_gptoss_kernels_931bc1b_dirty::{op_name}"
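For reference, the helper just qualifies an op name with the extension's namespace (assuming the built extension is importable):

from gptoss_kernels._ops import add_op_namespace_prefix

print(add_op_namespace_prefix("f32_bf16w_matmul"))
# _gptoss_kernels_931bc1b_dirty::f32_bf16w_matmul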
torch-ext/gptoss_kernels/test.py ADDED
@@ -0,0 +1,6 @@
1
+ import _gptoss_kernels_931bc1b_dirty
2
+ import torch
3
+
4
+ print(dir(_gptoss_kernels_931bc1b_dirty))
5
+
6
+ from gptoss_kernels import _gptoss_kernels_931bc1b_dirty
torch-ext/registration.h ADDED
@@ -0,0 +1,30 @@
1
+ // Registration macros from vLLM:
2
+ // https://github.com/vllm-project/vllm/blob/main/csrc/core/registration.h
3
+
4
+ #pragma once
5
+
6
+ #include <Python.h>
7
+
8
+ #define _CONCAT(A, B) A##B
9
+ #define CONCAT(A, B) _CONCAT(A, B)
10
+
11
+ #define _STRINGIFY(A) #A
12
+ #define STRINGIFY(A) _STRINGIFY(A)
13
+
14
+ // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
15
+ // could be a macro instead of a literal token.
16
+ #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
17
+
18
+ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
19
+ // could be a macro instead of a literal token.
20
+ #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
21
+ TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
22
+
23
+ // REGISTER_EXTENSION allows the shared library to be loaded and initialized
24
+ // via python's import statement.
25
+ #define REGISTER_EXTENSION(NAME) \
26
+ PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
27
+ static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
28
+ STRINGIFY(NAME), nullptr, 0, nullptr}; \
29
+ return PyModule_Create(&module); \
30
+ }
torch-ext/torch_binding.cpp CHANGED
@@ -0,0 +1,10 @@
1
+ #include <ATen/Tensor.h>
2
+ #include "torch_binding.h"
3
+ #include "registration.h"
4
+
5
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
6
+ ops.def("f32_bf16w_matmul(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor output, int num_tokens, int num_cols, int num_rows, int threadgroup_size) -> ()");
7
+ ops.impl("f32_bf16w_matmul", torch::kMPS, &f32_bf16w_matmul_torch);
8
+ }
9
+
10
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h CHANGED
@@ -0,0 +1,5 @@
1
+ #pragma once
2
+
3
+ #include <torch/torch.h>
4
+
5
+ void f32_bf16w_matmul_torch(const at::Tensor &input, const at::Tensor &weight_bf16, const at::Tensor &bias_bf16, at::Tensor &output, int64_t num_tokens, int64_t num_cols, int64_t num_rows, int64_t threadgroup_size);