-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
94 lines (72 loc) · 3.06 KB
/
app.py
File metadata and controls
94 lines (72 loc) · 3.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import streamlit as st
import pandas as pd
# HACK: blank out torch.classes' search path so Streamlit's file watcher
# does not probe it and trigger a spurious async-io error on rerun.
torch.classes.__path__ = [] # to handle async io issue
# Run inference on GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@st.cache_resource
def load_resources():
    """Load the fine-tuned seq2seq model, its tokenizer, and the SAMSum test split.

    Decorated with st.cache_resource so the expensive model/tokenizer/CSV
    loads happen once per server process, not on every Streamlit rerun.

    Returns:
        (model, tokenizer, test_data) — model in eval mode on `device`,
        test_data as a pandas DataFrame of the test CSV.
    """
    checkpoint_dir = "model"
    frame = pd.read_csv(r'data/samsum-test.csv')
    tok = AutoTokenizer.from_pretrained(checkpoint_dir)
    net = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir).to(device)
    net.eval()  # disable dropout etc. for deterministic inference
    return net, tok, frame
model, tokenizer, test_data = load_resources()
def generate_summary(article_text, num_tokens, num_beams):
    """Summarize a dialogue with the fine-tuned seq2seq model.

    Args:
        article_text: Raw conversation text to summarize.
        num_tokens: Upper bound on generated length (passed as max_length).
        num_beams: Beam-search width; 1 means greedy decoding.

    Returns:
        The decoded summary string with special tokens stripped.
    """
    # Wrap the input in the sentinel tokens the model was fine-tuned with.
    article_text = "[SOS] " + article_text + " [EOS]"
    inputs = tokenizer(article_text, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # early_stopping is only meaningful for beam search; the original dead
    # initial assignment and if/else reduce to this single expression.
    early_stop = num_beams > 1
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=num_tokens,
            num_beams=num_beams,
            early_stopping=early_stop,
            no_repeat_ngram_size=3
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# -------------------- Streamlit UI --------------------
# Seed session_state with defaults on the first run only; Streamlit keeps
# these values alive across the reruns that widget interaction triggers.
_state_defaults = {
    'conversation': '',
    'summary': '',
    'num_beams': 2,
    'num_tokens': 60,
}
for _key, _default in _state_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
st.title("🗨️SAMSummerizer")
# Three-column layout: conversation input | summary output | settings.
conv_col, summary_col, options_col = st.columns((0.4,0.4, 0.2))
def clear_board():
    """Blank both text areas by resetting their backing session-state keys."""
    for field in ('conversation', 'summary'):
        st.session_state[field] = ''
def generate_random():
    """Load a random dialogue from the test split into the conversation box.

    Fix: torch.randint's `high` bound is EXCLUSIVE, so the previous
    `high=len(test_data)+1` could yield index == len(test_data), which is
    out of range for `.loc` on the default integer index (KeyError).
    `high=len(test_data)` keeps indices in [0, len(test_data) - 1].
    """
    random_index = torch.randint(low=0, high=len(test_data), size=(1,))[0].item()
    conversation = test_data.loc[random_index, 'dialogue']
    st.session_state.conversation = conversation
def summarize():
    """Summarize the current conversation and store the result in session state."""
    with st.spinner("model summarizing...", show_time=True):
        dialogue = st.session_state.conversation
        st.session_state.summary = generate_summary(
            dialogue,
            st.session_state.num_tokens,
            st.session_state.num_beams,
        )
with conv_col:
    # Left column: editable conversation input plus a random-sample loader.
    st.text_area("Conversation:", key='conversation', height=220)
    st.button("Generate Random🔄️", on_click=generate_random)
with summary_col:
    # Middle column: read-only summary output with its action buttons beneath.
    st.text_area("Summary:", key='summary', disabled=True, height=220)
    summary_btn_col, clean_btn_col = st.columns((0.5,0.5))
    with summary_btn_col:
        st.button("Summarize🧾", on_click=summarize)
    with clean_btn_col:
        st.button("Clear🗑️", on_click=clear_board)
with options_col:
    # Right column: generation settings. The sliders write directly into
    # session_state via their keys; summarize() reads those values.
    with st.container(border=True):
        st.text("⚙ Settings:")
        st.slider(label = "Max no. of Beams:", min_value=1, max_value=5, key='num_beams')
        st.slider(label = "Max no. of Tokens:", min_value=10, max_value=350, key='num_tokens')