-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpositional_encoding.cpp
More file actions
50 lines (42 loc) · 1.92 KB
/
positional_encoding.cpp
File metadata and controls
50 lines (42 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <cmath>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>
// Generate sinusoidal positional encodings for patches.
//
// Returns a num_patches x embedding_dim matrix where, for position i and
// dimension j with pair index k = j / 2 (integer division):
//   PE[i][j] = sin(i / 10000^(2k / embedding_dim))   if j is even
//   PE[i][j] = cos(i / 10000^(2k / embedding_dim))   if j is odd
// Dimensions 2k and 2k+1 therefore share the same frequency; an odd
// embedding_dim simply ends with an unpaired sine dimension.
//
// @param num_patches   number of positions (rows); 0 yields an empty result.
// @param embedding_dim number of encoding dimensions (columns).
// @return the positional-encoding matrix.
// NOTE: negative sizes are not supported; the std::vector constructor throws.
std::vector<std::vector<float>> create_positional_encoding(int num_patches, int embedding_dim) {
    std::vector<std::vector<float>> positional_encoding(num_patches, std::vector<float>(embedding_dim));

    // The divisor depends only on the dimension, so compute it once per
    // dimension instead of once per (position, dimension) pair.
    std::vector<double> divisor(embedding_dim);
    for (int j = 0; j < embedding_dim; ++j) {
        // Integer division is intentional: dims 2k and 2k+1 share frequency k.
        const int pair_index = j / 2;
        divisor[j] = std::pow(10000.0, (2.0 * pair_index) / static_cast<double>(embedding_dim));
    }

    for (int i = 0; i < num_patches; ++i) {
        for (int j = 0; j < embedding_dim; ++j) {
            const float angle = static_cast<float>(i / divisor[j]);
            // Even dimensions use sine, odd dimensions use cosine.
            positional_encoding[i][j] = (j % 2 == 0) ? std::sin(angle) : std::cos(angle);
        }
    }
    return positional_encoding;
}
// Add scaled positional encodings to patch embeddings:
//   out[i][j] = patch_embeddings[i][j] + scale * positional_encoding[i][j]
//
// The output has patch_embeddings.size() rows, each as wide as
// patch_embeddings[0]. positional_encoding may be larger than needed (e.g. a
// table precomputed for a maximum patch count); only the leading rows and
// columns are read.
//
// @param patch_embeddings    patch embedding matrix; may be empty.
// @param positional_encoding encoding matrix with at least as many rows, and
//                            each row at least as wide, as the output.
// @param scale               weight applied to the positional encoding.
// @return a new matrix; the inputs are not modified.
// @throws std::invalid_argument if an input row is missing or too short
//         (previously undefined behavior).
std::vector<std::vector<float>> add_positional_encoding(
    const std::vector<std::vector<float>>& patch_embeddings,
    const std::vector<std::vector<float>>& positional_encoding,
    float scale)
{
    // Empty input has nothing to encode; avoid touching patch_embeddings[0].
    if (patch_embeddings.empty()) {
        return {};
    }
    const std::size_t num_patches = patch_embeddings.size();
    const std::size_t embedding_dim = patch_embeddings[0].size();
    if (positional_encoding.size() < num_patches) {
        throw std::invalid_argument(
            "add_positional_encoding: positional_encoding has fewer rows than patch_embeddings");
    }

    std::vector<std::vector<float>> patch_with_positional_encoding(num_patches, std::vector<float>(embedding_dim));
    for (std::size_t i = 0; i < num_patches; ++i) {
        const std::vector<float>& embedding = patch_embeddings[i];
        const std::vector<float>& encoding = positional_encoding[i];
        // Every row is read up to the width of row 0, so shorter rows would overrun.
        if (embedding.size() < embedding_dim || encoding.size() < embedding_dim) {
            throw std::invalid_argument(
                "add_positional_encoding: row narrower than the embedding dimension");
        }
        std::vector<float>& out_row = patch_with_positional_encoding[i];
        for (std::size_t j = 0; j < embedding_dim; ++j) {
            out_row[j] = embedding[j] + scale * encoding[j];
        }
    }
    return patch_with_positional_encoding;
}

// Add positional encoding to patch embeddings with the default scaling
// factor of 0.1, which keeps the positional signal from dominating the
// embedding. See the three-argument overload for the full contract.
std::vector<std::vector<float>> add_positional_encoding(
    const std::vector<std::vector<float>>& patch_embeddings,
    const std::vector<std::vector<float>>& positional_encoding)
{
    const float lambda = 0.1f;
    return add_positional_encoding(patch_embeddings, positional_encoding, lambda);
}