From d957b2f6eb33ae32d0be7ad0735a3fea58e798fe Mon Sep 17 00:00:00 2001 From: xu20160924 Date: Sat, 24 May 2025 14:06:27 +0800 Subject: [PATCH] BiSeNet: compatible with platforms except CUDA --- utils.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/utils.py b/utils.py index d4faef1..57e3015 100644 --- a/utils.py +++ b/utils.py @@ -206,15 +206,25 @@ def mask_BiSeNet(crop, hair=False, hat=False, ): + device = ( + "cuda" + # Device for NVIDIA or AMD GPUs + if torch.cuda.is_available() + else "mps" + # Device for Apple Silicon (Metal Performance Shaders) + if torch.backends.mps.is_available() + else "cpu" + ) + with torch.no_grad(): bisenet = BiSeNet(n_classes=19) - bisenet.cuda() + bisenet.to(device) model_path = os.path.join(models_dir, 'bisenet', '79999_iter.pth') - bisenet.load_state_dict(torch.load(model_path)) + bisenet.load_state_dict(torch.load(model_path, map_location=device)) bisenet.eval() - crop_t = crop.permute(0,3,1,2).cuda().float() + crop_t = crop.permute(0,3,1,2).to(device).float() segms_t = bisenet(crop_t)[0].argmax(1).float() - + dic = { 'skin': 1, 'l_brow': 2, @@ -240,7 +250,7 @@ def mask_BiSeNet(crop, if k in dic and v: keep.append(dic[k]) - face_part_ids = torch.tensor(keep).cuda() + face_part_ids = torch.tensor(keep).to(device) segms_t = torch.sum(segms_t.repeat(len(face_part_ids), 1,1,1) == face_part_ids[...,None,None,None], axis=0).float() mask = segms_t.cpu() return mask @@ -299,7 +309,7 @@ def mask_jonathandinu(crop, skin=True, nose=True, eye_g=True, l_eye=True, r_eye= for k, v in locals().items(): if k in ids and v: keep.append(ids[k]) - face_part_ids = torch.tensor(keep).cuda() + face_part_ids = torch.tensor(keep).to(device) mask = torch.sum(labels.repeat(len(face_part_ids), 1,1,1) == face_part_ids[...,None,None,None], axis=0).float().cpu()