-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathpredict.py
More file actions
165 lines (137 loc) · 4.96 KB
/
predict.py
File metadata and controls
165 lines (137 loc) · 4.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from cog import BasePredictor, Input
from cog import Path as CogPath
import sys
import time
import shutil
import torch
import core.globals
import insightface
# Fall back to CPU inference when no CUDA device is present: this swaps the
# ONNX Runtime provider list that insightface and the swapper models read
# from the shared core.globals module.
if not torch.cuda.is_available():
    core.globals.providers = ["CPUExecutionProvider"]
    print("gpu poor detected using cpu to run the model")
import glob
import os
from pathlib import Path
import cv2
from typing import Iterator
from subprocess import call, check_call
from core.processor import get_face_swapper, process_video, process_img
from core.utils import (
is_img,
detect_fps,
set_fps,
create_video,
add_audio,
extract_frames,
)
from core.config import get_face
from core.enhancer import enhance_face, get_face_enhancer
def status(string):
    """Print a progress message prefixed with "Status: "."""
    print(f"Status: {string}")
def run_cmd(command):
    """Run *command* through the shell and return its exit status.

    A Ctrl-C (KeyboardInterrupt) aborts the whole program with exit code 1
    instead of leaving a traceback from a half-finished download.

    Returns:
        int: the child process's return code (0 on success). The original
        discarded this, so a failed ``wget`` in ``setup()`` went unnoticed;
        callers may now check it (existing callers that ignore the return
        value are unaffected).

    NOTE(review): executed with ``shell=True`` — only safe because every
    caller passes a fixed, hard-coded command string. Never route user
    input through this function.
    """
    try:
        return call(command, shell=True)
    except KeyboardInterrupt:
        print("Process interrupted")
        sys.exit(1)
class Predictor(BasePredictor):
    """Cog predictor that swaps the face from a source image into a target
    image or video, optionally guided by a reference image.

    All model work is delegated to the project-local ``core`` package
    (processor/enhancer/utils); this class orchestrates weight downloads,
    input validation, and the extract-frames -> swap -> re-encode ->
    re-mux pipeline.
    """

    def setup(self):
        # NOTE(review): reason for the fixed 10s delay is not visible here —
        # presumably waiting for the runtime/GPU to come up; confirm.
        time.sleep(10)
        # check_call("nvidia-smi", shell=True)
        self.face_analyser = insightface.app.FaceAnalysis(
            name="buffalo_l", providers=core.globals.providers
        )
        # Download model weights into the working directory on first run
        # only; later runs reuse the files already on disk.
        if os.path.isfile("inswapper_128_fp16.onnx"):
            print("Model already downloaded")
        else:
            run_cmd(
                "wget https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128_fp16.onnx"
            )
        if os.path.isfile("GFPGANv1.4.pth"):
            print("Model already downloaded")
        else:
            run_cmd(
                "wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
            )
        # Construct the lazily-built models now so the first predict()
        # call doesn't pay the initialization cost.
        get_face_swapper()
        get_face_enhancer()
        self.face_analyser.prepare(ctx_id=0, det_size=(640, 640))
        # assert torch.cuda.is_available()

    def predict(
        self,
        source: CogPath = Input(description="Source", default=None),
        target: CogPath = Input(description="Target", default=None),
        reference_image: CogPath = Input(description="Reference Image", default=None),
        keep_fps: bool = Input(description="Keep FPS", default=True),
        keep_frames: bool = Input(description="Keep Frames", default=True),
    ) -> Iterator[CogPath]:
        """Yield the swapped output as a CogPath (image or video).

        Returns early — yielding nothing — when an input file is missing
        or no face can be detected in the source image.
        """
        print("source: ", source)
        print("target: ", target)
        print("reference_image: ", reference_image)
        print("keep_fps: ", keep_fps)
        print("keep_frames: ", keep_frames)
        # Validate inputs up front; on failure this generator just ends.
        if not source or not os.path.isfile(source):
            print("\n[WARNING] Please select an image containing a face.")
            return
        elif not target or not os.path.isfile(target):
            print("\n[WARNING] Please select a video/image to swap face in.")
            return
        # Downstream helpers expect plain string paths, not CogPath.
        source = str(source)
        target = str(target)
        reference_image = str(reference_image) if reference_image else None
        face_analyser = self.face_analyser
        # Sanity-check that the source actually contains a detectable face.
        test_face = get_face(cv2.imread(source), face_analyser)
        if not test_face:
            print(
                "\n[WARNING] No face detected in source image. Please try with another one.\n"
            )
            return
        # Still-image target: one direct swap, no frame pipeline needed.
        if is_img(target):
            output = process_img(source, target, face_analyser, reference_image)
            yield CogPath(output)
            status("swap successful!")
            return
        video_name = "output.mp4"
        output_dir = "./output"
        # Start each run from a clean scratch directory.
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        Path(output_dir).mkdir(exist_ok=True)
        status("detecting video's FPS...")
        fps = detect_fps(target)
        if not keep_fps and fps > 30:
            # NOTE(review): video_name already ends in ".mp4", so this path is
            # "./output/output.mp4.mp4" — apparently deliberate, keeping the
            # re-encoded input distinct from create_video's "output.mp4"
            # (whose audio is later pulled from `target`). Confirm before
            # "fixing" the double extension.
            this_path = output_dir + "/" + video_name + ".mp4"
            set_fps(target, this_path, 30)
            target, fps = this_path, 30
        else:
            shutil.copy(target, output_dir)
        status("extracting frames...")
        extract_frames(target, output_dir)
        # Frames are assumed to be named "<number>.png"; sort numerically so
        # the swap processes them in playback order. (The f-string prefix on
        # the glob pattern is redundant — it has no placeholders.)
        frame_paths = tuple(
            sorted(
                glob.glob(output_dir + f"/*.png"),
                key=lambda x: int(x.split("/")[-1].replace(".png", "")),
            )
        )
        status("swapping in progress...")
        start_time = time.time()
        process_video(source, frame_paths, face_analyser, reference_image)
        end_time = time.time()
        print(f"Processing time: {end_time - start_time:.2f} seconds")
        status("creating video...")
        output_file = create_video(video_name, fps, output_dir)
        status("adding audio...")
        # add_audio's result supersedes create_video's path; keep_frames is
        # forwarded to it (presumably controls whether the extracted PNG
        # frames are retained — TODO confirm against core.utils).
        output_file = add_audio(output_dir, target, keep_frames)
        print("\n\nVideo saved as:", output_file, "\n\n")
        yield CogPath(output_file)
        status("swap successful!")
if __name__ == "__main__":
    # Smoke-test entry point: run a single swap against hard-coded local
    # files and print the first artifact the pipeline yields.
    runner = Predictor()
    runner.setup()
    results = runner.predict(
        source=CogPath("source3.jpg"),
        target=CogPath("target.mp4"),
        reference_image=CogPath("ref.png"),
    )
    # Mirror the original for/break: print at most the first yielded value,
    # printing nothing if the generator is empty.
    _missing = object()
    first = next(iter(results), _missing)
    if first is not _missing:
        print(first)
    print("done")