Second order gradients of Normal distribution #33

@pevnak

I need the normal (Gaussian) distribution to support second-order gradients. The current implementation relies on broadcasting, which breaks down in that case.

I quickly hacked it together as follows, but I would be happy if it got merged with proper tests, so that I can be sure it does what it is supposed to do.


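using ConditionalDists, Distributions, Zygote

# Log-density of a diagonal Gaussian, evaluated for each column of `x`.
# `n` is the dimension, `μ` the mean and `σ2` the variance.
function _l(x::Matrix{T}, n, μ, σ2) where {T}
    -(vec(sum(((x - μ).^2) ./ σ2 .+ log.(σ2), dims=1)) .+ n*log(T(2π))) / 2
end

# Hand-written pullback of `_l`: gradients with respect to (x, n, μ, σ2)
# for an incoming sensitivity Δ with one entry per column of `x`.
function _∇l(Δ, x, n, μ, σ2)
    Δ = Δ'
    δ = Δ .* (x - μ) ./ σ2
    (-δ, nothing, δ, Δ .* (((x - μ).^2 ./ (σ2.^2)) - 1 ./ σ2) / 2)
end


function Distributions.logpdf(d::ConditionalDists.BMN, x::Matrix{T}) where T<:Real
    n = size(d.μ, 1)
    μ = mean(d)
    σ2 = var(d)
    _l(x, n, μ, σ2)
end

# Register `_∇l` as the custom adjoint of `_l`.
Zygote.@adjoint function _l(x, n, μ, σ2)
    _l(x, n, μ, σ2), Δ -> _∇l(Δ, x, n, μ, σ2)
end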
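
As a starting point for the tests mentioned above, here is a minimal sketch that compares the hand-written pullback against central finite differences. It copies `_l` and `_∇l` so that it runs without Zygote or ConditionalDists. The helper `fd_grad`, the scalar function `loss`, the sizes, and the assumption that `x`, `μ` and `σ2` are all matrices of the same shape are mine, not part of the package.

```julia
using Random

# Copies of the two functions from the issue, so the check is self-contained.
function _l(x::Matrix{T}, n, μ, σ2) where {T}
    -(vec(sum(((x - μ).^2) ./ σ2 .+ log.(σ2), dims=1)) .+ n*log(T(2π))) / 2
end

function _∇l(Δ, x, n, μ, σ2)
    Δ = Δ'
    δ = Δ .* (x - μ) ./ σ2
    (-δ, nothing, δ, Δ .* (((x - μ).^2 ./ (σ2.^2)) - 1 ./ σ2) / 2)
end

# Hypothetical helper: central finite-difference gradient of a scalar
# function `f` with respect to every entry of the array `a`.
function fd_grad(f, a; h=1e-6)
    g = similar(a)
    for i in eachindex(a)
        ap = copy(a); ap[i] += h
        am = copy(a); am[i] -= h
        g[i] = (f(ap) - f(am)) / (2h)
    end
    g
end

Random.seed!(0)
d, N = 3, 4                  # assumed sizes: dimension × number of samples
x  = randn(d, N)
μ  = randn(d, N)             # assumption: same shape as x
σ2 = rand(d, N) .+ 0.5       # strictly positive variances
Δ  = randn(N)                # sensitivity, one entry per column

# Scalar whose gradient equals the pullback of `_l` applied to Δ.
loss(x, μ, σ2) = sum(Δ .* _l(x, d, μ, σ2))

∇x, _, ∇μ, ∇σ2 = _∇l(Δ, x, d, μ, σ2)

@assert isapprox(∇x,  fd_grad(a -> loss(a, μ, σ2), x);  atol=1e-5)
@assert isapprox(∇μ,  fd_grad(a -> loss(x, a, σ2), μ);  atol=1e-5)
@assert isapprox(∇σ2, fd_grad(a -> loss(x, μ, a),  σ2); atol=1e-5)
```

This only checks the first-order pullback. A test for the second-order case would additionally need to differentiate through `_∇l` itself, for example with nested Zygote calls, which is not sketched here.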
