Skip to content

Commit b89e114

Browse files
committed
(FIX) handle rowvar in pairedcorr denominator
1 parent 0c46143 commit b89e114

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

src/nitrix/functional/covariance.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,10 @@ def pairedcorr(
445445
inddof = ddof
446446
if inddof is None:
447447
inddof = 1 - bias
448-
varX = X.var(-1, keepdims=True, ddof=inddof)
449-
varY = Y.var(-1, keepdims=True, ddof=inddof)
450-
fact = jax.lax.sqrt(varX @ varY.swapaxes(-2, -1))
448+
ax = -1 if rowvar else -2
449+
varX = X.var(ax, ddof=inddof)[..., None]
450+
varY = Y.var(ax, ddof=inddof)[..., None, :]
451+
fact = jax.lax.sqrt(varX @ varY)
451452
return (
452453
pairedcov(X, Y, rowvar=rowvar, bias=bias, ddof=ddof, **params) / fact
453454
)

0 commit comments

Comments
 (0)