Motivation and description
Say I have the arrays `x[i,b]`, `y[j,b]`, and `A[i,j,b]`. Is there an efficient way to do the following "batched dot" operation:
```julia
[sum(x[i,b] * A[i,j,b] * y[j,b] for i = axes(A,1) for j = axes(A,2)) for b = ...]
```
where `b` traverses the batch dimension. As usual, we could have e.g. `size(x,2) == 1` or `size(A,3) == 1`, which would mean that the corresponding size-1 batch dimension is broadcast.
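For reference, the comprehension above can also be written with plain Julia broadcasting, which handles the size-1 batch dimensions automatically. This is a minimal sketch, not an NNlib function: it assumes `x` is `n×Bx`, `A` is `n×m×BA`, and `y` is `m×By`, with each batch size either 1 or the common batch size, and the name `batched_dot_bcast` is hypothetical.

```julia
# Hypothetical helper: s[b] = sum over i, j of x[i,b] * A[i,j,b] * y[j,b].
# Singleton batch dimensions broadcast through ordinary Julia broadcasting.
function batched_dot_bcast(x::AbstractMatrix, A::AbstractArray{<:Any,3}, y::AbstractMatrix)
    # Place x along dims (1, 3) and y along dims (2, 3) so they line up with A[i, j, b].
    xr = reshape(x, size(x, 1), 1, size(x, 2))
    yr = reshape(y, 1, size(y, 1), size(y, 2))
    # Elementwise product (materialises an n×m×B array), then reduce over i and j.
    return vec(sum(xr .* A .* yr; dims = (1, 2)))
end

# Usage: two batch entries, with A shared across the batch (size(A,3) == 1).
x = [1.0 2.0; 3.0 4.0]
A = reshape(collect(1.0:4.0), 2, 2, 1)
y = [1.0 0.0; 0.0 1.0]
s = batched_dot_bcast(x, A, y)   # [7.0, 22.0]
```

The trade-off of this form is the `n×m×B` temporary created by the broadcast before the reduction, which is the kind of cost an efficient batched kernel would avoid.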
Apologies if there is already a way to do this (efficiently) with existing functions in NNlib; I could not figure it out.
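As a point of comparison, a lower-allocation CPU variant can loop over the batch and call the three-argument `dot(x, A, y)` from Julia's `LinearAlgebra` standard library, which computes `x' * A * y` without storing `A * y`. This is a sketch that uses the standard library rather than NNlib; `batched_dot_loop` is a hypothetical name, and the scalar loop over views is written with CPU arrays in mind.

```julia
using LinearAlgebra: dot

# Hypothetical helper: one generalized dot product per batch entry, with no
# n×m×B temporary. Batch index 1 is reused for any argument whose batch size is 1.
function batched_dot_loop(x::AbstractMatrix, A::AbstractArray{<:Any,3}, y::AbstractMatrix)
    Bx, BA, By = size(x, 2), size(A, 3), size(y, 2)
    B = max(Bx, BA, By)
    # Each batch size must be 1 or equal to B, as in ordinary broadcasting.
    all(s -> s == 1 || s == B, (Bx, BA, By)) ||
        throw(DimensionMismatch("incompatible batch sizes ($Bx, $BA, $By)"))
    pick(s, b) = s == 1 ? 1 : b   # index into a possibly singleton batch dimension
    T = promote_type(eltype(x), eltype(A), eltype(y))
    out = Vector{T}(undef, B)
    for b in 1:B
        # dot(u, M, v) == u' * M * v (u is conjugated for complex inputs).
        out[b] = dot(view(x, :, pick(Bx, b)), view(A, :, :, pick(BA, b)), view(y, :, pick(By, b)))
    end
    return out
end

# Usage: same data as a shared-A example (size(A,3) == 1).
x = [1.0 2.0; 3.0 4.0]
A = reshape(collect(1.0:4.0), 2, 2, 1)
y = [1.0 0.0; 0.0 1.0]
s = batched_dot_loop(x, A, y)   # [7.0, 22.0]
```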
Possible Implementation
No response