-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
157 lines (122 loc) · 4.42 KB
/
server.py
File metadata and controls
157 lines (122 loc) · 4.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = 'True'
import json
import numpy as np
import torch
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from starlette.middleware.base import _StreamingResponse
from torch import nn
from config import Config
from error_code import ReException, ErrorCode
from model import model_func_map
from proto import Model, UserItems, ReResponse
from service.feature_service import FeatureService
app = FastAPI()
# Currently loaded scoring model; populated by POST /model/load, cleared by
# POST /clean. None means "no model loaded yet".
model: nn.Module = None
# Route paths, declared once so the middleware set below stays in sync with
# the route decorators.
HEALTH_PATH = "/health"
CLEAN_PATH = "/clean"
MODEL_PATH = "/model"
MODEL_LOAD_PATH = MODEL_PATH + "/load"
MODEL_SCORE_PATH = MODEL_PATH + "/score"
# Only responses for these paths are re-wrapped into the ReResponse envelope
# by the response_format middleware.
PROXY_PATH_SET = {
HEALTH_PATH,
CLEAN_PATH,
MODEL_LOAD_PATH,
MODEL_SCORE_PATH,
}
# NOTE(review): expected total feature dimension (user + item); appears
# unused in this view -- presumably consumed elsewhere or by callers. Verify.
FEATURE_DIM = 421
# Shared feature lookup service used by /model/score.
feature_service = FeatureService()
@app.exception_handler(HTTPException)
async def exception_handler(request, exception: ReException):
    """Translate raised HTTPExceptions into the unified ReResponse JSON envelope.

    NOTE(review): registered for HTTPException but annotated as ReException --
    presumably ReException subclasses HTTPException; confirm.
    """
    payload = ReResponse(
        code=exception.status_code,
        status="fail",
        data=None,
        message=exception.detail,
    )
    return JSONResponse(content=payload.to_dict(), status_code=exception.status_code)
@app.middleware("http")
async def response_format(request: Request, call_next):
    """Wrap responses of the paths in PROXY_PATH_SET into the ReResponse envelope.

    Non-proxied paths pass through untouched. Streaming responses are drained,
    parsed, and re-emitted as a success envelope; error statuses at/above
    UNKNOWN_ERROR are normalized to HTTP 200 (the body already carries the
    error code) and passed through as-is.
    """
    response = await call_next(request)
    if request.url.path not in PROXY_PATH_SET:
        return response
    if type(response) is ReResponse:
        return JSONResponse(content=response.to_dict())
    elif type(response) is _StreamingResponse:
        if response.status_code >= ErrorCode.UNKNOWN_ERROR.code:
            response.status_code = 200
            return response

        async def build_content(stream_response):
            # Drain the streaming body into one bytes blob.
            content = b''
            async for chunk in stream_response.body_iterator:
                content += chunk
            return content

        content = await build_content(response)
        content = content.decode("utf-8")
        try:
            content = json.loads(content)
        except json.JSONDecodeError:
            # HACK/SECURITY: eval() on a response body is arbitrary code
            # execution if any handler can return attacker-influenced text.
            # Replace with a safe parser (e.g. ast.literal_eval) -- TODO.
            content = eval(content)
        return JSONResponse(content=ReResponse(code=0, status="success", data=content).to_dict())
    else:
        # BUG FIX: Starlette's Response.body is a bytes attribute, not a
        # method -- the original `response.body()` raised TypeError here.
        # Decode so the payload is JSON-serializable.
        return JSONResponse(content=response.body.decode("utf-8"))
@app.get("/")
def index():
    """Serve the static landing page from the working directory."""
    page = "index.html"
    return FileResponse(page)
@app.get("/health")
def health():
    """Liveness probe: always responds with the constant marker string."""
    return "ok"
@app.post("/model/load")
async def load_model(model_info: Model):
    """Instantiate the requested model type and load its weights.

    Raises ReException(INVALID_MODEL) for unknown model types,
    ReException(MODEL_NOT_FOUND) when the weight file is missing, and
    ReException(LOAD_MODEL_FAILED) for any other load error.

    BUG FIX: the original assigned the global `model` *before* loading the
    weights, so a failed load left an untrained model serving /model/score.
    Build into a local and publish to the global only on success.
    """
    global model
    model_type = model_info.type.strip().lower()
    if model_type not in model_func_map:
        raise ReException(ErrorCode.INVALID_MODEL)
    candidate = model_func_map[model_type](model_info.dim)
    try:
        candidate.load_state_dict(torch.load(model_info.model))
        candidate.eval()
    except FileNotFoundError as fnfe:
        raise ReException(ErrorCode.MODEL_NOT_FOUND) from fnfe
    except Exception as e:
        raise ReException(ErrorCode.LOAD_MODEL_FAILED) from e
    model = candidate
@app.post("/clean")
def clean():
    """Drop the currently loaded model (if any) and release cached CUDA memory."""
    global model
    if model is not None:
        model = None
        torch.cuda.empty_cache()
@app.post("/model/score")
def score(user_items: UserItems):
    """Score each requested item for the given user with the loaded model.

    Items with no feature vector get a fallback score of 0.0. A user with no
    feature vector is zero-padded so user+item features fill the model's
    input dimension.

    Returns a dict mapping item_id -> float score.
    Raises ReException(MODEL_NOT_LOAD_YET) when no model is loaded and
    ReException(INFERENCE_FAILED) on any inference error.
    """
    if not model:
        raise ReException(ErrorCode.MODEL_NOT_LOAD_YET)
    try:
        user_features = feature_service.get_user_feature_by_id(user_items.user_id)
        batch_features = []
        item_score_map = {}
        hit_items = []
        for item_id in user_items.item_ids:
            item_features = feature_service.get_item_feature_by_id(item_id)
            if item_features is None:
                # Unknown item: deterministic fallback score.
                item_score_map[item_id] = 0.0
                continue
            if user_features is None:
                # Cold-start user: pad with zeros up to the model input dim.
                user_features = np.zeros(model.dim - item_features.size)
            batch_features.append(torch.cat(
                (
                    torch.tensor(user_features, dtype=torch.float32),
                    torch.tensor(item_features, dtype=torch.float32)
                ),
                dim=0
            ))
            hit_items.append(item_id)
        # BUG FIX: torch.stack([]) raises, so an all-miss request used to
        # fail with INFERENCE_FAILED instead of returning the zero scores.
        if hit_items:
            with torch.no_grad():
                outputs = model(torch.stack(batch_features))
            # BUG FIX: .squeeze().tolist() yields a bare float when there is
            # exactly one hit item, which breaks zip(); flatten to a 1-D
            # list instead. (Also renamed the local so it no longer shadows
            # this function's name.)
            hit_item_scores = outputs.view(-1).tolist()
            for h_item_id, h_item_score in zip(hit_items, hit_item_scores):
                item_score_map[h_item_id] = h_item_score
        return item_score_map
    except Exception as e:
        raise ReException(ErrorCode.INFERENCE_FAILED) from e
if __name__ == "__main__":
    # Dev entry point: serve the ASGI app with uvicorn using host/port
    # taken from the project Config.
    import uvicorn
    uvicorn.run(app, host=Config.SERVER.HOST, port=Config.SERVER.PORT)