theonlyengine committed on
Commit
4a98549
·
verified ·
1 Parent(s): 7f2f8c3

Create setup.py

Browse files
Files changed (1) hide show
  1. setup.py +532 -0
setup.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import sys
4
+ import warnings
5
+ import os
6
+ import re
7
+ import ast
8
+ import glob
9
+ import shutil
10
+ from pathlib import Path
11
+ from packaging.version import parse, Version
12
+ import platform
13
+
14
+ from setuptools import setup, find_packages
15
+ import subprocess
16
+
17
+ import urllib.request
18
+ import urllib.error
19
+ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
20
+
21
+ import torch
22
+ from torch.utils.cpp_extension import (
23
+ BuildExtension,
24
+ CppExtension,
25
+ CUDAExtension,
26
+ CUDA_HOME,
27
+ ROCM_HOME,
28
+ IS_HIP_EXTENSION,
29
+ )
30
+
31
+
32
# Read the long description from the README that sits next to this script.
# The path is anchored on __file__ (not the current working directory) so the
# script also works when it is invoked from another directory.
with open(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"),
    "r",
    encoding="utf-8",
) as fh:
    long_description = fh.read()
34
+
35
+
36
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

# BUILD_TARGET selects the GPU backend to build for: "cuda", "rocm", or "auto"
# (the default), which follows torch's IS_HIP_EXTENSION flag.
BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")

if BUILD_TARGET == "auto":
    IS_ROCM = bool(IS_HIP_EXTENSION)
elif BUILD_TARGET == "cuda":
    IS_ROCM = False
elif BUILD_TARGET == "rocm":
    IS_ROCM = True
else:
    # Fail here with a clear message; otherwise IS_ROCM would stay undefined
    # and the first later use would die with an unrelated-looking NameError.
    raise ValueError(
        f"Unsupported BUILD_TARGET {BUILD_TARGET!r}: expected 'auto', 'cuda' or 'rocm'"
    )

PACKAGE_NAME = "flash_attn"

# URL template of the prebuilt wheels; filled in with the release tag and wheel filename.
BASE_WHEEL_URL = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
)

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
# Each flag is enabled only by the exact string "TRUE".
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
64
+
65
+
66
def get_platform():
    """
    Returns the platform name as used in wheel filenames.

    Linux yields ``linux_<machine>``, macOS yields
    ``macosx_<major>.<minor>_<machine>`` and Windows yields ``win_amd64``.

    Raises:
        ValueError: if ``sys.platform`` is none of the platforms above.
    """
    if sys.platform.startswith("linux"):
        return f"linux_{platform.uname().machine}"
    if sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        # Report the interpreter's real architecture instead of assuming Intel;
        # fall back to the previous hard-coded value if it cannot be determined.
        machine = platform.machine() or "x86_64"
        return f"macosx_{mac_version}_{machine}"
    if sys.platform == "win32":
        return "win_amd64"
    raise ValueError("Unsupported platform: {}".format(sys.platform))
79
+
80
+
81
def get_cuda_bare_metal_version(cuda_dir):
    """
    Run ``<cuda_dir>/bin/nvcc -V`` and extract the toolkit release version.

    Returns a ``(raw_output, bare_metal_version)`` tuple: the unmodified text
    printed by nvcc and the parsed version object.
    """
    nvcc_report = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    tokens = nvcc_report.split()
    # The version is the whitespace-separated token right after "release";
    # anything from the first comma onwards is dropped before parsing.
    version_token = tokens[tokens.index("release") + 1]
    return nvcc_report, parse(version_token.split(",")[0])
88
+
89
+
90
def check_if_cuda_home_none(global_option: str) -> None:
    """
    Emit a warning naming ``global_option`` when ``CUDA_HOME`` is ``None``.

    Does nothing when ``CUDA_HOME`` is set.
    """
    if CUDA_HOME is None:
        # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
        # in that case.
        warnings.warn(
            f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
100
+
101
+
102
def check_if_rocm_home_none(global_option: str) -> None:
    """
    Emit a warning naming ``global_option`` when ``ROCM_HOME`` is ``None``.

    Does nothing when ``ROCM_HOME`` is set.
    """
    if ROCM_HOME is None:
        # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
        # in that case.
        warnings.warn(
            f"{global_option} was requested, but hipcc was not found."
        )
110
+
111
+
112
def append_nvcc_threads(nvcc_extra_args):
    """
    Return a new list: ``nvcc_extra_args`` followed by ``--threads <count>``.

    The count is the ``NVCC_THREADS`` environment variable, or "4" when that
    variable is unset or empty. The input list is not modified.
    """
    thread_count = os.getenv("NVCC_THREADS")
    if not thread_count:
        thread_count = "4"
    threads_option = ["--threads", thread_count]
    return nvcc_extra_args + threads_option
115
+
116
+
117
def rename_cpp_to_cu(cpp_files):
    """
    Copy every path in ``cpp_files`` to a sibling file with a ``.cu`` extension.

    Despite the name, the originals are copied rather than moved, so they
    remain in place.
    """
    for source_path in cpp_files:
        stem, _extension = os.path.splitext(source_path)
        shutil.copy(source_path, stem + ".cu")
120
+
121
+
122
def validate_and_update_archs(archs):
    """
    Check that every entry of ``archs`` is a supported GPU architecture.

    Despite the name, ``archs`` is only validated, never modified.

    Args:
        archs: iterable of architecture names (e.g. ``["native"]``).

    Raises:
        ValueError: if any entry is not in the allowed list. An explicit raise
            is used instead of ``assert`` so the check still runs under
            ``python -O``.
    """
    # List of allowed architectures
    allowed_archs = ("native", "gfx90a", "gfx940", "gfx941", "gfx942")

    # Collect the offending entries so the error names them explicitly.
    unsupported = [arch for arch in archs if arch not in allowed_archs]
    if unsupported:
        raise ValueError(
            f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention "
            f"(unsupported: {unsupported}; allowed: {list(allowed_archs)})"
        )
130
+
131
+
132
cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
submodule_path = "csrc/composable_kernel" if IS_ROCM else "csrc/cutlass"
try:
    # Best effort: the exit status is deliberately not checked.
    subprocess.run(["git", "submodule", "update", "--init", submodule_path])
except OSError as err:
    # git could not be launched at all (e.g. it is not installed). Do not abort
    # the whole setup for a best-effort step; just report it.
    warnings.warn(f"Could not run git to update submodule {submodule_path}: {err}")
141
+
142
if not SKIP_CUDA_BUILD and not IS_ROCM:
    # ---- CUDA build of the flash_attn_2_cuda extension ----
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_cuda_home_none("flash_attn")
    # sm_80 is always targeted; sm_90 is added when the local nvcc is >= 11.8.
    # cc_flag.append("-gencode")
    # cc_flag.append("arch=compute_75,code=sm_75")
    cc_flag = ["-gencode", "arch=compute_80,code=sm_80"]
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                "FlashAttention is only supported on CUDA 11.6 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )
        if bare_metal_version >= Version("11.8"):
            cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    # One kernel source per (pass, causal variant, head dim, dtype) combination:
    # 3 passes x 2 variants x 7 head dims x 2 dtypes = 84 files, e.g.
    # csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu
    kernel_sources = [
        f"csrc/flash_attn/src/flash_{kernel_pass}_hdim{head_dim}_{dtype}{causal}_sm80.cu"
        for kernel_pass in ("fwd", "bwd", "fwd_split")
        for causal in ("", "_causal")
        for head_dim in (32, 64, 96, 128, 160, 192, 256)
        for dtype in ("fp16", "bf16")
    ]

    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=["csrc/flash_attn/flash_api.cpp"] + kernel_sources,
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_HALF2_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        # "--ptxas-options=-v",
                        # "--ptxas-options=-O2",
                        # "-lineinfo",
                        # "-DFLASHATTENTION_DISABLE_BACKWARD",
                        # "-DFLASHATTENTION_DISABLE_DROPOUT",
                        # "-DFLASHATTENTION_DISABLE_ALIBI",
                        # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                        # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                        # "-DFLASHATTENTION_DISABLE_LOCAL",
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_attn",
                Path(this_dir) / "csrc" / "flash_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
elif not SKIP_CUDA_BUILD and IS_ROCM:
    # ---- ROCm build of the flash_attn_2_cuda extension (composable_kernel backend) ----
    ck_dir = "csrc/composable_kernel"

    # use codegen get code dispatch
    os.makedirs("build", exist_ok=True)

    # Run the generator once per direction. Arguments are passed as a list (no
    # shell involved, so paths with spaces are safe) and check=True makes a
    # failed generation stop the build here rather than surfacing later as
    # missing sources.
    generate_script = f"{ck_dir}/example/ck_tile/01_fmha/generate.py"
    for direction in ("fwd", "bwd"):
        subprocess.run(
            [sys.executable, generate_script, "-d", direction, "--output_dir", "build", "--receipt", "2"],
            check=True,
        )

    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_rocm_home_none("flash_attn")

    # GPU_ARCHS is a ';'-separated list of target architectures (default "native").
    archs = os.getenv("GPU_ARCHS", "native").split(";")
    validate_and_update_archs(archs)

    cc_flag = [f"--offload-arch={arch}" for arch in archs]

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    sources = [
        "csrc/flash_attn_ck/flash_api.cpp",
        "csrc/flash_attn_ck/mha_bwd.cpp",
        "csrc/flash_attn_ck/mha_fwd.cpp",
        "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
        "csrc/flash_attn_ck/mha_varlen_fwd.cpp",
    ] + glob.glob("build/fmha_*wd*.cpp")

    # Copy every .cpp to a .cu sibling, then compile exactly those copies.
    # Deriving the list from `sources` (instead of globbing build/ again) keeps
    # stale .cu files from earlier runs out of the build.
    rename_cpp_to_cu(sources)
    renamed_sources = [os.path.splitext(source)[0] + ".cu" for source in sources]

    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"] + generator_flag,
        "nvcc": [
            "-O3",
            "-std=c++17",
            "-mllvm",
            "-enable-post-misched=0",
            "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
            "-fgpu-flush-denormals-to-zero",
            "-DCK_ENABLE_BF16",
            "-DCK_ENABLE_BF8",
            "-DCK_ENABLE_FP16",
            "-DCK_ENABLE_FP32",
            "-DCK_ENABLE_FP64",
            "-DCK_ENABLE_FP8",
            "-DCK_ENABLE_INT8",
            "-DCK_USE_XDL",
            "-DUSE_PROF_API=1",
            "-D__HIP_PLATFORM_HCC__=1",
            # "-DFLASHATTENTION_DISABLE_BACKWARD",
        ]
        + generator_flag
        + cc_flag,
    }

    include_dirs = [
        Path(this_dir) / "csrc" / "composable_kernel" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
    ]

    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=renamed_sources,
            extra_compile_args=extra_compile_args,
            include_dirs=include_dirs,
        )
    )
391
+
392
+
393
def get_package_version():
    """
    Return the package version declared in ``flash_attn/__init__.py``.

    The ``__version__ = ...`` assignment is located with a regex and its
    right-hand side is evaluated as a Python literal. When the
    ``FLASH_ATTN_LOCAL_VERSION`` environment variable is non-empty, it is
    appended as ``+<local_version>``.

    Raises:
        RuntimeError: if no ``__version__`` assignment is found in the file.
    """
    init_path = Path(this_dir) / "flash_attn" / "__init__.py"
    # Explicit encoding so the read does not depend on the locale default.
    with open(init_path, "r", encoding="utf-8") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        raise RuntimeError(f"Unable to find a __version__ assignment in {init_path}")
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    return str(public_version)
402
+
403
+
404
def get_wheel_url():
    """
    Compose the filename and download URL of the prebuilt wheel for this environment.

    The filename encodes the package version, the GPU backend version (ROCm or
    CUDA), the torch major.minor version, the C++11 ABI flag, the Python tag
    and the platform name.

    Returns a ``(wheel_url, wheel_filename)`` tuple.
    """
    torch_version_raw = parse(torch.__version__)
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    flash_version = get_package_version()
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    if not IS_ROCM:
        # Determine the version numbers that will be used to determine the correct wheel
        # We're using the CUDA version used to build torch, not the one currently installed
        torch_cuda_version = parse(torch.version.cuda)
        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
        # to save CI time. Minor versions should be compatible.
        if torch_cuda_version.major == 11:
            torch_cuda_version = parse("11.8")
        else:
            torch_cuda_version = parse("12.3")
        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"

        # Determine wheel URL based on CUDA version, torch version, python version and OS
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    else:
        # Take the last whitespace-separated token of torch.version.hip and
        # turn any remaining '-' into '+' before parsing it as a version.
        hip_release = torch.version.hip.split()[-1].rstrip('-').replace('-', '+')
        torch_hip_version = parse(hip_release)
        hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
    return wheel_url, wheel_filename
433
+
434
+
435
class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuits the standard full build pipeline.
    """

    def run(self):
        """
        Download a matching prebuilt wheel into ``dist_dir``, or build from source.

        A source build happens when ``FORCE_BUILD`` is set or when the download
        fails with an HTTP/URL error.
        """
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()
            return

        # Make the archive
        # Lifted from the root wheel processing command
        # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
        os.makedirs(self.dist_dir, exist_ok=True)

        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

        wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
        print("Raw wheel path", wheel_path)
        # shutil.move (unlike os.rename) also works when dist_dir lives on a
        # different filesystem than the download location.
        shutil.move(wheel_filename, wheel_path)
468
+
469
+
470
class NinjaBuildExtension(BuildExtension):
    """BuildExtension that picks a default ``MAX_JOBS`` from CPU count and free memory."""

    def __init__(self, *args, **kwargs) -> None:
        # do not override env MAX_JOBS if already exists
        if not os.environ.get("MAX_JOBS"):
            import psutil

            # calculate the maximum allowed NUM_JOBS based on cores
            # (os.cpu_count() may return None when the count is undeterminable)
            max_num_jobs_cores = max(1, (os.cpu_count() or 1) // 2)

            # calculate the maximum allowed NUM_JOBS based on free memory
            free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)  # free memory in GB
            max_num_jobs_memory = int(free_memory_gb / 9)  # each JOB peak memory cost is ~8-9GB when threads = 4

            # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
            max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
            os.environ["MAX_JOBS"] = str(max_jobs)

        super().__init__(*args, **kwargs)
488
+
489
+
490
# bdist_wheel always goes through the prebuilt-wheel lookup; build_ext is only
# hooked when there is at least one extension module to compile.
cmdclass = {"bdist_wheel": CachedWheelsCommand}
if ext_modules:
    cmdclass["build_ext"] = NinjaBuildExtension

# Top-level names that must not be picked up as installable packages.
excluded_packages = (
    "build",
    "csrc",
    "include",
    "tests",
    "dist",
    "docs",
    "benchmarks",
    "flash_attn.egg-info",
)

setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(exclude=excluded_packages),
    author="Tri Dao",
    author_email="tri@tridao.me",
    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/flash-attention",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    python_requires=">=3.8",
    install_requires=["torch", "einops"],
    setup_requires=["packaging", "psutil", "ninja"],
)