Skip to content

[Feature requests] Add xpu support #1

@WhitePr

Description

@WhitePr

This is not a difficult task; simply add an option to switch the inference device to XPU.
I've tested it on my Intel Arc B580 and it functions correctly.
(Note: the unchanged context line `raise vs.Error("drba: CUDA is not available")` still mentions CUDA; it should read "XPU is not available" once this patch is applied.)

Here is the diff file:

diff --git a/vsdrba_distilled/__init__.py b/vsdrba_distilled/__init__.py
index 9e28965..7b54953 100644
--- a/vsdrba_distilled/__init__.py
+++ b/vsdrba_distilled/__init__.py
@@ -100,7 +100,7 @@ def drba_distilled(
     if clip.num_frames < 2:
         raise vs.Error("drba: clip's number of frames must be at least 2")
 
-    if not torch.cuda.is_available():
+    if not torch.xpu.is_available():
         raise vs.Error("drba: CUDA is not available")
 
     if model not in models:
@@ -138,7 +138,7 @@ def drba_distilled(
     fp16 = clip.format.bits_per_sample == 16
     dtype = torch.half if fp16 else torch.float
 
-    device = torch.device("cuda", device_index)
+    device = torch.device("xpu", device_index)
 
     modulo = 64
     match model:
@@ -182,7 +182,7 @@ def drba_distilled(
                     + f"_{dimensions}"
                     + f"_{'fp16' if fp16 else 'fp32'}"
                     + f"_scale-{scale}"
-                    + f"_{torch.cuda.get_device_name(device)}"
+                    + f"_{torch.xpu.get_device_name(device)}"
                     + f"_trt-{tensorrt.__version__}"
                     + (f"_workspace-{trt_workspace_size}" if trt_workspace_size > 0 else "")
                     + (f"_aux-{trt_max_aux_streams}" if trt_max_aux_streams is not None else "")
@@ -236,19 +236,19 @@ def drba_distilled(
     else:
         flownet = init_module(model_name, IFNet, scale, device, dtype)
 
-    inf_stream = torch.cuda.Stream(device)
-    inf_f2t_stream = torch.cuda.Stream(device)
-    inf_t2f_stream = torch.cuda.Stream(device)
+    inf_stream = torch.xpu.Stream(device)
+    inf_f2t_stream = torch.xpu.Stream(device)
+    inf_t2f_stream = torch.xpu.Stream(device)
 
     inf_stream_lock = Lock()
     inf_f2t_stream_lock = Lock()
     inf_t2f_stream_lock = Lock()
 
-    torch.cuda.current_stream(device).synchronize()
+    torch.xpu.current_stream(device).synchronize()
 
     @torch.inference_mode()
     def inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame:
-        with inf_f2t_stream_lock, torch.cuda.stream(inf_f2t_stream):
+        with inf_f2t_stream_lock, torch.xpu.stream(inf_f2t_stream):
             # t = n * factor_den % factor_num / factor_num
             t = 0.5 + n * factor_den % factor_num / factor_num
 
@@ -265,12 +265,12 @@ def drba_distilled(
 
             inf_f2t_stream.synchronize()
 
-        with inf_stream_lock, torch.cuda.stream(inf_stream):
+        with inf_stream_lock, torch.xpu.stream(inf_stream):
             output = flownet(img0, img1, img2, torch.tensor([t], device=img0.device, dtype=img0.dtype))
 
             inf_stream.synchronize()
 
-        with inf_t2f_stream_lock, torch.cuda.stream(inf_t2f_stream):
+        with inf_t2f_stream_lock, torch.xpu.stream(inf_t2f_stream):
             if need_pad:
                 output = output[:, :, :h, :w]
 
@@ -333,7 +333,7 @@ def frame_to_tensor(frame: vs.VideoFrame, device: torch.device) -> torch.Tensor:
     ).unsqueeze(0)
 
 
-def tensor_to_frame(tensor: torch.Tensor, frame: vs.VideoFrame, stream: torch.cuda.Stream) -> vs.VideoFrame:
+def tensor_to_frame(tensor: torch.Tensor, frame: vs.VideoFrame, stream: torch.xpu.Stream) -> vs.VideoFrame:
     tensor = tensor.squeeze(0).detach()
     tensors = [tensor[plane].to("cpu", non_blocking=True) for plane in range(frame.format.num_planes)]
 
diff --git a/vsdrba_distilled/distilDRBA.py b/vsdrba_distilled/distilDRBA.py
index 08966a2..1c4d7fb 100644
--- a/vsdrba_distilled/distilDRBA.py
+++ b/vsdrba_distilled/distilDRBA.py
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 
 from .warplayer import warp
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
 
 
 def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
diff --git a/vsdrba_distilled/distilDRBA_v2_lite.py b/vsdrba_distilled/distilDRBA_v2_lite.py
index a1f5095..53d90a2 100644
--- a/vsdrba_distilled/distilDRBA_v2_lite.py
+++ b/vsdrba_distilled/distilDRBA_v2_lite.py
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 
 from .warplayer import warp
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
 
 
 def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions