-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmat.h
More file actions
173 lines (141 loc) · 3.96 KB
/
mat.h
File metadata and controls
173 lines (141 loc) · 3.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
//
// Created by xixuan on 4/6/17.
//
#ifndef MAT_H
#define MAT_H
#include <string>
#include <cublas_v2.h>
// Compile-time constants and indexing macros shared by the implementation.
// NOTE(review): how the numeric constants are used is not visible from this
// header; the descriptions below are inferred from their names -- confirm.
// Also note: identifiers that begin with an underscore followed by an
// uppercase letter are reserved for the implementation in C++; renaming them
// would be safer, but must be coordinated with every file that uses them.

// Presumably the CUDA thread-block size used when launching kernels.
#define _THREADS_PER_BLOCK 256
// Presumably the upper / lower clamp applied to exponents (e.g. by
// vec_exponentiating) to avoid overflow/underflow -- confirm.
#define _EXPONENT_CAP 10
#define _EXPONENT_BOT -10
// Presumably a floor used to keep small magnitudes away from zero
// (e.g. before dividing) -- confirm.
#define _SMALL_NUM_BOT 1e-8
// Column-major linear index of element (i, j) in a 2-D array whose columns
// are ld_i elements apart: j * ld_i + i. Every argument is parenthesized,
// so expressions are safe to pass.
#define IDX(i, j, ld_i) (((j)*(ld_i))+(i))
// Column-major linear index of element (i, j, k) in a 3-D array stored as
// consecutive ld_i x ld_j pages: k * (ld_i * ld_j) + j * ld_i + i.
#define IDXX(i, j, k, ld_i, ld_j) (((k)*((ld_i)*(ld_j)))+((j)*(ld_i))+(i))
// ===========================================================================
// Vector operations
//
// NOTE(review): the definitions are not in this header, so the one-line
// descriptions below are inferred from names and signatures -- confirm them
// against the implementation. The `dev_` prefix suggests device (GPU)
// pointers. Every routine returns void; output presumably goes through the
// non-const pointer argument.
// ===========================================================================

// Presumably: dev_result[i] = left_vec[i] * right_vec[i], i in [0, length).
void elementwise_vec_prod(const double *left_vec, const double *right_vec,
                          double *dev_result, int length);

// Presumably: dev_result[i] = left_vec[i] + right_vec[i], i in [0, length).
void elementwise_vec_sum(const double *left_vec, const double *right_vec,
                         double *dev_result, int length);

// Presumably: dev_result[i] = left_vec[i] - right_vec[i], i in [0, length).
void elementwise_vec_sub(const double *left_vec, const double *right_vec,
                         double *dev_result, int length);

// Presumably: dev_result[i] = left_vec[i] / right_vec[i], i in [0, length).
// NOTE(review): handling of zero divisors is not visible here.
void elementwise_vec_div(const double *left_vec, const double *right_vec,
                         double *dev_result, int length);

// Presumably: dev_result[i] = scaler * dev_mat[i]. The input is named
// `dev_mat`, but only a `length` is passed, so it is presumably treated as a
// flat array of `length` elements.
void vec_scaling(const double *dev_mat, double *dev_result, int length,
                 double scaler);

// Presumably: dev_result[i] = dev_mat[i] + scaler (the shift amount is passed
// in the parameter named `scaler`).
void vec_shifting(const double *dev_mat, double *dev_result, int length,
                  double scaler);

// Presumably: dev_result[i] = dev_mat[i] * dev_mat[i].
void vec_squaring(const double *dev_mat, double *dev_result, int length);

// Presumably: dev_result[i] = dev_mat[i] * dev_mat[i] * dev_mat[i].
void vec_cubing(const double *dev_mat, double *dev_result, int length);

// Presumably: dev_result[i] = exp(dev_mat[i]), possibly with the exponent
// clamped to [_EXPONENT_BOT, _EXPONENT_CAP] -- confirm.
void vec_exponentiating(const double *dev_mat, double *dev_result,
                        int length);
// Presumably reduces the `length` doubles at `dev_vec` to a single value and
// writes it through `host_result`, using the supplied cuBLAS `handle`.
// NOTE(review): the definition is not visible here. The names suggest
// `dev_vec` is device memory and `host_result` is host memory -- confirm.
// cuBLAS has no plain-sum routine; if the implementation uses cublasDasum,
// the result is the sum of ABSOLUTE values, not the signed sum.
void vec_sum(cublasHandle_t handle, const double *dev_vec,
             double *host_result, int length);
// Presumably copies `length` doubles from `dev_vec` into `dev_result`.
// NOTE(review): the definition is not visible here; the `dev_` prefixes
// suggest both pointers refer to device memory -- confirm.
void copy(const double *dev_vec, double *dev_result, int length);

// Presumably fills the first `length` elements of `dev_vec` with `val`.
void set_vec_to_val(double *dev_vec, int length, double val);
// ===========================================================================
// Matrix operations
//
// NOTE(review): the definitions are not in this header; the descriptions
// below are inferred from names and signatures. Matrices are presumably
// num_rows x num_cols and stored column-major (consistent with the IDX
// macro and with cuBLAS) -- confirm against the implementation.
// ===========================================================================

// Presumably a first-order finite-difference partial derivative of
// `dev_mat` along the row direction with step `delta`.
// NOTE(review): which axis "row" refers to, and the boundary handling,
// are not visible here.
void row_1st_pd(const double *dev_mat, double *dev_result,
                int num_rows, int num_cols, double delta);

// Presumably the first-order increment (a difference without dividing by a
// step) of `dev_mat` along the row direction.
void row_1st_inc(const double *dev_mat, double *dev_result,
                 int num_rows, int num_cols);

// Presumably the column-direction counterpart of row_1st_pd.
void col_1st_pd(const double *dev_mat, double *dev_result,
                int num_rows, int num_cols, double delta);

// Presumably shifts each column of `dev_mat` by an amount derived from the
// row vector `dev_vec`. NOTE(review): the roles of `min_K` and `delta` are
// not evident from this header -- document them from the definition.
void colwise_mat_shift_by_row_vec(const double *dev_mat, const double *dev_vec,
                                  double *dev_result, int num_rows,
                                  int num_cols, double min_K, double delta);

// Presumably sums each column of `dev_mat`, giving one value per column in
// `dev_result`.
void colwise_sum_of_mat(const double *dev_mat, double *dev_result,
                        int num_rows, int num_cols);
// NOTE(review): unlike every other matrix routine in this header, the four
// declarations below take `num_rows` / `num_cols` as `double` rather than
// `int`. Integer arguments convert implicitly (exactly, for values below
// 2^53), so callers compile, but the implementation presumably has to
// convert the values back to an integer type for indexing / launch sizes --
// confirm. The types cannot be fixed in this header alone: parameter types
// are part of the C++ mangled symbol, so the (not visible) definitions and
// this header must be changed together. Adding `int` overloads instead
// would make existing integer calls pick the new, undefined overloads.
//
// NOTE(review): the definitions are not visible here; the per-function
// descriptions are inferred from names and signatures -- confirm.

// Presumably: dev_result(i, j) = dev_mat(i, j) / dev_vec[j], i.e. each
// column j is divided by entry j of the row vector -- confirm which
// dimension `dev_vec` spans and how zero divisors are handled.
void colwise_mat_div_with_row_vec( double const *dev_mat,
double const *dev_vec,
double *dev_result,
double num_rows,
double num_cols);
// Presumably: dev_result(i, j) = dev_mat(i, j) * dev_vec[j] (column j scaled
// by entry j of the row vector) -- confirm.
void colwise_mat_prod_with_row_vec( double const *dev_mat,
double const *dev_vec,
double *dev_result,
double num_rows,
double num_cols);
// Presumably: dev_result(i, j) = dev_mat(i, j) * dev_vec[i] (row i scaled
// by entry i of the column vector) -- confirm.
void rowwise_mat_prod_with_col_vec( double const *dev_mat,
double const *dev_vec,
double *dev_result,
double num_rows,
double num_cols);
// Operates in place: `dev_mat` is the only (non-const) pointer. Presumably
// normalizes each column (the normalization used, e.g. unit sum or unit
// norm, is not visible here -- confirm).
void colwise_normalization( double *dev_mat,
double num_rows,
double num_cols);
// NOTE(review): the definitions are not visible here; the descriptions are
// inferred from names and signatures -- confirm against the implementation.

// Presumably a cumulative (running) product down each column of the
// num_rows x num_cols matrix `dev_mat`, seeded with `init_value`, written to
// `dev_result` -- confirm whether the seed occupies the first row.
void colwise_mat_accu_prod(const double *dev_mat, double *dev_result,
                           int num_rows, int num_cols, double init_value);

// Presumably copies the main diagonal of the square
// num_rs_or_cs x num_rs_or_cs matrix `dev_mat` into `dev_result`
// (num_rs_or_cs elements).
void get_diag_line_of_square_mat(const double *dev_mat, double *dev_result,
                                 int num_rs_or_cs);
// ---------------------------------------------------------------------------
// Allocation, release and host/device transfer helpers.
//
// NOTE(review): the definitions are not visible here; the descriptions are
// inferred from names and signatures. `name` is presumably a label used in
// diagnostics. It is taken by value, so every call copies the string.
// Switching to `const std::string &` would change the symbol, so it must be
// done together with the definitions.
// ---------------------------------------------------------------------------

// Presumably allocates `size` doubles of device memory into `dev_ptr` and
// copies `size` doubles from `host_ptr` into it. `dev_ptr` is taken by
// reference, so the caller's pointer can be updated.
void dev_alloc_and_init(const double *const &host_ptr, double *&dev_ptr,
                        std::string name, int size);

// Presumably allocates `size` doubles of device memory into `dev_ptr`
// without initializing the contents.
void dev_alloc(double *&dev_ptr, std::string name, int size);

// Presumably frees `dev_ptr`. The pointer is passed by reference, so the
// implementation may also reset it -- confirm before relying on that.
void dev_release(double *&dev_ptr, std::string name);

// Presumably copies `size` doubles from `dev_ptr` to the host. `host_ptr` is
// a non-const reference, so the implementation may (re)assign it, e.g. by
// allocating the host buffer -- confirm who owns and frees that memory.
void dev_download(double *&host_ptr, double *const &dev_ptr,
                  std::string name, int size);
// ---------------------------------------------------------------------------
// Debugging helpers.
// NOTE(review): the definitions are not visible here; the descriptions are
// inferred from names and signatures -- confirm.
// ---------------------------------------------------------------------------

// Presumably prints a num_rows x num_cols matrix held in host memory.
void
print(double *const &host_ptr, int num_rows, int num_cols);

// Presumably prints a num_rows x num_cols matrix held in device memory
// (copying it to the host first).
void
dev_print(double *const &dev_ptr, int num_rows, int num_cols);

// Presumably prints a num_rows x num_cols x num_pages array held in device
// memory, one page at a time.
void
dev_print_3d(double *const &dev_ptr, int num_rows, int num_cols,
             int num_pages);

// Presumably reads the single element at offset `index` of the device
// array `dev_ptr` and returns it on the host.
double
dev_get(const double *dev_ptr, int index);
#endif // MAT_H