Skip to content
Open
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
# Manifest.toml

# visual studio settings folder
.vscode/*
52 changes: 40 additions & 12 deletions Manifest.toml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ authors = ["Niklas Heim <heim.niklas@gmail.com>"]
version = "0.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
11 changes: 11 additions & 0 deletions src/GenerativeModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@ module GenerativeModels
using Zygote: @nograd, @adjoint
using DiffEqBase: ODEProblem, solve
using OrdinaryDiffEq: Tsit5
using SpecialFunctions
using Distributions
using Adapt

abstract type AbstractGM end
abstract type AbstractVAE{T<:Real} <: AbstractGM end
abstract type AbstractGAN{T<:Real} <: AbstractGM end
abstract type AbstractSVAE{T<:Real} <: AbstractGM end

# functions that are overloaded by this module
import Base.length
import Random.rand
import Statistics.mean
import SpecialFunctions: besselix, logabsgamma

# needed to make e.g. sampling work
@nograd similar, randn!, fill!
Expand All @@ -24,15 +29,21 @@ module GenerativeModels
include(joinpath("utils", "nogradarray.jl"))
include(joinpath("utils", "saveload.jl"))
include(joinpath("utils", "utils.jl"))
include(joinpath("utils", "vmf.jl"))
include(joinpath("utils", "flux_ode_decoder.jl"))

include(joinpath("pdfs", "gaussian.jl"))
include(joinpath("pdfs", "hs_uniform.jl"))
include(joinpath("pdfs", "vonmisesfisher.jl"))
include(joinpath("pdfs", "abstract_cgaussian.jl"))
include(joinpath("pdfs", "cmean_gaussian.jl"))
include(joinpath("pdfs", "cmeanvar_gaussian.jl"))
include(joinpath("pdfs", "abstract_cvmf.jl"))
include(joinpath("pdfs", "cmeanconc_vmf.jl"))

include(joinpath("models", "vae.jl"))
include(joinpath("models", "rodent.jl"))
include(joinpath("models", "gan.jl"))
include(joinpath("models", "svae.jl"))

end # module
87 changes: 87 additions & 0 deletions src/models/svae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
export SVAE, SVAE_vmf_prior, SVAE_hsu_prior

"""
SVAE{T}([prior::Union{HypersphericalUniform{T}, VonMisesFisher{T}}, zlen::Int] encoder::AbstractCVMF, decoder::AbstractCPDF)

HyperSpherical Variational Auto-Encoder.

# Example
Create an S-VAE with either HSU prior or VMF prior with μ = [1, 0, ..., 0] and κ = 1 with:
```julia-repl
julia> enc = CMeanConcVMF{Float32}(Dense(5,4), 3)
CMeanConcVMF{Float32}(mapping=Dense(5, 4), μ_from_hidden=Chain(Dense(4, 3), #51), κ_from_hidden=Dense(4, 1, #52))

julia> dec = CMeanVarGaussian{Float32,ScalarVar}(Dense(3, 6))
CMeanVarGaussian{Float32,ScalarVar}(mapping=Dense(3, 6))

julia> svae = SVAE(HypersphericalUniform{Float32}(3), enc, dec)
SVAE{Float32}:
prior = HypersphericalUniform{Float32}(3)
encoder = (CMeanConcVMF{Float32}(mapping=Dense(5, 4), μ_from_hidden=Chain(Dens...)
decoder = CMeanVarGaussian{Float32,ScalarVar}(mapping=Dense(3, 6))

julia> mean(svae.decoder, mean(svae.encoder, rand(5, 1)))
5×1 Array{Float32,2}:
-0.7267006
0.6847478
-0.032789093
0.13542232
-0.270345421

julia> elbo(svae, rand(Float32, 5, 1))
15.011719478567946
```
"""
struct SVAE{T} <: AbstractSVAE{T}
prior::Union{HypersphericalUniform{T}, VonMisesFisher{T}}
encoder::AbstractCVMF{T}
decoder::AbstractCPDF{T}
end

Flux.@functor SVAE

# SVAE(p::Union{HypersphericalUniform{T}, VonMisesFisher{T}}, e::AbstractCVMF{T}, d::AbstractCPDF{T}) where T = SVAE{T}(p, e, d)

function SVAE_vmf_prior(zlength::Int, enc::AbstractCPDF{T}, dec::AbstractCPDF{T}) where T
μp = NoGradArray(zeros(T, zlength))
μp[1] = T(1)
κp = NoGradArray(ones(T, 1))
prior = VonMisesFisher(μp, κp)
SVAE{T}(prior, enc, dec)
end

function SVAE_hsu_prior(zlength::Int, enc::AbstractCPDF{T}, dec::AbstractCPDF{T}) where T
prior = HypersphericalUniform{T}(zlength)
SVAE{T}(prior, enc, dec)
end

"""
elbo(m::SVAE, x::AbstractArray; β=1)

Evidence lower boundary of the SVAE model. `β` scales the KLD term. (Assumes hyperspherical uniform prior)
"""
function elbo(m::SVAE{T}, x::AbstractArray{T}; β=T(1)) where {T}
z = rand(m.encoder, x)
llh = mean(-loglikelihood(m.decoder, x, z))
kl = mean(kld(m.encoder, m.prior, x))
llh + β*kl
end

"""
mmd(m::SVAE, x::AbstractArray, k)

Maximum mean discrepancy of a SVAE model given data `x` and kernel function `k(x,y)`.
"""
mmd(m::SVAE{T}, x::AbstractArray{T}, k) where {T} = mmd(m.encoder, m.prior, x, k)

function Base.show(io::IO, m::SVAE{T}) where T
p = short_repr(m.prior, 70)
e = short_repr(m.encoder, 70)
d = short_repr(m.decoder, 70)
msg = """$(typeof(m)):
prior = $(p)
encoder = $(e)
decoder = $(d)
"""
print(io, msg)
end
54 changes: 54 additions & 0 deletions src/pdfs/abstract_cvmf.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
export loglikelihood, kld, rand, mean_conc, concentration

abstract type AbstractCVMF{T} <: AbstractCPDF{T} end

function rand(p::AbstractCVMF{T}, z::AbstractArray{T}) where {T}
(μ, κ) = mean_conc(p, z)
sample_vmf(μ, κ)
end

function loglikelihood(p::AbstractCVMF{T}, x::AbstractArray{T}, z::AbstractArray{T}) where {T}
(μ, κ) = mean_conc(p, z)
log_vmf(x, μ, κ)
end

# This is here because we always compute KLD with VMF and hyperspherical uniform - nothing else as KLD between two VMFs is rather complicated to compute
"""
kld(p::AbstractCVMF, q::HypersphericalUniform, z::AbstractArray)

Compute Kullback-Leibler divergence between a conditional Von Mises-Fisher distribution `p` given `z`
and a hyperspherical uniform distribution `q` with the same dimensionality.
"""
function kld(p::AbstractCVMF{T}, q::HypersphericalUniform{T}, z::AbstractArray{T}) where {T}
(μ, κ) = mean_conc(p, z)
if size(μ, 1) != q.dims
error("Cannot compute KLD between VMF and HSU with different dimensionality")
end
.- vmfentropy.(q.dims, κ) .+ huentropy(q.dims, T)
end

"""
mean_conc(p::AbstractCVMF, z::AbstractArray)

Returns mean and concentration of a conditional VMF distribution.
"""
mean_conc(p::AbstractCVMF, z::AbstractArray) = error("Not implemented!")


"""
mean(p::AbstractCVMF, z::AbstractArray)

Returns mean of a conditional VMF distribution.
"""
mean(p::AbstractCVMF, z::AbstractArray) = mean_conc(p, z)[1]

"""
concentration(p::AbstractCVMF, z::AbstractArray)

Returns variance of a conditional VMF distribution.
"""
concentration(p::AbstractCVMF, z::AbstractArray) = mean_conc(p, z)[2]




Loading