-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Description
This is no difficult task; simply add a command to switch the inference device to XPU.
I've tested it on my Intel Arc B580 and it functions correctly.
Here is the diff file:
diff --git a/vsdrba_distilled/__init__.py b/vsdrba_distilled/__init__.py
index 9e28965..7b54953 100644
--- a/vsdrba_distilled/__init__.py
+++ b/vsdrba_distilled/__init__.py
@@ -100,7 +100,7 @@ def drba_distilled(
if clip.num_frames < 2:
raise vs.Error("drba: clip's number of frames must be at least 2")
- if not torch.cuda.is_available():
+ if not torch.xpu.is_available():
- raise vs.Error("drba: CUDA is not available")
+ raise vs.Error("drba: XPU is not available")
if model not in models:
@@ -138,7 +138,7 @@ def drba_distilled(
fp16 = clip.format.bits_per_sample == 16
dtype = torch.half if fp16 else torch.float
- device = torch.device("cuda", device_index)
+ device = torch.device("xpu", device_index)
modulo = 64
match model:
@@ -182,7 +182,7 @@ def drba_distilled(
+ f"_{dimensions}"
+ f"_{'fp16' if fp16 else 'fp32'}"
+ f"_scale-{scale}"
- + f"_{torch.cuda.get_device_name(device)}"
+ + f"_{torch.xpu.get_device_name(device)}"
+ f"_trt-{tensorrt.__version__}"
+ (f"_workspace-{trt_workspace_size}" if trt_workspace_size > 0 else "")
+ (f"_aux-{trt_max_aux_streams}" if trt_max_aux_streams is not None else "")
@@ -236,19 +236,19 @@ def drba_distilled(
else:
flownet = init_module(model_name, IFNet, scale, device, dtype)
- inf_stream = torch.cuda.Stream(device)
- inf_f2t_stream = torch.cuda.Stream(device)
- inf_t2f_stream = torch.cuda.Stream(device)
+ inf_stream = torch.xpu.Stream(device)
+ inf_f2t_stream = torch.xpu.Stream(device)
+ inf_t2f_stream = torch.xpu.Stream(device)
inf_stream_lock = Lock()
inf_f2t_stream_lock = Lock()
inf_t2f_stream_lock = Lock()
- torch.cuda.current_stream(device).synchronize()
+ torch.xpu.current_stream(device).synchronize()
@torch.inference_mode()
def inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame:
- with inf_f2t_stream_lock, torch.cuda.stream(inf_f2t_stream):
+ with inf_f2t_stream_lock, torch.xpu.stream(inf_f2t_stream):
# t = n * factor_den % factor_num / factor_num
t = 0.5 + n * factor_den % factor_num / factor_num
@@ -265,12 +265,12 @@ def drba_distilled(
inf_f2t_stream.synchronize()
- with inf_stream_lock, torch.cuda.stream(inf_stream):
+ with inf_stream_lock, torch.xpu.stream(inf_stream):
output = flownet(img0, img1, img2, torch.tensor([t], device=img0.device, dtype=img0.dtype))
inf_stream.synchronize()
- with inf_t2f_stream_lock, torch.cuda.stream(inf_t2f_stream):
+ with inf_t2f_stream_lock, torch.xpu.stream(inf_t2f_stream):
if need_pad:
output = output[:, :, :h, :w]
@@ -333,7 +333,7 @@ def frame_to_tensor(frame: vs.VideoFrame, device: torch.device) -> torch.Tensor:
).unsqueeze(0)
-def tensor_to_frame(tensor: torch.Tensor, frame: vs.VideoFrame, stream: torch.cuda.Stream) -> vs.VideoFrame:
+def tensor_to_frame(tensor: torch.Tensor, frame: vs.VideoFrame, stream: torch.xpu.Stream) -> vs.VideoFrame:
tensor = tensor.squeeze(0).detach()
tensors = [tensor[plane].to("cpu", non_blocking=True) for plane in range(frame.format.num_planes)]
diff --git a/vsdrba_distilled/distilDRBA.py b/vsdrba_distilled/distilDRBA.py
index 08966a2..1c4d7fb 100644
--- a/vsdrba_distilled/distilDRBA.py
+++ b/vsdrba_distilled/distilDRBA.py
@@ -4,7 +4,7 @@ import torch.nn.functional as F
from .warplayer import warp
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
diff --git a/vsdrba_distilled/distilDRBA_v2_lite.py b/vsdrba_distilled/distilDRBA_v2_lite.py
index a1f5095..53d90a2 100644
--- a/vsdrba_distilled/distilDRBA_v2_lite.py
+++ b/vsdrba_distilled/distilDRBA_v2_lite.py
@@ -4,7 +4,7 @@ import torch.nn.functional as F
from .warplayer import warp
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels