
Initial XPU support for sglang diffusion#1

Open: xiangyuT wants to merge 43 commits into main from xpu_0122


xiangyuT (Collaborator) commented Feb 11, 2026

Motivation

Add Intel XPU (GPU) support to SGLang Diffusion, enabling image and video generation on Intel GPUs (e.g., Intel Arc Pro B60).

Tested Models

| Model | Type |
| --- | --- |
| Wan-AI/Wan2.1-T2V-1.3B-Diffusers | Text-to-Video |
| Wan-AI/Wan2.2-TI2V-5B-Diffusers | Text/Image-to-Video |
| black-forest-labs/FLUX.1-dev | Text-to-Image |
| black-forest-labs/FLUX.2-klein-4B | Text-to-Image |
| black-forest-labs/FLUX.2-klein-9B | Text-to-Image |
| Tongyi-MAI/Z-Image-Turbo | Text-to-Image |

Modifications

1. XPU Platform Abstraction

  • New file: platforms/xpu.py, which defines XpuPlatform with device capability detection, memory management, and attention backend selection.
  • Updated platforms/__init__.py and platforms/interface.py to register XPU in the platform auto-detection chain and add the XPU_FLASH_ATTN attention backend enum.

2. Distributed Communication (XCCL)

  • New file: device_communicators/xpu_communicator.py, which defines XpuCommunicator, using the PyTorch XCCL backend for Intel XPU collective operations.
  • Updated distributed_init_method to select XCCL backend when running on XPU.
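
A minimal sketch of the backend selection described above, assuming a platform object that exposes is_xpu() as used elsewhere in this PR. The helper name pick_distributed_backend is hypothetical and is not taken from the PR; only the XPU to "xccl" mapping comes from the description.

```python
from types import SimpleNamespace


def pick_distributed_backend(platform) -> str:
    """Return the torch.distributed backend name for the active platform.

    Hypothetical helper: only the XPU -> "xccl" mapping comes from this PR;
    "nccl" is assumed here as the default for the existing accelerator path.
    """
    return "xccl" if platform.is_xpu() else "nccl"


# Stand-in platform objects so the sketch runs without an accelerator.
xpu_platform = SimpleNamespace(is_xpu=lambda: True)
cuda_platform = SimpleNamespace(is_xpu=lambda: False)

print(pick_distributed_backend(xpu_platform))
print(pick_distributed_backend(cuda_platform))
```

The returned string would then be passed as the backend argument when the process group is initialized.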

3. XPU Flash Attention Backend

  • New file: attention/backends/xpu_flash_attn.py, which defines XpuFlashAttentionBackend / XpuFlashAttentionImpl.
  • Falls back to torch.nn.functional.scaled_dot_product_attention when sgl-kernel is unavailable or the device does not support flash attention. The backend can also be selected explicitly with --attention-backend TORCH_SDPA or --attention-backend XPU_FLASH_ATTN.
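
The fallback rule above could be sketched roughly as follows. The function resolve_attention_backend and its boolean parameters are hypothetical and only illustrate the decision; the real selection happens inside XpuPlatform and may differ in detail.

```python
from typing import Optional


def resolve_attention_backend(
    requested: Optional[str],
    sgl_kernel_available: bool,
    device_supports_flash: bool,
) -> str:
    """Choose between XPU_FLASH_ATTN and TORCH_SDPA (illustrative only).

    `requested` mirrors the value of --attention-backend, or None if unset.
    """
    # An explicit request for SDPA is always honored.
    if requested == "TORCH_SDPA":
        return "TORCH_SDPA"
    # Flash attention needs both the kernel package and device support.
    if sgl_kernel_available and device_supports_flash:
        return "XPU_FLASH_ATTN"
    # Otherwise fall back to scaled_dot_product_attention.
    return "TORCH_SDPA"
```

Under this sketch, requesting XPU_FLASH_ATTN on a device without flash support still resolves to TORCH_SDPA, which matches the fallback behavior described in the bullet.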

4. XPU-Specific Fixes in Hot Paths

Accuracy Tests

All tested models produce visually correct outputs on Intel Arc Pro B60 GPUs. Sample outputs attached in comments below.

Results

  • Wan-AI/Wan2.1-T2V-1.3B-Diffusers
sglang generate --prompt 'A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest' --save-output --log-level debug --output-path outputs --model-path "/llm/models/Wan2.1-T2V-1.3B-Diffusers" --num-gpus 2 --tp-size 2 --height 480 --width 832 --pin-cpu-memory --dit-cpu-offload --text-encoder-cpu-offload --vae-cpu-offload --vae-tiling
Wan-AI.Wan2.1-T2V-1.3B-Diffusers.mp4
  • black-forest-labs/FLUX.2-klein 4B
sglang generate --model-path "/llm/models/FLUX.2-klein-4B"   --prompt "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest"   --seed 0   --profile --vae-cpu-offload --pin-cpu-memory --text-encoder-cpu-offload --num-gpus 2 --tp-size 2
black-forest-labs:FLUX 2-klein 4B
  • black-forest-labs/FLUX.2-klein 9B
sglang generate --model-path "/llm/models/FLUX.2-klein-9B"   --prompt "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest"   --seed 0   --profile --vae-cpu-offload --pin-cpu-memory --text-encoder-cpu-offload --num-gpus 2 --tp-size 2
black-forest-labs:FLUX 2-klein 9B
  • Tongyi-MAI/Z-Image-Turbo
sglang generate --model-path "/llm/models/Z-Image-Turbo"   --prompt "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest"   --seed 0   --profile --vae-cpu-offload --pin-cpu-memory --text-encoder-cpu-offload --num-gpus 2 --tp-size 2
Tongyi-MAI:Z-Image-Turbo
  • black-forest-labs/FLUX.1-dev
sglang generate --model-path "/llm/models/FLUX.1-dev"   --prompt "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest"   --seed 0 --vae-cpu-offload --vae-tiling  --pin-cpu-memory --text-encoder-cpu-offload --num-gpus 4 --tp-size 4 --attention-backend TORCH_SDPA
image
  • Wan-AI/Wan2.2-TI2V-5B-Diffusers
sglang generate --prompt 'A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest' --save-output --log-level debug --output-path outputs --model-path "/llm/models/Wan2.2-TI2V-5B-Diffusers" --num-gpus 2 --tp-size 2 --height 480 --width 832 --pin-cpu-memory --dit-cpu-offload --text-encoder-cpu-offload --vae-cpu-offload --vae-tiling --attention-backend TORCH_SDPA
wan2.2_5b.mp4

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

xiangyuT and others added 30 commits January 9, 2026 14:41
Resolved 11 merge conflicts while maintaining Intel XPU support:
- pyproject.toml: Combined both dependency sets
- group_coordinator.py: Use current_platform.get_local_torch_device()
- parallel_state.py: Added XPU to device_id exclusion list
- layernorm.py: Kept both forward_xpu() and forward_npu()
- triton_ops.py: Use torch.get_device_module() (works for XPU)
- component_loader.py: Accept remote's directory refactoring
- gpu_worker.py: Use torch.get_device_module() for device ops
- clip.py: Kept XPU attention workaround alongside ROCm/MUSA
- platforms/__init__.py: Kept both XPU and NPU plugins
- platforms/interface.py: Added both XPU and NPU in get_device()
- layerwise_offload.py: Preserved _is_xpu flag with new API
- xpu.py: Added get_local_torch_device() method
Replace unguarded torch.cuda.Event() and torch.cuda.current_stream().wait_event()
with device-agnostic calls via torch.get_device_module(), which correctly returns
torch.xpu on XPU devices.
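
As a rough illustration of the device-agnostic pattern in this commit message, the sketch below takes the device module as a parameter instead of hard-coding torch.cuda. The helper name wait_on_new_event is hypothetical; it assumes the module exposes Event() and current_stream(), as torch.cuda and torch.xpu do.

```python
def wait_on_new_event(device_module):
    """Record an event on the current stream and make that stream wait on it.

    `device_module` is expected to be torch.cuda, torch.xpu, or whatever
    torch.get_device_module() returns for the active device.
    """
    event = device_module.Event()
    stream = device_module.current_stream()
    event.record(stream)
    stream.wait_event(event)
    return event


# Intended usage (requires PyTorch and an accelerator, so it is left commented):
# import torch
# wait_on_new_event(torch.get_device_module())
```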
Jasonzzt (Collaborator) commented Feb 27, 2026

Results

  • Wan-AI/Wan2.1-T2V-1.3B-Diffusers
sglang generate --prompt 'A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest' --save-output --log-level debug --output-path outputs --model-path "/llm/models/Wan2.1-T2V-1.3B-Diffusers" --num-gpus 2 --tp-size 2 --height 480 --width 832 --pin-cpu-memory --dit-cpu-offload --text-encoder-cpu-offload --vae-cpu-offload --vae-tiling
Wan-AI.Wan2.1-T2V-1.3B-Diffusers.mp4
  • black-forest-labs/FLUX.2-klein 4B
sglang generate --model-path "/llm/models/FLUX.2-klein-4B"   --prompt "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest"   --seed 0   --profile --vae-cpu-offload --pin-cpu-memory --text-encoder-cpu-offload --num-gpus 2 --tp-size 2
black-forest-labs:FLUX 2-klein 4B
  • black-forest-labs/FLUX.2-klein 9B
sglang generate --model-path "/llm/models/FLUX.2-klein-9B"   --prompt "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest"   --seed 0   --profile --vae-cpu-offload --pin-cpu-memory --text-encoder-cpu-offload --num-gpus 2 --tp-size 2
black-forest-labs:FLUX 2-klein 9B
  • Tongyi-MAI/Z-Image-Turbo
sglang generate --model-path "/llm/models/Z-Image-Turbo"   --prompt "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest"   --seed 0   --profile --vae-cpu-offload --pin-cpu-memory --text-encoder-cpu-offload --num-gpus 2 --tp-size 2
Tongyi-MAI:Z-Image-Turbo

xiangyuT (Collaborator, Author) commented Feb 27, 2026

sglang generate --model-path "/llm/models/FLUX.1-dev"   --prompt "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest"   --seed 0 --vae-cpu-offload --vae-tiling  --pin-cpu-memory --text-encoder-cpu-offload --num-gpus 4 --tp-size 4 --attention-backend TORCH_SDPA
image
sglang generate --prompt 'A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest' --save-output --log-level debug --output-path outputs --model-path "/llm/models/Wan2.2-TI2V-5B-Diffusers" --num-gpus 2 --tp-size 2 --height 480 --width 832 --pin-cpu-memory --dit-cpu-offload --text-encoder-cpu-offload --vae-cpu-offload --vae-tiling --attention-backend TORCH_SDPA
wan2.2_5b.mp4

xiangyuT requested a review from Copilot February 27, 2026 07:14
xiangyuT marked this pull request as ready for review February 27, 2026 07:14
Copilot AI left a comment

Pull request overview

Adds initial Intel XPU support to the multimodal diffusion runtime by introducing an XPU platform abstraction, wiring up distributed communication via XCCL, and adding several XPU-specific fallbacks/workarounds across attention, normalization, and profiling.

Changes:

  • Introduces XpuPlatform and platform auto-detection; adds XPU enum/backends and routes distributed backend to XCCL.
  • Adds an XPU-specific DeviceCommunicator and updates group coordination / distributed init for XPU.
  • Adds XPU handling in a few hot paths (CLIP SDPA masking, RMSNorm, profiler activities/sync, AVG-reduce workarounds).

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| python/sglang/multimodal_gen/utils.py | Switches stream tracking patching to be platform-aware (but currently breaks XPU stream caching due to an early return). |
| python/sglang/multimodal_gen/runtime/utils/profiler.py | Adds XPU profiling activity + XPU synchronize support. |
| python/sglang/multimodal_gen/runtime/platforms/xpu.py | New XPU platform abstraction (device info, memory queries, attention backend selection, communicator class). |
| python/sglang/multimodal_gen/runtime/platforms/interface.py | Adds PlatformEnum.XPU, AttentionBackendEnum.XPU_FLASH_ATTN, XPU device creation, and XPU distributed backend string (xccl). |
| python/sglang/multimodal_gen/runtime/platforms/__init__.py | Adds XPU plugin detection + resolution order. |
| python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py | Uses non-batched P2P ops for XPU/XCCL compatibility in halo exchange. |
| python/sglang/multimodal_gen/runtime/models/encoders/clip.py | Adds XPU-specific SDPA masking logic to avoid is_causal + attn_mask conflict. |
| python/sglang/multimodal_gen/runtime/layers/layernorm.py | Adds forward_xpu() and XPU AVG-reduce workaround (but currently references CUDA-only imports). |
| python/sglang/multimodal_gen/runtime/layers/custom_op.py | Adds forward_xpu() dispatch path (defaults to native). |
| python/sglang/multimodal_gen/runtime/distributed/parallel_state.py | Switches NCCL→XCCL on XPU; sets XPU device before init_process_group; adjusts device_id passing. |
| python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py | Instantiates XpuCommunicator and adds XPU stream context in graph_capture(). |
| python/sglang/multimodal_gen/runtime/distributed/device_communicators/xpu_communicator.py | New XPU communicator using functional collectives for all-to-all workaround (but all_gather/gather tensor shapes are currently incorrect). |
| python/sglang/multimodal_gen/runtime/cache/cache_dit_integration.py | Uses SUM+divide instead of AVG all-reduce on XPU due to XCCL AVG issues. |
| python/sglang/jit_kernel/diffusion/triton/scale_shift.py | Allows XPU tensors into a Triton-based kernel (needs a safe XPU fallback). |
| python/pyproject_xpu.toml | Adds diffusion optional dependency group for the XPU build. |
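
The SUM+divide workaround listed for cache_dit_integration.py can be sketched as below. This is an illustration under assumptions, not the PR's code: the helper name all_reduce_mean is hypothetical, and the `dist` argument stands for torch.distributed (passed in so the sketch does not import PyTorch itself).

```python
def all_reduce_mean(tensor, dist, group=None):
    """Average a floating-point tensor across ranks without ReduceOp.AVG.

    `dist` is expected to be torch.distributed. The tensor is summed in place
    across the group and then divided by the group size, which yields the same
    result as an AVG reduction for floating-point tensors.
    """
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    tensor.div_(dist.get_world_size(group=group))
    return tensor


# Intended usage inside an initialized process group (left commented):
# import torch.distributed as dist
# all_reduce_mean(local_stat, dist)
```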


    activities.append(torch.profiler.ProfilerActivity.CUDA)
if current_platform.is_npu():
    activities.append(torch_npu.profiler.ProfilerActivity.NPU)
if hasattr(torch, "xpu") and torch.xpu.is_available():
Copilot AI Feb 27, 2026

torch.profiler.ProfilerActivity.XPU is referenced unconditionally when torch.xpu.is_available(). In some PyTorch builds the XPU enum value may not exist even if torch.xpu is present, which would raise AttributeError during profiler construction. Guard this with hasattr(torch.profiler.ProfilerActivity, "XPU") (or a try/except) before appending.

Suggested change
- if hasattr(torch, "xpu") and torch.xpu.is_available():
+ if (
+     hasattr(torch, "xpu")
+     and torch.xpu.is_available()
+     and hasattr(torch.profiler.ProfilerActivity, "XPU")
+ ):

Comment on lines 132 to +138
  if torch.cuda.is_available():
      torch.cuda.synchronize()
- if current_platform.is_npu():
+ elif current_platform.is_npu():
      torch.npu.synchronize()
      export_trace = False  # set to false because our internal torch_npu.profiler will generate trace file
+ elif hasattr(torch, 'xpu') and torch.xpu.is_available():
+     torch.xpu.synchronize()
Copilot AI Feb 27, 2026

stop() sync selection is driven by torch.cuda.is_available() first, and uses elif for NPU/XPU. If CUDA is available in the environment while running on NPU/XPU (possible given platform resolution now prefers XPU before CUDA), this will synchronize CUDA and skip the actual device synchronize. Prefer selecting the synchronize call based on current_platform (cuda_alike vs npu vs xpu) rather than CUDA availability.
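
One possible shape for the platform-driven selection suggested here, as a hedged sketch: select_synchronize is a hypothetical helper, and `platform` and `torch_module` are passed in (standing for current_platform and torch) so the example does not depend on an installed accelerator stack.

```python
def select_synchronize(platform, torch_module):
    """Return the synchronize callable for the active platform.

    The decision is driven by the resolved platform rather than by
    torch.cuda.is_available(), so an XPU run synchronizes the XPU even
    when CUDA also happens to be visible in the environment.
    """
    if platform.is_cuda_alike():
        return torch_module.cuda.synchronize
    if platform.is_npu():
        return torch_module.npu.synchronize
    if platform.is_xpu():
        return torch_module.xpu.synchronize
    # CPU or unknown platform: nothing to synchronize.
    return lambda: None


# Intended usage in the profiler's stop() path (left commented):
# select_synchronize(current_platform, torch)()
```

The NPU-specific export_trace handling from the original hunk would still need its own branch; it is omitted here to keep the sketch focused on the synchronize choice.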

Comment on lines +148 to +170
def forward_xpu(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """XPU-specific implementation of RMSNorm.

    Uses sgl_kernel if available, otherwise falls back to native implementation.
    """
    shape = x.shape
    x = x.contiguous()
    x = x.view(-1, shape[-1])
    if x.dtype == torch.float or self.variance_size_override is not None:
        return self.forward_native(x.view(shape), residual)
    if residual is not None:
        try:
            fused_add_rmsnorm(x, residual.view(-1, shape[-1]), self.weight.data, self.variance_epsilon)
            return x.view(shape), residual
        except Exception:
            return self.forward_native(x.view(shape), residual)
    try:
        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        out = out.view(shape)

This comment was marked as resolved.

Comment on lines +120 to +136
# Allocate output tensor
output_tensor = torch.empty(
    (self.world_size,) + input_size,
    dtype=input_.dtype,
    device=input_.device
)

# All-gather into tensor
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)

# Reshape to concatenate along specified dimension
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(
    input_size[:dim]
    + (self.world_size * input_size[dim],)
    + input_size[dim + 1:]
)

This comment was marked as resolved.

Comment on lines +163 to +170
output_tensor = torch.empty(
    (self.world_size,) + input_size,
    dtype=input_.dtype,
    device=input_.device
)

# All-gather
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)

This comment was marked as resolved.

- assert x.is_cuda and scale.is_cuda
+ assert (x.is_cuda and scale.is_cuda) or (x.is_xpu and scale.is_xpu)
  assert x.is_contiguous()

Copilot AI Feb 27, 2026

fuse_scale_shift_kernel is implemented with Triton kernels, but the assertion now allows XPU tensors. If Triton is not available or not backend-compatible with XPU in the target environment, this will fail at runtime (either at import/compile time or kernel launch). Consider adding an explicit XPU fallback implementation (e.g., pure PyTorch x * (1 + scale*scale_constant) + shift) and keep the Triton path CUDA-only.

Suggested change
# XPU fallback: use pure PyTorch implementation and avoid Triton kernels.
if x.is_xpu:
    B, L, C = x.shape
    if scale.dim() == 4:
        # scale/shift: [B, F, 1, C]
        B_s, num_frames, _, C_s = scale.shape
        assert (
            B_s == B and C_s == C
        ), "scale shape must be compatible with input tensor shape"
        assert (
            L % num_frames == 0
        ), "seq_len must be divisible by num_frames for 4D scale/shift"
        frame_seqlen = L // num_frames
        # Expand [B, F, 1, C] to [B, F, frame_seqlen, C], then reshape to [B, L, C]
        scale_expanded = (
            scale.expand(B, num_frames, frame_seqlen, C).reshape(B, L, C)
        )
        shift_expanded = (
            shift.expand(B, num_frames, frame_seqlen, C).reshape(B, L, C)
        )
    else:
        scale_expanded = scale
        shift_expanded = shift
    return x * (1 + scale_expanded * scale_constant) + shift_expanded

Comment on lines +107 to +123
+ # For non-CUDA-alike platforms, return None
+ if not current_platform.is_cuda_alike():
+     return None
+
  global _current_stream
  if _current_stream is None:
      # when this function is called before any stream is set,
      # we return the default stream.
-     # On ROCm using the default 0 stream in combination with RCCL
-     # is hurting performance. Therefore creating a dedicated stream
-     # per process
-     _current_stream = (
-         torch.cuda.Stream()
-         if current_platform.is_rocm()
-         else torch.cuda.current_stream()
-     )
+     if current_platform.is_rocm():
+         # On ROCm using the default 0 stream in combination with RCCL
+         # is hurting performance. Therefore creating a dedicated stream
+         # per process
+         _current_stream = torch.cuda.Stream()
+     elif current_platform.is_xpu():
+         _current_stream = torch.xpu.Stream()
+     else:
+         _current_stream = torch.cuda.current_stream()
Copilot AI Feb 27, 2026

current_stream() returns early for any non-cuda_alike platform, which includes XPU. That makes the elif current_platform.is_xpu(): _current_stream = torch.xpu.Stream() branch unreachable, and any XPU callers will always get None. Consider changing the guard to allow XPU (or introducing a separate is_stream_supported() check) so the XPU stream cache is actually used.
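
A sketch of one way the guard could be relaxed so the XPU branch becomes reachable. The StreamCache class is hypothetical (the PR uses a module-level _current_stream global), and `platform` and `torch_module` are injected stand-ins for current_platform and torch.

```python
class StreamCache:
    """Cache one stream per process, allowing XPU as well as CUDA-alike platforms."""

    def __init__(self):
        self._stream = None

    def get(self, platform, torch_module):
        # Only platforms with no stream concept return None; XPU is let through.
        if not (platform.is_cuda_alike() or platform.is_xpu()):
            return None
        if self._stream is None:
            if platform.is_rocm():
                # Dedicated stream on ROCm, as in the original code.
                self._stream = torch_module.cuda.Stream()
            elif platform.is_xpu():
                self._stream = torch_module.xpu.Stream()
            else:
                self._stream = torch_module.cuda.current_stream()
        return self._stream
```

Whether XPU should get a dedicated stream or the current default stream is a design choice for the PR author; the sketch simply keeps the branch from the diff and fixes its reachability.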

3 participants