-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvec_lib.hpp
More file actions
201 lines (160 loc) · 4.48 KB
/
vec_lib.hpp
File metadata and controls
201 lines (160 loc) · 4.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#ifndef VEC_LIB_HPP_
#define VEC_LIB_HPP_
#include <bit>
#include <cstdint>
#include <immintrin.h>
// Thin value wrapper around one 256-bit AVX register holding eight 32-bit
// floats.
//
// The only state is the raw __m256 value, so copy construction and copy
// assignment are left to the compiler (Rule of Zero). All read-only
// operations are const so they can be used on const objects and references.
class Vec8x32f {
private:
    __m256 vec_;

public:
    // Loads eight floats starting at f32data; the pointer does not need to be
    // 32-byte aligned (unaligned load).
    explicit Vec8x32f(const float* f32data) : vec_{_mm256_loadu_ps(f32data)} {}

    // Broadcasts val into all eight lanes.
    explicit Vec8x32f(float val) : vec_{_mm256_set1_ps(val)} {}

    // Wraps an already-built raw register value.
    explicit Vec8x32f(__m256 raw_vec) : vec_{raw_vec} {}

    // Replaces the contents with eight floats read from f32data (unaligned).
    void Load(const float* f32data) {
        vec_ = _mm256_loadu_ps(f32data);
    }

    // Writes the eight lanes to f32data (unaligned store).
    void Store(float* f32data) const {
        _mm256_storeu_ps(f32data, vec_);
    }

    // Returns the underlying raw register value.
    __m256 GetRawVec() const {
        return vec_;
    }

    // Lane-wise product.
    Vec8x32f operator*(Vec8x32f other) const {
        return Vec8x32f{_mm256_mul_ps(vec_, other.vec_)};
    }

    // Lane-wise sum.
    Vec8x32f operator+(Vec8x32f other) const {
        return Vec8x32f{_mm256_add_ps(vec_, other.vec_)};
    }

    // Lane-wise "less than or equal" comparison.
    //
    // Returns a bit mask rather than a bool: bit i (i = 0..7) is set when
    // lane i of *this is <= lane i of other. The comparison is ordered and
    // non-signalling (_CMP_LE_OQ), so a lane containing NaN yields a 0 bit.
    uint32_t operator<=(Vec8x32f other) const {
        const __m256 lane_mask = _mm256_cmp_ps(vec_, other.vec_, _CMP_LE_OQ);
        // movemask collects the sign bit of every lane into the low 8 bits.
        return std::bit_cast<uint32_t>(_mm256_movemask_ps(lane_mask));
    }
};
// Thin value wrapper around one 256-bit AVX2 integer register treated as four
// unsigned 64-bit lanes.
//
// The only state is the raw __m256i value, so copy construction and copy
// assignment are left to the compiler (Rule of Zero). All read-only
// operations are const so they can be used on const objects and references.
class Vec4x64u {
private:
    __m256i vec_;

public:
    // Loads four 64-bit values starting at u64data (unaligned load).
    explicit Vec4x64u(const uint64_t* u64data)
        : vec_{_mm256_loadu_si256(reinterpret_cast<const __m256i*>(u64data))} {}

    // Broadcasts val into all four lanes; default-constructs to all zeros.
    explicit Vec4x64u(uint64_t val = 0)
        : vec_{_mm256_set1_epi64x(std::bit_cast<int64_t>(val))} {}

    // Wraps an already-built raw register value.
    explicit Vec4x64u(__m256i val) : vec_{val} {}

    // Replaces the contents with four values read from u64data (unaligned).
    void Load(const uint64_t* u64data) {
        vec_ = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u64data));
    }

    // Writes the four lanes to u64data (unaligned store).
    void Store(uint64_t* u64data) const {
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(u64data), vec_);
    }

    // Returns the underlying raw register value.
    __m256i GetRawVec() const {
        return vec_;
    }

    // Lane-wise widening multiply of the LOW 32 bits of each 64-bit lane.
    //
    // NOTE: this is _mm256_mul_epu32, not a general 64x64 multiply: the upper
    // 32 bits of both operands are ignored and each lane receives the full
    // 64-bit product of the two unsigned 32-bit low halves.
    Vec4x64u operator*(Vec4x64u other) const {
        return Vec4x64u{_mm256_mul_epu32(vec_, other.vec_)};
    }

    // Lane-wise 64-bit sum (wraps around on overflow).
    Vec4x64u operator+(Vec4x64u other) const {
        return Vec4x64u{_mm256_add_epi64(vec_, other.vec_)};
    }

    // Logical right shift of every lane by the same count, shifting in zeros.
    Vec4x64u operator>>(uint32_t count) const {
        // The shift instruction reads its count from the low 64 bits of an
        // XMM register.
        const __m128i vec_count = _mm_set1_epi64x(count);
        return Vec4x64u{_mm256_srl_epi64(vec_, vec_count)};
    }

    // Bitwise AND of the two registers.
    Vec4x64u operator&(Vec4x64u other) const {
        return Vec4x64u{_mm256_and_si256(vec_, other.vec_)};
    }
};
// Thin value wrapper around one 256-bit AVX2 integer register treated as
// eight unsigned 32-bit lanes.
//
// The only state is the raw __m256i value, so copy construction and copy
// assignment are left to the compiler (Rule of Zero). All read-only
// operations are const so they can be used on const objects and references.
class Vec8x32u {
private:
    __m256i vec_;

public:
    // Loads eight 32-bit values starting at u32data (unaligned load).
    explicit Vec8x32u(const uint32_t* u32data)
        : vec_{_mm256_loadu_si256(reinterpret_cast<const __m256i*>(u32data))} {}

    // Broadcasts val into all eight lanes; default-constructs to all zeros.
    explicit Vec8x32u(uint32_t val = 0)
        : vec_{_mm256_set1_epi32(std::bit_cast<int32_t>(val))} {}

    // Wraps an already-built raw register value.
    explicit Vec8x32u(__m256i val) : vec_{val} {}

    // Replaces the contents with eight values read from u32data (unaligned).
    void Load(const uint32_t* u32data) {
        vec_ = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u32data));
    }

    // Writes the eight lanes to u32data (unaligned store).
    void Store(uint32_t* u32data) const {
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(u32data), vec_);
    }

    // Narrows two vectors of four 64-bit lanes into this vector of eight
    // 32-bit lanes by keeping only the low 32 bits of every 64-bit lane.
    // Lanes 0..3 come from vec_lo, lanes 4..7 from vec_hi.
    void Reduce2u64Tou32(Vec4x64u vec_lo, Vec4x64u vec_hi) {
        // _mm256_set_epi32 lists elements from highest to lowest, so the low
        // four result elements pick source dwords 0, 2, 4, 6 (the low half of
        // each 64-bit lane). The upper four indices are irrelevant because
        // only the low 128 bits of each permutation are kept.
        const __m256i perm_mask = _mm256_set_epi32(0, 0, 0, 0, 6, 4, 2, 0);
        const __m128i lo = _mm256_castsi256_si128(
            _mm256_permutevar8x32_epi32(vec_lo.GetRawVec(), perm_mask));
        const __m128i hi = _mm256_castsi256_si128(
            _mm256_permutevar8x32_epi32(vec_hi.GetRawVec(), perm_mask));
        vec_ = _mm256_set_m128i(hi, lo);
    }

    // Returns the underlying raw register value.
    __m256i GetRawVec() const {
        return vec_;
    }
};
inline uint32_t CountOnes(uint32_t val) {
return std::bit_cast<uint32_t>(_mm_popcnt_u32(val));
}
// Reinterprets the bits of one vector wrapper as another wrapper type without
// changing them. ToT must be spelled out by the caller; FromT is deduced by
// the compiler from the argument. Only a declaration is given here, so the
// call is usable only for (ToT, FromT) pairs that have an explicit
// specialization.
template <typename ToT, typename FromT>
ToT VecBitCast(FromT vec);
// Reinterprets eight floats as eight 32-bit unsigned integers; the register
// bits are unchanged (_mm256_castps_si256 generates no instructions).
//
// An explicit specialization is an ordinary function definition, so it must
// be `inline` to live in a header included from several translation units.
template <>
inline Vec8x32u VecBitCast(Vec8x32f vec) {
    return Vec8x32u{_mm256_castps_si256(vec.GetRawVec())};
}
// Reinterprets eight 32-bit unsigned integers as eight floats; the register
// bits are unchanged (_mm256_castsi256_ps generates no instructions).
//
// An explicit specialization is an ordinary function definition, so it must
// be `inline` to live in a header included from several translation units.
template <>
inline Vec8x32f VecBitCast(Vec8x32u vec) {
    return Vec8x32f{_mm256_castsi256_ps(vec.GetRawVec())};
}
// Converts the lane values of one vector wrapper to another wrapper type
// numerically (as opposed to reinterpreting the bits). ToT must be spelled
// out by the caller; FromT is deduced by the compiler from the argument.
// Only a declaration is given here, so the call is usable only for
// (ToT, FromT) pairs that have an explicit specialization.
template <typename ToT, typename FromT>
ToT VecValueCast(FromT vec);
// Converts eight unsigned 32-bit integers to the nearest floats.
//
// AVX2 only offers a SIGNED 32-bit -> float conversion (_mm256_cvtepi32_ps),
// which would turn every lane >= 2^31 into a negative number. To convert the
// full unsigned range, each lane is split into two 16-bit halves that are
// both non-negative as signed integers and exactly representable as floats:
//     value = hi16 * 65536 + lo16
// The multiplication by 65536 is exact, so the final addition is the only
// rounding step and the result is the correctly rounded float. For lanes
// below 2^31 this matches the plain signed conversion.
//
// An explicit specialization is an ordinary function definition, so it must
// be `inline` to live in a header included from several translation units.
template <>
inline Vec8x32f VecValueCast(Vec8x32u vec) {
    const __m256i raw = vec.GetRawVec();
    const __m256i lo16 = _mm256_and_si256(raw, _mm256_set1_epi32(0xFFFF));
    const __m256i hi16 = _mm256_srli_epi32(raw, 16);
    const __m256 lo_f = _mm256_cvtepi32_ps(lo16);
    const __m256 hi_f = _mm256_cvtepi32_ps(hi16);
    return Vec8x32f{
        _mm256_add_ps(_mm256_mul_ps(hi_f, _mm256_set1_ps(65536.0f)), lo_f)
    };
}
// FIXME: implement the remaining VecBitCast / VecValueCast specializations.
#endif // VEC_LIB_HPP_