From c3e9b5d26111fddf7c63558819e70adee36ec62a Mon Sep 17 00:00:00 2001 From: AHsu98 <34590951+AHsu98@users.noreply.github.com> Date: Mon, 18 Dec 2023 20:33:47 -0800 Subject: [PATCH 1/8] Added group norm L0 and shifted group norm L0 --- src/groupNormL0.jl | 68 ++++++++++++++++++++++++++++++++++ src/shiftedGroupNormL0.jl | 77 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 src/groupNormL0.jl create mode 100644 src/shiftedGroupNormL0.jl diff --git a/src/groupNormL0.jl b/src/groupNormL0.jl new file mode 100644 index 00000000..fc8f88c9 --- /dev/null +++ b/src/groupNormL0.jl @@ -0,0 +1,68 @@ +# Group L2 norm (times a constant) + +export GroupNormL0 + +@doc raw""" + GroupNormL0(λ = 1, idx = [:]) + +Returns the group ``\ell_0``-norm operator +```math +f(x) = \sum_i \lambda_i \| \|x_{[i]}\|_2 \|_0 +``` +for groups ``x_{[i]}`` and nonnegative weights ``\lambda_i``. +This assumes that the groups ``x_{[i]}`` are non-overlapping +""" +struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I} + lambda::RR + idx::I + + function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} + if any(lambda .< 0) + error("weights λ must be nonnegative") + elseif length(lambda) != length(idx) + error("number of weights and groups must be the same") + else + new{R, RR, I}(lambda, idx) + end + end +end + +GroupNormL0(lambda::AbstractVector{R} = [one(R)], idx::I = [:]) where {R <: Real, I} = + GroupNormL0{R, typeof(lambda), I}(lambda, idx) + +function (f::GroupNormL0)(x::AbstractArray{R}) where {R <: Real} + sum_c = R(0) + for (idx, λ) ∈ zip(f.idx, f.lambda) + y = norm(x[idx]) + if y>0 + sum_c += λ + end + end + return sum_c +end + +function prox!( + y::AbstractArray{R}, + f::GroupNormL0{R, RR, I}, + x::AbstractArray{R}, + γ::R = R(1), +) where {R <: Real, RR <: AbstractVector{R}, I} + ysum = R(0) + for (idx, λ) ∈ zip(f.idx, f.lambda) + yt = norm(x[idx])^2 + if yt !=0 + ysum += λ + end + if yt <= 2 * γ * λ + y[idx] .= 0 + else + y[idx] .= x[idx] + end + end + return ysum +end + +fun_name(f::GroupNormL0) = "Group L₀-norm" +fun_dom(f::GroupNormL0) = "AbstractArray{Float64}, AbstractArray{Complex}" +fun_expr(f::GroupNormL0) = "x ↦ Σᵢ λᵢ ‖ ‖xᵢ‖₂ ‖₀" +fun_params(f::GroupNormL0) = "λ = $(f.lambda), g = $(f.g)" diff --git a/src/shiftedGroupNormL0.jl b/src/shiftedGroupNormL0.jl new file mode 100644 index 00000000..7a044385 --- /dev/null +++ b/src/shiftedGroupNormL0.jl @@ -0,0 +1,77 @@ +export ShiftedGroupNormL0 + +mutable struct ShiftedGroupNormL0{ + R <: Real, + RR <: AbstractVector{R}, + I, + V0 <: AbstractVector{R}, + V1 <: AbstractVector{R}, + V2 <: AbstractVector{R}, +} <: ShiftedProximableFunction + h::GroupNormL0{R, RR, I} + xk::V0 + sj::V1 + sol::V2 + shifted_twice::Bool + xsy::V2 + + function ShiftedGroupNormL0( + h::GroupNormL0{R, RR, I}, + xk::AbstractVector{R}, + sj::AbstractVector{R}, + shifted_twice::Bool, + ) where {R <: Real, RR <: AbstractVector{R}, I} + sol = similar(sj) + xsy = similar(sj) + new{R, RR, I, typeof(xk), typeof(sj), typeof(sol)}(h, xk, sj, sol, shifted_twice, xsy) + end +end + +shifted( + h::GroupNormL0{R, RR, I}, + xk::AbstractVector{R}, +) where {R <: Real, RR <: AbstractVector{R}, I} = ShiftedGroupNormL0(h, xk, zero(xk), false) +shifted(h::NormL2{R}, xk::AbstractVector{R}) where {R <: Real} = + ShiftedGroupNormL0(GroupNormL0([h.lambda]), xk, zero(xk), false) +shifted( + ψ::ShiftedGroupNormL0{R, RR, I, V0, V1, V2}, + sj::AbstractVector{R}, +) where { + R <: Real, + RR <: AbstractVector{R}, + I, + V0 <: AbstractVector{R}, + V1 <: AbstractVector{R}, + V2 <: AbstractVector{R}, +} = ShiftedGroupNormL0(ψ.h, ψ.xk, sj, true) + +fun_name(ψ::ShiftedGroupNormL0) = "shifted x ↦ Σᵢ λᵢ ‖ ‖xᵢ‖₂ ‖₀ function" +fun_expr(ψ::ShiftedGroupNormL0) = "x ↦ Σᵢ λᵢ ‖ ‖xk + sj + t‖₂" +fun_params(ψ::ShiftedGroupNormL0) = "xk = $(ψ.xk)\n" * " "^14 * "sj = $(ψ.sj)\n" * " "^14 + +function prox!( + y::AbstractVector{R}, + ψ::ShiftedGroupNormL0{R, RR, I, V0, V1, V2}, + q::AbstractVector{R}, + σ::R, +) where { + R <: Real, + RR <: AbstractVector{R}, + I, + V0 <: AbstractVector{R}, + V1 <: AbstractVector{R}, + V2 <: AbstractVector{R}, +} + ψ.sol .= q + ψ.xk + ψ.sj + + for (idx, λ) ∈ zip(ψ.h.idx, ψ.h.lambda) + snorm = norm(ψ.sol[idx])^2 + if snorm <= 2 * γ * λ + y[idx] .= 0 + else + y[idx] .= ψ.sol[idx] + end + end + y .-= (ψ.xk + ψ.sj) + return y +end From 198ff3839dbfedd37d937dd0eff8ef8c09b7e9e2 Mon Sep 17 00:00:00 2001 From: Alexander Hsu <34590951+AHsu98@users.noreply.github.com> Date: Sat, 22 Feb 2025 17:09:42 -0800 Subject: [PATCH 2/8] Applied changes from @dpo + to .+ and some minor syntax changes Co-authored-by: Dominique --- src/groupNormL0.jl | 12 ++++-------- src/shiftedGroupNormL0.jl | 4 ++-- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/groupNormL0.jl b/src/groupNormL0.jl index fc8f88c9..f1312f07 100644 --- a/src/groupNormL0.jl +++ b/src/groupNormL0.jl @@ -17,13 +17,9 @@ struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I} idx::I function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} - if any(lambda .< 0) - error("weights λ must be nonnegative") - elseif length(lambda) != length(idx) - error("number of weights and groups must be the same") - else - new{R, RR, I}(lambda, idx) - end + any(lambda .< 0) && error("weights λ must be nonnegative") + length(lambda) != length(idx) && error("number of weights and groups must be the same") + new{R, RR, I}(lambda, idx) end end @@ -34,7 +30,7 @@ function (f::GroupNormL0)(x::AbstractArray{R}) where {R <: Real} sum_c = R(0) for (idx, λ) ∈ zip(f.idx, f.lambda) y = norm(x[idx]) - if y>0 + if y > 0 sum_c += λ end end diff --git a/src/shiftedGroupNormL0.jl b/src/shiftedGroupNormL0.jl index 7a044385..1922dfc6 100644 --- a/src/shiftedGroupNormL0.jl +++ b/src/shiftedGroupNormL0.jl @@ -62,7 +62,7 @@ function prox!( V1 <: AbstractVector{R}, V2 <: AbstractVector{R}, } - ψ.sol .= q + ψ.xk + ψ.sj + ψ.sol .= q .+ ψ.xk .+ ψ.sj for (idx, λ) ∈ zip(ψ.h.idx, ψ.h.lambda) snorm = norm(ψ.sol[idx])^2 @@ -72,6 +72,6 @@ function prox!( y[idx] .= ψ.sol[idx] end end - y .-= (ψ.xk + ψ.sj) + y .-= (ψ.xk .+ ψ.sj) return y end From 76ed6448dbed41cdb8854e872c48df1c49dd52e7 Mon Sep 17 00:00:00 2001 From: Alexander Hsu <34590951+AHsu98@users.noreply.github.com> Date: Sat, 22 Feb 2025 17:11:45 -0800 Subject: [PATCH 3/8] fixed description --- src/groupNormL0.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/groupNormL0.jl b/src/groupNormL0.jl index f1312f07..00f1342c 100644 --- a/src/groupNormL0.jl +++ b/src/groupNormL0.jl @@ -1,4 +1,4 @@ -# Group L2 norm (times a constant) +# Group L0 norm (times a constant) export GroupNormL0 From cec91e34f28388d41ef7a85fd982836e36ce96d1 Mon Sep 17 00:00:00 2001 From: Alexander Hsu <34590951+AHsu98@users.noreply.github.com> Date: Sat, 22 Feb 2025 17:26:25 -0800 Subject: [PATCH 4/8] added a check that the groups are non-overlapping. Also snuck in the same changes that @dpo suggested on the checking for this groupNormL0 to make groupNormL2 match as well --- src/groupNormL0.jl | 1 + src/groupNormL2.jl | 11 ++++------- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/groupNormL0.jl b/src/groupNormL0.jl index 00f1342c..586b489f 100644 --- a/src/groupNormL0.jl +++ b/src/groupNormL0.jl @@ -19,6 +19,7 @@ struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I} function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} any(lambda .< 0) && error("weights λ must be nonnegative") length(lambda) != length(idx) && error("number of weights and groups must be the same") + length(Set(Iterators.flatten(v))) != sum(length, v) && error("groups must be non-overlapping") new{R, RR, I}(lambda, idx) end end diff --git a/src/groupNormL2.jl b/src/groupNormL2.jl index bf000337..bd9bb426 100644 --- a/src/groupNormL2.jl +++ b/src/groupNormL2.jl @@ -17,13 +17,10 @@ struct GroupNormL2{R <: Real, RR <: AbstractVector{R}, I} idx::I function GroupNormL2{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} - if any(lambda .< 0) - error("weights λ must be nonnegative") - elseif length(lambda) != length(idx) - error("number of weights and groups must be the same") - else - new{R, RR, I}(lambda, idx) - end + any(lambda .< 0) && error("weights λ must be nonnegative") + length(lambda) != length(idx) && error("number of weights and groups must be the same") + length(Set(Iterators.flatten(v))) != sum(length, v) && error("groups must be non-overlapping") + new{R, RR, I}(lambda, idx) end end From 7a17fd72393ff1c4619f701b98dc61746ff8e625 Mon Sep 17 00:00:00 2001 From: Alexander Hsu <34590951+AHsu98@users.noreply.github.com> Date: Sat, 22 Feb 2025 17:27:22 -0800 Subject: [PATCH 5/8] Fixed typo --- src/groupNormL0.jl | 2 +- src/groupNormL2.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/groupNormL0.jl b/src/groupNormL0.jl index 586b489f..194ec774 100644 --- a/src/groupNormL0.jl +++ b/src/groupNormL0.jl @@ -19,7 +19,7 @@ struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I} function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} any(lambda .< 0) && error("weights λ must be nonnegative") length(lambda) != length(idx) && error("number of weights and groups must be the same") - length(Set(Iterators.flatten(v))) != sum(length, v) && error("groups must be non-overlapping") + length(Set(Iterators.flatten(idx))) != sum(length, idx) && error("groups must be non-overlapping") new{R, RR, I}(lambda, idx) end end diff --git a/src/groupNormL2.jl b/src/groupNormL2.jl index bd9bb426..b2ef2f0a 100644 --- a/src/groupNormL2.jl +++ b/src/groupNormL2.jl @@ -19,7 +19,7 @@ struct GroupNormL2{R <: Real, RR <: AbstractVector{R}, I} function GroupNormL2{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} any(lambda .< 0) && error("weights λ must be nonnegative") length(lambda) != length(idx) && error("number of weights and groups must be the same") - length(Set(Iterators.flatten(v))) != sum(length, v) && error("groups must be non-overlapping") + length(Set(Iterators.flatten(idx))) != sum(length, idx) && error("groups must be non-overlapping") new{R, RR, I}(lambda, idx) end end From 48ce9471c4431a3f3d9ae22501aaa017f006e898 Mon Sep 17 00:00:00 2001 From: Alexander Hsu <34590951+AHsu98@users.noreply.github.com> Date: Sat, 22 Feb 2025 18:27:42 -0800 Subject: [PATCH 6/8] fixed a naming issue in shiftedGroupNormL0, dropped overlap checking in groupNormL0 and groupNormL2 as it was causing tests to error with the way I was checking, and added the groupNormL0 and shiftedGroupNormL0 to the main file. Added groupNormL0 to the tests, and its not erroring, but haven't added the correctness check yet. --- src/ShiftedProximalOperators.jl | 2 ++ src/groupNormL0.jl | 1 - src/groupNormL2.jl | 1 - src/shiftedGroupNormL0.jl | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ShiftedProximalOperators.jl b/src/ShiftedProximalOperators.jl index c400c958..527c13a6 100644 --- a/src/ShiftedProximalOperators.jl +++ b/src/ShiftedProximalOperators.jl @@ -30,6 +30,7 @@ include("utils.jl") include("psvd.jl") include("rootNormLhalf.jl") +include("groupNormL0.jl") include("groupNormL2.jl") include("Rank.jl") include("cappedl1.jl") @@ -39,6 +40,7 @@ include("shiftedNormL0.jl") include("shiftedNormL0Box.jl") include("shiftedRootNormLhalf.jl") include("shiftedNormL1.jl") +include("shiftedGroupNormL0.jl") include("shiftedGroupNormL2.jl") include("shiftedNormL1B2.jl") diff --git a/src/groupNormL0.jl b/src/groupNormL0.jl index 194ec774..00f1342c 100644 --- a/src/groupNormL0.jl +++ b/src/groupNormL0.jl @@ -19,7 +19,6 @@ struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I} function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} any(lambda .< 0) && error("weights λ must be nonnegative") length(lambda) != length(idx) && error("number of weights and groups must be the same") - length(Set(Iterators.flatten(idx))) != sum(length, idx) && error("groups must be non-overlapping") new{R, RR, I}(lambda, idx) end end diff --git a/src/groupNormL2.jl b/src/groupNormL2.jl index b2ef2f0a..45e8106a 100644 --- a/src/groupNormL2.jl +++ b/src/groupNormL2.jl @@ -19,7 +19,6 @@ struct GroupNormL2{R <: Real, RR <: AbstractVector{R}, I} function GroupNormL2{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} any(lambda .< 0) && error("weights λ must be nonnegative") length(lambda) != length(idx) && error("number of weights and groups must be the same") - length(Set(Iterators.flatten(idx))) != sum(length, idx) && error("groups must be non-overlapping") new{R, RR, I}(lambda, idx) end end diff --git a/src/shiftedGroupNormL0.jl b/src/shiftedGroupNormL0.jl index 1922dfc6..3b12808b 100644 --- a/src/shiftedGroupNormL0.jl +++ b/src/shiftedGroupNormL0.jl @@ -66,7 +66,7 @@ function prox!( for (idx, λ) ∈ zip(ψ.h.idx, ψ.h.lambda) snorm = norm(ψ.sol[idx])^2 - if snorm <= 2 * γ * λ + if snorm <= 2 * σ * λ y[idx] .= 0 else y[idx] .= ψ.sol[idx] From b6075f79db09354ce58b71f267792b891835f59a Mon Sep 17 00:00:00 2001 From: Alexander Hsu <34590951+AHsu98@users.noreply.github.com> Date: Sat, 22 Feb 2025 18:28:11 -0800 Subject: [PATCH 7/8] missed this on in last commit --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index da2507e2..d1c41269 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -179,7 +179,7 @@ for (op, shifted_op) ∈ zip((:NormL2,), (:ShiftedGroupNormL2,)) end end -for (op, shifted_op) ∈ zip((:GroupNormL2,), (:ShiftedGroupNormL2,)) +for (op, shifted_op) ∈ zip((:GroupNormL2,:GroupNormL0), (:ShiftedGroupNormL2,:ShiftedGroupNormL0)) @testset "$shifted_op" begin ShiftedOp = eval(shifted_op) Op = eval(op) From 36fa78f62de7fca8c4cc90a4919525b431ac0d53 Mon Sep 17 00:00:00 2001 From: Alexander Hsu <34590951+AHsu98@users.noreply.github.com> Date: Sat, 22 Feb 2025 18:30:43 -0800 Subject: [PATCH 8/8] dropped groupNormL0 from tests, will add again later --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index d1c41269..da2507e2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -179,7 +179,7 @@ for (op, shifted_op) ∈ zip((:NormL2,), (:ShiftedGroupNormL2,)) end end -for (op, shifted_op) ∈ zip((:GroupNormL2,:GroupNormL0), (:ShiftedGroupNormL2,:ShiftedGroupNormL0)) +for (op, shifted_op) ∈ zip((:GroupNormL2,), (:ShiftedGroupNormL2,)) @testset "$shifted_op" begin ShiftedOp = eval(shifted_op) Op = eval(op)