-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathclassify.py
More file actions
157 lines (128 loc) · 4.54 KB
/
classify.py
File metadata and controls
157 lines (128 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import asyncio
import os
import openai
from aiohttp import ClientSession
from dotenv import load_dotenv
from errors import *
# Load variables from a .env file (if present) into the environment before
# the API token is looked up.
load_dotenv()

# The token comes from the `openai_token` environment variable; only trailing
# newlines are trimmed. Importing this module fails fast when it is missing.
openai.api_key = os.getenv('openai_token', '').rstrip("\n")
if openai.api_key == '':
    raise ConfigError("OpenAI key not found")
class FileClassifier:
@staticmethod
# Works with short summaries
async def classify_short(text: str, labels: dict[str, str]) -> str | None:
"""
Classifies a given text with into a set of categories
Returns: the chosen label / None if no label is chosen
:param str text: The text to be classified
:param dict[str, str] labels: a dict of labels + summaries
"""
# TODO: tweak
PRE_PROMPT = "Classify the document content into the right folder based on its summary.\n"
GPT_ARGS = {
'model': 'text-davinci-003',
'temperature': 0,
'max_tokens': 60,
'top_p': 1.0,
'frequency_penalty': 0.0,
'presence_penalty': 0.0
}
text = text.replace('\n', ' ')
labels = labels.copy()
for l in labels:
labels[l] = labels[l].strip()
prompt = f"{PRE_PROMPT} Document Content: {text}\n\n"
prompt += "Folders:\n" + \
'\n'.join(f'- [{l}]({labels[l]})' for l in labels)
prompt += "\n\n The right folder path:"
print(prompt)
# TODO: add folders to this
print("_____")
print(prompt)
print("_____")
response = await openai.Completion.acreate(
prompt=prompt,
**GPT_ARGS,
)
if type(response) is not dict: raise InvalidResponse()
try:
# i hate union types >:(
cat = response['choices'][0].text.strip()
print(cat)
for k, v in labels.items():
if cat.find(v) != -1:
return k
return None
except (KeyError, IndexError):
raise InvalidResponse()
# Works with the longer summaries
@staticmethod
async def classify_long(text: str, labels: dict[str, str]) -> str | None:
"""
Classifies a given text with into a set of categories
Returns: the chosen label / None if no label is chosen
:param str text: The text to be classified
:param dict[str, str] labels: a dict of labels + summaries
"""
# TODO: tweak
PRE_PROMPT = "Classify the text into one of the following categories:"
GPT_ARGS = {
'model': 'text-davinci-003',
'temperature': 0,
'max_tokens': 60,
'top_p': 1.0,
'frequency_penalty': 0.0,
'presence_penalty': 0.0
}
text = text.replace('\n', ' ')
labels = labels.copy()
for l in labels:
labels[l] = labels[l].strip()
prompt = f"{PRE_PROMPT} {', '.join(labels.values())}\n\n" \
f"Text: \n{text}"
response = await openai.Completion.acreate(
prompt=prompt,
**GPT_ARGS,
)
if type(response) is not dict: raise InvalidResponse()
try:
# i hate union types >:(
cat = response['choices'][0].text.strip()
for k, v in labels.items():
if cat.find(v) != -1:
return k
return None
except (KeyError, IndexError):
return None
@staticmethod
async def summarize(text, max_chars=100, max_words=5):
"""
Summarizes the text
Returns: summary
:param str text: The text to be classified
:param int max_chars: a dict of labels + summaries
"""
# TODO: tweak
PRE_PROMPT = f"Write a short summary (within {max_chars} characters and {max_words} words) for the following text:"
GPT_ARGS = {
'model': 'text-davinci-003',
'temperature': 0,
'max_tokens': 60,
'top_p': 1.0,
'frequency_penalty': 0.0,
'presence_penalty': 0.0
}
async with ClientSession() as session:
openai.aiosession.set(session)
prompt = f"{PRE_PROMPT}\n\n" \
f"Text: \n{text}"
response = openai.Completion.create(
prompt=prompt,
**GPT_ARGS
)
if type(response) is not dict: raise InvalidResponse()
try:
return response['choices'][0].text.strip()
except (KeyError, IndexError):
return None