-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsummarizer.py
More file actions
80 lines (60 loc) · 2.46 KB
/
summarizer.py
File metadata and controls
80 lines (60 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import sys
import os
from common import inference, download, get_text_from_html, download_and_extract_text_from_pdf, tokenize_gpt2, detokenize_gpt2, split_text
# Set before any tokenization happens below.
# NOTE(review): presumably silences/disables Hugging Face tokenizers
# parallelism -- confirm against the tokenizer used in `common`.
os.environ["TOKENIZERS_PARALLELISM"] = "False"

# Total token window assumed for one prompt, and the part of it reserved for
# the generated answer.
MAX_TOKENS = 4000
MAX_NEW_TOKENS = 500

# Prompt scaffolding placed around the document in summary mode.
SUMMARY_PREPROMPT = "SUMMARIZE THE FOLLOWING DOCUMENT:\n=====\n"
SUMMARY_MIDPROMPT = ""
SUMMARY_POSTPROMPT = "\n=====\nSUMMARY:\n"

# Prompt scaffolding placed around the document and the question in question
# mode.
QUESTION_PREPROMPT = "ANSWER A QUESTION ABOUT THE FOLLOWING DOCUMENT:\n=====\n"
QUESTION_MIDPROMPT = "\n=====\nQUESTION: "
QUESTION_POSTPROMPT = "\nANSWER:\n"

# Tokens consumed by the fixed scaffolding of each prompt kind.
_SUMMARY_PROMPT_TOKENS = sum(
    len(tokenize_gpt2(piece))
    for piece in (SUMMARY_PREPROMPT, SUMMARY_POSTPROMPT)
)
_QUESTION_PROMPT_TOKENS = sum(
    len(tokenize_gpt2(piece))
    for piece in (QUESTION_PREPROMPT, QUESTION_POSTPROMPT, QUESTION_MIDPROMPT)
)

# Tokens left for document content once the scaffolding and the generated
# output are accounted for. The question itself is not included here.
SUMMARY_TOKEN_BUDGET = MAX_TOKENS - _SUMMARY_PROMPT_TOKENS - MAX_NEW_TOKENS
QUESTION_TOKEN_BUDGET = MAX_TOKENS - _QUESTION_PROMPT_TOKENS - MAX_NEW_TOKENS
def summarize(text, mode="summary", question=""):
    """Summarize ``text`` or answer ``question`` about it, chunking long input.

    The text is tokenized and passed to ``split_text`` together with the token
    budget available for document content. When that yields more than one
    chunk, each chunk is processed recursively, every partial result is
    printed with its index, and the partial results are joined with single
    spaces to form the document of the final prompt. Otherwise ``text`` itself
    is the document.

    If the joined partial results still exceed the budget, they are reduced
    again before the final prompt is built, so the prompt stays within the
    window the budget constants are computed from.

    Args:
        text: Document text to process.
        mode: ``"summary"`` selects the summary prompt; any other value
            selects the question prompt.
        question: Question inserted into the question prompt. In question
            mode its token count is subtracted from the chunk budget.

    Returns:
        The value returned by ``inference`` for the final prompt, called with
        ``MAX_NEW_TOKENS`` as its second argument.
    """
    summary_mode = mode == "summary"
    if summary_mode:
        budget = SUMMARY_TOKEN_BUDGET
    else:
        budget = QUESTION_TOKEN_BUDGET - len(tokenize_gpt2(question))

    tokens = tokenize_gpt2(text)
    chunks = split_text(tokens, budget)

    if len(chunks) > 1:
        partials = []
        for index, chunk in enumerate(chunks):
            partial = summarize(detokenize_gpt2(chunk), mode, question)
            partials.append(partial)
            # Progress output: chunk index and its partial result.
            print(index, partial)
            print("")
        document = " ".join(partials)

        # Many chunks can produce a joined text that is itself over budget.
        # Reduce it again, but only when it is strictly shorter than the
        # input: that guarantees the recursion terminates, and otherwise we
        # fall through and build the prompt from the joined text as-is.
        joined_len = len(tokenize_gpt2(document))
        if budget < joined_len < len(tokens):
            return summarize(document, mode, question)
    else:
        document = text

    if summary_mode:
        prompt = SUMMARY_PREPROMPT + document + SUMMARY_POSTPROMPT
    else:
        prompt = (
            QUESTION_PREPROMPT
            + document
            + QUESTION_MIDPROMPT
            + question
            + QUESTION_POSTPROMPT
        )
    return inference(prompt, MAX_NEW_TOKENS)
# print()
# Command-line driver: <script> <url> [question words ...]
# Without a URL there is nothing to fetch, so report usage and exit instead of
# failing with an IndexError traceback on sys.argv[1].
if len(sys.argv) < 2:
    print("Usage:", sys.argv[0], "<url> [question ...]", file=sys.stderr)
    sys.exit(1)

# Take URL from the first argument
url = sys.argv[1]
# Take the rest of the arguments and concatenate with a space
question = " ".join(sys.argv[2:]).strip()
# No question on the command line means plain summarization.
mode = "summary" if question == "" else "question"
print("Mode:", mode)
print("Question:", question)

# Choose the extractor from the URL suffix. The comparison is
# case-insensitive so that links ending in ".PDF" are not parsed as HTML.
if url.lower().endswith(".pdf"):
    # Get the text from the PDF
    text = download_and_extract_text_from_pdf(url)
else:
    # Download the page
    html = download(url)
    # Get the text from the downloaded HTML
    text = get_text_from_html(html)

# Summarize (or answer the question) and show the result
result = summarize(text, mode, question)
print("Final result:", result)