Skip to content

Commit 0bf3d95

Browse files
cianciosa
authored and committed
Initial support for complex metal kernels.
1 parent 99463f5 commit 0bf3d95

5 files changed

Lines changed: 331 additions & 84 deletions

File tree

graph_framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ target_compile_definitions (graph_framework
2626
$<$<BOOL:${SHOW_USE_COUNT}>:SHOW_USE_COUNT>
2727
$<$<BOOL:${USE_INDEX_CACHE}>:USE_INDEX_CACHE>
2828
$<IF:$<BOOL:${USE_VERBOSE}>,USE_VERBOSE=true,USE_VERBOSE=false>
29+
$<$<BOOL:${USE_METAL}>:HEADER_DIR="$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>">
2930
)
3031

3132
target_include_directories (graph_framework

graph_framework/cuda_context.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -973,11 +973,13 @@ namespace gpu {
973973
source_buffer << "input[i];" << std::endl;
974974
}
975975
source_buffer << " for (size_t index = i + 1024; index < " << size <<"; index += 1024) {" << std::endl;
976+
source_buffer << " sub_max = max(sub_max, ";
976977
if constexpr (jit::complex_scalar<T>) {
977-
source_buffer << " sub_max = max(sub_max, abs(input[index]));" << std::endl;
978+
source_buffer << "abs(input[index]";
978979
} else {
979-
source_buffer << " sub_max = max(sub_max, input[index]);" << std::endl;
980+
source_buffer << "input[index]";
980981
}
982+
source_buffer << ");" << std::endl;
981983
source_buffer << " }" << std::endl;
982984
source_buffer << " __shared__ " << jit::type_to_string<T> () << " thread_max[32];" << std::endl;
983985
source_buffer << " for (int index = 16; index > 0; index /= 2) {" << std::endl;

graph_framework/jit.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ namespace jit {
6464
#ifdef USE_CUDA
6565
gpu::cuda_context<T, SAFE_MATH>,
6666
#elif defined(USE_METAL)
67-
gpu::metal_context<SAFE_MATH>,
67+
gpu::metal_context<T, SAFE_MATH>,
6868
#else
6969
gpu::cpu_context<T, SAFE_MATH>,
7070
#endif

graph_framework/metal_context.hpp

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#define metal_context_h
1010

1111
#include <unordered_set>
12+
#include <stdlib.h>
1213

1314
#import <Metal/Metal.h>
1415

@@ -19,9 +20,10 @@ namespace gpu {
1920
//------------------------------------------------------------------------------
2021
/// @brief Class representing a metal gpu context.
2122
///
23+
/// @tparam T Base type of the calculation.
2224
/// @tparam SAFE_MATH Use @ref general_concepts_safe_math operations.
2325
//------------------------------------------------------------------------------
24-
template<bool SAFE_MATH=false>
26+
template<jit::float_scalar T, bool SAFE_MATH=false>
2527
class metal_context {
2628
private:
2729
/// The metal device.
@@ -82,6 +84,7 @@ namespace gpu {
8284
std::vector<std::string> names,
8385
const bool add_reduction=false) {
8486
NSError *error;
87+
setenv("MTL_HEADER_SEARCH_PATHS", HEADER_DIR, 1);
8588
library = [device newLibraryWithSource:[NSString stringWithCString:kernel_source.c_str()
8689
encoding:NSUTF8StringEncoding]
8790
options:compile_options()
@@ -141,9 +144,9 @@ namespace gpu {
141144
std::set<graph::leaf_node<float, SAFE_MATH> *> needed_buffers;
142145

143146
const size_t buffer_element_size = sizeof(float);
144-
for (graph::shared_variable<float, SAFE_MATH> &input : inputs) {
147+
for (graph::shared_variable<T, SAFE_MATH> &input : inputs) {
145148
if (!kernel_arguments.contains(input.get())) {
146-
backend::buffer<float> buffer = input->evaluate();
149+
backend::buffer<T> buffer = input->evaluate();
147150
kernel_arguments[input.get()] = [device newBufferWithBytes:buffer.data()
148151
length:buffer.size()*buffer_element_size
149152
options:MTLResourceStorageModeShared];
@@ -155,7 +158,7 @@ namespace gpu {
155158
needed_buffers.insert(input.get());
156159
}
157160
}
158-
for (graph::shared_leaf<float, SAFE_MATH> &output : outputs) {
161+
for (graph::shared_leaf<T, SAFE_MATH> &output : outputs) {
159162
if (!kernel_arguments.contains(output.get())) {
160163
kernel_arguments[output.get()] = [device newBufferWithLength:num_rays*sizeof(float)
161164
options:MTLResourceStorageModeShared];
@@ -296,8 +299,8 @@ namespace gpu {
296299
/// @param[in] run Function to run before reduction.
297300
/// @returns A lambda function to run the kernel.
298301
//------------------------------------------------------------------------------
299-
std::function<float(void)> create_max_call(graph::shared_leaf<float, SAFE_MATH> &argument,
300-
std::function<void(void)> run) {
302+
std::function<T(void)> create_max_call(graph::shared_leaf<T, SAFE_MATH> &argument,
303+
std::function<void(void)> run) {
301304
MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new];
302305
compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
303306
compute.computeFunction = [library newFunctionWithName:@"max_reduction"];
@@ -375,7 +378,7 @@ namespace gpu {
375378
/// @param[in] nodes Nodes to output.
376379
//------------------------------------------------------------------------------
377380
void print_results(const size_t index,
378-
const graph::output_nodes<float, SAFE_MATH> &nodes) {
381+
const graph::output_nodes<T, SAFE_MATH> &nodes) {
379382
wait();
380383
for (auto &out : nodes) {
381384
std::cout << static_cast<float *> ([kernel_arguments[out.get()] contents])[index] << " ";
@@ -390,10 +393,10 @@ namespace gpu {
390393
/// @param[in] node Node to check the value for.
391394
/// @returns The value at the index.
392395
//------------------------------------------------------------------------------
393-
float check_value(const size_t index,
394-
const graph::shared_leaf<float, SAFE_MATH> &node) {
396+
T check_value(const size_t index,
397+
const graph::shared_leaf<T, SAFE_MATH> &node) {
395398
wait();
396-
return static_cast<float *> ([kernel_arguments[node.get()] contents])[index];
399+
return static_cast<T *> ([kernel_arguments[node.get()] contents])[index];
397400
}
398401

399402
//------------------------------------------------------------------------------
@@ -402,8 +405,8 @@ namespace gpu {
402405
/// @param[in] node Not to copy buffer to.
403406
/// @param[in] source Host side buffer to copy from.
404407
//------------------------------------------------------------------------------
405-
void copy_to_device(graph::shared_leaf<float, SAFE_MATH> node,
406-
float *source) {
408+
void copy_to_device(graph::shared_leaf<T, SAFE_MATH> node,
409+
T *source) {
407410
const size_t size = [kernel_arguments[node.get()] length];
408411
memcpy([kernel_arguments[node.get()] contents],
409412
source, size);
@@ -415,8 +418,8 @@ namespace gpu {
415418
/// @param[in] node Node to copy buffer from.
416419
/// @param[in,out] destination Host side buffer to copy to.
417420
//------------------------------------------------------------------------------
418-
void copy_to_host(graph::shared_leaf<float, SAFE_MATH> node,
419-
float *destination) {
421+
void copy_to_host(graph::shared_leaf<T, SAFE_MATH> node,
422+
T *destination) {
420423
command_buffer = [queue commandBuffer];
421424

422425
[command_buffer commit];
@@ -436,6 +439,10 @@ namespace gpu {
436439
source_buffer << "#include <metal_stdlib>" << std::endl;
437440
source_buffer << "#include <metal_simdgroup>" << std::endl;
438441
source_buffer << "using namespace metal;" << std::endl;
442+
if constexpr (jit::complex_scalar<T>) {
443+
source_buffer << "#define METAL_DEVICE_CODE" << std::endl;
444+
source_buffer << "#include <special_functions.hpp>" << std::endl;
445+
}
439446
}
440447

441448
//------------------------------------------------------------------------------
@@ -596,8 +603,22 @@ namespace gpu {
596603
<< jit::to_string('v', in.get())
597604
<< "[index] = ";
598605
if constexpr (SAFE_MATH) {
599-
source_buffer << "isnan(" << registers[a.get()]
600-
<< ") ? 0.0 : ";
606+
if constexpr (jit::complex_scalar<T>) {
607+
jit::add_type<T> (source_buffer);
608+
source_buffer << " (";
609+
source_buffer << "isnan(real(" << registers[a.get()]
610+
<< ")) ? 0.0 : real(" << registers[a.get()]
611+
<< "), ";
612+
source_buffer << "isnan(imag(" << registers[a.get()]
613+
<< ")) ? 0.0 : imag(" << registers[a.get()]
614+
<< "));" << std::endl;
615+
} else {
616+
source_buffer << "isnan(" << registers[a.get()]
617+
<< ") ? 0.0 : " << registers[a.get()]
618+
<< ";" << std::endl;
619+
}
620+
} else {
621+
source_buffer << registers[a.get()] << ";" << std::endl;
601622
}
602623
source_buffer << registers[a.get()] << ";" << std::endl;
603624
out_registers.insert(out.get());
@@ -614,8 +635,22 @@ namespace gpu {
614635
source_buffer << " " << jit::to_string('o', out.get())
615636
<< "[index] = ";
616637
if constexpr (SAFE_MATH) {
617-
source_buffer << "isnan(" << registers[a.get()]
618-
<< ") ? 0.0 : ";
638+
if constexpr (jit::complex_scalar<T>) {
639+
jit::add_type<T> (source_buffer);
640+
source_buffer << " (";
641+
source_buffer << "isnan(real(" << registers[a.get()]
642+
<< ")) ? 0.0 : real(" << registers[a.get()]
643+
<< "), ";
644+
source_buffer << "isnan(imag(" << registers[a.get()]
645+
<< ")) ? 0.0 : imag(" << registers[a.get()]
646+
<< "));" << std::endl;
647+
} else {
648+
source_buffer << "isnan(" << registers[a.get()]
649+
<< ") ? 0.0 : " << registers[a.get()]
650+
<< ";" << std::endl;
651+
}
652+
} else {
653+
source_buffer << registers[a.get()] << ";" << std::endl;
619654
}
620655
source_buffer << registers[a.get()] << ";" << std::endl;
621656
out_registers.insert(out.get());
@@ -635,15 +670,30 @@ namespace gpu {
635670
const size_t size) {
636671
source_buffer << std::endl;
637672
source_buffer << "kernel void max_reduction(" << std::endl;
638-
source_buffer << " constant float *input [[buffer(0)]]," << std::endl;
639-
source_buffer << " device float *result [[buffer(1)]]," << std::endl;
673+
source_buffer << " constant ";
674+
jit::add_type<T> (source_buffer);
675+
source_buffer << " *input [[buffer(0)]]," << std::endl;
676+
source_buffer << " device ";
677+
jit::add_type<T> (source_buffer);
678+
source_buffer << " *result [[buffer(1)]]," << std::endl;
640679
source_buffer << " uint i [[thread_position_in_grid]]," << std::endl;
641680
source_buffer << " uint j [[simdgroup_index_in_threadgroup]]," << std::endl;
642681
source_buffer << " uint k [[thread_index_in_simdgroup]]) {" << std::endl;
643682
source_buffer << " if (i < " << size << ") {" << std::endl;
644-
source_buffer << " float sub_max = input[i];" << std::endl;
683+
source_buffer << " " << jit::type_to_string<T> () << " sub_max = ";
684+
if constexpr (jit::complex_scalar<T>) {
685+
source_buffer << "abs(input[i]);" << std::endl;
686+
} else {
687+
source_buffer << "input[i];" << std::endl;
688+
}
645689
source_buffer << " for (size_t index = i + 1024; index < " << size <<"; index += 1024) {" << std::endl;
646-
source_buffer << " sub_max = max(sub_max, input[index]);" << std::endl;
690+
source_buffer << " sub_max = max(sub_max, ";
691+
if constexpr (jit::complex_scalar<T>) {
692+
source_buffer << "abs(input[index]";
693+
} else {
694+
source_buffer << "input[index]";
695+
}
696+
source_buffer << ");" << std::endl;
647697
source_buffer << " }" << std::endl;
648698
source_buffer << " threadgroup float thread_max[32];" << std::endl;
649699
source_buffer << " thread_max[j] = simd_max(sub_max);" << std::endl;
@@ -660,8 +710,8 @@ namespace gpu {
660710
///
661711
/// @param[in] node Node to get the buffer for.
662712
//------------------------------------------------------------------------------
663-
float *get_buffer(graph::shared_leaf<float, SAFE_MATH> &node) {
664-
return static_cast<float *> ([kernel_arguments[node.get()] contents]);
713+
T *get_buffer(graph::shared_leaf<T, SAFE_MATH> &node) {
714+
return static_cast<T *> ([kernel_arguments[node.get()] contents]);
665715
}
666716
};
667717
}

0 commit comments

Comments (0)