# localhost:5000/demo?u_lip=0,0,255&l_lip=0,0,255&hair=0,0,255&skin=255,255,255
@app.route("/demo", methods=["GET"])
def apply_makeup():
    """Decode the posted image, recolor the requested face segments, and
    return the result as a JPEG.

    Segment colors are taken from the query string, e.g. ``u_lip=0,0,255``;
    the raw image bytes come in the request body.
    """
    np_array = np.frombuffer(request.data, np.uint8)
    img = cv2.imdecode(np_array, cv2.IMREAD_COLOR)

    cp = "cp/79999_iter.pth"

    # NOTE(review): test.evaluate() calls Image.open() on its first argument,
    # so passing a decoded ndarray here will fail — confirm evaluate() accepts
    # arrays (its own commented-out Image.fromarray() line suggests it did once).
    parsing = evaluate(img, cp)
    # cv2.resize takes dsize as (width, height); img.shape[0:2] is (height,
    # width), so the axes were swapped for non-square images.
    parsing = cv2.resize(
        parsing, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST
    )
    # request.args is a MultiDict of "R,G,B" strings; change_color parses them.
    modified_img = change_color(img, parsing, **request.args)

    # change_color returns RGB; the channel swap restores the BGR order that
    # cv2.imwrite expects.
    cv2.imwrite("img.jpg", cv2.cvtColor(modified_img, cv2.COLOR_BGR2RGB))

    return send_file(
        "img.jpg",
        mimetype="image/jpeg",
        # NOTE(review): attachment_filename was removed in Flask 2.2 in favor
        # of download_name — confirm the Flask version pinned for this app.
        attachment_filename="new.jpeg",
        as_attachment=False,
    )


if __name__ == "__main__":
    app.debug = True
    app.run()
def change_color(image, parsed_mask, **kwargs):
    """Recolor face segments of a BGR image according to a parsing mask.

    :param image: BGR image (uint8 ndarray).
    :param parsed_mask: per-pixel segment-id array, same H/W as ``image``.
    :param kwargs: segment-name -> color mapping. Values may be tuples or
        ``"R,G,B"`` strings (as delivered by Flask query parameters).
        Valid names are the keys of ``SEGMENTS``.
    :return: recolored image converted to RGB, or ``image`` unchanged (still
        BGR) when no segments were requested.
    """
    # Map segment names to their numeric ids in the parsing mask.
    query = {SEGMENTS[key]: color for key, color in kwargs.items()}
    if not query:
        return image

    image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    changed_image = None
    for key, color in query.items():
        if not isinstance(color, tuple):
            # Query-string values arrive as "R,G,B" strings; the components
            # must be ints before they can fill a uint8 numpy array (the
            # original kept them as strings, which raises on assignment).
            color = tuple(int(c) for c in color.split(","))
        # NOTE(review): the first component lands in the blue channel even
        # though the docstring advertises (R, G, B) — confirm intended order.
        b, g, r = color
        # Flat target-color image, converted to HSV for channel transplanting.
        mask = np.zeros_like(image_hsv)
        mask[:, :, 0] = b
        mask[:, :, 1] = g
        mask[:, :, 2] = r
        target_hsv = cv2.cvtColor(mask, cv2.COLOR_BGR2HSV)

        if key in (12, 13):
            # Lips: replace hue and saturation, keep original value (shading).
            image_hsv[:, :, 0:2] = target_hsv[:, :, 0:2]
        elif key == 17:
            # NOTE(review): the hair branch only sharpens the HSV image and
            # never applies the requested color — looks unfinished; confirm.
            image_hsv = sharpen(image_hsv)
        else:
            # Other segments: replace hue only.
            image_hsv[:, :, 0:1] = target_hsv[:, :, 0:1]

        new_image = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)
        # Restore original pixels everywhere outside the current segment.
        new_image[parsed_mask != key] = image[parsed_mask != key]

        image = new_image
        changed_image = new_image.copy()

    return cv2.cvtColor(changed_image, cv2.COLOR_BGR2RGB)
def __init__(self, in_chan, out_chan, *args, **kwargs):
    """Set up the BiSeNet feature-fusion layers.

    A 1x1 ConvBNReLU projects the concatenated features to ``out_chan``,
    followed by a squeeze/excite pair of bias-free 1x1 convolutions whose
    bottleneck is a quarter of the output width.
    """
    super(FeatureFusionModule, self).__init__()
    bottleneck = out_chan // 4  # attention-branch width
    # Registration order matters for state_dict compatibility — keep it.
    self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
    self.conv1 = nn.Conv2d(
        out_chan, bottleneck, kernel_size=1, stride=1, padding=0, bias=False
    )
    self.conv2 = nn.Conv2d(
        bottleneck, out_chan, kernel_size=1, stride=1, padding=0, bias=False
    )
    self.relu = nn.ReLU(inplace=True)
    self.sigmoid = nn.Sigmoid()
    self.init_weight()
def conv3x3(in_planes, out_planes, stride=1):
    """Return a padded, bias-free 3x3 convolution (ResNet building block)."""
    opts = {"kernel_size": 3, "stride": stride, "padding": 1, "bias": False}
    return nn.Conv2d(in_planes, out_planes, **opts)
def evaluate(image_path="./imgs/116.jpg", cp="cp/79999_iter.pth"):
    """Run BiSeNet face parsing on an image and return the label map.

    :param image_path: path to the input image (opened via PIL).
        NOTE(review): interface.py passes a decoded ndarray here, which
        Image.open() cannot handle — confirm callers, or accept both.
    :param cp: path to the BiSeNet checkpoint file.
    :return: (512, 512) numpy array of per-pixel segment ids.
    """
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    if CUDA:
        net.cuda()
    # Load the checkpoint to CPU first so CUDA-less machines still work;
    # load_state_dict moves the tensors onto the model's device.
    net.load_state_dict(torch.load(cp, map_location=torch.device("cpu")))
    net.eval()

    # Standard ImageNet normalization, matching the checkpoint's training.
    to_tensor = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )

    with torch.no_grad():
        img = Image.open(image_path)
        # The network expects 512x512 input.
        image = img.resize((512, 512), Image.BILINEAR)
        tensor = to_tensor(image)
        tensor = torch.unsqueeze(tensor, 0)
        if CUDA:
            tensor = tensor.cuda()
        out = net(tensor)[0]
        parsing = out.squeeze(0).cpu().numpy().argmax(0)

    # Side effect: writes output/parsing_map_on_im.{jpg,png}.
    # NOTE(review): cv2.imwrite fails silently when the output/ directory is
    # missing — confirm it exists or create it here.
    vis_parsing_maps(image, parsing, stride=1)

    return parsing


if __name__ == "__main__":
    evaluate(image_path="./imgs/before.jpg", cp="./cp/79999_iter.pth")