diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..da10dbd --- /dev/null +++ b/__init__.py @@ -0,0 +1,161 @@ +"""Warp-accelerated PyPose via global monkey-patching. + +``import PyposeWarp`` patches ``pypose`` globally so that every +``import pypose as pp`` yields warp-accelerated Lie constructors. + +For code that needs the original (unpatched) constructors — e.g. the EKF +Jacobian pipeline that relies on ``torch.autograd.functional.jacobian`` — +use ``import PyposeWarp.generic as pp`` instead. +""" + +import sys +import types +from typing import Any, cast +from functools import partial + +import pypose +from . import pypose_warp + +_wp = cast(Any, pypose_warp) + +# ====== Save originals before patching ====== + +_PATCHED_ATTRS = [ + "so3", "SO3", "se3", "SE3", + "SO3_type", "so3_type", "SE3_type", "se3_type", + "identity_SO3", "identity_SE3", "identity_so3", "identity_se3", + "randn_SO3", "randn_SE3", "randn_so3", "randn_se3", + "from_matrix", "euler2SO3", "mat2SO3", "mat2SE3", +] + +_orig: dict[str, Any] = {name: getattr(pypose, name) for name in _PATCHED_ATTRS} + +# ====== Build PyposeWarp.generic module ====== + +generic = types.ModuleType("PyposeWarp.generic") +generic.__doc__ = ( + "Original (unpatched) PyPose constructors. Use this in files that " + "need torch.autograd.functional.jacobian compatibility." +) +generic.__dict__.update(_orig) +generic.__package__ = "PyposeWarp" + + +def _generic_getattr(name: str) -> Any: + """Delegate non-patched attributes to the real pypose module.""" + return getattr(pypose, name) + + +generic.__getattr__ = _generic_getattr # type: ignore[attr-defined] +sys.modules["PyposeWarp.generic"] = generic + +# ====== Warp-backed wrapper functions ====== + + +def _wp_identity(orig_fn: Any, warp_ltype: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> pypose.LieTensor: + result = orig_fn(*args, **kwargs) + result.ltype = warp_ltype + return result + return wrapper + + +def _wp_randn(orig_fn: Any, warp_ltype: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> pypose.LieTensor: + result = orig_fn(*args, **kwargs) + result.ltype = warp_ltype + return result + return wrapper + + +def wp_from_matrix(mat: Any, ltype: Any, check: bool = True, + rtol: float = 1e-5, atol: float = 1e-5) -> pypose.LieTensor: + if ltype is _wp.warpSE3_type: + result = _orig["from_matrix"](mat, _orig["SE3_type"], check=check, rtol=rtol, atol=atol) + result.ltype = _wp.warpSE3_type + elif ltype is _wp.warpSO3_type: + result = _orig["from_matrix"](mat, _orig["SO3_type"], check=check, rtol=rtol, atol=atol) + result.ltype = _wp.warpSO3_type + else: + result = _orig["from_matrix"](mat, ltype, check=check, rtol=rtol, atol=atol) + return result + + +def wp_euler2SO3(*args: Any, **kwargs: Any) -> pypose.LieTensor: + result = _orig["euler2SO3"](*args, **kwargs) + result.ltype = _wp.warpSO3_type + return result + + +def wp_mat2SO3(*args: Any, **kwargs: Any) -> pypose.LieTensor: + result = _orig["mat2SO3"](*args, **kwargs) + result.ltype = _wp.warpSO3_type + return result + + +def wp_mat2SE3(*args: Any, **kwargs: Any) -> pypose.LieTensor: + result = _orig["mat2SE3"](*args, **kwargs) + result.ltype = _wp.warpSE3_type + return result + + +# ====== Global monkey-patching ====== + +# Type objects +pypose.SE3_type = _wp.warpSE3_type +pypose.SO3_type = _wp.warpSO3_type +pypose.se3_type = _wp.warpse3_type +pypose.so3_type = _wp.warpso3_type + +# Constructors +pypose.SO3 = partial(pypose.LieTensor, ltype=_wp.warpSO3_type) +pypose.so3 = partial(pypose.LieTensor, ltype=_wp.warpso3_type) +pypose.SE3 = partial(pypose.LieTensor, ltype=_wp.warpSE3_type) +pypose.se3 = partial(pypose.LieTensor, ltype=_wp.warpse3_type) + +# Identity functions +pypose.identity_SE3 = _wp_identity(_orig["identity_SE3"], _wp.warpSE3_type) +pypose.identity_SO3 = _wp_identity(_orig["identity_SO3"], _wp.warpSO3_type) +pypose.identity_se3 = _wp_identity(_orig["identity_se3"], _wp.warpse3_type) +pypose.identity_so3 = _wp_identity(_orig["identity_so3"], _wp.warpso3_type) + +# Randn functions +pypose.randn_SE3 = _wp_randn(_orig["randn_SE3"], _wp.warpSE3_type) +pypose.randn_SO3 = _wp_randn(_orig["randn_SO3"], _wp.warpSO3_type) +pypose.randn_se3 = _wp_randn(_orig["randn_se3"], _wp.warpse3_type) +pypose.randn_so3 = _wp_randn(_orig["randn_so3"], _wp.warpso3_type) + +# Conversion functions +pypose.from_matrix = wp_from_matrix +pypose.euler2SO3 = wp_euler2SO3 +pypose.mat2SO3 = wp_mat2SO3 +pypose.mat2SE3 = wp_mat2SE3 + +# ====== Patch pypose internal module bindings ====== +# These modules bind `so3`, `identity_SO3` etc. at import time via +# ``from .. import so3, identity_SO3``, so we must patch those bindings +# directly after global patching. + +import pypose.module.imu_preintegrator as _preint # noqa: E402 + +_preint.so3 = pypose.so3 +_preint.SO3 = pypose.SO3 +_preint.identity_SO3 = pypose.identity_SO3 +_preint.LieTensor = pypose.LieTensor + +# ====== Re-exports ====== + +from .pypose_warp import to_warp_backend, to_pypose_backend, is_warp_backend # noqa: E402 + + +# ====== Module-level delegation ====== + + +def __getattr__(name: str) -> Any: + """Delegate unresolved attributes to the (now patched) pypose module.""" + return getattr(pypose, name) + + +def __dir__() -> list[str]: + """Expose both local symbols and pypose public attributes.""" + return sorted(set(globals().keys()) | set(dir(pypose))) diff --git a/generic.py b/generic.py new file mode 100644 index 0000000..ef6e4a1 --- /dev/null +++ b/generic.py @@ -0,0 +1,10 @@ +"""Original (unpatched) PyPose constructors for autograd-compatible code. + +This module is populated by ``PyposeWarp.__init__`` before monkey-patching. +Use ``import PyposeWarp.generic as pp`` in files that need +``torch.autograd.functional.jacobian`` compatibility (e.g. EKF NLS pipeline). + +Attributes not explicitly saved (e.g. ``module``, ``optim``, ``vec2skew``) +are delegated to the real ``pypose`` module via ``__getattr__``; these are +unaffected by monkey-patching so they work identically either way. +""" diff --git a/pypose_warp/__init__.py b/pypose_warp/__init__.py index 2968ce0..2bfede8 100644 --- a/pypose_warp/__init__.py +++ b/pypose_warp/__init__.py @@ -19,6 +19,7 @@ ] _PP_TO_WP = {pp_ltype : wp_ltype for pp_ltype, wp_ltype in _BACKEND_LIST} _WP_TO_PP = {wp_ltype : pp_ltype for pp_ltype, wp_ltype in _BACKEND_LIST} +_ORIG_PP_TYPES = frozenset(pp_ltype for pp_ltype, _ in _BACKEND_LIST) def to_warp_backend(x: pp.LieTensor) -> pp.LieTensor: @@ -46,7 +47,4 @@ def is_warp_backend(x: pp.LieTensor) -> bool: def is_pypose_backend(x: pp.LieTensor) -> bool: - return x.ltype in { - pp.SE3_type, pp.SO3_type, pp.RxSO3_type, pp.Sim3_type, - pp.se3_type, pp.so3_type, pp.rxso3_type, pp.sim3_type - } + return x.ltype in _ORIG_PP_TYPES