From a53b07eefea2d3bdb3b9b4e98bb21867e07d79ff Mon Sep 17 00:00:00 2001 From: orangeH25 <18085625039@163.com> Date: Tue, 7 Apr 2026 09:06:13 +0000 Subject: [PATCH] Adapt HIF8 tensor --- torch_npu/utils/hif8_tensor.py | 509 ++++++++++++++++----------------- 1 file changed, 249 insertions(+), 260 deletions(-) diff --git a/torch_npu/utils/hif8_tensor.py b/torch_npu/utils/hif8_tensor.py index ab410156f2..a01dc6d176 100644 --- a/torch_npu/utils/hif8_tensor.py +++ b/torch_npu/utils/hif8_tensor.py @@ -8,10 +8,11 @@ __all__ = [] -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple import torch from torch.utils._pytree import tree_map +from torch._subclasses.fake_tensor import FakeTensor import torch_npu from torch_npu.utils._error_code import ErrCode, pta_error @@ -21,6 +22,7 @@ tex = torch_npu._C._cd aten = torch.ops.aten +HIF8_OPS_TABLE: Dict[Any, Any] = {} NPU_CUSTOM_DType = { torch.uint8: tex.DType.uint8, @@ -30,7 +32,31 @@ torch.bfloat16: tex.DType.bfloat16, } +def implements(aten_ops): + """Register aten ops to the HIF8 op table.""" + def decorator(func): + for op in aten_ops: + if op in HIF8_OPS_TABLE: + raise RuntimeError( + f"HIF8 op {op} is already registered to {HIF8_OPS_TABLE[op].__name__}" + ) + HIF8_OPS_TABLE[op] = func + return func + + return decorator + +def _is_fakeish_tensor(x): + return ( + isinstance(x, torch.Tensor) + and ( + isinstance(x, FakeTensor) + or x.device.type == "meta" + or type(x).__name__ == "FunctionalTensor" + ) + ) + +@torch._dynamo.allow_in_graph class _FromHiFloat8Func(torch.autograd.Function): """Cast from HIF8 to other dtype""" @@ -45,7 +71,7 @@ def forward( data = tensor._data.contiguous().view(1, -1).detach() out = tex.cast_from_fp8( data, - tex.DType.hifloat8, + NPU_CUSTOM_DType[dtype], # tex.DType.hifloat8, NPU_CUSTOM_DType[dtype], ) out = out.view(tensor.size()) @@ -60,6 +86,7 @@ def backward( return grad, None +@torch._dynamo.allow_in_graph class _ToHiFloat8Func(torch.autograd.Function): """Cast to HIF8 from other dtype""" @@ -77,7 +104,7 @@ def forward( # Cast data to HIF8 data = tex.cast_to_fp8( tensor.view(1, -1), - tex.DType.hifloat8, + NPU_CUSTOM_DType[tensor.dtype], # tex.DType.hifloat8, ) data = data.view(tensor.size()) @@ -95,153 +122,29 @@ def backward( # Assume that we want gradients in full precision return grad, None - -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new _HiFloat8Tensor using the provided tensor's attributes. 
- - """ - - @staticmethod - def forward( - ctx, - tensor: _HiFloat8Tensor, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> torch.Tensor: - - # Return input tensor if constructor kwargs are not provided - ctx.input_dtype = tensor.dtype - if init_kwargs is None: - return tensor - - # Construct new tensor if constructor kwargs are provided - default_kwargs = dict( - data=tensor._data, +def _from_hifloat8_impl( + tensor: _HiFloat8Tensor, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if dtype is None: + dtype = tensor.dtype + data = tensor._data + if _is_fakeish_tensor(data): + if tuple(data.size()) != tuple(tensor.size()): + data = data.view(tensor.size()) + if data.dtype != dtype: + data = data.to(dtype) + return data + return _FromHiFloat8Func.apply(tensor, dtype) + + +def _to_hifloat8_impl(tensor: torch.Tensor) -> _HiFloat8Tensor: + if _is_fakeish_tensor(tensor): + return _HiFloat8Tensor( + data=tensor, dtype=tensor.dtype, ) - for key, val in default_kwargs.items(): - if key not in init_kwargs: - init_kwargs[key] = val - return _HiFloat8Tensor(**init_kwargs) - - @staticmethod - def backward(ctx, grad): - return grad.to(ctx.input_dtype), None - - -class _ViewFunc(torch.autograd.Function): - """View function - - View the _HiFloat8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, _HiFloat8Tensor): - return _HiFloat8Tensor.make_like( - tensor, - data=tensor._data.view(*shape), - ) - return tensor.view(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Union[torch.Tensor, None], ...]: - - if isinstance(grad, _HiFloat8Tensor): - dgrad = _HiFloat8Tensor.make_like( - grad, - data=grad._data.view(ctx.shape), - ) - return dgrad, None - return grad.view(ctx.shape), None - - -class _ReshapeFunc(torch.autograd.Function): - """Reshape function - - Reshape the _HiFloat8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, _HiFloat8Tensor): - return _HiFloat8Tensor.make_like( - tensor, - data=tensor._data.reshape(*shape), - ) - return tensor.reshape(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Union[torch.Tensor, None], ...]: - - if isinstance(grad, _HiFloat8Tensor): - dgrad = _HiFloat8Tensor.make_like( - grad, - data=grad._data.reshape(ctx.shape), - ) - return dgrad, None - return grad.reshape(ctx.shape), None - - -class _TransposeFunc(torch.autograd.Function): - """Transpose function - - Transpose the _HiFloat8Tensor. 
- - """ - - @staticmethod - def forward(ctx, tensor, dim0, dim1): - ctx.save_for_backward(dim0, dim1) - if isinstance(tensor, _HiFloat8Tensor): - return _HiFloat8Tensor.make_like( - tensor, - data=tensor._data.transpose(dim0, dim1), - ) - return tensor.transpose(dim0, dim1) - - @staticmethod - def backward(ctx, grad): - dim0, dim1 = ctx.saved_tensors - if isinstance(grad, _HiFloat8Tensor): - dgrad = _HiFloat8Tensor.make_like( - grad, - data=grad._data.transpose(dim0, dim1), - ) - return dgrad, None - return grad.transpose(dim0, dim1), None, None + return _ToHiFloat8Func.apply(tensor) class _HiFloat8Tensor(torch.Tensor): @@ -268,17 +171,17 @@ def __new__( dtype: torch.dtype = torch.float32, ): # Check that data buffer is valid - if data.element_size() != 1: - raise ValueError( - f"HiFloat8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" - + pta_error(ErrCode.VALUE) - ) + # if data.element_size() != 1: + # raise ValueError( + # f"HiFloat8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" + # + pta_error(ErrCode.VALUE) + # ) if data.requires_grad: raise ValueError( "HiFloat8Tensor requires non-differentiable data buffer" + pta_error(ErrCode.VALUE) ) - if not data.is_npu: + if not _is_fakeish_tensor(data) and not data.is_npu: data = data.npu() # Initialize tensor object @@ -292,7 +195,7 @@ def __new__( requires_grad=data.requires_grad, device=data.device, ) - self._data: torch.Tensor = data + self._data = data return self @@ -331,7 +234,7 @@ def from_hifloat8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: By default the resulting tensor's dtype is the _HiFloat8Tensor's nominal dtype. """ - return _FromHiFloat8Func.apply(self, dtype) + return _from_hifloat8_impl(self, dtype) @classmethod def to_hifloat8( @@ -339,9 +242,7 @@ def to_hifloat8( tensor: torch.Tensor ): """Construct _HiFloat8Tensor from PyTorch tensor""" - return _ToHiFloat8Func.apply( - tensor - ) + return _to_hifloat8_impl(tensor) def float(self) -> torch.Tensor: return self.from_hifloat8(dtype=torch.float32) @@ -356,13 +257,16 @@ def cpu(self) -> torch.Tensor: return self.from_hifloat8().cpu() def clone(self) -> _HiFloat8Tensor: - return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) + return aten.clone.default(self) def view(self, *shape: Tuple[int]) -> _HiFloat8Tensor: - return _ViewFunc.apply(self, shape) + return aten.view.default(self, shape) def reshape(self, *shape: Tuple[int]) -> _HiFloat8Tensor: - return _ReshapeFunc.apply(self, shape) + return aten.reshape.default(self, shape) + + def transpose(self, dim0, dim1): + return aten.transpose.int(self, dim0, dim1) def contiguous( self, @@ -376,9 +280,9 @@ def contiguous( """ if self._data.is_contiguous(memory_format=memory_format): return self - return _IdentityFunc.apply( + return _HiFloat8Tensor.make_like( self, - {"data": self._data.detach().contiguous(memory_format=memory_format)}, + data=self._data.detach().contiguous(memory_format=memory_format), ) def to_dtype(self, dtype: torch.dtype) -> _HiFloat8Tensor: @@ -395,102 +299,27 @@ def to_dtype(self, dtype: torch.dtype) -> _HiFloat8Tensor: @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): + if kwargs is None: + kwargs = {} - # In-place copy op - if func == aten.copy_.default: - - # Check tensors - dst = args[0] - src = args[1] - if not isinstance(dst, torch.Tensor): - raise RuntimeError( - "Attempted to copy into something that isn't a PyTorch tensor" - + pta_error(ErrCode.TYPE) - ) - if not isinstance(src, torch.Tensor): - 
raise RuntimeError( - "Attempted to copy from something that isn't a PyTorch tensor" - + pta_error(ErrCode.TYPE) - ) - - # Special handling based on which tensors are HIF8 - dst_is_hif8 = isinstance(dst, _HiFloat8Tensor) - src_is_hif8 = isinstance(src, _HiFloat8Tensor) - if dst_is_hif8 and src_is_hif8: - # Directly copy HIF8 data if possible - dst._data.copy_(src._data) - - elif not dst_is_hif8 and src_is_hif8: - # Cast source tensor to higher precision - dst.copy_(src.from_hifloat8()) - - elif dst_is_hif8 and not src_is_hif8: - # Make sure input is in expected format - src = src.expand(dst.size()) - src = src.to( - device=dst.device, - memory_format=torch.contiguous_format, - ) - - # Cast to HIF8 - if not dst._data.is_contiguous(): - raise RuntimeError( - "Transformer Engine cast kernels require contiguous data" - + pta_error(ErrCode.INTERNAL) - ) - tex.cast_to_fp8_noalloc( - src.view(1, -1), - dst._data.view(1, -1), - tex.DType.hifloat8, + def allowed_subclasses(type_): + return ( + issubclass(cls, type_) + or issubclass(torch._subclasses.fake_tensor.FakeTensor, type_) + or issubclass( + torch._subclasses.functional_tensor.FunctionalTensor, type_ ) - else: - # Invalid case - raise RuntimeError( - "Using HiFloat8Tensor copy logic, but no HiFloat8Tensor found" - + pta_error(ErrCode.INTERNAL) - ) - - # Nothing to return for in-place ops - return None - - # Slice op - if func == aten.slice.Tensor: - tensor = args[0] - data = tensor._data - data_slice = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - return _HiFloat8Tensor.make_like(tensor, data=data_slice) - - # Detach op - if func == aten.detach.default: - # Simply return a new _HiFloat8Tensor with the same attrs - return _HiFloat8Tensor.make_like( - args[0], - data=args[0]._data, ) - # View op - if func == aten.view.default: - tensor = args[0] - data = tensor._data - data_view = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - return _HiFloat8Tensor.make_like( - tensor, - data=data_view, - ) + if not all(allowed_subclasses(t) for t in types): + return NotImplemented + if func in HIF8_OPS_TABLE: + return HIF8_OPS_TABLE[func](func, args, kwargs) + def maybe_unwrap(t): if isinstance(t, _HiFloat8Tensor): - return t.from_hifloat8() + return _from_hifloat8_impl(t) return t def maybe_update_inplace(arg, new_arg, schema_arg): @@ -517,7 +346,7 @@ def maybe_update_inplace(arg, new_arg, schema_arg): new_kwargs = tree_map(maybe_unwrap, kwargs) schema_args = func._schema.arguments args_len = len(args) - out = super().__torch_dispatch__(func, types, new_args, new_kwargs) + super().__torch_dispatch__(func, types, new_args, new_kwargs) for arg, new_arg, schema_arg in zip(args, new_args, schema_args): maybe_update_inplace(arg, new_arg, schema_arg) for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): @@ -529,10 +358,8 @@ def maybe_update_inplace(arg, new_arg, schema_arg): # Default op # Note: cast to higher precision and perform op args = tree_map(maybe_unwrap, args) - if kwargs is not None: - kwargs = tree_map(maybe_unwrap, kwargs) - out = super().__torch_dispatch__(func, types, args, kwargs) - return out + kwargs = tree_map(maybe_unwrap, kwargs) + return super().__torch_dispatch__(func, types, args, kwargs) @classmethod def _make_in_reduce_ex( @@ -580,5 +407,167 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} return torch._C._disabled_torch_function_impl(func, types, args, kwargs) - def transpose(self, dim0, dim1): - 
return _TransposeFunc.apply(self, dim0, dim1) + def __tensor_flatten__(self): + return ["_data"], {"dtype": self.dtype} + + @staticmethod + def __tensor_unflatten__(tensor_data_dict, meta, outer_size, outer_stride): + data = tensor_data_dict["_data"] + if outer_size is not None and outer_stride is not None: + if tuple(data.size()) != tuple(outer_size) or tuple(data.stride()) != tuple( + outer_stride + ): + data = data.as_strided(outer_size, outer_stride, data.storage_offset()) + + + if isinstance(data, FakeTensor) or data.device.type == "meta": + shape = outer_size if outer_size is not None else data.shape + stride = outer_stride if outer_stride is not None else data.stride() + return torch.empty_strided( + shape, + stride, + dtype=meta["dtype"], + device=data.device, + ) + + return _HiFloat8Tensor( + data=data, + dtype=meta["dtype"], + ) + +def _wrap_hif8_like( + tensor: _HiFloat8Tensor, + data: torch.Tensor, + dtype: Optional[torch.dtype] = None, +) -> _HiFloat8Tensor: + return _HiFloat8Tensor.make_like( + tensor, + data=data, + dtype=tensor.dtype if dtype is None else dtype, + ) + + +def _unwrap_hif8(x): + if isinstance(x, _HiFloat8Tensor): + return _from_hifloat8_impl(x) + return x + +@implements( + [ + aten._unsafe_view.default, + aten.as_strided.default, + aten.clone.default, + aten.slice.Tensor, + aten.fill_.Scalar, + aten.reshape.default, + ] +) +def hif8_desugar_op(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **(kwargs or {})) + return _wrap_hif8_like(args[0], new_data) + + +@implements([aten.detach.default]) +def hif8_detach(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **(kwargs or {})) + return _wrap_hif8_like(args[0], new_data) + + +@implements([aten.t.default, aten.transpose.int]) +def hif8_transpose(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **(kwargs or {})) + return _wrap_hif8_like(args[0], new_data) + + +@implements([aten.view.default]) +def hif8_view(aten_op, args, kwargs=None): + new_data = aten_op(args[0]._data, *args[1:], **(kwargs or {})) + return _wrap_hif8_like(args[0], new_data) + + +@implements([aten._to_copy.default]) +def hif8_to_copy(aten_op, args, kwargs=None): + kwargs = kwargs or {} + tensor = args[0] + target_dtype = kwargs.get("dtype", tensor.dtype) + data_kwargs = {k: v for k, v in kwargs.items() if k != "dtype"} + new_data = tensor._data + if data_kwargs: + new_data = aten_op(new_data, **data_kwargs) + return _wrap_hif8_like(tensor, new_data, dtype=target_dtype) + + +@implements([aten.mm.default, aten.matmul.default]) +def hif8_matmul(aten_op, args, kwargs=None): + kwargs = kwargs or {} + new_args = tree_map(_unwrap_hif8, args) + new_kwargs = tree_map(_unwrap_hif8, kwargs) + return aten_op(*new_args, **new_kwargs) + + +@implements([aten.copy_.default]) +def hif8_copy(aten_op, args, kwargs=None): + kwargs = kwargs or {} + dst = args[0] + src = args[1] + if not isinstance(dst, torch.Tensor): + raise RuntimeError( + "Attempted to copy into something that isn't a PyTorch tensor" + + pta_error(ErrCode.TYPE) + ) + if not isinstance(src, torch.Tensor): + raise RuntimeError( + "Attempted to copy from something that isn't a PyTorch tensor" + + pta_error(ErrCode.TYPE) + ) + + dst_is_hif8 = isinstance(dst, _HiFloat8Tensor) + src_is_hif8 = isinstance(src, _HiFloat8Tensor) + + if not dst_is_hif8 and src_is_hif8: + src_hp = src.from_hifloat8() + return aten_op(dst, src_hp, *args[2:], **kwargs) + + if dst_is_hif8 and src_is_hif8: + fp8_out = aten_op(dst._data, src._data, 
*args[2:], **kwargs) + return _wrap_hif8_like(dst, fp8_out) + + if dst_is_hif8 and not src_is_hif8: + + if _is_fakeish_tensor(dst._data) or _is_fakeish_tensor(src): + if not _is_fakeish_tensor(src): + src = dst._data.new_empty_strided( + src.size(), + src.stride(), + dtype=src.dtype, + ) + + # keep broadcast semantics check + _ = src.expand(dst.size()) + + # copy_ mutates dst; result keeps dst metadata + return _wrap_hif8_like(dst, dst._data) + + + src = src.expand(dst.size()) + src = src.to( + device=dst.device, + memory_format=torch.contiguous_format, + ) + + if not dst._data.is_contiguous(): + raise RuntimeError( + "Transformer Engine cast kernels require contiguous data" + + pta_error(ErrCode.INTERNAL) + ) + tex.cast_to_fp8_noalloc( + src.contiguous().view(1, -1), + dst._data.view(1, -1), + NPU_CUSTOM_DType[dst._data.dtype], + ) + return dst + + raise RuntimeError( + "Using HiFloat8Tensor copy logic, but no HiFloat8Tensor found" + + pta_error(ErrCode.INTERNAL) + ) \ No newline at end of file
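
The core of this change is the registration-table dispatch: the per-op branches that used to live in __torch_dispatch__ are replaced by HIF8_OPS_TABLE, populated through the implements decorator, with a generic unwrap-to-high-precision fallback for every unregistered op. The sketch below reproduces that pattern on a toy wrapper subclass using only stock PyTorch, so it can be run without torch_npu hardware; the names _OPS_TABLE, _WrapperTensor and _rewrap_op are illustrative stand-ins, not identifiers from this patch.

import torch
from torch.utils._pytree import tree_map

aten = torch.ops.aten

# Illustrative op table, mirroring the role of HIF8_OPS_TABLE in the patch.
_OPS_TABLE = {}


def implements(aten_ops):
    # Same shape as the patch's decorator: map each aten overload to a handler.
    def decorator(func):
        for op in aten_ops:
            _OPS_TABLE[op] = func
        return func
    return decorator


class _WrapperTensor(torch.Tensor):
    # Toy subclass holding its payload in `_data`, the way _HiFloat8Tensor holds raw bytes.

    @staticmethod
    def __new__(cls, data):
        self = torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            dtype=data.dtype,
            device=data.device,
        )
        self._data = data
        return self

    def __repr__(self):
        return f"_WrapperTensor({self._data!r})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        # Consult the registration table first, as the patched __torch_dispatch__ does.
        if func in _OPS_TABLE:
            return _OPS_TABLE[func](func, args, kwargs)

        # Fallback: unwrap every subclass argument and run the op on the plain payload.
        def unwrap(t):
            return t._data if isinstance(t, _WrapperTensor) else t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))


@implements([aten.view.default, aten.detach.default, aten.clone.default])
def _rewrap_op(func, args, kwargs):
    # Run the op on the raw payload, then rewrap so the subclass type is preserved.
    out = func(args[0]._data, *args[1:], **kwargs)
    return _WrapperTensor(out)


if __name__ == "__main__":
    t = _WrapperTensor(torch.arange(6, dtype=torch.float32))
    v = t.view(2, 3)  # routed through _rewrap_op via the table
    print(type(v).__name__, tuple(v._data.shape))  # _WrapperTensor (2, 3)

Registering one decorated handler per aten overload keeps the subclass declarative: supporting another op (as the patch does for aten.slice.Tensor, aten._to_copy.default, aten.copy_.default and friends) means adding a handler rather than another branch in __torch_dispatch__, which is what allows the _IdentityFunc/_ViewFunc/_ReshapeFunc/_TransposeFunc autograd wrappers to be dropped.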