-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
126 lines (91 loc) · 3.54 KB
/
utils.py
File metadata and controls
126 lines (91 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import annotations
import json
from pathlib import Path
from typing import Iterable, List, Tuple
import faiss
import numpy as np
from wordfreq import top_n_list
# On-disk artifact locations, rooted in a "data" directory next to this module.
DATA_DIR = Path(__file__).resolve().parent / "data"
WORDS_PATH = DATA_DIR / "words.json"  # JSON array of vocabulary words (see save_words/load_words)
EMBEDDINGS_PATH = DATA_DIR / "embeddings.npy"  # float32 matrix; presumably row i corresponds to words[i] — TODO confirm against the build pipeline
FAISS_INDEX_PATH = DATA_DIR / "faiss.index"  # serialized FAISS index (see save_faiss_index/load_faiss_index)
def normalize_word(word: str) -> str:
    """Return *word* trimmed of surrounding whitespace and lowercased."""
    trimmed = word.strip()
    return trimmed.lower()
def is_alpha_word(word: str) -> bool:
    """True when *word* is non-empty and made up solely of alphabetic characters."""
    if not word:
        return False
    return all(ch.isalpha() for ch in word)
def generate_vocabulary(
    target_size: int = 50_000,
    min_length: int = 2,
) -> List[str]:
    """Generate a frequency-based list of common daily English words using wordfreq.

    Words are taken in wordfreq frequency order, normalized, and kept only
    when purely alphabetic, at least *min_length* characters, and not yet seen.

    Raises:
        ValueError: if *target_size* is not positive.
        RuntimeError: if filtering cannot produce *target_size* words.
    """
    if target_size <= 0:
        raise ValueError("target_size must be positive")
    # Over-fetch so that filtering and deduplication still leave enough candidates.
    raw_words = top_n_list("en", max(target_size * 2, 120_000))
    vocabulary: list[str] = []
    accepted: set[str] = set()
    for raw in raw_words:
        word = normalize_word(raw)
        acceptable = (
            len(word) >= min_length
            and is_alpha_word(word)
            and word not in accepted
        )
        if not acceptable:
            continue
        accepted.add(word)
        vocabulary.append(word)
        if len(vocabulary) == target_size:
            break
    if len(vocabulary) < target_size:
        raise RuntimeError(
            f"Unable to generate {target_size} filtered words; got {len(vocabulary)}"
        )
    return vocabulary
def save_words(words: Iterable[str], path: Path = WORDS_PATH) -> None:
    """Serialize *words* to *path* as a UTF-8 JSON array, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(list(words), ensure_ascii=False)
    path.write_text(payload, encoding="utf-8")
def load_words(path: Path = WORDS_PATH) -> List[str]:
    """Read back the JSON word list stored at *path*."""
    raw = path.read_text(encoding="utf-8")
    return json.loads(raw)
def l2_normalize(vectors: np.ndarray) -> np.ndarray:
    """Normalize each row of a 2D matrix to unit L2 norm, as float32.

    Rows with (near-)zero norm are divided by a tiny floor instead of
    zero, so the result never contains NaN or inf.
    """
    matrix = np.asarray(vectors, dtype=np.float32)
    row_norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    safe_norms = np.clip(row_norms, 1e-12, None)  # same effect as maximum(norms, 1e-12)
    return matrix / safe_norms
def save_embeddings(embeddings: np.ndarray, path: Path = EMBEDDINGS_PATH) -> None:
    """Persist *embeddings* to *path* in .npy format as float32, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    as_float32 = embeddings.astype(np.float32)
    np.save(path, as_float32)
def load_embeddings(path: Path = EMBEDDINGS_PATH) -> np.ndarray:
    """Load the embedding matrix previously saved at *path*."""
    matrix = np.load(path)
    return matrix
def save_faiss_index(index: faiss.Index, path: Path = FAISS_INDEX_PATH) -> None:
    """Write the FAISS index to disk at *path*, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    target = str(path)  # faiss expects a plain string path
    faiss.write_index(index, target)
def load_faiss_index(path: Path = FAISS_INDEX_PATH) -> faiss.Index:
    """Read a FAISS index from disk at *path*."""
    source = str(path)  # faiss expects a plain string path
    return faiss.read_index(source)
def build_faiss_index(embeddings: np.ndarray) -> faiss.Index:
    """Build an exact (brute-force) inner-product FAISS index over *embeddings*.

    Args:
        embeddings: matrix of shape (n, dim). For inner product to act as
            cosine similarity, rows should be L2-normalized by the caller —
            presumably via ``l2_normalize`` — TODO confirm in the pipeline.

    Returns:
        A populated ``faiss.IndexFlatIP`` with n vectors.

    Raises:
        ValueError: if *embeddings* is not 2D.
    """
    if embeddings.ndim != 2:
        raise ValueError("embeddings must be a 2D matrix")
    # FAISS requires float32 *and* C-contiguous memory. np.asarray(dtype=...)
    # alone does not guarantee contiguity (e.g. for sliced or transposed
    # inputs), which can make index.add fail or misread the data.
    data = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatIP(data.shape[1])
    index.add(data)
    return index
def find_rank_and_score(indices: np.ndarray, scores: np.ndarray, target_index: int) -> Tuple[int, float]:
    """Locate *target_index* in the first row of FAISS search results.

    Returns a (rank, score) pair where rank is 1-based. When the target is
    not among the returned neighbors, returns (k + 1, -inf) with k being
    the number of neighbors in the result row.
    """
    result_row = indices[0]
    hits = np.flatnonzero(result_row == target_index)
    if hits.size == 0:
        return len(result_row) + 1, float("-inf")
    first_hit = int(hits[0])
    return first_hit + 1, float(scores[0][first_hit])
def rank_to_color(rank: int) -> str:
    """Map a 1-based similarity rank to a hex color for display.

    Bands: 1 → exact match, 2-199 → very close, 200-1000 → green,
    1001-10000 → yellow, 10001-25000 → orange; anything else
    (rank > 25000, or out-of-range values like 0) → gray.
    """
    if rank == 1:
        return "#0b7a35"  # exact match
    # Ascending (upper_bound, color) thresholds for ranks >= 2.
    bands = (
        (199, "#1f9d55"),    # very close
        (1000, "#2e8b57"),   # green
        (10000, "#f2c94c"),  # yellow
        (25000, "#f2994a"),  # orange
    )
    if rank >= 2:
        for upper_bound, color in bands:
            if rank <= upper_bound:
                return color
    return "#9ca3af"  # gray: rank > 25000 or non-positive input