Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/core/algorithm/hnsw/hnsw_algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
namespace zvec {
namespace core {

HnswAlgorithm::HnswAlgorithm(HnswEntity &entity)
HnswAlgorithm::HnswAlgorithm(HnswStreamerEntityNew &entity)
: entity_(entity),
mt_(std::chrono::system_clock::now().time_since_epoch().count()),
lock_pool_(kLockCnt) {}
Expand Down Expand Up @@ -113,7 +113,7 @@ void HnswAlgorithm::select_entry_point(level_t level, node_id_t *entry_point,
auto &entity = ctx->get_entity();
HnswDistCalculator &dc = ctx->dist_calculator();
while (true) {
const Neighbors neighbors = entity.get_neighbors(level, *entry_point);
const Neighbors neighbors = entity.get_neighbors_new(level, *entry_point);
if (ailego_unlikely(ctx->debugging())) {
(*ctx->mutable_stats_get_neighbors())++;
}
Expand All @@ -123,7 +123,7 @@ void HnswAlgorithm::select_entry_point(level_t level, node_id_t *entry_point,
}

std::vector<IndexStorage::MemoryBlock> neighbor_vec_blocks;
int ret = entity.get_vector(&neighbors[0], size, neighbor_vec_blocks);
int ret = entity.get_vector_new(&neighbors[0], size, neighbor_vec_blocks);
if (ailego_unlikely(ctx->debugging())) {
(*ctx->mutable_stats_get_vector())++;
}
Expand Down Expand Up @@ -208,7 +208,7 @@ void HnswAlgorithm::search_neighbors(level_t level, node_id_t *entry_point,
}

candidates.pop();
const Neighbors neighbors = entity.get_neighbors(level, main_node);
const Neighbors neighbors = entity.get_neighbors_new(level, main_node);
ailego_prefetch(neighbors.data);
if (ailego_unlikely(ctx->debugging())) {
(*ctx->mutable_stats_get_neighbors())++;
Expand All @@ -232,7 +232,8 @@ void HnswAlgorithm::search_neighbors(level_t level, node_id_t *entry_point,
}

std::vector<IndexStorage::MemoryBlock> neighbor_vec_blocks;
int ret = entity.get_vector(neighbor_ids.data(), size, neighbor_vec_blocks);
int ret =
entity.get_vector_new(neighbor_ids.data(), size, neighbor_vec_blocks);
if (ailego_unlikely(ctx->debugging())) {
(*ctx->mutable_stats_get_vector())++;
}
Expand Down Expand Up @@ -332,7 +333,7 @@ void HnswAlgorithm::expand_neighbors_by_group(TopkHeap &topk,
node_id_t main_node = top->first;

candidates.pop();
const Neighbors neighbors = entity.get_neighbors(0, main_node);
const Neighbors neighbors = entity.get_neighbors_new(0, main_node);
if (ailego_unlikely(ctx->debugging())) {
(*ctx->mutable_stats_get_neighbors())++;
}
Expand All @@ -356,7 +357,7 @@ void HnswAlgorithm::expand_neighbors_by_group(TopkHeap &topk,

std::vector<IndexStorage::MemoryBlock> neighbor_vec_blocks;
int ret =
entity.get_vector(neighbor_ids.data(), size, neighbor_vec_blocks);
entity.get_vector_new(neighbor_ids.data(), size, neighbor_vec_blocks);
if (ailego_unlikely(ctx->debugging())) {
(*ctx->mutable_stats_get_vector())++;
}
Expand Down Expand Up @@ -463,7 +464,7 @@ void HnswAlgorithm::reverse_update_neighbors(HnswDistCalculator &dc,

uint32_t lock_idx = id & kLockMask;
lock_pool_[lock_idx].lock();
const Neighbors neighbors = entity_.get_neighbors(level, id);
const Neighbors neighbors = entity_.get_neighbors_new(level, id);
size_t size = neighbors.size();
ailego_assert_with(size <= max_neighbor_cnt, "invalid neighbor size");
if (size < max_neighbor_cnt) {
Expand Down
6 changes: 3 additions & 3 deletions src/core/algorithm/hnsw/hnsw_algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include <ailego/parallel/lock.h>
#include "hnsw_context.h"
#include "hnsw_dist_calculator.h"
#include "hnsw_entity.h"
#include "hnsw_streamer_entity_new.h"

namespace zvec {
namespace core {
Expand All @@ -29,7 +29,7 @@ class HnswAlgorithm {

public:
//! Constructor
explicit HnswAlgorithm(HnswEntity &entity);
explicit HnswAlgorithm(HnswStreamerEntityNew &entity);

//! Destructor
~HnswAlgorithm() = default;
Expand Down Expand Up @@ -116,7 +116,7 @@ class HnswAlgorithm {
static constexpr uint32_t kLockCnt{1U << 8};
static constexpr uint32_t kLockMask{kLockCnt - 1U};

HnswEntity &entity_;
HnswStreamerEntityNew &entity_;
mutable std::mt19937 mt_{};
std::vector<double> level_probas_{};

Expand Down
Loading
Loading