99#define metal_context_h
1010
1111#include < unordered_set>
12+ #include < stdlib.h>
1213
1314#import < Metal/Metal.h>
1415
@@ -19,9 +20,10 @@ namespace gpu {
1920// ------------------------------------------------------------------------------
2021// / @brief Class representing a metal gpu context.
2122// /
23+ // / @tparam T Base type of the calculation.
2224// / @tparam SAFE_MATH Use @ref general_concepts_safe_math operations.
2325// ------------------------------------------------------------------------------
24- template <bool SAFE_MATH=false >
26+ template <jit::float_scalar T, bool SAFE_MATH=false >
2527 class metal_context {
2628 private:
2729// / The metal device.
@@ -82,6 +84,7 @@ namespace gpu {
8284 std::vector<std::string> names,
8385 const bool add_reduction=false ) {
8486 NSError *error;
87+ setenv (" MTL_HEADER_SEARCH_PATHS" , HEADER_DIR, 1 );
8588 library = [device newLibraryWithSource:[NSString stringWithCString:kernel_source.c_str ()
8689 encoding:NSUTF8StringEncoding]
8790 options:compile_options ()
@@ -141,9 +144,9 @@ namespace gpu {
141144 std::set<graph::leaf_node<float , SAFE_MATH> *> needed_buffers;
142145
143146 const size_t buffer_element_size = sizeof (float );
144- for (graph::shared_variable<float , SAFE_MATH> &input : inputs) {
147+ for (graph::shared_variable<T , SAFE_MATH> &input : inputs) {
145148 if (!kernel_arguments.contains (input.get ())) {
146- backend::buffer<float > buffer = input->evaluate ();
149+ backend::buffer<T > buffer = input->evaluate ();
147150 kernel_arguments[input.get ()] = [device newBufferWithBytes:buffer.data ()
148151 length:buffer.size ()*buffer_element_size
149152 options:MTLResourceStorageModeShared];
@@ -155,7 +158,7 @@ namespace gpu {
155158 needed_buffers.insert (input.get ());
156159 }
157160 }
158- for (graph::shared_leaf<float , SAFE_MATH> &output : outputs) {
161+ for (graph::shared_leaf<T , SAFE_MATH> &output : outputs) {
159162 if (!kernel_arguments.contains (output.get ())) {
160163 kernel_arguments[output.get ()] = [device newBufferWithLength:num_rays*sizeof (float )
161164 options:MTLResourceStorageModeShared];
@@ -296,8 +299,8 @@ namespace gpu {
296299// / @param[in] run Function to run before reduction.
297300// / @returns A lambda function to run the kernel.
298301// ------------------------------------------------------------------------------
299- std::function<float (void )> create_max_call (graph::shared_leaf<float , SAFE_MATH> &argument,
300- std::function<void (void )> run) {
302+ std::function<T (void )> create_max_call (graph::shared_leaf<T , SAFE_MATH> &argument,
303+ std::function<void (void )> run) {
301304 MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new];
302305 compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
303306 compute.computeFunction = [library newFunctionWithName:@" max_reduction" ];
@@ -375,7 +378,7 @@ namespace gpu {
375378// / @param[in] nodes Nodes to output.
376379// ------------------------------------------------------------------------------
377380 void print_results (const size_t index,
378- const graph::output_nodes<float , SAFE_MATH> &nodes) {
381+ const graph::output_nodes<T , SAFE_MATH> &nodes) {
379382 wait ();
380383 for (auto &out : nodes) {
381384 std::cout << static_cast <float *> ([kernel_arguments[out.get ()] contents])[index] << " " ;
@@ -390,10 +393,10 @@ namespace gpu {
390393// / @param[in] node Node to check the value for.
391394// / @returns The value at the index.
392395// ------------------------------------------------------------------------------
393- float check_value (const size_t index,
394- const graph::shared_leaf<float , SAFE_MATH> &node) {
396+ T check_value (const size_t index,
397+ const graph::shared_leaf<T , SAFE_MATH> &node) {
395398 wait ();
396- return static_cast <float *> ([kernel_arguments[node.get ()] contents])[index];
399+ return static_cast <T *> ([kernel_arguments[node.get ()] contents])[index];
397400 }
398401
399402// ------------------------------------------------------------------------------
@@ -402,8 +405,8 @@ namespace gpu {
402405// / @param[in] node Not to copy buffer to.
403406// / @param[in] source Host side buffer to copy from.
404407// ------------------------------------------------------------------------------
405- void copy_to_device (graph::shared_leaf<float , SAFE_MATH> node,
406- float *source) {
408+ void copy_to_device (graph::shared_leaf<T , SAFE_MATH> node,
409+ T *source) {
407410 const size_t size = [kernel_arguments[node.get ()] length];
408411 memcpy ([kernel_arguments[node.get ()] contents],
409412 source, size);
@@ -415,8 +418,8 @@ namespace gpu {
415418// / @param[in] node Node to copy buffer from.
416419// / @param[in,out] destination Host side buffer to copy to.
417420// ------------------------------------------------------------------------------
418- void copy_to_host (graph::shared_leaf<float , SAFE_MATH> node,
419- float *destination) {
421+ void copy_to_host (graph::shared_leaf<T , SAFE_MATH> node,
422+ T *destination) {
420423 command_buffer = [queue commandBuffer];
421424
422425 [command_buffer commit];
@@ -436,6 +439,10 @@ namespace gpu {
436439 source_buffer << " #include <metal_stdlib>" << std::endl;
437440 source_buffer << " #include <metal_simdgroup>" << std::endl;
438441 source_buffer << " using namespace metal;" << std::endl;
442+ if constexpr (jit::complex_scalar<T>) {
443+ source_buffer << " #define METAL_DEVICE_CODE" << std::endl;
444+ source_buffer << " #include <special_functions.hpp>" << std::endl;
445+ }
439446 }
440447
441448// ------------------------------------------------------------------------------
@@ -596,8 +603,22 @@ namespace gpu {
596603 << jit::to_string (' v' , in.get ())
597604 << " [index] = " ;
598605 if constexpr (SAFE_MATH) {
599- source_buffer << " isnan(" << registers[a.get ()]
600- << " ) ? 0.0 : " ;
606+ if constexpr (jit::complex_scalar<T>) {
607+ jit::add_type<T> (source_buffer);
608+ source_buffer << " (" ;
609+ source_buffer << " isnan(real(" << registers[a.get ()]
610+ << " )) ? 0.0 : real(" << registers[a.get ()]
611+ << " ), " ;
612+ source_buffer << " isnan(imag(" << registers[a.get ()]
613+ << " )) ? 0.0 : imag(" << registers[a.get ()]
614+ << " ));" << std::endl;
615+ } else {
616+ source_buffer << " isnan(" << registers[a.get ()]
617+ << " ) ? 0.0 : " << registers[a.get ()]
618+ << " ;" << std::endl;
619+ }
620+ } else {
621+ source_buffer << registers[a.get ()] << " ;" << std::endl;
601622 }
602623 source_buffer << registers[a.get ()] << " ;" << std::endl;
603624 out_registers.insert (out.get ());
@@ -614,8 +635,22 @@ namespace gpu {
614635 source_buffer << " " << jit::to_string (' o' , out.get ())
615636 << " [index] = " ;
616637 if constexpr (SAFE_MATH) {
617- source_buffer << " isnan(" << registers[a.get ()]
618- << " ) ? 0.0 : " ;
638+ if constexpr (jit::complex_scalar<T>) {
639+ jit::add_type<T> (source_buffer);
640+ source_buffer << " (" ;
641+ source_buffer << " isnan(real(" << registers[a.get ()]
642+ << " )) ? 0.0 : real(" << registers[a.get ()]
643+ << " ), " ;
644+ source_buffer << " isnan(imag(" << registers[a.get ()]
645+ << " )) ? 0.0 : imag(" << registers[a.get ()]
646+ << " ));" << std::endl;
647+ } else {
648+ source_buffer << " isnan(" << registers[a.get ()]
649+ << " ) ? 0.0 : " << registers[a.get ()]
650+ << " ;" << std::endl;
651+ }
652+ } else {
653+ source_buffer << registers[a.get ()] << " ;" << std::endl;
619654 }
620655 source_buffer << registers[a.get ()] << " ;" << std::endl;
621656 out_registers.insert (out.get ());
@@ -635,15 +670,30 @@ namespace gpu {
635670 const size_t size) {
636671 source_buffer << std::endl;
637672 source_buffer << " kernel void max_reduction(" << std::endl;
638- source_buffer << " constant float *input [[buffer(0)]]," << std::endl;
639- source_buffer << " device float *result [[buffer(1)]]," << std::endl;
673+ source_buffer << " constant " ;
674+ jit::add_type<T> (source_buffer);
675+ source_buffer << " *input [[buffer(0)]]," << std::endl;
676+ source_buffer << " device " ;
677+ jit::add_type<T> (source_buffer);
678+ source_buffer << " *result [[buffer(1)]]," << std::endl;
640679 source_buffer << " uint i [[thread_position_in_grid]]," << std::endl;
641680 source_buffer << " uint j [[simdgroup_index_in_threadgroup]]," << std::endl;
642681 source_buffer << " uint k [[thread_index_in_simdgroup]]) {" << std::endl;
643682 source_buffer << " if (i < " << size << " ) {" << std::endl;
644- source_buffer << " float sub_max = input[i];" << std::endl;
683+ source_buffer << " " << jit::type_to_string<T> () << " sub_max = " ;
684+ if constexpr (jit::complex_scalar<T>) {
685+ source_buffer << " abs(input[i]);" << std::endl;
686+ } else {
687+ source_buffer << " input[i];" << std::endl;
688+ }
645689 source_buffer << " for (size_t index = i + 1024; index < " << size <<" ; index += 1024) {" << std::endl;
646- source_buffer << " sub_max = max(sub_max, input[index]);" << std::endl;
690+ source_buffer << " sub_max = max(sub_max, " ;
691+ if constexpr (jit::complex_scalar<T>) {
692+ source_buffer << " abs(input[index]" ;
693+ } else {
694+ source_buffer << " input[index]" ;
695+ }
696+ source_buffer << " );" << std::endl;
647697 source_buffer << " }" << std::endl;
648698 source_buffer << " threadgroup float thread_max[32];" << std::endl;
649699 source_buffer << " thread_max[j] = simd_max(sub_max);" << std::endl;
@@ -660,8 +710,8 @@ namespace gpu {
660710// /
661711// / @param[in] node Node to get the buffer for.
662712// ------------------------------------------------------------------------------
663- float *get_buffer (graph::shared_leaf<float , SAFE_MATH> &node) {
664- return static_cast <float *> ([kernel_arguments[node.get ()] contents]);
713+ T *get_buffer (graph::shared_leaf<T , SAFE_MATH> &node) {
714+ return static_cast <T *> ([kernel_arguments[node.get ()] contents]);
665715 }
666716 };
667717}
0 commit comments