-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinterface.py
More file actions
70 lines (60 loc) · 1.84 KB
/
interface.py
File metadata and controls
70 lines (60 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import asyncio
from global_variables import (
title,
description,
article_trainer,
article_helper,
article_tester,
)
import gradio as gr
from src.message_processor.message_processor import MessageProcessor
from llama_index.core.base.llms.types import (
ChatMessage,
)
from llama_index.core.memory.chat_summary_memory_buffer import ChatSummaryMemoryBuffer
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from src.documents_handler.load_documents import query_engine
from src.documents_handler.load_documents import vector_index
import os
# Hosted LLM reached through the Hugging Face Inference API. The token comes
# from the environment; os.getenv yields None when the variable is unset.
_LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
_hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
llm = HuggingFaceInferenceAPI(model_name=_LLM_MODEL_NAME, token=_hf_token)

# Retriever over the vector index imported from src.documents_handler.load_documents.
vector_index_retriever = vector_index.as_retriever()

# Chat history & memory.
# The shared history list starts with one default-constructed ChatMessage.
# NOTE(review): confirm MessageProcessor really expects this placeholder seed
# message rather than an empty list.
chat_history = [ChatMessage()]

# Memory buffer over the same history, limited to 900 tokens.
# NOTE(review): chat_memory is never handed to MessageProcessor below, so it is
# currently unused — wire it in or drop it once MessageProcessor's API is known.
chat_memory = ChatSummaryMemoryBuffer(chat_history=chat_history, token_limit=900)

# Handler object whose methods back the Gradio interfaces.
message_processor = MessageProcessor(
    llm=llm,
    retriever=vector_index_retriever,
    chat_history=chat_history,
)
# Outer Blocks container; the tabbed layout is attached to it further down.
demo = gr.Blocks()

# Trainer inputs, listed in the positional order gr.Interface passes their
# values to pitch_train_handler: difficulty, recorded audio path, extra text.
_trainer_inputs = [
    gr.Dropdown(
        choices=["easy", "medium", "hard", "extreme"],
        value="hard",
        type="value",
    ),
    gr.Audio(type="filepath", label="Use Your Microphone For Best Results"),
    gr.Textbox(label="Add Additional Information Via Text Here"),
]
_trainer_outputs = [gr.Textbox(label="Tonic Pitch Trainer")]

# Pitch-training interface; flagging is disabled.
pitch_trainer = gr.Interface(
    fn=message_processor.pitch_train_handler,
    inputs=_trainer_inputs,
    outputs=_trainer_outputs,
    title=title,
    description=description,
    article=article_trainer,
    allow_flagging="never",
)
# gr.TabbedInterface pairs interfaces with tab names positionally (it zips the
# two lists). Passing three names for one interface silently drops the surplus
# names and labels the trainer tab with the FIRST name ("Tonic Pitch
# Assistant"). Keeping each interface beside its own label keeps the two lists
# in lock-step while the helper/tester interfaces are not built yet.
_tabs = [
    # (pitch_helper, "Tonic Pitch Assistant"),
    # (pitch_tester, "Test Your Pitching"),
    (pitch_trainer, "Train For Your Pitch"),
]
with demo:
    gr.TabbedInterface(
        [interface for interface, _ in _tabs],
        [tab_name for _, tab_name in _tabs],
    )

# Hold at most 5 pending events in the queue, then start the server. This runs
# at module level, so importing this file launches the app.
# NOTE(review): server_name="localhost" binds to loopback only; use "0.0.0.0"
# if the app must be reachable from other hosts or a container.
demo.queue(max_size=5)
demo.launch(server_name="localhost", show_api=False)