diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 2e385f2c55..12e61f56e2 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -30,6 +30,9 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/conv_2d.h" +#ifdef TENSORFLOW_USE_LIBXSMM +#include "tensorflow/core/kernels/xsmm_conv2d.h" +#endif #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -88,6 +91,79 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_LIBXSMM +template +struct LaunchXsmmBackwardFilter { + bool operator()(OpKernelContext* context, const Device& d, + typename TTypes::ConstTensor input_backward, + typename TTypes::Tensor kernel, + typename TTypes::ConstTensor output_backward, + int input_rows, int input_cols, int row_stride, + int col_stride, int pad_h, int pad_w, TensorFormat data_format) const { + return false; + } +}; + +template <> +struct LaunchXsmmBackwardFilter { + bool operator()(OpKernelContext* context, const CPUDevice& d, + typename TTypes::ConstTensor input, + typename TTypes::Tensor filter, + typename TTypes::ConstTensor output, + int input_rows, int input_cols, int row_stride, + int col_stride,int pad_h, int pad_w, TensorFormat data_format) const { + auto batch = input.dimension(0); + auto in_depth = input.dimension(3); + auto out_depth = output.dimension(3); + auto filter_rows = filter.dimension(0); + auto filter_cols = filter.dimension(1); + + auto num_threads = + context->device()->tensorflow_cpu_worker_threads()->num_threads; + // See libxsmm_dnn.h for this struct definition. + libxsmm_dnn_conv_desc desc; + desc.N = batch; + desc.C = in_depth; + desc.H = input_rows; + desc.W = input_cols; + desc.K = out_depth; + desc.R = filter_rows; + desc.S = filter_cols; + desc.u = row_stride; + desc.v = col_stride; + desc.pad_h = pad_h; + desc.pad_w = pad_w; + desc.pad_h_in = 0; // pad_rows; // ignored by libxsmm for now. + desc.pad_w_in = 0; // pad_cols; // ignored by libxsmm for now. + desc.pad_h_out = 0; + desc.pad_w_out = 0; + desc.threads = num_threads; + desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT; + desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC; + desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK; + desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; + desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE; + desc.datatype = LIBXSMM_DNN_DATATYPE_F32; + + + if (!CanUseXsmmConv2D(desc, data_format)) { + return false; + } + if(desc.u > 1 || desc.v > 1) + return false; + + auto input_ptr = input.data(); + auto filter_ptr = filter.data(); + auto output_ptr = output.data(); + bool success = functor::XsmmBkwFilterConv2D()( + context, desc, input_ptr, filter_ptr, output_ptr); + return success; + } +}; +#endif + + + template class Conv2DFastBackpropFilterOp : public OpKernel { public: @@ -135,6 +211,39 @@ class Conv2DFastBackpropFilterOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape, &filter_backprop)); + #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD + + int64 pad_top, pad_bottom; + int64 pad_left, pad_right; + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, + dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, + dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); + + if ( pad_left == pad_right && pad_top == pad_bottom ) { + + if (LaunchXsmmBackwardFilter()( + context, context->eigen_device(), + input.tensor(),filter_backprop->tensor(), + out_backprop.tensor(), dims.spatial_dims[0].input_size, + dims.spatial_dims[1].input_size, (int)dims.spatial_dims[0].stride, + (int)dims.spatial_dims[1].stride,(int)pad_top, (int)pad_left, data_format_)) { + return; + } + } + #endif + + + + functor::SpatialConvolutionBackwardKernel()( context->eigen_device(), filter_backprop->tensor(), input.tensor(), out_backprop.tensor(), @@ -213,6 +322,19 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, dims.spatial_dims[1].stride, padding_, &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); + #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD + if ( pad_left == pad_right && pad_top == pad_bottom ) { + + if (LaunchXsmmBackwardFilter()( + context, context->eigen_device(), + input.tensor(),filter_backprop->tensor(), + out_backprop.tensor(), dims.spatial_dims[0].input_size, + dims.spatial_dims[1].input_size, (int)dims.spatial_dims[0].stride, + (int)dims.spatial_dims[1].stride,(int)pad_top, (int)pad_left, data_format_)) { + return; + } + } + #endif // The total dimension size of each kernel. const int filter_total_size = dims.spatial_dims[0].filter_size * diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index 8bc79bebd9..7e0912b4db 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -131,7 +131,7 @@ struct LaunchXsmmBackwardInputConvolution { typename TTypes::ConstTensor kernel, typename TTypes::ConstTensor output_backward, int input_rows, int input_cols, int row_stride, - int col_stride, TensorFormat data_format) const { + int col_stride, int pad_h, int pad_w, TensorFormat data_format) const { return false; } }; @@ -143,7 +143,7 @@ struct LaunchXsmmBackwardInputConvolution { typename TTypes::ConstTensor kernel, typename TTypes::ConstTensor output_backward, int input_rows, int input_cols, int row_stride, - int col_stride, TensorFormat data_format) const { + int col_stride, int pad_h, int pad_w, TensorFormat data_format) const { auto batch = input_backward.dimension(0); auto in_depth = input_backward.dimension(3); auto out_depth = output_backward.dimension(3); @@ -162,10 +162,10 @@ struct LaunchXsmmBackwardInputConvolution { desc.S = filter_cols; desc.u = row_stride; desc.v = col_stride; - desc.pad_h = 0; - desc.pad_w = 0; - desc.pad_h_in = 0; // pad_rows; // ignored by libxsmm for now. - desc.pad_w_in = 0; // pad_cols; // ignored by libxsmm for now. + desc.pad_h = pad_h; + desc.pad_w = pad_w; + desc.pad_h_in = 0; + desc.pad_w_in = 0; desc.pad_h_out = 0; desc.pad_w_out = 0; desc.threads = num_threads; @@ -174,7 +174,7 @@ struct LaunchXsmmBackwardInputConvolution { desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; // LIBXSMM_DNN_TENSOR_FORMAT_RSCK; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; - desc.options = LIBXSMM_DNN_CONV_OPTION_NONE; + desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE; desc.datatype = LIBXSMM_DNN_DATATYPE_F32; auto input_ptr = input_backward.data(); @@ -236,13 +236,30 @@ class Conv2DFastBackpropInputOp : public OpKernel { context->allocate_output(0, input_shape, &in_backprop)); #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD - if (LaunchXsmmBackwardInputConvolution()( + int64 pad_top, pad_bottom; + int64 pad_left, pad_right; + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, + dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, + dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); + + if ( pad_left == pad_right && pad_top == pad_bottom ) { + if (LaunchXsmmBackwardInputConvolution()( context, context->eigen_device(), in_backprop->tensor(), filter.tensor(), out_backprop.tensor(), dims.spatial_dims[0].input_size, - dims.spatial_dims[1].input_size, dims.spatial_dims[0].stride, - dims.spatial_dims[1].stride, data_format_)) { - return; + dims.spatial_dims[1].input_size, (int)dims.spatial_dims[0].stride, + (int)dims.spatial_dims[1].stride, (int)pad_top, (int)pad_left, data_format_)) { + return; + } } #endif @@ -309,21 +326,38 @@ class Conv2DCustomBackpropInputOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); + // TODO(andydavis) Consider moving code shared with + // Conv2DCustomBackpropFilterOp into a shared helper function. #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD - if (LaunchXsmmBackwardInputConvolution()( + int64 pad_top, pad_bottom; + int64 pad_left, pad_right; + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, + dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, + dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); + + if ( pad_left == pad_right && pad_top == pad_bottom ) { + if (LaunchXsmmBackwardInputConvolution()( context, context->eigen_device(), in_backprop->tensor(), filter.tensor(), out_backprop.tensor(), dims.spatial_dims[0].input_size, - dims.spatial_dims[1].input_size, dims.spatial_dims[0].stride, - dims.spatial_dims[1].stride, data_format_)) { - return; + dims.spatial_dims[1].input_size, (int)dims.spatial_dims[0].stride, + (int)dims.spatial_dims[1].stride, (int)pad_top, (int)pad_left, data_format_)) { + return; + } } -#endif - - // TODO(andydavis) Consider moving code shared with - // Conv2DCustomBackpropFilterOp into a shared helper function. +#else int64 pad_top, pad_bottom; int64 pad_left, pad_right; +#endif OP_REQUIRES_OK( context, GetWindowedOutputSizeVerbose( diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index facfe4467d..8076daf387 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -213,8 +213,8 @@ class LaunchXsmmConvOp { desc.v = stride_cols; desc.pad_h = pad_rows; desc.pad_w = pad_cols; - desc.pad_h_in = pad_rows; // libxsmm supports only physical padding for now - desc.pad_w_in = pad_cols; // libxsmm supports only physical padding for now + desc.pad_h_in = 0; + desc.pad_w_in = 0; desc.pad_h_out = 0; desc.pad_w_out = 0; desc.threads = num_threads; @@ -222,13 +222,17 @@ class LaunchXsmmConvOp { desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC; desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; - desc.options = LIBXSMM_DNN_CONV_OPTION_NONE; + desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE; desc.datatype = LIBXSMM_DNN_DATATYPE_F32; if (!CanUseXsmmConv2D(desc, data_format)) { return false; } + if (!CanUseXsmmConv2D(desc, data_format)) { + return false; + } + auto input_ptr = input.template flat().data(); auto filter_ptr = filter.template flat().data(); auto output_ptr = output->template flat().data(); diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc index 2ed0522ce4..46e743b4cf 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_matmul_op.cc @@ -837,15 +837,6 @@ class SparseMatMul { }; #ifdef TENSORFLOW_USE_LIBXSMM -#ifdef EXTRA_CACHE_LOGGING -static tensorflow::mutex global_cache_stats_lock; -static int total_num_entries_outstanding GUARDED_BY(global_cache_stats_lock) = - 0; -static int total_num_entries_in_cache GUARDED_BY(global_cache_stats_lock) = 0; -#endif // EXTRA_CACHE_LOGGING - -static const int max_entries_per_graph_node = 40; - template class LibxsmmSparseMatMul { typedef Eigen::Tensor MatrixL; @@ -861,7 +852,6 @@ class LibxsmmSparseMatMul { MatrixMapR; public: -#if 1 // This structure contains a set of libxsmm kernels for sizes that have been // encountered previously by this operator so that libxsmm does not need to // reallocate its scratchpad memory each time (which hurts performance @@ -880,181 +870,57 @@ class LibxsmmSparseMatMul { // useful (it is an empty struct right now) typename SparseMatMul::TensorInfoCache non_libxsmm_cache; // Currently not used - TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCacheEntry); - ~TensorInfoCacheEntry() { -#ifdef EXTRA_CACHE_LOGGING - LOG(INFO) << "Deleting tensor cache entry at " << (void*)this; -#endif // EXTRA_CACHE_LOGGING - libxsmm_spmdm_destroy(&handle); - } }; - // protects entries; invariant: entries is a valid std::list. + // protects entries; invariant: entries is a valid std::multimap tensorflow::mutex lock; // Because there could be multiple matrix multiplies with the same sizes // going on at the same time, we need to allow multiple cache entries for a // given set of parameters. Taking and returning entries is used to make // sure the same cache entry is not used from two threads at a time. - using entries_map_type = std::list, - std::unique_ptr>>; // multimap in LRU order - entries_map_type entries GUARDED_BY( - lock); // MRU element at end so reverse search will find it first - int num_entries_outstanding GUARDED_BY(lock); - - TensorInfoCache() : lock(), entries(), num_entries_outstanding(0) {} + std::multimap, + std::unique_ptr> + entries GUARDED_BY(lock); + + TensorInfoCache() : lock(), entries() {} // Look up and remove first entry with these parameters, creating one if // there isn't one std::unique_ptr take_cache_entry(int M, int K, int N, int max_threads) -#ifdef EXTRA_CACHE_LOGGING - LOCKS_EXCLUDED(lock, global_cache_stats_lock) -#else - LOCKS_EXCLUDED(lock) -#endif - { + LOCKS_EXCLUDED(lock) { tensorflow::mutex_lock ml(lock); -#ifdef EXTRA_CACHE_LOGGING - tensorflow::mutex_lock ml2(global_cache_stats_lock); -#endif auto key = std::make_tuple(M, K, N, max_threads); - auto it_rev = - std::find_if(entries.rbegin(), entries.rend(), - [&](const typename entries_map_type::value_type& e) { - return e.first == key; - }); - auto it = - (it_rev == entries.rend() ? entries.end() : std::next(it_rev).base()); + auto it = entries.find(key); if (it != entries.end()) { auto val = std::move(it->second); entries.erase(it); - ++num_entries_outstanding; -#ifdef EXTRA_CACHE_LOGGING - ++total_num_entries_outstanding; - --total_num_entries_in_cache; - LOG(INFO) << "Used existing cache entry at " << (void*)val.get() - << " for " << M << "x" << K << "x" << N << " max_threads " - << max_threads - << ", num_entries_outstanding = " << num_entries_outstanding - << ", new cache size = " << entries.size() - << ", total num_entries_outstanding = " - << total_num_entries_outstanding - << ", total cache size = " << total_num_entries_in_cache; -#endif return val; } else { - while (!entries.empty() && - entries.size() + num_entries_outstanding + 1 > - max_entries_per_graph_node) { -#ifdef EXTRA_CACHE_LOGGING - LOG(INFO) << "Removing old cache entry at " - << (void*)entries.front().second.get(); -#endif - entries.pop_front(); - } std::unique_ptr e{ new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}}; // setup scoped allocator, which uses cpu_allocator() for this scope const libxsmm_tf_allocator tf_allocator; libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr); - ++num_entries_outstanding; -#ifdef EXTRA_CACHE_LOGGING - ++total_num_entries_outstanding; - LOG(INFO) << "Created cache entry at " << (void*)e.get() << " for " << M - << "x" << K << "x" << N << " max_threads " << max_threads - << ", num_entries_outstanding = " << num_entries_outstanding - << ", new cache size = " << entries.size() - << ", total num_entries_outstanding = " - << total_num_entries_outstanding - << ", total cache size = " << total_num_entries_in_cache; -#endif return e; } } // Add a cache entry with certain parameters void return_cache_entry(std::unique_ptr e) -#ifdef EXTRA_CACHE_LOGGING - LOCKS_EXCLUDED(lock, global_cache_stats_lock) -#else - LOCKS_EXCLUDED(lock) -#endif - { + LOCKS_EXCLUDED(lock) { tensorflow::mutex_lock ml(lock); -#ifdef EXTRA_CACHE_LOGGING - tensorflow::mutex_lock ml2(global_cache_stats_lock); -#endif auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads); - --num_entries_outstanding; -#ifdef EXTRA_CACHE_LOGGING - --total_num_entries_outstanding; - LOG(INFO) << "Returned cache entry at " << (void*)e.get() << " for " - << e->M << "x" << e->K << "x" << e->N << " max_threads " - << e->max_threads - << ", num_entries_outstanding = " << num_entries_outstanding - << ", prev cache size = " << entries.size() - << ", total num_entries_outstanding = " - << total_num_entries_outstanding - << ", total cache size = " << total_num_entries_in_cache; -#endif - entries.push_back(std::make_pair(key, std::move(e))); -#ifdef EXTRA_CACHE_LOGGING - ++total_num_entries_in_cache; -#endif + entries.insert(std::make_pair(key, std::move(e))); } ~TensorInfoCache() { tensorflow::mutex_lock ml(lock); -#ifdef EXTRA_CACHE_LOGGING - tensorflow::mutex_lock ml2(global_cache_stats_lock); - LOG(INFO) << "Deleting TensorInfoCache, cache size = " << entries.size() - << ", total num_entries_outstanding = " - << total_num_entries_outstanding - << ", total cache size = " << total_num_entries_in_cache; -#endif - CHECK_EQ(num_entries_outstanding, 0); + for (auto& p : entries) { + libxsmm_spmdm_destroy(&p.second->handle); + } entries.clear(); } private: TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache); }; -#else - // This structure contains a set of libxsmm kernels for sizes that have been - // encountered previously by this operator so that libxsmm does not need to - // reallocate its scratchpad memory each time (which hurts performance - // substantially). - struct TensorInfoCache { - struct TensorInfoCacheEntry { - // Parameters for kernel - int M; - int K; - int N; - int max_threads; - // libxsmm handle and matrix data - libxsmm_spmdm_handle handle; - libxsmm_CSR_sparseslice* output_csr; - // Chain to non-libxsmm implementation's cache in case that ever becomes - // useful (it is an empty struct right now) - typename SparseMatMul::TensorInfoCache - non_libxsmm_cache; // Currently not used - }; - TensorInfoCache() {} - // Look up and remove first entry with these parameters, creating one if - // there isn't one - std::unique_ptr take_cache_entry(int M, int K, int N, - int max_threads) { - std::unique_ptr e{ - new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}}; - libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr); - return e; - } - // Add a cache entry with certain parameters - void return_cache_entry(std::unique_ptr e) { - libxsmm_spmdm_destroy(&e->handle); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache); - }; -#endif // Perform matrix multiplication of "left" and "right", and store the result // in *"output". @@ -1479,21 +1345,21 @@ inline void SparseMatMul::ComputeBlockSizes( template void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool, - ptrdiff_t max_thread_count, const F& f) { + const F& f) { int num_threads = thread_pool->num_threads; if (num_threads == 0) { LOG(FATAL) << "Have 0 threads in thread pool"; } else if (num_threads == 1) { - f(0, 1); + f(0); } else { BlockingCounter counter(num_threads - 1); for (int i = 1; i < num_threads; ++i) { thread_pool->workers->Schedule([&, i]() { - f(i, num_threads); + f(i); counter.DecrementCount(); }); } - f(0, num_threads); + f(0); counter.Wait(); } } @@ -1522,24 +1388,21 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( void wrapper_libxsmm_spmdm_compute_generic_thread( empty_type_wrapper, const libxsmm_spmdm_handle* handle, - char transA, char transB, libxsmm_CSR_sparseslice* A_sparse, - const bfloat16* B, char transC, float* C, int block_id, int tid, - int nthreads) { - const uint16 alpha = 1; - const uint16 beta = 0; + char transA, char transB, const bfloat16* alpha, + libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC, + const bfloat16* beta, float* C, int block_id, int tid, int nthreads) { return libxsmm_spmdm_compute_bfloat16_thread( - handle, transA, transB, &alpha, A_sparse, - reinterpret_cast(B), transC, &beta, C, block_id, tid, - nthreads); + handle, transA, transB, reinterpret_cast(alpha), A_sparse, + reinterpret_cast(B), transC, + reinterpret_cast(beta), C, block_id, tid, nthreads); } void wrapper_libxsmm_spmdm_compute_generic_thread( empty_type_wrapper, const libxsmm_spmdm_handle* handle, char transA, - char transB, libxsmm_CSR_sparseslice* A_sparse, const float* B, char transC, - float* C, int block_id, int tid, int nthreads) { - const float alpha = 1.f; - const float beta = 0.f; - return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, &alpha, - A_sparse, B, transC, &beta, C, + char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse, + const float* B, char transC, const float* beta, float* C, int block_id, + int tid, int nthreads) { + return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha, + A_sparse, B, transC, beta, C, block_id, tid, nthreads); } @@ -1590,13 +1453,11 @@ inline void LibxsmmSparseMatMul::Compute( const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1); const int right_dim0 = right.dimension(0); const int right_dim1 = right.dimension(1); - const int output_dim0 = - transpose_output ? output->dimension(1) : output->dimension(0); - const int output_dim1 = - transpose_output ? output->dimension(0) : output->dimension(1); CHECK_EQ(left_dim1, right_dim0); - CHECK_EQ(left_dim0, output_dim0); - CHECK_EQ(right_dim1, output_dim1); + CHECK_EQ(left_dim0, + (transpose_output ? output->dimension(1) : output->dimension(0))); + CHECK_EQ(right_dim1, + (transpose_output ? output->dimension(0) : output->dimension(1))); if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) { // Causes problems in libxsmm SparseMatMul::Compute( @@ -1614,50 +1475,42 @@ inline void LibxsmmSparseMatMul::Compute( // Convert the left matrix to compressed sparse row (CSR) format ptrdiff_t total_num_creation_blocks = libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle); - ptrdiff_t total_num_mult_blocks = - libxsmm_spmdm_get_num_compute_blocks(&entry->handle); - bool use_libxsmm = - !(total_num_creation_blocks + total_num_mult_blocks < num_threads && - !transpose_left && !transpose_output); - if (!use_libxsmm) { - // Avoid some performance issues in libxsmm (FIXME) - cache->return_cache_entry(std::move(entry)); - SparseMatMul::Compute( - nullptr /* Assumes no cached data for fallback */, left, right, - transpose_left, thread_pool, transpose_output, output); - return; - } std::atomic cur_create_block_number; cur_create_block_number.store(0); - do_on_all_threads(thread_pool, total_num_creation_blocks, - [&](int i, int actual_num_threads) { - PinnedToCurrentCPU pin; - while (true) { - int work_item = cur_create_block_number.fetch_add(1); - if (work_item >= total_num_creation_blocks) break; - wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( - empty_type_wrapper{}, &entry->handle, - (transpose_left ? 'T' : 'N'), left_data, - entry->output_csr, work_item, i, - actual_num_threads); - } - }); + do_on_all_threads(thread_pool, [&](int i) { + PinnedToCurrentCPU pin; + while (true) { + int work_item = cur_create_block_number.fetch_add(1); + if (work_item >= total_num_creation_blocks) break; + wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( + empty_type_wrapper{}, &entry->handle, + (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item, + i, num_threads); + } + }); // Do matrix-matrix multiplication + // TODO(jewillco): libxsmm doesn't support beta != 1 yet -- remove when + // release + // includes beta handling + memset(output_data, 0, left_dim0 * right_dim1 * sizeof(TR)); + ptrdiff_t total_num_mult_blocks = + libxsmm_spmdm_get_num_compute_blocks(&entry->handle); std::atomic cur_mult_block_number; cur_mult_block_number.store(0); - do_on_all_threads( - thread_pool, total_num_mult_blocks, [&](int i, int actual_num_threads) { - PinnedToCurrentCPU pin; - while (true) { - int work_item = cur_mult_block_number.fetch_add(1); - if (work_item >= total_num_mult_blocks) break; - wrapper_libxsmm_spmdm_compute_generic_thread( - empty_type_wrapper{}, &entry->handle, - (transpose_left ? 'T' : 'N'), 'N', entry->output_csr, right_data, - (transpose_output ? 'T' : 'N'), output_data, work_item, i, - actual_num_threads); - } - }); + do_on_all_threads(thread_pool, [&](int i) { + PinnedToCurrentCPU pin; + while (true) { + int work_item = cur_mult_block_number.fetch_add(1); + if (work_item >= total_num_mult_blocks) break; + const TL alpha(1.0); // Stored in a variable so we can get a pointer + const TL beta(0.0); // Stored in a variable so we can get a pointer + wrapper_libxsmm_spmdm_compute_generic_thread( + empty_type_wrapper{}, &entry->handle, + (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr, + right_data, (transpose_output ? 'T' : 'N'), &beta, output_data, + work_item, i, num_threads); + } + }); // Put handle + CSR storage back into cache cache->return_cache_entry(std::move(entry)); } @@ -1803,17 +1656,15 @@ inline void SparseMatMul::Compute( SparseMatMulOp); #endif +REGISTER_SPARSE_MATMUL(bfloat16, bfloat16); + REGISTER_SPARSE_MATMUL(float, bfloat16); REGISTER_SPARSE_MATMUL(bfloat16, float); #ifdef TENSORFLOW_USE_LIBXSMM -REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16); - REGISTER_SPARSE_MATMUL_LIBXSMM(float, float); #else -REGISTER_SPARSE_MATMUL(bfloat16, bfloat16); - REGISTER_SPARSE_MATMUL(float, float); #endif diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h index bff6a0c9b3..61bd6593c3 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.h +++ b/tensorflow/core/kernels/sparse_matmul_op.h @@ -255,13 +255,12 @@ EIGEN_STRONG_INLINE Packet8d pbroadcast_second(const Packet8d& a_in) { } template <> EIGEN_STRONG_INLINE Packet8d pbroadcast_third(const Packet8d& a_in) { - Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1); + Packet2d a = _mm512_extractf32x4_ps(a_in, 1); return _mm512_broadcastsd_pd(a); } template <> EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth(const Packet8d& a_in) { - Packet2d a = - _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3); + Packet2d a = _mm_permute_pd(_mm512_extractf32x4_ps(a_in, 1), 3); return _mm512_broadcastsd_pd(a); } template <> @@ -418,17 +417,14 @@ EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth(const Packet8f& a) { template EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) { - return _mm512_castsi512_ps(_mm512_slli_epi32( - _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))), - 16)); + return _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(from)), + 16); } template EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) { - return _mm512_castsi512_ps( - _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_castpd_si256( - _mm512_extractf64x4_pd(_mm512_castps_pd(from), 1))), - 16)); + return _mm512_slli_epi32( + _mm512_cvtepu16_epi32(_mm512_extractf64x4_pd(from, 1)), 16); } #endif diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc index 823cdf7e09..c9c53edefc 100644 --- a/tensorflow/core/kernels/xsmm_conv2d.cc +++ b/tensorflow/core/kernels/xsmm_conv2d.cc @@ -24,16 +24,20 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(void); #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/xsmm_conv2d.h" - +#include #include +#include +#if 0 +#include +#endif #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/threadpool.h" #include "include/libxsmm_cpuid.h" -#include "libxsmm_dnn_handle.h" -#include "libxsmm_malloc.h" +#include "include/libxsmm_malloc.h" +#include "libxsmm_main.h" // TODO: API to avoid incl. header from src/ namespace tensorflow { @@ -59,10 +63,6 @@ bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc, VLOG(1) << "Cannot use XSMM convolutions: unsupported format!"; return false; } - if (desc.pad_h_in != 0 || desc.pad_w_in != 0) { - VLOG(1) << "Cannot use XSMM convolutions: unsupported padding!"; - return false; - } if (desc.K % VECTOR_SIZE != 0) { VLOG(1) << "Cannot use XSMM convolutions: output features count not" " divisible by vector size!"; @@ -109,7 +109,6 @@ LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float *kcrs, int R, i } } - class libxsmm_dnn_conv_desc_wrap{ public: @@ -127,45 +126,68 @@ class libxsmm_dnn_conv_desc_wrap{ d.S == w.d.S && d.u == w.d.u && d.v == w.d.v && + d.pad_h == w.d.pad_h && + d.pad_w == w.d.pad_w && d.pad_h_in == w.d.pad_h_in && - d.pad_w_in == w.d.pad_w_in - ); + d.pad_w_in == w.d.pad_w_in && + d.pad_h_out == w.d.pad_h_out && + d.pad_w_out == w.d.pad_w_out && + d.threads == w.d.threads && + d.algo == d.algo && + d.buffer_format == w.d.buffer_format && + d.filter_format == w.d.filter_format && + d.fuse_ops == w.d.fuse_ops && + d.options == w.d.options && + d.datatype == w.d.datatype); } }; - - + + struct HashFunction{ std::size_t operator()(const libxsmm_dnn_conv_desc_wrap & w) const{ + + + + + //unsigned char ptr[sizeof(&w.d)]; + + + //memcpy(ptr, (unsigned char *)&w.d, sizeof(&w.d)) + + + // + std::ostringstream N,C,H,W,K,R,S,u,v,padh,padw; N << w.d.N; C << w.d.C; H << w.d.H; W << w.d.W; K << w.d.K; R << w.d.R; S << w.d.S; u << w.d.u; - v << w.d.v; padh << w.d.pad_h_in; - padw << w.d.pad_w_in; - - + v << w.d.v; padh << w.d.pad_h; + padw << w.d.pad_w; + + std::string out_ = N.str() + C.str()\ + H.str() + W.str()\ + K.str() + R.str()\ + S.str() + u.str()\ + v.str() + padh.str()\ + padw.str(); - return ( std::hash()(out_)); + //return ( std::hash()((unsigned long long)&(w.d))); } }; - + + class handles{ public: libxsmm_dnn_layer* find( const libxsmm_dnn_conv_desc_wrap &w) { - std::unordered_map::iterator i = libxsmm_handles.find(w); + + std::unique_lock lock(mutex_); + std::unordered_map::iterator i = libxsmm_handles.find(w); if (i == libxsmm_handles.end()){ libxsmm_dnn_err_t status; - libxsmm_dnn_layer* libxsmm_handle = - libxsmm_dnn_create_conv_layer(w.d, &status); + libxsmm_dnn_layer* libxsmm_handle = libxsmm_dnn_create_conv_layer(w.d, &status); chk_libxsmm_err(status, "Create handle"); libxsmm_handles.insert(std::make_pair(w, libxsmm_handle)); return libxsmm_handle; @@ -174,23 +196,31 @@ class handles{ return i->second; } ~handles(){ - std::unordered_map::iterator i; + std::unordered_map::iterator i; for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++) chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second), "Destroy handle"); } private: - std::unordered_map libxsmm_handles; + std::mutex mutex_; + std::unordered_map libxsmm_handles; + }; + static handles libxsmm_handles; +//#define LIBXSMM_DETAILED_TIMING + template static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, libxsmm_dnn_compute_kind kind, InputPtr input, FilterPtr filter, OutputPtr output) { +#if defined(LIBXSMM_DETAILED_TIMING) + unsigned long long l_tick1, l_tick2, l_tick3, l_tick4, l_tick5, l_tick6, l_tick7, l_tick8, l_tick9, l_tick10; + l_tick1 = libxsmm_timer_tick(); +#endif // setup scoped allocator, which adopts the allocator from the context const libxsmm_tf_allocator tf_allocator(*ctx); libxsmm_dnn_err_t status; @@ -198,17 +228,17 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, libxsmm_dnn_conv_desc_wrap w(desc); void* scratch; - if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) - libxsmm_handle = libxsmm_handles.find(w); - else { - libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status); - chk_libxsmm_err(status, "Create handle"); - } + //if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) + libxsmm_handle = libxsmm_handles.find(w); + //else{ + // libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status); + // chk_libxsmm_err(status, "Create handle"); + //} status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind); if (status == LIBXSMM_DNN_WARN_FALLBACK) { - chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle), - "Destroy handle"); + //chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle), + // "Destroy handle"); return false; // Use non-libxsmm code } chk_libxsmm_err(status, "Check codegen status"); @@ -217,57 +247,67 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, libxsmm_dnn_buffer* libxsmm_output; libxsmm_dnn_filter* libxsmm_filter; - /* - const DeviceBase::CpuWorkerThreads* worker_threads = - ctx->device()->tensorflow_cpu_worker_threads(); - - int num_threads = worker_threads->num_threads; -*/ +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick2 = libxsmm_timer_tick(); +#endif int ifmblock = (libxsmm_handle->ifmblock); - int ofmblock = (libxsmm_handle->ofmblock); + int ofmblock = (libxsmm_handle->ofmblock); - int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1; + int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1; int blocksofm = desc.K%ofmblock ==0 ? desc.K/ofmblock :desc.K/ofmblock + 1; - float *native_filter = (float*)libxsmm_aligned_scratch( - blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), - 2097152); + float *native_filter = (float*)libxsmm_aligned_scratch( blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), 2097152); + const DeviceBase::CpuWorkerThreads* worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; - - if(blocksofm > num_threads){ - int work = blocksofm; - BlockingCounter count(num_threads); - for (int i = 0; i < num_threads; ++i) { +#if 1 + if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD || kind == LIBXSMM_DNN_COMPUTE_KIND_BWD){ + if(blocksofm > num_threads){ + int work = blocksofm; + BlockingCounter count(num_threads); + for (int i = 0; i < num_threads; ++i) { worker_threads->workers->Schedule([=, &count]() { int start = work/num_threads*i; - int end = (start + work/num_threads) > work ? work: start + work/num_threads; + int end = (start + work/num_threads) > work ? work: start + work/num_threads; copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S,desc.C, desc.K,blocksifm,blocksofm,ifmblock,ofmblock,start, end); count.DecrementCount(); }); + } + count.Wait(); } - count.Wait(); - } - else{ - - int work = blocksofm; - int num_threads = work; - - BlockingCounter count(num_threads); - for (int i = 0; i < num_threads; ++i) { + else{ + + int work = blocksofm; + int num_threads = work; + + BlockingCounter count(num_threads); + for (int i = 0; i < num_threads; ++i) { worker_threads->workers->Schedule([=, &count]() { int start = i; int end = i+1; copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S,desc.C, desc.K,blocksifm,blocksofm,ifmblock,ofmblock, start, end); count.DecrementCount(); }); + } + count.Wait(); } - count.Wait(); } + //Added: for weight update + else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD){ + libxsmm_filter = libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter, LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status); + chk_libxsmm_err(status, "Link filter");//weight update is in RSCK as filter should be returned in RSCK format + } +#else + memset( native_filter, 0, blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float)); +#endif + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick3 = libxsmm_timer_tick(); +#endif libxsmm_input = libxsmm_dnn_link_buffer( libxsmm_handle, LIBXSMM_DNN_INPUT, input, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); @@ -275,14 +315,14 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, libxsmm_output = libxsmm_dnn_link_buffer( libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); chk_libxsmm_err(status, "Link output buffer"); + if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD || kind == LIBXSMM_DNN_COMPUTE_KIND_BWD){ libxsmm_filter = libxsmm_dnn_link_filter( libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status); chk_libxsmm_err(status, "Link filter"); - - chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output"); - - + } if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) { + chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output"); + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT), "Bind input forward"); chk_libxsmm_err( @@ -290,7 +330,9 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, "Bind output forward"); chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER), "Bind filter forward"); - } else { + } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { + chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input"); + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_GRADIENT_INPUT), "Bind input backward"); chk_libxsmm_err( @@ -298,17 +340,42 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, "Bind output backward"); chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER), "Bind filter backward"); + } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { + chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter"); + + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT), + "Bind input weight udpate"); + chk_libxsmm_err( + libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT), + "Bind output weight update"); + chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_GRADIENT_FILTER), + "Bind filter weight update"); + } else { + /* shouldn't happen */ } +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick4 = libxsmm_timer_tick(); +#endif + /* bind scratch */ - scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, kind, &status ), 2097152); + scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status ), 2097152); chk_libxsmm_err( status, "scratch allocation" ); - chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, kind, scratch ), "binding scratch" ); + chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch ), "binding scratch" ); + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick5 = libxsmm_timer_tick(); +#endif if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER); } +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick6 = libxsmm_timer_tick(); +#endif + +#if 1 BlockingCounter counter(num_threads); for (int i = 0; i < num_threads; ++i) { @@ -319,6 +386,24 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, }); } counter.Wait(); +#else + #pragma omp parallel + { + chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, omp_get_thread_num()), "Worker"); + } +#endif + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick7 = libxsmm_timer_tick(); +#endif + + if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { + libxsmm_dnn_reduce_wu_filters( libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER ); + } + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick8 = libxsmm_timer_tick(); +#endif /* clean up */ chk_libxsmm_err( libxsmm_dnn_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ), "release scratch" ); @@ -326,21 +411,47 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ), "release input" ); chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT ), "release output" ); chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" ); - } else { + } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT ), "release input" ); chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ), "release output" ); chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" ); + } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { + chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ), "release input" ); + chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ), "release output" ); + chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER ), "release filter" ); + } else { + /* shouldn't happen */ } chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input"); chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output"); chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter"); + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick9 = libxsmm_timer_tick(); +#endif - if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD) - chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle), - "Destroy handle"); + //if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD) + //chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle), + // "Destroy handle"); libxsmm_free(native_filter); libxsmm_free(scratch); + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick10 = libxsmm_timer_tick(); + printf("time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", desc.N, desc.C, desc.K, desc.R, desc.S, + libxsmm_timer_duration(l_tick1, l_tick2), + libxsmm_timer_duration(l_tick2, l_tick3), + libxsmm_timer_duration(l_tick3, l_tick4), + libxsmm_timer_duration(l_tick4, l_tick5), + libxsmm_timer_duration(l_tick5, l_tick6), + libxsmm_timer_duration(l_tick6, l_tick7), + libxsmm_timer_duration(l_tick7, l_tick8), + libxsmm_timer_duration(l_tick8, l_tick9), + libxsmm_timer_duration(l_tick9, l_tick10), + libxsmm_timer_duration(l_tick1, l_tick10) ); +#endif + return true; // Succeeded } diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index dac04440d0..818a7e64bd 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -89,11 +89,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): native.new_http_archive( name = "libxsmm_archive", urls = [ - "http://bazel-mirror.storage.googleapis.com/github.com/hfp/libxsmm/archive/1.7.1.tar.gz", - "https://github.com/hfp/libxsmm/archive/1.7.1.tar.gz", + "http://bazel-mirror.storage.googleapis.com/github.com/hfp/libxsmm/archive/1.8.tar.gz", + "https://github.com/hfp/libxsmm/archive/1.8.tar.gz", ], - sha256 = "9d3f63ce3eed62f04e4036de6f2be2ce0ff07781ca571af6e0bf85b077edf17a", - strip_prefix = "libxsmm-1.7.1", + sha256 = "0330201afb5525d0950ec861fec9dd75eb40a03845ebe03d2c635cf8bfc14fea", + strip_prefix = "libxsmm-1.8", build_file = str(Label("//third_party:libxsmm.BUILD")), ) diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD index 037009c072..9559f8be80 100644 --- a/third_party/libxsmm.BUILD +++ b/third_party/libxsmm.BUILD @@ -56,22 +56,26 @@ genrule( cc_library( name = "xsmm_avx", srcs = [ - "src/libxsmm_main.c", + "src/libxsmm_cpuid_x86.c", + "src/libxsmm_dnn.c", + "src/libxsmm_dnn_convolution_backward.c", + "src/libxsmm_dnn_convolution_forward.c", + "src/libxsmm_dnn_convolution_weight_update.c", + "src/libxsmm_dnn_convolution_winograd_backward.c", + "src/libxsmm_dnn_convolution_winograd_forward.c", + "src/libxsmm_dnn_convolution_winograd_weight_update.c", + "src/libxsmm_dnn_handle.c", "src/libxsmm_dump.c", - "src/libxsmm_malloc.c", + "src/libxsmm_fsspmdm.c", "src/libxsmm_gemm.c", + "src/libxsmm_main.c", + "src/libxsmm_malloc.c", + "src/libxsmm_perf.c", + "src/libxsmm_spmdm.c", + "src/libxsmm_sync.c", "src/libxsmm_timer.c", "src/libxsmm_trace.c", "src/libxsmm_trans.c", - "src/libxsmm_sync.c", - "src/libxsmm_perf.c", - "src/libxsmm_spmdm.c", - "src/libxsmm_dnn.c", - "src/libxsmm_dnn_handle.c", - "src/libxsmm_dnn_convolution_forward.c", - "src/libxsmm_dnn_convolution_backward.c", - "src/libxsmm_dnn_convolution_weight_update.c", - "src/libxsmm_cpuid_x86.c", ] + glob([ "src/generator_*.c", ]), @@ -79,6 +83,7 @@ cc_library( "include/libxsmm_cpuid.h", "include/libxsmm_dnn.h", "include/libxsmm_frontend.h", + "include/libxsmm_fsspmdm.h", "include/libxsmm_generator.h", "include/libxsmm_intrinsics_x86.h", "include/libxsmm_macros.h", @@ -91,14 +96,15 @@ cc_library( "include/libxsmm.h", "include/libxsmm_config.h", "include/libxsmm_dispatch.h", - ], + ] + glob([ # trigger rebuild if template changed + "src/template/*.c", + ]), copts = [ "-mavx", # JIT does not work without avx anyway, and this silences some CRC32 warnings. "-Wno-vla", # Libxsmm convolutions heavily use VLA. ], defines = [ "LIBXSMM_BUILD", - "LIBXSMM_CPUID_X86_NOINLINE", "__BLAS=0", ], includes = [ @@ -111,24 +117,28 @@ cc_library( py_library( name = "libxsmm_scripts", + srcs_version = "PY2AND3", srcs = glob(["scripts/*.py"]), data = ["version.txt"], ) py_binary( name = "libxsmm_interface", + srcs_version = "PY2AND3", srcs = ["scripts/libxsmm_interface.py"], deps = [":libxsmm_scripts"], ) py_binary( name = "libxsmm_config", + srcs_version = "PY2AND3", srcs = ["scripts/libxsmm_config.py"], deps = [":libxsmm_scripts"], ) py_binary( name = "libxsmm_dispatch", + srcs_version = "PY2AND3", srcs = ["scripts/libxsmm_dispatch.py"], deps = [":libxsmm_scripts"], )