Thank you for your great work. I have a question about Definition 3.2:
To my understanding, the transport map $T_{m,l}$ maps C (the source distribution) to D (the target distribution). Why, then, is this same source-to-target map applied to C to obtain a new C, and to D to obtain a new D? Does it still make sense to compute the transport map in the next layer from these new distributions?
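To make the question concrete, here is a toy sketch under my own assumption that C and D are 1-D Gaussians. This is not the paper's setting, and the helper names (`gaussian_ot_map`, `pushforward_params`) are hypothetical. For 1-D Gaussians with quadratic cost, the optimal transport map from $N(m_C, s_C^2)$ to $N(m_D, s_D^2)$ is the affine map $T(x) = m_D + (s_D/s_C)(x - m_C)$. Pushing C through $T$ gives D, but pushing D through the same $T$ gives a third distribution.

```python
import random
import statistics


def gaussian_ot_map(m_src, s_src, m_tgt, s_tgt):
    """Hypothetical helper: 1-D OT map from N(m_src, s_src^2) to N(m_tgt, s_tgt^2)."""
    scale = s_tgt / s_src

    def T(x):
        return m_tgt + scale * (x - m_src)

    return T


def pushforward_params(m, s, m_src, s_src, m_tgt, s_tgt):
    """Hypothetical helper: exact (mean, std) of T(X) for X ~ N(m, s^2); T is affine."""
    scale = s_tgt / s_src
    return m_tgt + scale * (m - m_src), abs(scale) * s


# Assumed toy parameters for C (source) and D (target).
m_C, s_C = 0.0, 1.0
m_D, s_D = 3.0, 2.0
T = gaussian_ot_map(m_C, s_C, m_D, s_D)

# "New C" = T(C): matches D exactly.
new_C = pushforward_params(m_C, s_C, m_C, s_C, m_D, s_D)
# "New D" = T(D): the same source-to-target map moves D somewhere else.
new_D = pushforward_params(m_D, s_D, m_C, s_C, m_D, s_D)

# Sample-based check that T pushes C onto D.
random.seed(0)
mapped_C = [T(random.gauss(m_C, s_C)) for _ in range(100_000)]
print(new_C, new_D)
print(round(statistics.mean(mapped_C), 2), round(statistics.stdev(mapped_C), 2))  # close to 3.0 and 2.0
```

In this toy case, a next-layer transport map computed between the new C and the new D would be a map between $N(3, 2^2)$ and $N(9, 4^2)$, which is the step I am unsure about.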
I noticed that the `LinearProj` class has a `reverse` parameter, but I couldn't find any code that calls it with `reverse=True`:
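```python
import torch
import torch.nn as nn


class LinearProj(nn.Module):
    """Element-wise affine map: x * w1 + b1 (forward) or (x - b1) / w1 (reverse)."""

    def __init__(
        self,
        dim: int,
    ):
        super().__init__()
        # One learnable scale and bias per feature, broadcast over the batch dimension.
        self.w1 = nn.Parameter(torch.randn(1, dim))
        self.b1 = nn.Parameter(torch.zeros((1, dim)))

    def forward(self, x: torch.Tensor, reverse: bool = False):
        assert x.shape[-1] == self.w1.shape[-1]
        if not reverse:
            # Forward map: scale, then shift.
            return x * self.w1 + self.b1
        else:
            # Inverse map: undo the shift, then the scale (1e-10 guards against a zero weight).
            return (x - self.b1) / (self.w1 + 1e-10)
```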
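As far as I can tell, `reverse=True` is simply the inverse of the element-wise affine map, so a forward pass followed by a reverse pass should return the input. Here is a minimal plain-Python sketch of that round trip; `affine_forward` and `affine_reverse` are hypothetical names that mirror the two branches of `LinearProj.forward`, not functions from the repository, and the sketch says nothing about where (or whether) the repository uses `reverse`.

```python
def affine_forward(x, w, b):
    """Element-wise x * w + b, mirroring LinearProj.forward(x, reverse=False)."""
    return [xi * wi + bi for xi, wi, bi in zip(x, w, b)]


def affine_reverse(y, w, b, eps=1e-10):
    """Element-wise (y - b) / (w + eps), mirroring LinearProj.forward(y, reverse=True)."""
    return [(yi - bi) / (wi + eps) for yi, wi, bi in zip(y, w, b)]


# Assumed example weights, biases, and input (one entry per feature).
w = [0.5, -2.0, 3.0]
b = [1.0, 0.0, -1.0]
x = [2.0, 4.0, -6.0]

y = affine_forward(x, w, b)
x_rec = affine_reverse(y, w, b)
print(y)  # [2.0, -8.0, -19.0]
print(x_rec)  # equal to x up to the tiny error introduced by eps
```

The round trip is exact only up to the `1e-10` term. A weight component of exactly zero makes the affine map non-invertible, which is what the epsilon guards against, although it does not help for a weight equal to `-1e-10`.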
Is this reverse mapping related to Definition 3.2? Could you please elaborate on it? Thank you!