Skip to content

Better handling of symmetric case in from_rotation_matrix#10

Merged
egidioln merged 11 commits into
mainfrom
9-better-handling-of-symmetric-case-in-from_rotation_matrix
Apr 19, 2026
Merged

Better handling of symmetric case in from_rotation_matrix#10
egidioln merged 11 commits into
mainfrom
9-better-handling-of-symmetric-case-in-from_rotation_matrix

Conversation

@egidioln
Copy link
Copy Markdown
Owner

@egidioln egidioln commented Apr 15, 2026

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates Quaternion.from_rotation_matrix to better handle the “symmetric rotation matrix” edge case (notably 180° rotations where the trace is −1) and adds a regression-style test that exercises conversion across a span of rotations around π.

Changes:

  • Reworks from_rotation_matrix to compute an unnormalized quaternion from (1 + trace, R - Rᵀ) and normalizes at return.
  • Simplifies symmetric-case detection to trigger when 1 + trace ≈ 0 (i.e., trace ≈ −1).
  • Adds a test that converts a dense SLERP-generated quaternion span to rotation matrices and back.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
src/quatorch/quaternion.py Adjusts quaternion extraction from rotation matrices; updates symmetric-case masking and normalizes output.
test/unit_tests/test_quaternion.py Adds a regression test spanning rotations around π to validate from_rotation_matrix continuity/accuracy.
Comments suppressed due to low confidence (1)

src/quatorch/quaternion.py:208

  • The symmetric (trace≈-1) branch computes v as a softmax-weighted sum of columns of uuT. When two (or more) columns have equal norm but opposite sign (e.g., axis components with equal magnitude and opposite sign), the weighted sum can cancel to ~0, making u = v / v.norm(...) produce NaNs. Consider selecting a single well-conditioned column/row via argmax on vs_norm (or using an eigenvector / diagonal-based reconstruction) instead of averaging columns, and guard against v.norm()==0.
        uuT = (R[mask] + torch.eye(3, device=R.device, dtype=R.dtype)) / 2
        vs_norm = torch.norm(uuT, dim=-2, keepdim=True)
        v = torch.einsum(
            "... cd, ... cd-> ...c", uuT, torch.softmax(vs_norm * 100, dim=-1)
        )
        u = v / v.norm(dim=-1, keepdim=True)

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/quatorch/quaternion.py Outdated
Comment thread test/unit_tests/test_quaternion.py
Copy link
Copy Markdown
Contributor

@jvdoorss jvdoorss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added some naming suggestions, but the new logic does seem an improvement 👍 I'll add some more comments in #9

Comment thread src/quatorch/quaternion.py Outdated
trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
w = torch.sqrt(1.0 + trace) / 2.0
asR = (R - R.transpose(-2, -1)) / (4.0 * w.view(-1, 1, 1))
w = 1.0 + trace
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would, for your future selfs sake, use a better name here, e.g.

ww = (1.0 + trace) / 4.0

Comment thread src/quatorch/quaternion.py Outdated
asR = (R - R.transpose(-2, -1)) / (4.0 * w.view(-1, 1, 1))
w = 1.0 + trace
asR = R - R.transpose(-2, -1) # anti-symmetric part of R
x = asR[..., 2, 1]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly, better naming here could be useful:

wx = asR[..., 2, 1] / 4.0
wy = asR[..., 0, 2] / 4.0
wz = asR[..., 1, 0] / 4.0

Comment thread src/quatorch/quaternion.py Outdated
# Excluding identity, as it's symmetric, but it's not problematic
mask = symmetric_mask & ~identity_mask
eps = torch.finfo(w.dtype).resolution * 2
mask = w < eps # w should be non-negative for SO(3) matrices
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in theory the w defined above is in [0.0,4.0], but precision errors could violate this, I think.

Either way, if it does become negative it should definitely be in the mask, so not a concern here 🤷‍♂️

Comment thread src/quatorch/quaternion.py Outdated
@@ -220,7 +217,7 @@ def from_rotation_matrix(R: torch.Tensor) -> "Quaternion":

q = torch.stack([w, x, y, z], dim=-1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with the naming suggestions above, this would become more readable as

q = torch.stack([ww, wx, wy, wz], dim=-1)

making clear that in the next normalization step the division by w will reappear 😢

@egidioln egidioln merged commit 8efc4ca into main Apr 19, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Better handling of symmetric case in `from_rotation_matrix

3 participants