-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtokenizers.go
More file actions
120 lines (97 loc) · 2.76 KB
/
tokenizers.go
File metadata and controls
120 lines (97 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package main
import (
"fmt"
"path/filepath"
"github.com/pkoukk/tiktoken-go"
"github.com/sugarme/tokenizer"
"github.com/sugarme/tokenizer/pretrained"
)
// Tokenizer counts tokens in text against a configured token limit.
// Implementations wrap a specific tokenization backend (tiktoken or a
// HuggingFace tokenizer file).
type Tokenizer interface {
// CountTokens tokenizes text and returns the count plus limit metadata.
CountTokens(text string) (TokenCount, error)
// Name returns a human-readable identifier for the backend/model.
Name() string
}
// TiktokenTokenizer counts tokens using OpenAI's tiktoken BPE encodings.
type TiktokenTokenizer struct {
// encoding is the loaded tiktoken encoding (cl100k_base — see constructor).
encoding *tiktoken.Tiktoken
// model is the model label used only for Name(); it does not affect encoding.
model string
// tokenLimit is the hard cap used for Truncated/TokensPerc calculations.
tokenLimit int
// warnLimit is the soft threshold (80% of tokenLimit) reported to callers.
warnLimit int
}
// NewTiktokenTokenizer builds a tiktoken-backed Tokenizer for the given model
// label and token limit. The cl100k_base encoding is used regardless of the
// model string, which only affects Name().
//
// limit must be positive: CountTokens divides by it, and a zero or negative
// limit would produce Inf/NaN percentages and a meaningless Truncated flag.
func NewTiktokenTokenizer(model string, limit int) (*TiktokenTokenizer, error) {
	if limit <= 0 {
		return nil, fmt.Errorf("token limit must be positive, got %d", limit)
	}
	encoding, err := tiktoken.GetEncoding("cl100k_base")
	if err != nil {
		return nil, fmt.Errorf("failed to get tiktoken encoding: %w", err)
	}
	// Warn at 80% of the hard limit, matching NewHuggingFaceTokenizer.
	warnLimit := int(float64(limit) * 0.8)
	return &TiktokenTokenizer{
		encoding:   encoding,
		model:      model,
		tokenLimit: limit,
		warnLimit:  warnLimit,
	}, nil
}
// CountTokens encodes text with the tiktoken encoding and reports the token
// count alongside the configured limit metadata. It never returns an error.
func (t *TiktokenTokenizer) CountTokens(text string) (TokenCount, error) {
	// nil, nil: no allowed/disallowed special-token overrides.
	n := len(t.encoding.Encode(text, nil, nil))
	return TokenCount{
		Count:      n,
		TokensPerc: 100 * float64(n) / float64(t.tokenLimit),
		Truncated:  n > t.tokenLimit,
		TokenLimit: t.tokenLimit,
		WarnLimit:  t.warnLimit,
	}, nil
}
// Name returns the backend identifier, e.g. "tiktoken-gpt-4".
func (t *TiktokenTokenizer) Name() string {
	return "tiktoken-" + t.model
}
// HuggingFaceTokenizer counts tokens using a tokenizer loaded from a
// HuggingFace tokenizer.json-style file.
type HuggingFaceTokenizer struct {
// tokenizer is the loaded pretrained tokenizer.
tokenizer *tokenizer.Tokenizer
// modelPath is the file the tokenizer was loaded from; its basename is
// used in Name().
modelPath string
// tokenLimit is the hard cap used for Truncated/TokensPerc calculations.
tokenLimit int
// warnLimit is the soft threshold (80% of tokenLimit) reported to callers.
warnLimit int
}
// NewHuggingFaceTokenizer loads a tokenizer from modelPath (a HuggingFace
// tokenizer file) and wraps it with the given token limit.
//
// limit must be positive: CountTokens divides by it, and a zero or negative
// limit would produce Inf/NaN percentages and a meaningless Truncated flag.
func NewHuggingFaceTokenizer(modelPath string, limit int) (*HuggingFaceTokenizer, error) {
	if limit <= 0 {
		return nil, fmt.Errorf("token limit must be positive, got %d", limit)
	}
	tok, err := pretrained.FromFile(modelPath)
	if err != nil {
		return nil, fmt.Errorf("failed to load HuggingFace tokenizer: %w", err)
	}
	// Warn at 80% of the hard limit, matching NewTiktokenTokenizer.
	warnLimit := int(float64(limit) * 0.8)
	return &HuggingFaceTokenizer{
		tokenizer:  tok,
		modelPath:  modelPath,
		tokenLimit: limit,
		warnLimit:  warnLimit,
	}, nil
}
// CountTokens encodes text with the HuggingFace tokenizer and reports the
// token count alongside the configured limit metadata.
func (h *HuggingFaceTokenizer) CountTokens(text string) (TokenCount, error) {
	enc, err := h.tokenizer.EncodeSingle(text)
	if err != nil {
		return TokenCount{}, fmt.Errorf("failed to encode text: %w", err)
	}
	n := len(enc.Ids)
	return TokenCount{
		Count:      n,
		TokensPerc: 100 * float64(n) / float64(h.tokenLimit),
		Truncated:  n > h.tokenLimit,
		TokenLimit: h.tokenLimit,
		WarnLimit:  h.warnLimit,
	}, nil
}
// Name returns the backend identifier, e.g. "huggingface-tokenizer.json",
// derived from the basename of the loaded model file.
func (h *HuggingFaceTokenizer) Name() string {
	return "huggingface-" + filepath.Base(h.modelPath)
}
// NewTokenizer is the factory for Tokenizer implementations. modelPath is
// only used for the HuggingFace case; limit is the hard token cap passed to
// every backend.
func NewTokenizer(tokType TokenizerType, modelPath string, limit int) (Tokenizer, error) {
switch tokType {
case TiktokenGPT35:
return NewTiktokenTokenizer("gpt-3.5-turbo", limit)
case TiktokenGPT4:
return NewTiktokenTokenizer("gpt-4", limit)
case TiktokenClaude:
// NOTE(review): this reuses tiktoken's cl100k_base encoding for Claude.
// Claude does not publish a tiktoken encoding, so counts here are
// presumably an approximation — confirm this is intentional.
return NewTiktokenTokenizer("claude", limit)
case HuggingFace:
return NewHuggingFaceTokenizer(modelPath, limit)
default:
return nil, fmt.Errorf("unsupported tokenizer type: %s", tokType)
}
}