Better handling of symmetric case in from_rotation_matrix#10
Conversation
There was a problem hiding this comment.
Pull request overview
This PR updates Quaternion.from_rotation_matrix to better handle the “symmetric rotation matrix” edge case (notably 180° rotations where the trace is −1) and adds a regression-style test that exercises conversion across a span of rotations around π.
Changes:
- Reworks
from_rotation_matrixto compute an unnormalized quaternion from(1 + trace, R - Rᵀ)and normalizes at return. - Simplifies symmetric-case detection to trigger when
1 + trace ≈ 0(i.e., trace ≈ −1). - Adds a test that converts a dense SLERP-generated quaternion span to rotation matrices and back.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
src/quatorch/quaternion.py |
Adjusts quaternion extraction from rotation matrices; updates symmetric-case masking and normalizes output. |
test/unit_tests/test_quaternion.py |
Adds a regression test spanning rotations around π to validate from_rotation_matrix continuity/accuracy. |
Comments suppressed due to low confidence (1)
src/quatorch/quaternion.py:208
- The symmetric (trace≈-1) branch computes
vas a softmax-weighted sum of columns ofuuT. When two (or more) columns have equal norm but opposite sign (e.g., axis components with equal magnitude and opposite sign), the weighted sum can cancel to ~0, makingu = v / v.norm(...)produce NaNs. Consider selecting a single well-conditioned column/row viaargmaxonvs_norm(or using an eigenvector / diagonal-based reconstruction) instead of averaging columns, and guard againstv.norm()==0.
uuT = (R[mask] + torch.eye(3, device=R.device, dtype=R.dtype)) / 2
vs_norm = torch.norm(uuT, dim=-2, keepdim=True)
v = torch.einsum(
"... cd, ... cd-> ...c", uuT, torch.softmax(vs_norm * 100, dim=-1)
)
u = v / v.norm(dim=-1, keepdim=True)
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] | ||
| w = torch.sqrt(1.0 + trace) / 2.0 | ||
| asR = (R - R.transpose(-2, -1)) / (4.0 * w.view(-1, 1, 1)) | ||
| w = 1.0 + trace |
There was a problem hiding this comment.
I would, for your future selfs sake, use a better name here, e.g.
ww = (1.0 + trace) / 4.0| asR = (R - R.transpose(-2, -1)) / (4.0 * w.view(-1, 1, 1)) | ||
| w = 1.0 + trace | ||
| asR = R - R.transpose(-2, -1) # anti-symmetric part of R | ||
| x = asR[..., 2, 1] |
There was a problem hiding this comment.
similarly, better naming here could be useful:
wx = asR[..., 2, 1] / 4.0
wy = asR[..., 0, 2] / 4.0
wz = asR[..., 1, 0] / 4.0| # Excluding identity, as it's symmetric, but it's not problematic | ||
| mask = symmetric_mask & ~identity_mask | ||
| eps = torch.finfo(w.dtype).resolution * 2 | ||
| mask = w < eps # w should be non-negative for SO(3) matrices |
There was a problem hiding this comment.
in theory the w defined above is in [0.0,4.0], but precision errors could violate this, I think.
Either way, if it does become negative it should definitely be in the mask, so not a concern here 🤷♂️
| @@ -220,7 +217,7 @@ def from_rotation_matrix(R: torch.Tensor) -> "Quaternion": | |||
|
|
|||
| q = torch.stack([w, x, y, z], dim=-1) | |||
There was a problem hiding this comment.
with the naming suggestions above, this would become more readable as
q = torch.stack([ww, wx, wy, wz], dim=-1)making clear that in the next normalization step the division by w will reappear 😢
… scipy Co-authored-by: jvdoorss <jvdoorss@gmail.com>
cc @jvdoorss