Conversation
Resolved 11 merge conflicts while maintaining Intel XPU support: - pyproject.toml: Combined both dependency sets - group_coordinator.py: Use current_platform.get_local_torch_device() - parallel_state.py: Added XPU to device_id exclusion list - layernorm.py: Kept both forward_xpu() and forward_npu() - triton_ops.py: Use torch.get_device_module() (works for XPU) - component_loader.py: Accept remote's directory refactoring - gpu_worker.py: Use torch.get_device_module() for device ops - clip.py: Kept XPU attention workaround alongside ROCm/MUSA - platforms/__init__.py: Kept both XPU and NPU plugins - platforms/interface.py: Added both XPU and NPU in get_device() - layerwise_offload.py: Preserved _is_xpu flag with new API - xpu.py: Added get_local_torch_device() method
Replace unguarded torch.cuda.Event() and torch.cuda.current_stream().wait_event() with device-agnostic calls via torch.get_device_module(), which correctly returns torch.xpu on XPU devices.
There was a problem hiding this comment.
Pull request overview
Adds initial Intel XPU support to the multimodal diffusion runtime by introducing an XPU platform abstraction, wiring up distributed communication via XCCL, and adding several XPU-specific fallbacks/workarounds across attention, normalization, and profiling.
Changes:
- Introduces
XpuPlatformand platform auto-detection; adds XPU enum/backends and routes distributed backend to XCCL. - Adds an XPU-specific
DeviceCommunicatorand updates group coordination / distributed init for XPU. - Adds XPU handling in a few hot paths (CLIP SDPA masking, RMSNorm, profiler activities/sync, AVG-reduce workarounds).
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| python/sglang/multimodal_gen/utils.py | Switches stream tracking patching to be platform-aware (but currently breaks XPU stream caching due to an early return). |
| python/sglang/multimodal_gen/runtime/utils/profiler.py | Adds XPU profiling activity + XPU synchronize support. |
| python/sglang/multimodal_gen/runtime/platforms/xpu.py | New XPU platform abstraction (device info, memory queries, attention backend selection, communicator class). |
| python/sglang/multimodal_gen/runtime/platforms/interface.py | Adds PlatformEnum.XPU, AttentionBackendEnum.XPU_FLASH_ATTN, XPU device creation, and XPU distributed backend string (xccl). |
| python/sglang/multimodal_gen/runtime/platforms/init.py | Adds XPU plugin detection + resolution order. |
| python/sglang/multimodal_gen/runtime/models/vaes/parallel/wan_dist_utils.py | Uses non-batched P2P ops for XPU/XCCL compatibility in halo exchange. |
| python/sglang/multimodal_gen/runtime/models/encoders/clip.py | Adds XPU-specific SDPA masking logic to avoid is_causal + attn_mask conflict. |
| python/sglang/multimodal_gen/runtime/layers/layernorm.py | Adds forward_xpu() and XPU AVG-reduce workaround (but currently references CUDA-only imports). |
| python/sglang/multimodal_gen/runtime/layers/custom_op.py | Adds forward_xpu() dispatch path (defaults to native). |
| python/sglang/multimodal_gen/runtime/distributed/parallel_state.py | Switches NCCL→XCCL on XPU; sets XPU device before init_process_group; adjusts device_id passing. |
| python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py | Instantiates XpuCommunicator and adds XPU stream context in graph_capture(). |
| python/sglang/multimodal_gen/runtime/distributed/device_communicators/xpu_communicator.py | New XPU communicator using functional collectives for all-to-all workaround (but all_gather/gather tensor shapes are currently incorrect). |
| python/sglang/multimodal_gen/runtime/cache/cache_dit_integration.py | Uses SUM+divide instead of AVG all-reduce on XPU due to XCCL AVG issues. |
| python/sglang/jit_kernel/diffusion/triton/scale_shift.py | Allows XPU tensors into a Triton-based kernel (needs a safe XPU fallback). |
| python/pyproject_xpu.toml | Adds diffusion optional dependency group for the XPU build. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| activities.append(torch.profiler.ProfilerActivity.CUDA) | ||
| if current_platform.is_npu(): | ||
| activities.append(torch_npu.profiler.ProfilerActivity.NPU) | ||
| if hasattr(torch, "xpu") and torch.xpu.is_available(): |
There was a problem hiding this comment.
torch.profiler.ProfilerActivity.XPU is referenced unconditionally when torch.xpu.is_available(). In some PyTorch builds the XPU enum value may not exist even if torch.xpu is present, which would raise AttributeError during profiler construction. Guard this with hasattr(torch.profiler.ProfilerActivity, "XPU") (or a try/except) before appending.
| if hasattr(torch, "xpu") and torch.xpu.is_available(): | |
| if ( | |
| hasattr(torch, "xpu") | |
| and torch.xpu.is_available() | |
| and hasattr(torch.profiler.ProfilerActivity, "XPU") | |
| ): |
| if torch.cuda.is_available(): | ||
| torch.cuda.synchronize() | ||
| if current_platform.is_npu(): | ||
| elif current_platform.is_npu(): | ||
| torch.npu.synchronize() | ||
| export_trace = False # set to false because our internal torch_npu.profiler will generate trace file | ||
| elif hasattr(torch, 'xpu') and torch.xpu.is_available(): | ||
| torch.xpu.synchronize() |
There was a problem hiding this comment.
stop() sync selection is driven by torch.cuda.is_available() first, and uses elif for NPU/XPU. If CUDA is available in the environment while running on NPU/XPU (possible given platform resolution now prefers XPU before CUDA), this will synchronize CUDA and skip the actual device synchronize. Prefer selecting the synchronize call based on current_platform (cuda_alike vs npu vs xpu) rather than CUDA availability.
| def forward_xpu( | ||
| self, | ||
| x: torch.Tensor, | ||
| residual: Optional[torch.Tensor] = None, | ||
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | ||
| """XPU-specific implementation of RMSNorm. | ||
|
|
||
| Uses sgl_kernel if available, otherwise falls back to native implementation. | ||
| """ | ||
| shape = x.shape | ||
| x = x.contiguous() | ||
| x = x.view(-1, shape[-1]) | ||
| if x.dtype == torch.float or self.variance_size_override is not None: | ||
| return self.forward_native(x.view(shape), residual) | ||
| if residual is not None: | ||
| try: | ||
| fused_add_rmsnorm(x, residual.view(-1, shape[-1]), self.weight.data, self.variance_epsilon) | ||
| return x.view(shape), residual | ||
| except Exception: | ||
| return self.forward_native(x.view(shape), residual) | ||
| try: | ||
| out = rmsnorm(x, self.weight.data, self.variance_epsilon) | ||
| out = out.view(shape) |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
| # Allocate output tensor | ||
| output_tensor = torch.empty( | ||
| (self.world_size,) + input_size, | ||
| dtype=input_.dtype, | ||
| device=input_.device | ||
| ) | ||
|
|
||
| # All-gather into tensor | ||
| dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group) | ||
|
|
||
| # Reshape to concatenate along specified dimension | ||
| output_tensor = output_tensor.movedim(0, dim) | ||
| output_tensor = output_tensor.reshape( | ||
| input_size[:dim] | ||
| + (self.world_size * input_size[dim],) | ||
| + input_size[dim + 1:] | ||
| ) |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
| output_tensor = torch.empty( | ||
| (self.world_size,) + input_size, | ||
| dtype=input_.dtype, | ||
| device=input_.device | ||
| ) | ||
|
|
||
| # All-gather | ||
| dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group) |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
| assert x.is_cuda and scale.is_cuda | ||
| assert (x.is_cuda and scale.is_cuda) or (x.is_xpu and scale.is_xpu) | ||
| assert x.is_contiguous() | ||
|
|
There was a problem hiding this comment.
fuse_scale_shift_kernel is implemented with Triton kernels, but the assertion now allows XPU tensors. If Triton is not available or not backend-compatible with XPU in the target environment, this will fail at runtime (either at import/compile time or kernel launch). Consider adding an explicit XPU fallback implementation (e.g., pure PyTorch x * (1 + scale*scale_constant) + shift) and keep the Triton path CUDA-only.
| # XPU fallback: use pure PyTorch implementation and avoid Triton kernels. | |
| if x.is_xpu: | |
| B, L, C = x.shape | |
| if scale.dim() == 4: | |
| # scale/shift: [B, F, 1, C] | |
| B_s, num_frames, _, C_s = scale.shape | |
| assert ( | |
| B_s == B and C_s == C | |
| ), "scale shape must be compatible with input tensor shape" | |
| assert ( | |
| L % num_frames == 0 | |
| ), "seq_len must be divisible by num_frames for 4D scale/shift" | |
| frame_seqlen = L // num_frames | |
| # Expand [B, F, 1, C] to [B, F, frame_seqlen, C], then reshape to [B, L, C] | |
| scale_expanded = ( | |
| scale.expand(B, num_frames, frame_seqlen, C).reshape(B, L, C) | |
| ) | |
| shift_expanded = ( | |
| shift.expand(B, num_frames, frame_seqlen, C).reshape(B, L, C) | |
| ) | |
| else: | |
| scale_expanded = scale | |
| shift_expanded = shift | |
| return x * (1 + scale_expanded * scale_constant) + shift_expanded |
| # For non-CUDA-alike platforms, return None | ||
| if not current_platform.is_cuda_alike(): | ||
| return None | ||
|
|
||
| global _current_stream | ||
| if _current_stream is None: | ||
| # when this function is called before any stream is set, | ||
| # we return the default stream. | ||
| # On ROCm using the default 0 stream in combination with RCCL | ||
| # is hurting performance. Therefore creating a dedicated stream | ||
| # per process | ||
| _current_stream = ( | ||
| torch.cuda.Stream() | ||
| if current_platform.is_rocm() | ||
| else torch.cuda.current_stream() | ||
| ) | ||
| if current_platform.is_rocm(): | ||
| # On ROCm using the default 0 stream in combination with RCCL | ||
| # is hurting performance. Therefore creating a dedicated stream | ||
| # per process | ||
| _current_stream = torch.cuda.Stream() | ||
| elif current_platform.is_xpu(): | ||
| _current_stream = torch.xpu.Stream() | ||
| else: | ||
| _current_stream = torch.cuda.current_stream() |
There was a problem hiding this comment.
current_stream() returns early for any non-cuda_alike platform, which includes XPU. That makes the elif current_platform.is_xpu(): _current_stream = torch.xpu.Stream() branch unreachable, and any XPU callers will always get None. Consider changing the guard to allow XPU (or introducing a separate is_stream_supported() check) so the XPU stream cache is actually used.




Motivation
Add Intel XPU (GPU) support to SGLang Diffusion, enabling image and video generation on Intel GPUs (e.g., Intel Arc Pro B60).
Tested Models
Wan-AI/Wan2.1-T2V-1.3B-DiffusersWan-AI/Wan2.2-TI2V-5B-Diffusersblack-forest-labs/FLUX.1-devblack-forest-labs/FLUX.2-klein-4Bblack-forest-labs/FLUX.2-klein-9BTongyi-MAI/Z-Image-TurboModifications
1. XPU Platform Abstraction
platforms/xpu.py—XpuPlatformwith device capability detection, memory management, and attention backend selection.platforms/__init__.pyandplatforms/interface.pyto register XPU in the platform auto-detection chain and add theXPU_FLASH_ATTNattention backend enum.2. Distributed Communication (XCCL)
device_communicators/xpu_communicator.py—XpuCommunicatorusing PyTorch XCCL backend for Intel XPU collective operations.distributed_init_methodto select XCCL backend when running on XPU.3. XPU Flash Attention Backend
attention/backends/xpu_flash_attn.py—XpuFlashAttentionBackend/XpuFlashAttentionImpl.torch.nn.functional.scaled_dot_product_attentionwhen sgl-kernel is unavailable or the device does not support flash attention. Also can be configured by setting--attention-backend TORCH_SDPAor--attention-backend XPU_FLASH_ATTN.4. XPU-Specific Fixes in Hot Paths
Accuracy Tests
Results
Wan-AI.Wan2.1-T2V-1.3B-Diffusers.mp4
wan2.2_5b.mp4
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci