hafaio/jaxsne

jaxsne

A JAX library for dimensionality reduction in different metric spaces or with different distributions.

In addition to regular t-SNE for MNIST,

[figure: tsne]

this can also be used to embed points on the sphere,

[figure: ssne]

or even into hierarchical hyperbolic space.

[figure: psne]

The downside is that this is generally less performant than the t-SNE provided by scikit-learn, so it should only be used if you want to tweak the metrics or measures.
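
For comparison, if the standard Euclidean t-SNE is all you need, a minimal scikit-learn baseline might look like the sketch below. The random data and the perplexity and seed values are placeholders chosen for illustration, not recommendations from this project.

```python
import numpy as np
from sklearn.manifold import TSNE

# Placeholder data (assumption): 60 random points in 10 dimensions, n x d.
rng = np.random.default_rng(0)
data = rng.normal(size=(60, 10))

# scikit-learn requires perplexity to be smaller than the number of samples.
tsne = TSNE(n_components=2, perplexity=15.0, random_state=0)
reduced = tsne.fit_transform(data)

print(reduced.shape)  # (60, 2)
```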

Installation

pip install jaxsne

Basic Usage

import jaxsne

data = ... # n x d
reduced = jaxsne.sne(data)
# or
reduced = jaxsne.scaling(data)

Advanced Usage

import jaxsne
import jax
from jax import Array
from jax import numpy as jnp

@jax.jit
def manhattan(left: Array, right: Array) -> Array:
    # L1 (Manhattan) distance, summed over the last axis
    return jnp.abs(left - right).sum(-1)


data = ... # n x d
reduced = jaxsne.sne(data, in_metric=manhattan, out_metric=manhattan)
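
Other metrics can be written in the same style as `manhattan` above. The sketch below shows two that match the spherical and hyperbolic embeddings mentioned earlier: a great-circle distance and a Poincaré-ball distance. The names `great_circle` and `poincare` are hypothetical and written for this example. Whether jaxsne ships its own versions, or expects anything beyond the `(left, right) -> Array` signature shown above, is not confirmed here.

```python
import jax
from jax import Array
from jax import numpy as jnp


@jax.jit
def great_circle(left: Array, right: Array) -> Array:
    """Angle between points after projecting both onto the unit sphere."""
    left = left / jnp.linalg.norm(left, axis=-1, keepdims=True)
    right = right / jnp.linalg.norm(right, axis=-1, keepdims=True)
    # Clip to guard against rounding pushing the cosine outside [-1, 1].
    cos = jnp.clip((left * right).sum(-1), -1.0, 1.0)
    return jnp.arccos(cos)


@jax.jit
def poincare(left: Array, right: Array) -> Array:
    """Hyperbolic distance between points strictly inside the unit ball."""
    sq_dist = ((left - right) ** 2).sum(-1)
    denom = (1.0 - (left**2).sum(-1)) * (1.0 - (right**2).sum(-1))
    return jnp.arccosh(1.0 + 2.0 * sq_dist / denom)
```

Both functions reduce over the last axis, so they broadcast over batches of points the same way `manhattan` does. Their gradients are not finite where the two points coincide, so an optimizer may need a small offset or a masked diagonal.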

Development

uv run ruff format --check
uv run ruff check
uv run pyright
uv run pytest

Publishing

rm -rf dist
uv build
uv publish --username __token__

Tasks
