-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathsetup.py
More file actions
98 lines (80 loc) · 2.45 KB
/
setup.py
File metadata and controls
98 lines (80 loc) · 2.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import sys
from pathlib import Path
from setuptools import setup
from torch.utils.cpp_extension import (
CppExtension,
CUDAExtension,
BuildExtension,
CUDA_HOME,
)
# Top-level package name; get_extensions() uses it as the prefix of the
# compiled extension's dotted module path.
library_name: str = "digeo"
def get_extensions():
is_windows = sys.platform == "win32"
debug_mode = os.getenv("DEBUG", "0") == "1"
use_cuda = os.getenv("USE_CUDA", "1") == "1"
if debug_mode:
print("Compiling in debug mode")
use_cuda = use_cuda and CUDA_HOME is not None
print(f"Using CUDA: {use_cuda}")
extension = CUDAExtension if use_cuda else CppExtension
extra_link_args = []
extra_compile_args = {"cxx": [], "nvcc": []}
if is_windows:
# MSVC Flags
if debug_mode:
extra_compile_args["cxx"] = [
"/Od",
"/Z7",
"-DTORCH_USE_CUDA_DSA",
]
extra_link_args = ["/DEBUG"]
else:
extra_compile_args["cxx"] = ["/O2"]
else:
# GCC / Clang Flags
if debug_mode:
extra_compile_args["cxx"] = [
"-O0",
"-g",
"-fdiagnostics-color=always",
"-DTORCH_USE_CUDA_DSA",
]
extra_link_args = ["-O0", "-g"]
else:
extra_compile_args["cxx"] = ["-O3", "-fdiagnostics-color=always"]
# Device NVCC Compiler Flags
extra_compile_args["nvcc"] = [
"-O0" if debug_mode else "-O3",
"--fmad=false",
"--prec-div=true",
"--prec-sqrt=true",
]
if debug_mode:
extra_compile_args["nvcc"] += ["-g", "-G", "-lineinfo", "-DTORCH_USE_CUDA_DSA"]
if use_cuda:
extra_compile_args["cxx"] += ["-DUSE_CUDA"]
extra_compile_args["nvcc"] += ["-DUSE_CUDA"]
# Sources
extensions_dir = Path("src/digeo/ops/cuda")
sources = list(extensions_dir.glob("*.cpp"))
cuda_sources = list(extensions_dir.glob("*.cu"))
print(f"Found C++ sources: {[str(s) for s in sources]}")
if use_cuda:
print(f"Found CUDA sources: {[str(s) for s in cuda_sources]}")
if use_cuda:
sources += cuda_sources
ext_modules = [
extension(
f"{library_name}.ops.cuda._C",
sources,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
py_limited_api=False,
)
]
return ext_modules
# Register the compiled extension(s) and let PyTorch's BuildExtension drive
# the compiler (it routes the "cxx"/"nvcc" flag groups to the right tool).
setup(
    ext_modules=get_extensions(),
    cmdclass=dict(build_ext=BuildExtension),
)