Skip to content

Commit 37f87ee

Browse files
committed
Change check from rotation to orthonormal
1 parent cdf1e95 commit 37f87ee

2 files changed

Lines changed: 8 additions & 6 deletions

File tree

diffdrr/pose.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99

1010
from einops import rearrange
11-
from roma import is_rotation_matrix
11+
from roma import is_orthonormal_matrix
1212

1313

1414
class RigidTransform(torch.nn.Module):
@@ -18,11 +18,12 @@ class RigidTransform(torch.nn.Module):
1818
inversion, and conversions to various representations of SE(3).
1919
"""
2020

21-
def __init__(self, matrix):
21+
def __init__(self, matrix, eps=1e-6):
2222
super().__init__()
2323
if matrix.dim() == 2:
2424
matrix = matrix.unsqueeze(0)
2525
self.register_buffer("matrix", matrix)
26+
self.eps = eps
2627

2728
def __len__(self):
2829
return len(self.matrix)
@@ -44,7 +45,7 @@ def translation(self):
4445
return self.matrix[..., :3, 3]
4546

4647
def inverse(self):
47-
if is_rotation_matrix(self.matrix[..., :3, :3]):
48+
if is_orthonormal_matrix(self.matrix[..., :3, :3], self.eps):
4849
R = self.matrix[..., :3, :3]
4950
t = self.matrix[..., :3, 3]
5051
Rinv = R.mT

notebooks/api/06_pose.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
"import torch\n",
9393
"\n",
9494
"from einops import rearrange\n",
95-
"from roma import is_rotation_matrix\n",
95+
"from roma import is_orthonormal_matrix\n",
9696
"\n",
9797
"\n",
9898
"class RigidTransform(torch.nn.Module):\n",
@@ -102,11 +102,12 @@
102102
" inversion, and conversions to various representations of SE(3).\n",
103103
" \"\"\"\n",
104104
"\n",
105-
" def __init__(self, matrix):\n",
105+
" def __init__(self, matrix, eps=1e-6):\n",
106106
" super().__init__()\n",
107107
" if matrix.dim() == 2:\n",
108108
" matrix = matrix.unsqueeze(0)\n",
109109
" self.register_buffer(\"matrix\", matrix)\n",
110+
" self.eps = eps\n",
110111
"\n",
111112
" def __len__(self):\n",
112113
" return len(self.matrix)\n",
@@ -128,7 +129,7 @@
128129
" return self.matrix[..., :3, 3]\n",
129130
"\n",
130131
" def inverse(self):\n",
131-
" if is_rotation_matrix(self.matrix[..., :3, :3]):\n",
132+
" if is_orthonormal_matrix(self.matrix[..., :3, :3], self.eps):\n",
132133
" R = self.matrix[..., :3, :3]\n",
133134
" t = self.matrix[..., :3, 3]\n",
134135
" Rinv = R.mT\n",

0 commit comments

Comments
 (0)