From 0fde10328ee0a2225843001454574a879cf8c70b Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Thu, 19 Mar 2026 16:36:07 +0800 Subject: [PATCH 1/3] feat: add xpu fp8 omni linear integration Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- comfy/quant_ops.py | 11 + comfy/xpu_fp8_linear.py | 197 +++++++++ comfy/xpu_quant_layout_ops.py | 71 +++ tests-unit/comfy_quant/test_xpu_fp8_linear.py | 410 ++++++++++++++++++ 4 files changed, 689 insertions(+) create mode 100644 comfy/xpu_fp8_linear.py create mode 100644 comfy/xpu_quant_layout_ops.py create mode 100644 tests-unit/comfy_quant/test_xpu_fp8_linear.py diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 42ee08fb22ca..ae5997871b44 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -40,6 +40,11 @@ class _CKNvfp4Layout: def register_layout_class(name, cls): pass + def register_layout_op(op, cls): + def decorator(func): + return func + return decorator + def get_layout_class(name): return None @@ -219,3 +224,9 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase): "QUANT_ALGOS", "register_layout_op", ] + + +try: + import comfy.xpu_quant_layout_ops # noqa: F401 +except ImportError: + pass diff --git a/comfy/xpu_fp8_linear.py b/comfy/xpu_fp8_linear.py new file mode 100644 index 000000000000..328e9aa4e41e --- /dev/null +++ b/comfy/xpu_fp8_linear.py @@ -0,0 +1,197 @@ +import logging +import os +import threading +from typing import Optional + +import torch + + +log = logging.getLogger(__name__) + +_omni_linear = None +_omni_logged_first_use = False +_omni_fp8_failure_cache = set() +_omni_fp8_failure_cache_lock = threading.Lock() + +try: + from omni_xpu_kernel import linear as _omni_linear +except ImportError: + _omni_linear = None + + +def _env_enabled(name: str, default: bool) -> bool: + value = os.environ.get(name) + if value is None: + return default + return value.strip().lower() not in {"0", "false", "no", "off", ""} + + +def _omni_fp8_enabled() -> bool: + return _env_enabled("COMFY_XPU_FP8_OMNI_ENABLE", True) + + +def _omni_fp8_log_enabled() -> bool: + return _env_enabled("COMFY_XPU_FP8_OMNI_LOG", False) + + +def _log_first_use(shape): + global _omni_logged_first_use + if not _omni_logged_first_use: + _omni_logged_first_use = True + log.info("[omni_xpu_kernel] First use in xpu_fp8_linear with input shape %s", shape) + + +def _log_fast_path_event(message: str, *args): + if _omni_fp8_log_enabled(): + log.info(message, *args) + + +def _is_primitive_creation_failure(error: RuntimeError) -> bool: + message = str(error).lower() + return "could not create a primitive" in message + + +def _primitive_failure_cache_key(input_tensor: torch.Tensor, qdata: torch.Tensor): + return (tuple(input_tensor.shape), tuple(qdata.shape), str(input_tensor.dtype), input_tensor.device.index) + + +def _log_bad_shape_event(reason: str, input_tensor: torch.Tensor, qdata: torch.Tensor, bias: Optional[torch.Tensor], error: Optional[RuntimeError] = None): + message = ( + "[omni_xpu_kernel] XPU FP8 bad shape %s input_shape=%s qdata_shape=%s dtype=%s device=%s has_bias=%s" + ) + args = [ + reason, + tuple(input_tensor.shape), + tuple(qdata.shape), + str(input_tensor.dtype), + str(input_tensor.device), + bias is not None, + ] + if error is not None: + message += " error=%s" + args.append(str(error)) + log.info(message, *args) + + +def _expand_weight_scale(scale, rows, device): + if scale is None: + return None + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, device=device, dtype=torch.float32) + scale = scale.to(device=device, dtype=torch.float32) + if scale.ndim == 0: + return scale.expand(rows).contiguous() + if scale.ndim == 1 and scale.numel() == rows: + return scale.contiguous() + return None + + +def _extract_qdata(weight): + return getattr(weight, "_qdata", None) + + +def _extract_params(weight): + return getattr(weight, "_params", None) + + +def _extract_layout_name(weight): + return getattr(weight, "_layout_cls", None) + + +def _normalize_layout_name(layout): + if isinstance(layout, str): + return layout + if hasattr(layout, "__name__"): + return layout.__name__ + if hasattr(layout, "__class__") and hasattr(layout.__class__, "__name__"): + return layout.__class__.__name__ + return None + + +def can_use_omni_fp8_linear(input_tensor, weight, bias: Optional[torch.Tensor]): + if not _omni_fp8_enabled(): + return False + if _omni_linear is None: + return False + if not isinstance(input_tensor, torch.Tensor): + return False + if not input_tensor.is_xpu or input_tensor.ndim != 2: + return False + if input_tensor.dtype not in (torch.float16, torch.bfloat16): + return False + + layout_name = _normalize_layout_name(_extract_layout_name(weight)) + if layout_name not in ("TensorCoreFP8Layout", "TensorCoreFP8E4M3Layout"): + return False + + qdata = _extract_qdata(weight) + params = _extract_params(weight) + if qdata is None or params is None: + return False + if not isinstance(qdata, torch.Tensor) or not qdata.is_xpu or qdata.ndim != 2: + return False + if qdata.device != input_tensor.device: + return False + if qdata.dtype != torch.float8_e4m3fn: + return False + + scales = _expand_weight_scale(getattr(params, "scale", None), qdata.shape[0], qdata.device) + if scales is None: + return False + + if bias is not None: + if not isinstance(bias, torch.Tensor): + return False + if not bias.is_xpu or bias.ndim != 1: + return False + if bias.device != input_tensor.device: + return False + if bias.shape[0] != qdata.shape[0]: + return False + if bias.dtype != input_tensor.dtype: + return False + + if qdata.shape[1] != input_tensor.shape[1]: + return False + + return True + + +def try_omni_fp8_linear(input_tensor, weight, bias: Optional[torch.Tensor]): + if not _omni_fp8_enabled(): + _log_fast_path_event("[omni_xpu_kernel] XPU FP8 fast path disabled by COMFY_XPU_FP8_OMNI_ENABLE") + return None + + if not can_use_omni_fp8_linear(input_tensor, weight, bias): + _log_fast_path_event("[omni_xpu_kernel] XPU FP8 fast path fallback for shape=%s", tuple(input_tensor.shape) if isinstance(input_tensor, torch.Tensor) else None) + return None + + qdata = _extract_qdata(weight) + params = _extract_params(weight) + scales = _expand_weight_scale(params.scale, qdata.shape[0], qdata.device) + failure_key = _primitive_failure_cache_key(input_tensor, qdata) + + with _omni_fp8_failure_cache_lock: + if failure_key in _omni_fp8_failure_cache: + _log_bad_shape_event("cached primitive creation failure", input_tensor, qdata, bias) + return None + + _log_first_use(tuple(input_tensor.shape)) + + try: + output = _omni_linear.onednn_w8a16_fp8(input_tensor.contiguous(), qdata.contiguous(), scales, bias=bias) + except RuntimeError as error: + if not _is_primitive_creation_failure(error): + raise + with _omni_fp8_failure_cache_lock: + _omni_fp8_failure_cache.add(failure_key) + _log_bad_shape_event("primitive creation failure", input_tensor, qdata, bias, error) + return None + + _log_fast_path_event( + "[omni_xpu_kernel] XPU FP8 fast path hit shape=%s dtype=%s cache=%s", + tuple(input_tensor.shape), + str(input_tensor.dtype), + _omni_linear.fp8_cache_stats() if hasattr(_omni_linear, "fp8_cache_stats") else None, + ) + return output diff --git a/comfy/xpu_quant_layout_ops.py b/comfy/xpu_quant_layout_ops.py new file mode 100644 index 000000000000..a092e3b2ed4b --- /dev/null +++ b/comfy/xpu_quant_layout_ops.py @@ -0,0 +1,71 @@ +import logging + +import torch + +from comfy.quant_ops import register_layout_op, TensorCoreFP8E4M3Layout, TensorCoreFP8Layout +from comfy.xpu_fp8_linear import try_omni_fp8_linear + + +log = logging.getLogger(__name__) + +REGISTRATION_STATUS = { + "attempted": False, + "registered": False, + "error": None, +} + + +def _fallback_linear(input_tensor, weight, bias): + target_dtype = getattr(getattr(input_tensor, "_params", None), "orig_dtype", None) + if target_dtype is None and isinstance(input_tensor, torch.Tensor): + target_dtype = input_tensor.dtype + + if hasattr(weight, "dequantize"): + weight = weight.dequantize() + if hasattr(input_tensor, "dequantize"): + input_tensor = input_tensor.dequantize() + + if target_dtype is not None: + if isinstance(input_tensor, torch.Tensor) and input_tensor.dtype != target_dtype: + input_tensor = input_tensor.to(dtype=target_dtype) + if isinstance(weight, torch.Tensor) and weight.dtype != target_dtype: + weight = weight.to(dtype=target_dtype) + if isinstance(bias, torch.Tensor) and bias.dtype != target_dtype: + bias = bias.to(dtype=target_dtype) + + return torch.nn.functional.linear(input_tensor, weight, bias) + + +def _xpu_fp8_linear_handler(func, args, kwargs): + kwargs = kwargs or {} + input_tensor = args[0] + weight = args[1] + bias = None + if len(args) > 2: + bias = args[2] + elif "bias" in kwargs: + bias = kwargs["bias"] + + reshape_back_to = None + fast_path_input = input_tensor + if isinstance(input_tensor, torch.Tensor) and input_tensor.ndim > 2: + reshape_back_to = tuple(input_tensor.shape[:-1]) + fast_path_input = input_tensor.reshape(-1, input_tensor.shape[-1]) + + output = try_omni_fp8_linear(fast_path_input, weight, bias) + if output is not None: + if reshape_back_to is not None: + output = output.reshape(*reshape_back_to, output.shape[-1]) + return output + + return _fallback_linear(input_tensor, weight, bias) + + +try: + REGISTRATION_STATUS["attempted"] = True + register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8E4M3Layout)(_xpu_fp8_linear_handler) + register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)(_xpu_fp8_linear_handler) + REGISTRATION_STATUS["registered"] = True +except Exception as e: + REGISTRATION_STATUS["error"] = str(e) + log.info("[omni_xpu_kernel] XPU FP8 layout registration skipped: %s", e) diff --git a/tests-unit/comfy_quant/test_xpu_fp8_linear.py b/tests-unit/comfy_quant/test_xpu_fp8_linear.py new file mode 100644 index 000000000000..26f97f4c1496 --- /dev/null +++ b/tests-unit/comfy_quant/test_xpu_fp8_linear.py @@ -0,0 +1,410 @@ +import os +import sys +import types +import unittest +from typing import cast +from unittest import mock + +import torch + + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + + +def has_xpu(): + try: + return torch.xpu.is_available() + except AttributeError: + return False + + +class FakeQuantizedTensor: + def __init__(self, qdata, layout_cls, scale, orig_dtype): + self._qdata = qdata + self._layout_cls = layout_cls + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, device=qdata.device, dtype=torch.float32) + self._params = types.SimpleNamespace( + scale=scale, + orig_dtype=orig_dtype, + orig_shape=tuple(qdata.shape), + ) + + def dequantize(self): + scale = self._params.scale + if scale.ndim == 0: + scale_for_mul = scale + elif self._qdata.ndim == 2 and scale.numel() == self._qdata.shape[0]: + scale_for_mul = scale.unsqueeze(1) + else: + scale_for_mul = scale + return self._qdata.to(self._params.orig_dtype) * scale_for_mul.to(self._params.orig_dtype) + + +class TestXpuFp8Linear(unittest.TestCase): + def test_quant_ops_auto_imports_xpu_quant_layout_ops(self): + sys.modules.pop("comfy.quant_ops", None) + sys.modules.pop("comfy.xpu_quant_layout_ops", None) + + __import__("comfy.quant_ops") + + self.assertIn("comfy.xpu_quant_layout_ops", sys.modules) + + def test_registration_module_imports_without_comfy_kitchen_backend(self): + module = __import__("comfy.xpu_quant_layout_ops", fromlist=["REGISTRATION_STATUS"]) + self.assertTrue(module.REGISTRATION_STATUS["attempted"]) + self.assertTrue(module.REGISTRATION_STATUS["registered"]) + self.assertIsNone(module.REGISTRATION_STATUS["error"]) + + def test_fallback_linear_casts_dequantized_weight_and_bias_to_input_dtype(self): + from comfy.xpu_quant_layout_ops import _fallback_linear + + input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls="TensorCoreFP8E4M3Layout", + scale=torch.tensor(2.0, dtype=torch.float32), + orig_dtype=torch.float32, + ) + bias = torch.randn(3, dtype=torch.float32) + + captured = {} + + def fake_linear(input_arg, weight_arg, bias_arg): + captured["input_dtype"] = input_arg.dtype + captured["weight_dtype"] = weight_arg.dtype + captured["bias_dtype"] = bias_arg.dtype if bias_arg is not None else None + return torch.zeros(2, 3, dtype=input_arg.dtype) + + with mock.patch("torch.nn.functional.linear", side_effect=fake_linear): + output = _fallback_linear(input_tensor, weight, bias) + + self.assertEqual(captured["input_dtype"], torch.bfloat16) + self.assertEqual(captured["weight_dtype"], torch.bfloat16) + self.assertEqual(captured["bias_dtype"], torch.bfloat16) + self.assertEqual(output.dtype, torch.bfloat16) + self.assertEqual(tuple(output.shape), (2, 3)) + + def test_try_omni_fp8_linear_rejects_nd_input_without_adapter(self): + from comfy import xpu_fp8_linear + + device = torch.device("cpu") + input_tensor = torch.randn(2, 5, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls="TensorCoreFP8E4M3Layout", + scale=torch.tensor(2.0, device=device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + + self.assertIsNone(output) + + def test_xpu_fp8_linear_handler_flattens_nd_input_for_fast_path_and_restores_shape(self): + from comfy.xpu_quant_layout_ops import _xpu_fp8_linear_handler + + input_tensor = torch.randn(2, 5, 4, dtype=torch.bfloat16) + weight = object() + bias = torch.randn(3, dtype=torch.bfloat16) + captured = {} + + def fake_try_omni_fp8_linear(input_arg, weight_arg, bias_arg): + captured["shape"] = tuple(input_arg.shape) + captured["weight"] = weight_arg + captured["bias"] = bias_arg + return torch.arange(30, dtype=input_arg.dtype).reshape(10, 3) + + with mock.patch("comfy.xpu_quant_layout_ops.try_omni_fp8_linear", side_effect=fake_try_omni_fp8_linear): + output = _xpu_fp8_linear_handler(None, (input_tensor, weight, bias), None) + + self.assertEqual(captured["shape"], (10, 4)) + self.assertIs(captured["weight"], weight) + self.assertIs(captured["bias"], bias) + self.assertEqual(tuple(output.shape), (2, 5, 3)) + + def test_xpu_fp8_linear_handler_preserves_nd_shape_on_fallback(self): + from comfy.xpu_quant_layout_ops import _xpu_fp8_linear_handler + + input_tensor = torch.randn(2, 5, 4, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls="TensorCoreFP8E4M3Layout", + scale=torch.tensor(2.0, dtype=torch.float32), + orig_dtype=torch.float32, + ) + bias = torch.randn(3, dtype=torch.float32) + + with mock.patch("comfy.xpu_quant_layout_ops.try_omni_fp8_linear", return_value=None): + output = _xpu_fp8_linear_handler(None, (input_tensor, weight, bias), None) + + self.assertEqual(tuple(output.shape), (2, 5, 3)) + self.assertEqual(output.dtype, torch.bfloat16) + + def test_try_omni_fp8_linear_rejects_non_per_channel_scale(self): + from comfy import xpu_fp8_linear + + device = torch.device("cpu") + input_tensor = torch.randn(2, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + bad_scale = torch.ones(2, device=device, dtype=torch.float32) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls="TensorCoreFP8E4M3Layout", + scale=bad_scale, + orig_dtype=torch.bfloat16, + ) + + output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + self.assertIsNone(output) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_try_omni_fp8_linear_falls_back_on_primitive_creation_failure(self): + from comfy import xpu_fp8_linear + from comfy.quant_ops import TensorCoreFP8E4M3Layout + + device = torch.device("xpu") + input_tensor = torch.randn(2, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls=TensorCoreFP8E4M3Layout, + scale=torch.tensor(2.0, device=device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + fake_module = mock.Mock() + fake_module.onednn_w8a16_fp8.side_effect = RuntimeError("could not create a primitive") + + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_module): + with mock.patch.object(xpu_fp8_linear, "_omni_fp8_failure_cache", set()): + output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + + self.assertIsNone(output) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_try_omni_fp8_linear_reraises_unrelated_runtime_errors(self): + from comfy import xpu_fp8_linear + from comfy.quant_ops import TensorCoreFP8E4M3Layout + + device = torch.device("xpu") + input_tensor = torch.randn(2, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls=TensorCoreFP8E4M3Layout, + scale=torch.tensor(2.0, device=device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + fake_module = mock.Mock() + fake_module.onednn_w8a16_fp8.side_effect = RuntimeError("some other runtime error") + + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_module): + with mock.patch.object(xpu_fp8_linear, "_omni_fp8_failure_cache", set()): + with self.assertRaisesRegex(RuntimeError, "some other runtime error"): + xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_try_omni_fp8_linear_caches_primitive_creation_failure_by_shape(self): + from comfy import xpu_fp8_linear + from comfy.quant_ops import TensorCoreFP8E4M3Layout + + device = torch.device("xpu") + input_tensor = torch.randn(2, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls=TensorCoreFP8E4M3Layout, + scale=torch.tensor(2.0, device=device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + fake_module = mock.Mock() + fake_module.onednn_w8a16_fp8.side_effect = RuntimeError("could not create a primitive") + + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_module): + with mock.patch.object(xpu_fp8_linear, "_omni_fp8_failure_cache", set()): + output_first = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + output_second = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + + self.assertIsNone(output_first) + self.assertIsNone(output_second) + self.assertEqual(fake_module.onednn_w8a16_fp8.call_count, 1) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_try_omni_fp8_linear_failure_cache_is_shape_specific(self): + from comfy import xpu_fp8_linear + from comfy.quant_ops import TensorCoreFP8E4M3Layout + + device = torch.device("xpu") + input_a = torch.randn(2, 4, device=device, dtype=torch.bfloat16) + input_b = torch.randn(3, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls=TensorCoreFP8E4M3Layout, + scale=torch.tensor(2.0, device=device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + fake_module = mock.Mock() + fake_module.onednn_w8a16_fp8.side_effect = RuntimeError("could not create a primitive") + + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_module): + with mock.patch.object(xpu_fp8_linear, "_omni_fp8_failure_cache", set()): + output_a = xpu_fp8_linear.try_omni_fp8_linear(input_a, weight, None) + output_b = xpu_fp8_linear.try_omni_fp8_linear(input_b, weight, None) + + self.assertIsNone(output_a) + self.assertIsNone(output_b) + self.assertEqual(fake_module.onednn_w8a16_fp8.call_count, 2) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_try_omni_fp8_linear_logs_bad_shape_on_primitive_creation_failure(self): + from comfy import xpu_fp8_linear + from comfy.quant_ops import TensorCoreFP8E4M3Layout + + device = torch.device("xpu") + input_tensor = torch.randn(2, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls=TensorCoreFP8E4M3Layout, + scale=torch.tensor(2.0, device=device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + fake_module = mock.Mock() + fake_module.onednn_w8a16_fp8.side_effect = RuntimeError("could not create a primitive") + + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_module): + with mock.patch.object(xpu_fp8_linear, "_omni_fp8_failure_cache", set()): + with self.assertLogs(xpu_fp8_linear.log, level="INFO") as logs: + output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + + self.assertIsNone(output) + joined = "\n".join(logs.output) + self.assertIn("primitive creation failure", joined) + self.assertIn("input_shape=(2, 4)", joined) + self.assertIn("qdata_shape=(3, 4)", joined) + self.assertIn("dtype=torch.bfloat16", joined) + self.assertIn("has_bias=False", joined) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_try_omni_fp8_linear_logs_cached_bad_shape_skip(self): + from comfy import xpu_fp8_linear + from comfy.quant_ops import TensorCoreFP8E4M3Layout + + device = torch.device("xpu") + input_tensor = torch.randn(2, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls=TensorCoreFP8E4M3Layout, + scale=torch.tensor(2.0, device=device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + failure_key = xpu_fp8_linear._primitive_failure_cache_key(input_tensor, qweight) + fake_module = mock.Mock() + + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_module): + with mock.patch.object(xpu_fp8_linear, "_omni_fp8_failure_cache", {failure_key}): + with self.assertLogs(xpu_fp8_linear.log, level="INFO") as logs: + output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + + self.assertIsNone(output) + fake_module.onednn_w8a16_fp8.assert_not_called() + joined = "\n".join(logs.output) + self.assertIn("cached primitive creation failure", joined) + self.assertIn("input_shape=(2, 4)", joined) + self.assertIn("qdata_shape=(3, 4)", joined) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_try_omni_fp8_linear_accepts_layout_class_object(self): + from comfy import xpu_fp8_linear + from comfy.quant_ops import TensorCoreFP8E4M3Layout + + device = torch.device("xpu") + input_tensor = torch.randn(2, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls=TensorCoreFP8E4M3Layout, + scale=torch.tensor(2.0, device=device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + fake_linear = mock.Mock() + fake_linear.onednn_w8a16_fp8.return_value = torch.ones(2, 3, device=device, dtype=torch.bfloat16) + fake_linear.fp8_cache_stats.return_value = {"hits": 1, "misses": 1, "size": 1} + + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_linear): + output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + + self.assertIsNotNone(output) + output = cast(torch.Tensor, output) + self.assertEqual(tuple(output.shape), (2, 3)) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_try_omni_fp8_linear_rejects_mixed_xpu_devices(self): + from comfy import xpu_fp8_linear + + if getattr(torch.xpu, "device_count", lambda: 0)() < 2: + self.skipTest("Need 2 XPU devices for mixed-device rejection test") + + input_device = torch.device("xpu:0") + weight_device = torch.device("xpu:1") + input_tensor = torch.randn(2, 4, device=input_device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=weight_device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls="TensorCoreFP8E4M3Layout", + scale=torch.tensor(2.0, device=weight_device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + self.assertIsNone(output) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_try_omni_fp8_linear_expands_scalar_weight_scale(self): + from comfy import xpu_fp8_linear + + device = torch.device("xpu") + input_tensor = torch.randn(2, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls="TensorCoreFP8E4M3Layout", + scale=torch.tensor(2.0, device=device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + captured = {} + + def fake_linear(x, w, scales, bias=None): + captured["scales"] = scales + return torch.zeros(x.shape[0], w.shape[0], device=x.device, dtype=x.dtype) + + fake_module = mock.Mock() + fake_module.onednn_w8a16_fp8.side_effect = fake_linear + fake_module.fp8_cache_stats.return_value = {"hits": 0, "misses": 1, "size": 1} + + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_module): + with mock.patch.object(xpu_fp8_linear, "_omni_fp8_failure_cache", set()): + output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + + self.assertIsNotNone(output) + output = cast(torch.Tensor, output) + self.assertEqual(tuple(output.shape), (2, 3)) + self.assertEqual(captured["scales"].shape[0], 3) + self.assertTrue(torch.allclose(captured["scales"], torch.full((3,), 2.0, device=device))) + + +if __name__ == "__main__": + unittest.main() From fb13e12a5a5937a8d1acb56af2cc476e7a6b7649 Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Thu, 19 Mar 2026 16:36:07 +0800 Subject: [PATCH 2/3] test: cover mixed precision fp8 xpu integration Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- .../comfy_quant/test_mixed_precision.py | 242 +++++++++++++++++- 1 file changed, 237 insertions(+), 5 deletions(-) diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 7c740491dcac..41df85697348 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -3,6 +3,7 @@ import sys import os import json +from unittest import mock # Add comfy to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) @@ -10,6 +11,13 @@ def has_gpu(): return torch.cuda.is_available() + +def has_xpu(): + try: + return torch.xpu.is_available() + except AttributeError: + return False + from comfy.cli_args import args if not has_gpu(): args.cpu = True @@ -20,11 +28,11 @@ def has_gpu(): class SimpleModel(torch.nn.Module): - def __init__(self, operations=ops.disable_weight_init): + def __init__(self, operations=ops.disable_weight_init, device="cpu"): super().__init__() - self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) - self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) - self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) + self.layer1 = operations.Linear(10, 20, device=device, dtype=torch.bfloat16) + self.layer2 = operations.Linear(20, 30, device=device, dtype=torch.bfloat16) + self.layer3 = operations.Linear(30, 40, device=device, dtype=torch.bfloat16) def forward(self, x): x = self.layer1(x) @@ -201,6 +209,231 @@ def apply_lora(weight): self.assertEqual(output.shape, (5, 40)) + def test_mixed_precision_forward_reaches_registered_fp8_linear_handler(self): + """Test that a real mixed_precision forward reaches the registered FP8 linear handler.""" + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict, _ = comfy.utils.convert_old_quants( + state_dict, + metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}, + ) + + model = SimpleModel(operations=ops.mixed_precision_ops({})) + model.load_state_dict(state_dict, strict=False) + + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + + handler_calls = [] + + def fake_try_omni_fp8_linear(input_tensor_arg, weight_arg, bias_arg): + handler_calls.append((input_tensor_arg, weight_arg, bias_arg)) + return None + + with mock.patch("comfy.xpu_quant_layout_ops.try_omni_fp8_linear", side_effect=fake_try_omni_fp8_linear): + with torch.inference_mode(): + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertEqual(len(handler_calls), 1) + handler_input, handler_weight, handler_bias = handler_calls[0] + self.assertEqual(tuple(handler_input.shape), (5, 10)) + self.assertIsInstance(handler_weight, QuantizedTensor) + self.assertEqual(handler_weight._layout_cls, "TensorCoreFP8E4M3Layout") + self.assertEqual(tuple(handler_bias.shape), (20,)) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_mixed_precision_xpu_forward_invokes_omni_kernel_fast_path(self): + """Test that XPU mixed_precision forward hits the omni FP8 kernel fast path.""" + from comfy import xpu_fp8_linear + + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict, _ = comfy.utils.convert_old_quants( + state_dict, + metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}, + ) + + model = SimpleModel(operations=ops.mixed_precision_ops({}), device="xpu") + model.load_state_dict(state_dict, strict=False) + + input_tensor = torch.randn(5, 10, device="xpu", dtype=torch.bfloat16) + fake_linear = mock.Mock() + fake_linear.onednn_w8a16_fp8.return_value = torch.ones(5, 20, device="xpu", dtype=torch.bfloat16) + + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_linear): + with torch.inference_mode(): + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertGreaterEqual(fake_linear.onednn_w8a16_fp8.call_count, 1) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_mixed_precision_xpu_forward_reuses_fp8_cache_for_same_shape(self): + """Test that repeated mixed_precision XPU forwards reuse the omni FP8 cache for the same shape.""" + from omni_xpu_kernel import linear + + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict, _ = comfy.utils.convert_old_quants( + state_dict, + metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}, + ) + + model = SimpleModel(operations=ops.mixed_precision_ops({}), device="xpu") + model.load_state_dict(state_dict, strict=False) + + input_tensor = torch.randn(5, 10, device="xpu", dtype=torch.bfloat16) + + linear.fp8_cache_clear() + self.assertEqual(linear.fp8_cache_stats(), {"hits": 0, "misses": 0, "size": 0}) + + with torch.inference_mode(): + output_first = model(input_tensor) + output_second = model(input_tensor) + + stats = linear.fp8_cache_stats() + self.assertEqual(tuple(output_first.shape), (5, 40)) + self.assertEqual(tuple(output_second.shape), (5, 40)) + self.assertEqual(stats["misses"], 1) + self.assertGreaterEqual(stats["hits"], 1) + self.assertEqual(stats["size"], 1) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_mixed_precision_xpu_forward_respects_omni_disable_env(self): + """Test that disabling the ComfyUI omni FP8 env prevents fast-path invocation.""" + from comfy import xpu_fp8_linear + + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict, _ = comfy.utils.convert_old_quants( + state_dict, + metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}, + ) + + model = SimpleModel(operations=ops.mixed_precision_ops({}), device="xpu") + model.load_state_dict(state_dict, strict=False) + + input_tensor = torch.randn(5, 10, device="xpu", dtype=torch.bfloat16) + fake_linear = mock.Mock() + fake_linear.onednn_w8a16_fp8.return_value = torch.ones(5, 20, device="xpu", dtype=torch.bfloat16) + + with mock.patch.dict(os.environ, {"COMFY_XPU_FP8_OMNI_ENABLE": "0"}, clear=False): + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_linear): + with torch.inference_mode(): + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertEqual(fake_linear.onednn_w8a16_fp8.call_count, 0) + + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_mixed_precision_xpu_forward_logs_fast_path_when_enabled(self): + """Test that enabling ComfyUI omni FP8 log env records fast-path use.""" + from comfy import xpu_fp8_linear + + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict, _ = comfy.utils.convert_old_quants( + state_dict, + metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}, + ) + + model = SimpleModel(operations=ops.mixed_precision_ops({}), device="xpu") + model.load_state_dict(state_dict, strict=False) + + input_tensor = torch.randn(5, 10, device="xpu", dtype=torch.bfloat16) + fake_linear = mock.Mock() + fake_linear.onednn_w8a16_fp8.return_value = torch.ones(5, 20, device="xpu", dtype=torch.bfloat16) + + with mock.patch.dict(os.environ, {"COMFY_XPU_FP8_OMNI_LOG": "1"}, clear=False): + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_linear): + with self.assertLogs("comfy.xpu_fp8_linear", level="INFO") as logs: + with torch.inference_mode(): + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertGreaterEqual(fake_linear.onednn_w8a16_fp8.call_count, 1) + self.assertTrue(any("fast path" in line.lower() for line in logs.output)) + def test_error_handling_unknown_format(self): """Test that unknown formats raise error""" # Configure with unknown format @@ -230,4 +463,3 @@ def test_error_handling_unknown_format(self): if __name__ == "__main__": unittest.main() - From c9be2cf155d7e3275646b74e465ce84207d4651f Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Fri, 20 Mar 2026 06:22:32 +0000 Subject: [PATCH 3/3] feat: add FP8 M-chunking for large GEMM shapes and 3-level log verbosity Split M dimension into chunks when oneDNN fails to create FP8 primitives for large M values (e.g. WAN 2.2 14B FFN layers with M=32760). Benchmarked chunk_m=512 yields 4-8% speedup over dequant+bf16 for FFN shapes. Add COMFY_XPU_FP8_OMNI_LOG env var with 3 levels: 0=off, 1=misses only (default), 2=verbose. Previously all logging was gated by a single bool. --- comfy/xpu_fp8_linear.py | 146 ++++++++++++++++-- .../comfy_quant/test_mixed_precision.py | 4 +- tests-unit/comfy_quant/test_xpu_fp8_linear.py | 49 +++++- 3 files changed, 184 insertions(+), 15 deletions(-) diff --git a/comfy/xpu_fp8_linear.py b/comfy/xpu_fp8_linear.py index 328e9aa4e41e..aab6ba0b55a5 100644 --- a/comfy/xpu_fp8_linear.py +++ b/comfy/xpu_fp8_linear.py @@ -13,6 +13,10 @@ _omni_fp8_failure_cache = set() _omni_fp8_failure_cache_lock = threading.Lock() +# M-chunking: (shape_key) → chunk_m for shapes that benefit from FP8 chunking +_omni_fp8_m_chunk_cache: dict = {} +_omni_fp8_m_chunk_cache_lock = threading.Lock() + try: from omni_xpu_kernel import linear as _omni_linear except ImportError: @@ -30,8 +34,26 @@ def _omni_fp8_enabled() -> bool: return _env_enabled("COMFY_XPU_FP8_OMNI_ENABLE", True) -def _omni_fp8_log_enabled() -> bool: - return _env_enabled("COMFY_XPU_FP8_OMNI_LOG", False) +def _omni_fp8_log_level() -> int: + """Return log verbosity level for FP8 operations. + + 0 = off (no logging) + 1 = misses only (primitive creation failures, bad shapes — first occurrence) + 2 = verbose (all events including hits and cached failures) + + Default is 1 (misses only) when COMFY_XPU_FP8_OMNI_LOG is not set. + Set COMFY_XPU_FP8_OMNI_LOG=0 to disable, =2 or =verbose for full logging. + """ + value = os.environ.get("COMFY_XPU_FP8_OMNI_LOG") + if value is None: + return 1 # default: misses only + value = value.strip().lower() + if value in {"0", "false", "no", "off", ""}: + return 0 + if value in {"2", "verbose", "debug", "all"}: + return 2 + # "1", "true", "yes", "on" → level 1 + return 1 def _log_first_use(shape): @@ -41,8 +63,15 @@ def _log_first_use(shape): log.info("[omni_xpu_kernel] First use in xpu_fp8_linear with input shape %s", shape) -def _log_fast_path_event(message: str, *args): - if _omni_fp8_log_enabled(): +def _log_miss_event(message: str, *args): + """Log cache miss / first failure events (level >= 1).""" + if _omni_fp8_log_level() >= 1: + log.info(message, *args) + + +def _log_verbose_event(message: str, *args): + """Log verbose events: hits, cached failures, fallbacks (level >= 2).""" + if _omni_fp8_log_level() >= 2: log.info(message, *args) @@ -70,7 +99,12 @@ def _log_bad_shape_event(reason: str, input_tensor: torch.Tensor, qdata: torch.T if error is not None: message += " error=%s" args.append(str(error)) - log.info(message, *args) + + # First failure (with error) → miss-level; cached failure → verbose-level + if error is not None: + _log_miss_event(message, *args) + else: + _log_verbose_event(message, *args) def _expand_weight_scale(scale, rows, device): @@ -108,6 +142,68 @@ def _normalize_layout_name(layout): return None +# --------------------------------------------------------------------------- +# M-chunking: split large M into chunks for oneDNN FP8 primitive creation. +# oneDNN fails to JIT-compile FP8 kernels for large M; chunking works around +# this while still being faster than dequant+bf16 for FFN-shaped GEMMs. +# --------------------------------------------------------------------------- + +_M_CHUNK_THRESHOLD = 4096 # below this, direct FP8 path works + +# (K, N) -> chunk_m. Only shapes where FP8 M-chunking beats dequant+bf16. +_M_CHUNK_TABLE = { + (5120, 13824): 512, # WAN 2.2 14B FFN up: 4% faster + (13824, 5120): 512, # WAN 2.2 14B FFN down: 8% faster +} + + +def _select_chunk_m(m: int, k: int, n: int, dtype: torch.dtype) -> Optional[int]: + """Return optimal chunk_m for M-chunking, or None if not beneficial.""" + if m <= _M_CHUNK_THRESHOLD: + return None + return _M_CHUNK_TABLE.get((k, n)) + + +def _try_fp8_m_chunked( + input_tensor: torch.Tensor, + qdata: torch.Tensor, + scales: torch.Tensor, + bias: Optional[torch.Tensor], + chunk_m: int, +) -> Optional[torch.Tensor]: + """Execute FP8 GEMM by splitting M dimension into chunks of size chunk_m.""" + m, k = input_tensor.shape + n = qdata.shape[0] + + output = torch.empty(m, n, dtype=input_tensor.dtype, device=input_tensor.device) + + for start in range(0, m, chunk_m): + end = min(start + chunk_m, m) + x_chunk = input_tensor[start:end].contiguous() + + try: + out_chunk = _omni_linear.onednn_w8a16_fp8( + x_chunk, qdata, scales, bias=bias, + ) + except RuntimeError as e: + if _is_primitive_creation_failure(e): + _log_miss_event( + "[omni_xpu_kernel] XPU FP8 M-chunking failed at chunk [%d:%d] " + "chunk_shape=(%d,%d) error=%s", + start, end, end - start, k, str(e), + ) + return None + raise + + output[start:end] = out_chunk + + return output + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + def can_use_omni_fp8_linear(input_tensor, weight, bias: Optional[torch.Tensor]): if not _omni_fp8_enabled(): return False @@ -159,11 +255,11 @@ def can_use_omni_fp8_linear(input_tensor, weight, bias: Optional[torch.Tensor]): def try_omni_fp8_linear(input_tensor, weight, bias: Optional[torch.Tensor]): if not _omni_fp8_enabled(): - _log_fast_path_event("[omni_xpu_kernel] XPU FP8 fast path disabled by COMFY_XPU_FP8_OMNI_ENABLE") + _log_verbose_event("[omni_xpu_kernel] XPU FP8 fast path disabled by COMFY_XPU_FP8_OMNI_ENABLE") return None if not can_use_omni_fp8_linear(input_tensor, weight, bias): - _log_fast_path_event("[omni_xpu_kernel] XPU FP8 fast path fallback for shape=%s", tuple(input_tensor.shape) if isinstance(input_tensor, torch.Tensor) else None) + _log_verbose_event("[omni_xpu_kernel] XPU FP8 fast path fallback for shape=%s", tuple(input_tensor.shape) if isinstance(input_tensor, torch.Tensor) else None) return None qdata = _extract_qdata(weight) @@ -174,6 +270,17 @@ def try_omni_fp8_linear(input_tensor, weight, bias: Optional[torch.Tensor]): with _omni_fp8_failure_cache_lock: if failure_key in _omni_fp8_failure_cache: _log_bad_shape_event("cached primitive creation failure", input_tensor, qdata, bias) + with _omni_fp8_m_chunk_cache_lock: + cached_chunk_m = _omni_fp8_m_chunk_cache.get(failure_key) + if cached_chunk_m is not None: + _log_verbose_event( + "[omni_xpu_kernel] XPU FP8 M-chunking (cached) chunk_m=%d for shape %s", + cached_chunk_m, tuple(input_tensor.shape), + ) + return _try_fp8_m_chunked( + input_tensor.contiguous(), qdata.contiguous(), scales, bias, + cached_chunk_m, + ) return None _log_first_use(tuple(input_tensor.shape)) @@ -186,11 +293,32 @@ def try_omni_fp8_linear(input_tensor, weight, bias: Optional[torch.Tensor]): with _omni_fp8_failure_cache_lock: _omni_fp8_failure_cache.add(failure_key) _log_bad_shape_event("primitive creation failure", input_tensor, qdata, bias, error) + + # Try M-chunking for this failed shape + m, k = input_tensor.shape + n = qdata.shape[0] + chunk_m = _select_chunk_m(m, k, n, input_tensor.dtype) + if chunk_m is not None: + result = _try_fp8_m_chunked( + input_tensor.contiguous(), qdata.contiguous(), scales, bias, + chunk_m, + ) + if result is not None: + # Cache successful chunk_m + with _omni_fp8_m_chunk_cache_lock: + _omni_fp8_m_chunk_cache[failure_key] = chunk_m + _log_miss_event( + "[omni_xpu_kernel] XPU FP8 M-chunking success, cached chunk_m=%d for shape %s", + chunk_m, tuple(input_tensor.shape), + ) + return result + return None - _log_fast_path_event( - "[omni_xpu_kernel] XPU FP8 fast path hit shape=%s dtype=%s cache=%s", + _log_verbose_event( + "[omni_xpu_kernel] XPU FP8 fast path hit input_shape=%s weight_shape=%s dtype=%s cache=%s", tuple(input_tensor.shape), + tuple(qdata.shape), str(input_tensor.dtype), _omni_linear.fp8_cache_stats() if hasattr(_omni_linear, "fp8_cache_stats") else None, ) diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 41df85697348..965dcc18134e 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -391,7 +391,7 @@ def test_mixed_precision_xpu_forward_respects_omni_disable_env(self): @unittest.skipUnless(has_xpu(), "XPU not available") def test_mixed_precision_xpu_forward_logs_fast_path_when_enabled(self): - """Test that enabling ComfyUI omni FP8 log env records fast-path use.""" + """Test that verbose log level (COMFY_XPU_FP8_OMNI_LOG=2) records fast-path hits.""" from comfy import xpu_fp8_linear layer_quant_config = { @@ -424,7 +424,7 @@ def test_mixed_precision_xpu_forward_logs_fast_path_when_enabled(self): fake_linear = mock.Mock() fake_linear.onednn_w8a16_fp8.return_value = torch.ones(5, 20, device="xpu", dtype=torch.bfloat16) - with mock.patch.dict(os.environ, {"COMFY_XPU_FP8_OMNI_LOG": "1"}, clear=False): + with mock.patch.dict(os.environ, {"COMFY_XPU_FP8_OMNI_LOG": "2"}, clear=False): with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_linear): with self.assertLogs("comfy.xpu_fp8_linear", level="INFO") as logs: with torch.inference_mode(): diff --git a/tests-unit/comfy_quant/test_xpu_fp8_linear.py b/tests-unit/comfy_quant/test_xpu_fp8_linear.py index 26f97f4c1496..bcbc483880b6 100644 --- a/tests-unit/comfy_quant/test_xpu_fp8_linear.py +++ b/tests-unit/comfy_quant/test_xpu_fp8_linear.py @@ -296,6 +296,7 @@ def test_try_omni_fp8_linear_logs_bad_shape_on_primitive_creation_failure(self): @unittest.skipUnless(has_xpu(), "XPU not available") def test_try_omni_fp8_linear_logs_cached_bad_shape_skip(self): + """Cached primitive creation failure logs only at verbose level (COMFY_XPU_FP8_OMNI_LOG=2).""" from comfy import xpu_fp8_linear from comfy.quant_ops import TensorCoreFP8E4M3Layout @@ -312,10 +313,18 @@ def test_try_omni_fp8_linear_logs_cached_bad_shape_skip(self): failure_key = xpu_fp8_linear._primitive_failure_cache_key(input_tensor, qweight) fake_module = mock.Mock() - with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_module): - with mock.patch.object(xpu_fp8_linear, "_omni_fp8_failure_cache", {failure_key}): - with self.assertLogs(xpu_fp8_linear.log, level="INFO") as logs: - output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + old_val = os.environ.get("COMFY_XPU_FP8_OMNI_LOG") + os.environ["COMFY_XPU_FP8_OMNI_LOG"] = "2" + try: + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_module): + with mock.patch.object(xpu_fp8_linear, "_omni_fp8_failure_cache", {failure_key}): + with self.assertLogs(xpu_fp8_linear.log, level="INFO") as logs: + output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + finally: + if old_val is None: + os.environ.pop("COMFY_XPU_FP8_OMNI_LOG", None) + else: + os.environ["COMFY_XPU_FP8_OMNI_LOG"] = old_val self.assertIsNone(output) fake_module.onednn_w8a16_fp8.assert_not_called() @@ -324,6 +333,38 @@ def test_try_omni_fp8_linear_logs_cached_bad_shape_skip(self): self.assertIn("input_shape=(2, 4)", joined) self.assertIn("qdata_shape=(3, 4)", joined) + @unittest.skipUnless(has_xpu(), "XPU not available") + def test_try_omni_fp8_linear_cached_failure_silent_at_default_log_level(self): + """At default log level (1), cached primitive creation failures produce no log output.""" + from comfy import xpu_fp8_linear + from comfy.quant_ops import TensorCoreFP8E4M3Layout + + device = torch.device("xpu") + input_tensor = torch.randn(2, 4, device=device, dtype=torch.bfloat16) + qweight = torch.randn(3, 4, device=device, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = FakeQuantizedTensor( + qdata=qweight, + layout_cls=TensorCoreFP8E4M3Layout, + scale=torch.tensor(2.0, device=device, dtype=torch.float32), + orig_dtype=torch.bfloat16, + ) + + failure_key = xpu_fp8_linear._primitive_failure_cache_key(input_tensor, qweight) + fake_module = mock.Mock() + + old_val = os.environ.get("COMFY_XPU_FP8_OMNI_LOG") + os.environ.pop("COMFY_XPU_FP8_OMNI_LOG", None) + try: + with mock.patch.object(xpu_fp8_linear, "_omni_linear", fake_module): + with mock.patch.object(xpu_fp8_linear, "_omni_fp8_failure_cache", {failure_key}): + output = xpu_fp8_linear.try_omni_fp8_linear(input_tensor, weight, None) + finally: + if old_val is not None: + os.environ["COMFY_XPU_FP8_OMNI_LOG"] = old_val + + self.assertIsNone(output) + fake_module.onednn_w8a16_fp8.assert_not_called() + @unittest.skipUnless(has_xpu(), "XPU not available") def test_try_omni_fp8_linear_accepts_layout_class_object(self): from comfy import xpu_fp8_linear