Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions __init__.py

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit concerned about this - the export pipeline seems too complicated.

Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Warp-accelerated PyPose via global monkey-patching.

``import PyposeWarp`` patches ``pypose`` globally so that every
``import pypose as pp`` yields warp-accelerated Lie constructors.

For code that needs the original (unpatched) constructors — e.g. the EKF
Jacobian pipeline that relies on ``torch.autograd.functional.jacobian`` —
use ``import PyposeWarp.generic as pp`` instead.
"""

import sys
import types
from typing import Any, cast
from functools import partial

import pypose
from . import pypose_warp

_wp = cast(Any, pypose_warp)

# ====== Save originals before patching ======

_PATCHED_ATTRS = [
"so3", "SO3", "se3", "SE3",
"SO3_type", "so3_type", "SE3_type", "se3_type",
"identity_SO3", "identity_SE3", "identity_so3", "identity_se3",
"randn_SO3", "randn_SE3", "randn_so3", "randn_se3",
"from_matrix", "euler2SO3", "mat2SO3", "mat2SE3",
]

_orig: dict[str, Any] = {name: getattr(pypose, name) for name in _PATCHED_ATTRS}

# ====== Build PyposeWarp.generic module ======

generic = types.ModuleType("PyposeWarp.generic")
generic.__doc__ = (
"Original (unpatched) PyPose constructors. Use this in files that "
"need torch.autograd.functional.jacobian compatibility."
)
generic.__dict__.update(_orig)
generic.__package__ = "PyposeWarp"


def _generic_getattr(name: str) -> Any:
"""Delegate non-patched attributes to the real pypose module."""
return getattr(pypose, name)


generic.__getattr__ = _generic_getattr # type: ignore[attr-defined]
sys.modules["PyposeWarp.generic"] = generic

# ====== Warp-backed wrapper functions ======


def _wp_identity(orig_fn: Any, warp_ltype: Any) -> Any:
def wrapper(*args: Any, **kwargs: Any) -> pypose.LieTensor:
result = orig_fn(*args, **kwargs)
result.ltype = warp_ltype
return result
return wrapper


def _wp_randn(orig_fn: Any, warp_ltype: Any) -> Any:
def wrapper(*args: Any, **kwargs: Any) -> pypose.LieTensor:
result = orig_fn(*args, **kwargs)
result.ltype = warp_ltype
return result
return wrapper


def wp_from_matrix(mat: Any, ltype: Any, check: bool = True,
rtol: float = 1e-5, atol: float = 1e-5) -> pypose.LieTensor:
if ltype is _wp.warpSE3_type:
result = _orig["from_matrix"](mat, _orig["SE3_type"], check=check, rtol=rtol, atol=atol)
result.ltype = _wp.warpSE3_type
elif ltype is _wp.warpSO3_type:
result = _orig["from_matrix"](mat, _orig["SO3_type"], check=check, rtol=rtol, atol=atol)
result.ltype = _wp.warpSO3_type
else:
result = _orig["from_matrix"](mat, ltype, check=check, rtol=rtol, atol=atol)
return result


def wp_euler2SO3(*args: Any, **kwargs: Any) -> pypose.LieTensor:
result = _orig["euler2SO3"](*args, **kwargs)
result.ltype = _wp.warpSO3_type
return result


def wp_mat2SO3(*args: Any, **kwargs: Any) -> pypose.LieTensor:
result = _orig["mat2SO3"](*args, **kwargs)
result.ltype = _wp.warpSO3_type
return result


def wp_mat2SE3(*args: Any, **kwargs: Any) -> pypose.LieTensor:
result = _orig["mat2SE3"](*args, **kwargs)
result.ltype = _wp.warpSE3_type
return result


# ====== Global monkey-patching ======

# Type objects
pypose.SE3_type = _wp.warpSE3_type
pypose.SO3_type = _wp.warpSO3_type
pypose.se3_type = _wp.warpse3_type
pypose.so3_type = _wp.warpso3_type

# Constructors
pypose.SO3 = partial(pypose.LieTensor, ltype=_wp.warpSO3_type)
pypose.so3 = partial(pypose.LieTensor, ltype=_wp.warpso3_type)
pypose.SE3 = partial(pypose.LieTensor, ltype=_wp.warpSE3_type)
pypose.se3 = partial(pypose.LieTensor, ltype=_wp.warpse3_type)

# Identity functions
pypose.identity_SE3 = _wp_identity(_orig["identity_SE3"], _wp.warpSE3_type)
pypose.identity_SO3 = _wp_identity(_orig["identity_SO3"], _wp.warpSO3_type)
pypose.identity_se3 = _wp_identity(_orig["identity_se3"], _wp.warpse3_type)
pypose.identity_so3 = _wp_identity(_orig["identity_so3"], _wp.warpso3_type)

# Randn functions
pypose.randn_SE3 = _wp_randn(_orig["randn_SE3"], _wp.warpSE3_type)
pypose.randn_SO3 = _wp_randn(_orig["randn_SO3"], _wp.warpSO3_type)
pypose.randn_se3 = _wp_randn(_orig["randn_se3"], _wp.warpse3_type)
pypose.randn_so3 = _wp_randn(_orig["randn_so3"], _wp.warpso3_type)

# Conversion functions
pypose.from_matrix = wp_from_matrix
pypose.euler2SO3 = wp_euler2SO3
pypose.mat2SO3 = wp_mat2SO3
pypose.mat2SE3 = wp_mat2SE3

# ====== Patch pypose internal module bindings ======
# These modules bind `so3`, `identity_SO3` etc. at import time via
# ``from .. import so3, identity_SO3``, so we must patch those bindings
# directly after global patching.

import pypose.module.imu_preintegrator as _preint # noqa: E402

_preint.so3 = pypose.so3
_preint.SO3 = pypose.SO3
_preint.identity_SO3 = pypose.identity_SO3
_preint.LieTensor = pypose.LieTensor

# ====== Re-exports ======

from .pypose_warp import to_warp_backend, to_pypose_backend, is_warp_backend # noqa: E402


# ====== Module-level delegation ======


def __getattr__(name: str) -> Any:
"""Delegate unresolved attributes to the (now patched) pypose module."""
return getattr(pypose, name)


def __dir__() -> list[str]:
"""Expose both local symbols and pypose public attributes."""
return sorted(set(globals().keys()) | set(dir(pypose)))
10 changes: 10 additions & 0 deletions generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Original (unpatched) PyPose constructors for autograd-compatible code.

This module is populated by ``PyposeWarp.__init__`` before monkey-patching.
Use ``import PyposeWarp.generic as pp`` in files that need
``torch.autograd.functional.jacobian`` compatibility (e.g. EKF NLS pipeline).

Attributes not explicitly saved (e.g. ``module``, ``optim``, ``vec2skew``)
are delegated to the real ``pypose`` module via ``__getattr__``; these are
unaffected by monkey-patching so they work identically either way.
"""
6 changes: 2 additions & 4 deletions pypose_warp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
]
_PP_TO_WP = {pp_ltype : wp_ltype for pp_ltype, wp_ltype in _BACKEND_LIST}
_WP_TO_PP = {wp_ltype : pp_ltype for pp_ltype, wp_ltype in _BACKEND_LIST}
_ORIG_PP_TYPES = frozenset(pp_ltype for pp_ltype, _ in _BACKEND_LIST)


def to_warp_backend(x: pp.LieTensor) -> pp.LieTensor:
Expand Down Expand Up @@ -46,7 +47,4 @@ def is_warp_backend(x: pp.LieTensor) -> bool:


def is_pypose_backend(x: pp.LieTensor) -> bool:
return x.ltype in {
pp.SE3_type, pp.SO3_type, pp.RxSO3_type, pp.Sim3_type,
pp.se3_type, pp.so3_type, pp.rxso3_type, pp.sim3_type
}
return x.ltype in _ORIG_PP_TYPES