-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapi_server.py
More file actions
170 lines (134 loc) · 7.03 KB
/
api_server.py
File metadata and controls
170 lines (134 loc) · 7.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# 실행 명령어: uvicorn api_server:app --host 0.0.0.0 --port 8000 --reload
import os
import requests
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import whisper
import librosa
# 코드 임포트
from model import EndToEndStressModel
from preprocess import get_preprocessors
app = FastAPI(title="V&T Real Inference API")
# --- 전역 변수 세팅 ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models = {}
# 백엔드에서 받을 JSON 규격 (Presigned URL)
class AnalyzeRequest(BaseModel):
file_url: str
@app.on_event("startup")
async def load_models():
"""서버 시작 시 무거운 AI 모델들을 메모리에 한 번만 로드합니다."""
print(f"🚀 AI 모델 로딩 시작... (Device: {device})")
# 1. Whisper 모델 로드
models["whisper"] = whisper.load_model("base").to(device)
# 2. 전처리기 로드
audio_processor, tokenizer = get_preprocessors()
models["audio_processor"] = audio_processor
models["tokenizer"] = tokenizer
# 3. 융합 모델 로드
vnt_model = EndToEndStressModel().to(device)
# 학습된 가중치 로드
vnt_model.load_state_dict(torch.load("best_model_best.pth", map_location=device))
vnt_model.eval() # 추론 모드로 전환 (Dropout 등 비활성화)
models["vnt_model"] = vnt_model
print("✅ 모든 AI 모델 로딩 완료!")
@app.post("/api/v1/analyze")
async def analyze_audio(req: AnalyzeRequest):
"""
S3 Presigned URL을 받아 다운로드 후, STT 분리 및 V&T 모델 추론을 수행합니다.
"""
temp_audio_path = "temp_downloaded.wav"
try:
# 1. S3 Presigned URL에서 오디오 파일 다운로드
print("\n[Step 1] S3에서 오디오 다운로드 중...")
response = requests.get(req.file_url)
if response.status_code != 200:
raise HTTPException(status_code=400, detail="S3 URL에서 파일을 다운로드할 수 없습니다.")
with open(temp_audio_path, "wb") as f:
f.write(response.content)
# 2. Whisper로 문장 단위 분리 및 인메모리 슬라이싱
print("[Step 2] Whisper STT 문장 분리 진행 중...")
stt_result = models["whisper"].transcribe(temp_audio_path, language="ko", condition_on_previous_text=False)
waveform, sr = librosa.load(temp_audio_path, sr=16000)
final_report = []
dissonance_sum = 0.0
valid_chunk_count = 0
print("[Step 3] 멀티모달 V&T 추론 시작...")
for idx, seg in enumerate(stt_result["segments"]):
start_time, end_time = seg["start"], seg["end"]
text = seg["text"].strip()
# 노이즈/환각 필터링
if (end_time - start_time) < 0.5:
continue
# Numpy 슬라이싱
start_sample, end_sample = int(start_time * sr), int(end_time * sr)
audio_chunk = waveform[start_sample:end_sample]
audio_values = models["audio_processor"](
audio_chunk, return_tensors="pt", sampling_rate=16000,
padding='max_length', max_length=80000, truncation=True
).input_values.to(device)
text_inputs = models["tokenizer"](
text, return_tensors="pt", padding='max_length', max_length=128, truncation=True
).to(device)
# 3. 모델 추론
with torch.no_grad():
fusion_logits, audio_logits, text_logits = models["vnt_model"](
audio_values,
text_inputs['input_ids'],
text_inputs['attention_mask']
)
# ----------------------------------------------------
# Threshold Tuning 적용 부분
# ----------------------------------------------------
# 1. Logits를 Softmax로 변환하여 0.0 ~ 1.0 사이의 확률값으로 만듦
fusion_probs = F.softmax(fusion_logits, dim=1).squeeze(0)
audio_probs = F.softmax(audio_logits, dim=1).squeeze(0)
text_probs = F.softmax(text_logits, dim=1).squeeze(0)
# 2. 부정(스트레스) 클래스의 확률만 뽑아냄
conflict_prob = fusion_probs[1].item() * 100
audio_negative_prob = audio_probs[1].item() * 100
text_negative_prob = text_probs[1].item() * 100
# 3. 임계값(Threshold) 설정
THRESHOLD = 35.0
# 4. 임계값을 기준으로 최종 감정 문자열 결정
is_conflict = bool(conflict_prob > 60.0) # 불일치는 보수적으로 60점 유지
audio_emo_str = "부정/스트레스" if audio_negative_prob > THRESHOLD else "긍정/안정"
text_emo_str = "부정/스트레스" if text_negative_prob > THRESHOLD else "긍정/안정"
# ----------------------------------------------------
final_report.append({
"time_range": f"{start_time:.1f}s - {end_time:.1f}s",
"stt_chunk": text,
"text_emotion": text_emo_str,
"audio_emotion": audio_emo_str,
"dissonance_score": round(conflict_prob, 2),
"is_conflict": is_conflict
})
dissonance_sum += conflict_prob
valid_chunk_count += 1
print(f" -> [{start_time:.1f}s] {text} | 텍스트:{text_emo_str}({text_negative_prob:.1f}%), 음성:{audio_emo_str}({audio_negative_prob:.1f}%) | 불일치율: {conflict_prob:.1f}%")
# 4. 종합 결과 계산
overall_dissonance = round(dissonance_sum / valid_chunk_count, 2) if valid_chunk_count > 0 else 0.0
OVERALL_THRESHOLD = 40.0
if overall_dissonance > OVERALL_THRESHOLD:
primary_emotion = "감정 불일치 (숨겨진 스트레스)"
else:
primary_emotion = "감정 일치 (표현과 진심이 같음)"
return {
"status": "success",
"message": "AI 멀티모달 감정 불일치 분석이 완료되었습니다.",
"data": {
"overall_analysis": {
"primary_emotion": primary_emotion,
"dissonance_index": overall_dissonance
},
"time_series_analysis": final_report
}
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"AI 분석 서버 에러: {str(e)}")
finally:
# 다운로드했던 임시 파일 삭제 (용량 관리)
if os.path.exists(temp_audio_path):
os.remove(temp_audio_path)