diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 871d7c4..a062f89 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -15,6 +15,8 @@ concurrency: env: JULIA_PKG_SERVER: "" # Fix for Windows: Forces GR/Plots to be headless + JULIA_NUM_THREADS: 1 + JULIA_PKG_PRECOMPILE_AUTO: 0 GKSwstype: "100" # Fix for macOS: Forces Matplotlib to use non-interactive backend MPLBACKEND: "Agg" diff --git a/.gitignore b/.gitignore index b94acdb..3b14e53 100644 Binary files a/.gitignore and b/.gitignore differ diff --git a/Project.toml b/Project.toml index 937c24c..2286d27 100644 --- a/Project.toml +++ b/Project.toml @@ -18,8 +18,10 @@ DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FourierFilterFlux = "3d7dfd45-6c90-4c9b-b697-194a05757159" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GLMNet = "8d5ece8b-de18-5317-b113-243142960cc6" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" @@ -44,26 +46,35 @@ Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a" WaveletsExt = "8f464e1e-25db-479f-b0a5-b7680379e03f" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +cuFFT = "533571aa-0936-420e-b4be-9c66f5f626ca" + +[extensions] +ScatteringPlotsExt = "Plots" [compat] AbstractFFTs = "1" Adapt = "3, 4" -CUDA = "4, 5" +CUDA = "4 - 6" ContinuousWavelets = "1" Distributions = "0.25" FFTW = "1" FileIO = "1" +FiniteDifferences = "0.12.34" Flux = "0.13, 0.14, 0.15, 0.16" +Functors = "0.5.2" JLD2 = "0.4, 0.5, 0.6" -Plots = "1" +Plots = "1.41.6" Reexport = "1" Wavelets = "0.9, 0.10" WaveletsExt = "0.2.3" Zygote = "0.6, 0.7" +cuFFT = "6.2.0" julia = "1.10" [extras] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "CUDA", "BenchmarkTools"] diff --git a/docs/make.jl b/docs/make.jl index 9c5ed78..2eb5e57 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,13 +2,14 @@ ENV["PLOTS_TEST"] = "true" ENV["GKSwstype"] = "100" ENV["LINES"] = "9" ENV["COLUMNS"] = "60" -using Documenter, ScatteringTransform, ScatteringPlots +using Documenter, ScatteringTransform, Plots mkpath("docs/src/figures") +const ScatteringPlotsExt = Base.get_extension(ScatteringTransform, :ScatteringPlotsExt) makedocs( sitename = "ScatteringTransform.jl", format = Documenter.HTML(), - modules = [ScatteringTransform, ScatteringPlots], + modules = [ScatteringTransform, ScatteringPlotsExt], authors="David Weber, Naoki Saito, Jared White", clean=true, checkdocs = :exports, diff --git a/docs/src/figures/firstLayer.png b/docs/src/figures/firstLayer.png index 3e8fc6c..a467e35 100644 Binary files a/docs/src/figures/firstLayer.png and b/docs/src/figures/firstLayer.png differ diff --git a/docs/src/figures/firstLayerAll.png b/docs/src/figures/firstLayerAll.png index 3391369..46bf0ac 100644 Binary files a/docs/src/figures/firstLayerAll.png and b/docs/src/figures/firstLayerAll.png differ diff --git a/docs/src/figures/jointPlot.png b/docs/src/figures/jointPlot.png index 77fcd7a..9eb92be 100644 Binary files a/docs/src/figures/jointPlot.png and b/docs/src/figures/jointPlot.png differ diff --git a/docs/src/figures/secondLayer.png b/docs/src/figures/secondLayer.png index 66beb75..81a0e98 100644 Binary files a/docs/src/figures/secondLayer.png and b/docs/src/figures/secondLayer.png differ diff --git a/docs/src/figures/secondLayerSpecificPath.png b/docs/src/figures/secondLayerSpecificPath.png index 0d0f23c..ebdae87 100644 Binary files a/docs/src/figures/secondLayerSpecificPath.png and b/docs/src/figures/secondLayerSpecificPath.png differ diff --git a/docs/src/figures/sliceByFirst.gif b/docs/src/figures/sliceByFirst.gif index a18bd48..c10c236 100644 Binary files a/docs/src/figures/sliceByFirst.gif and b/docs/src/figures/sliceByFirst.gif differ diff --git a/docs/src/figures/sliceBySecond.gif b/docs/src/figures/sliceBySecond.gif index 4fdc4f8..be22d30 100644 Binary files a/docs/src/figures/sliceBySecond.gif and b/docs/src/figures/sliceBySecond.gif differ diff --git a/docs/src/figures/zerothLayer.png b/docs/src/figures/zerothLayer.png index 6aefead..3bc2428 100644 Binary files a/docs/src/figures/zerothLayer.png and b/docs/src/figures/zerothLayer.png differ diff --git a/docs/src/index.md b/docs/src/index.md index e2ee5da..6d0dd4c 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -14,17 +14,21 @@ For a comparable package in python, see [Kymatio](https://www.kymat.io/). ## Basic Example -```@setup -using ScatteringTransform, Wavelets, Plots +```@example ex +using ScatteringTransform, ContinuousWavelets, Wavelets, Plots +const ScatteringPlotsExt = Base.get_extension(ScatteringTransform, :ScatteringPlotsExt) +import .ScatteringPlotsExt: plotZerothLayer1D, plotFirstLayer1D, plotFirstLayer1DAll, + plotSecondLayer1DFixAndVary, plotSecondLayer1DSpecificPath, + plotSecondLayer1D, jointPlot1D +nothing # hide ``` -As an example signal, lets work with the doppler signal: +As an example signal, let's work with a doppler signal: ```@example ex -using Wavelets, Plots N = 2047 -f = testfunction(N, "Doppler") -plot(f, legend=false, title="Doppler signal") +signal = testfunction(N, "Doppler") +plot(signal, legend=false, title="Doppler signal") savefig("figures/rawDoppler.svg"); #hide nothing # hide ``` @@ -32,12 +36,12 @@ nothing # hide ![](figures/rawDoppler.svg) First we need to make a `scatteringTransform` instance, which will create and store all of the necessary filters, subsampling operators, nonlinear functions, etc. -The parameters are described in the `scatteringTransform` type. +The parameters are described in the `scatteringTransform` type. The function `reshapeInputs` converts data matrices and vectors into a usable form and returns the reshaped array and its dimensions. It works for any shaped input signal. Since the Doppler signal is smooth, but with varying frequency, let's set the wavelet family `cw=Morlet(π)` specifies the mother wavelet to be a Morlet wavelet with mean frequency π, and frequency spacing `β=2`: ```@example ex -using ScatteringTransform, ContinuousWavelets -St = scatteringTransform((N, 1, 1), 2, cw=Morlet(π), β=2, σ=abs) +f, dims = reshapeInputs(signal) +St = scatteringTransform(dims, 2, cw=Morlet(π), β=2, σ=abs) sf = St(f) ``` @@ -56,7 +60,7 @@ plotZerothLayer1D(sf) The first layer is the average of the absolute value of the scalogram: ```@example ex -plotFirstLayer(sf, St) +plotFirstLayer1D(sf, St) ``` With the plotting utilities included in this package, you are able to display the previous plot along with the original signal and the first layer wavelet gradients: @@ -73,7 +77,7 @@ With our plotting utilities, you can display the second layer with respect to sp To this end, lets make two gifs, the first with the _first_ layer frequency varying with time: ```@example ex -plotSecondLayerFixAndVary(sf, St, 1:30, 1, fps=1, saveTo="figures/sliceByFirst.gif") +plotSecondLayer1DFixAndVary(sf, St, 1:30, 1, fps=1, saveTo="figures/sliceByFirst.gif") nothing # hide ``` ![](figures/sliceByFirst.gif) @@ -84,7 +88,7 @@ As the first layer frequency increases, the energy concentrates to the beginning The second has the _second_ layer frequency varying with time: ```@example ex -plotSecondLayerFixAndVary(sf, St, 1, 1:28, fps=1, saveTo="figures/sliceBySecond.gif") +plotSecondLayer1DFixAndVary(sf, St, 1, 1:28, fps=1, saveTo="figures/sliceBySecond.gif") nothing # hide ``` ![](figures/sliceBySecond.gif) @@ -92,7 +96,7 @@ nothing # hide If desired, this package allows one to plot the results of a specific path. Here is an example, where we are plotting the resulting plot if we were to use first layer wavelet 3 and second layer wavelet 1. ```@example ex -plotSecondLayerSpecificPath(sf, St, 3, 1, f) +plotSecondLayer1DSpecificPath(sf, St, 3, 1, f) ``` For any fixed second layer frequency, we get approximately the curve in the first layer scalogram, with different portions emphasized, and the overall mass decreasing as the frequency increases, corresponding to the decreasing amplitude of the envelope for the doppler signal. @@ -101,7 +105,7 @@ These plots can also be created using various plotting utilities defined in this For example, we can generate a denser representation with the `plotSecondLayer` function: ```@example ex -plotSecondLayer(sf, St) +plotSecondLayer1D(sf, St) ``` where the frequencies are along the axes, the heatmap gives the largest value across time for that path, and at each path is a small plot of the averaged timecourse. @@ -109,8 +113,13 @@ where the frequencies are along the axes, the heatmap gives the largest value ac ### Joint Plot -Finally, we can constuct a joint plot of much of our prior information. This plot will display the zeroth layer, first layer and second layer information for a given example. +Finally, we can constuct a joint plot of much of our prior information. This plot will display the original signal, zeroth layer, first layer and second layer information for a given example. ```@example ex -jointPlot(sf, "Scattering Transform", :viridis, St) +jointPlot1D(sf, "Scattering Transform", :viridis, St, f) ``` + + +## Future Updates + +In the future we will be adding plotting support for the third layer of the Scattering Transform. In addition, 2D variants of the plotting functions will also be created and documented here with examples. \ No newline at end of file diff --git a/docs/src/plots.md b/docs/src/plots.md index b48985c..6cf9c10 100644 --- a/docs/src/plots.md +++ b/docs/src/plots.md @@ -1,14 +1,29 @@ # Plotting Scattering Transforms + +## 1D Plotting Functions ```@docs -ScatteringTransform.plotZerothLayer1D -ScatteringTransform.plotFirstLayer1D -ScatteringTransform.gifFirstLayer -ScatteringTransform.plotFirstLayer1DAll -ScatteringTransform.plotFirstLayer -ScatteringTransform.plotSecondLayerSpecificPath -ScatteringTransform.plotSecondLayer1DSubsetGif -ScatteringTransform.plotSecondLayerFixAndVary -ScatteringTransform.plotSecondLayer -ScatteringTransform.jointPlot +ScatteringPlotsExt.plotZerothLayer1D +ScatteringPlotsExt.plotFirstLayer1DSingleWavelet +ScatteringPlotsExt.gifFirstLayer1D +ScatteringPlotsExt.plotFirstLayer1DAll +ScatteringPlotsExt.plotFirstLayer1D +ScatteringPlotsExt.plotSecondLayer1DSpecificPath +ScatteringPlotsExt.gifSecondLayer1DSubset +ScatteringPlotsExt.plotSecondLayer1DFixAndVary +ScatteringPlotsExt.plotSecondLayer1D +ScatteringPlotsExt.jointPlot1D ``` + +## 2D Plotting Functions +```@docs +ScatteringPlotsExt.plotOriginalSignal2D +ScatteringPlotsExt.plotZerothLayer2D +ScatteringPlotsExt.plotFirstLayer2DSingleWavelet +ScatteringPlotsExt.visualizeFirstLayer2D +ScatteringPlotsExt.plotFirstLayer2D +ScatteringPlotsExt.plotFirstLayer2DAll +ScatteringPlotsExt.plotSecondLayer2DSingleWavelet +ScatteringPlotsExt.visualizeSecondLayer2D +ScatteringPlotsExt.plotSecondLayer2D +``` \ No newline at end of file diff --git a/docs/src/utils.md b/docs/src/utils.md index 0b33028..c845a83 100644 --- a/docs/src/utils.md +++ b/docs/src/utils.md @@ -15,4 +15,5 @@ ScatteringTransform.normalize ScatteringTransform.processArgs ScatteringTransform.getParameters ScatteringTransform.extractAddPadding +ScatteringTransform.reshapeInputs ``` diff --git a/ext/ScatteringPlotsExt.jl b/ext/ScatteringPlotsExt.jl new file mode 100644 index 0000000..68b136a --- /dev/null +++ b/ext/ScatteringPlotsExt.jl @@ -0,0 +1,12 @@ +module ScatteringPlotsExt + +using ScatteringTransform +using Plots +using FFTW: rfft, irfft, fft, ifft +using LinearAlgebra: norm + +include(joinpath(@__DIR__, "..", "src", "scatteringplots.jl")) +export plotOriginalSignal1D, plotZerothLayer1D, plotFirstLayer1DSingleWavelet, gifFirstLayer1D, plotFirstLayer1DAll, plotFirstLayer1D, plotSecondLayer1DOld, plotSecondLayer1DSpecificPath, gifSecondLayer1DSubset, plotSecondLayer1DFixAndVary, plotSecondLayer1D, jointPlot1D +export plotOriginalSignal2D, plotZerothLayer2D, plotFirstLayer2DSingleWavelet, visualizeFirstLayer2D, plotFirstLayer2D, plotFirstLayer2DAll, plotSecondLayer2DSingleWavelet, visualizeSecondLayer2D, plotSecondLayer2D + +end \ No newline at end of file diff --git a/src/ScatteringTransform.jl b/src/ScatteringTransform.jl index 3f33f86..14fe4c6 100644 --- a/src/ScatteringTransform.jl +++ b/src/ScatteringTransform.jl @@ -12,7 +12,7 @@ using Adapt using RecipesBase using Base: tail using ChainRulesCore -using Plots # who cares about weight really? +# using Plots # who cares about weight really? using Statistics using Dates @@ -43,8 +43,24 @@ end include("utilities.jl") export getWavelets, flatten, roll, importantCoords, batchOff, getParameters, getMeanFreq, computeLoc -export roll, wrap, flatten +export roll, wrap, flatten, reshapeInputs include("adjoints.jl") -include("scatteringplots.jl") -export plotZerothLayer1D, plotFirstLayer1D, gifFirstLayer, plotFirstLayer1DAll, plotFirstLayer, plotSecondLayer, plotSecondLayer1D, plotSecondLayerSpecificPath, plotSecondLayer1DSubsetGif, plotSecondLayerFixAndVary, jointPlot + +for f in [:plotOriginalSignal1D, :plotZerothLayer1D, :plotFirstLayer1DSingleWavelet, + :gifFirstLayer1D, :plotFirstLayer1DAll, :plotFirstLayer1D, + :plotSecondLayer1DSpecificPath, :gifSecondLayer1DSubset, + :plotSecondLayer1DFixAndVary, :plotSecondLayer1D, :jointPlot1D, + :plotOriginalSignal2D, :plotZerothLayer2D, :plotFirstLayer2DSingleWavelet, + :visualizeFirstLayer2D, :plotFirstLayer2D, :plotFirstLayer2DAll, + :plotSecondLayer2DSingleWavelet, :visualizeSecondLayer2D, :plotSecondLayer2D] + @eval function $f(args...; kwargs...) + ext = Base.get_extension(ScatteringTransform, :ScatteringPlotsExt) + if isnothing(ext) + error("Load Plots.jl to use plotting functions: `using Plots`") + end + getfield(ext, $(QuoteNode(f)))(args...; kwargs...) + end + @eval export $f +end + end # end Module diff --git a/src/adjoints.jl b/src/adjoints.jl index ce4eabb..576d2de 100644 --- a/src/adjoints.jl +++ b/src/adjoints.jl @@ -13,9 +13,8 @@ end Zygote.@adjoint function getindex(F::T, i::Integer) where {T<:Scattered} function getInd_rrule(Ȳ) - zeroNonRefed = map(ii -> ii - 1 == i ? Ȳ : zeros(eltype(F.output[ii]), - size(F.output[ii])...), - (1:length(F.output)...,)) + # zeroNonRefed = map(ii -> ii - 1 == i ? Ȳ : zeros(eltype(F.output[ii]), size(F.output[ii])...), (1:length(F.output)...,)) + zeroNonRefed = map(ii -> ii - 1 == i ? Ȳ : zero(F.output[ii]), (1:length(F.output)...,)) ∂F = T(F.m, F.k, zeroNonRefed) return ∂F, nothing end @@ -24,9 +23,8 @@ end Zygote.@adjoint function getindex(F::T, inds::AbstractArray) where {T<:Scattered} function getInd_rrule(Ȳ) - zeroNonRefed = map(ii -> ii - 1 in inds ? Ȳ[indexin(ii - 1, inds)[1]] : - zeros(eltype(F.output[ii]), size(F.output[ii])...), - (1:length(F.output)...,)) + # zeroNonRefed = map(ii -> ii - 1 in inds ? Ȳ[indexin(ii - 1, inds)[1]] : zeros(eltype(F.output[ii]), size(F.output[ii])...), (1:length(F.output)...,)) + zeroNonRefed = map(ii -> ii - 1 in inds ? Ȳ[indexin(ii - 1, inds)[1]] : zero(F.output[ii]), (1:length(F.output)...,)) ∂F = T(F.m, F.k, zeroNonRefed) return ∂F, nothing end @@ -35,9 +33,8 @@ end Zygote.@adjoint function getindex(x::T, p::pathLocs) where {T<:Scattered} function getInd_rrule(Δ) - zeroNonRefed = map(ii -> zeros(eltype(x.output[ii]), - size(x.output[ii])...), - (1:length(x.output)...,)) + # zeroNonRefed = map(ii -> zeros(eltype(x.output[ii]), size(x.output[ii])...), (1:length(x.output)...,)) + zeroNonRefed = map(ii -> zero(x.output[ii]), (1:length(x.output)...,)) ∂x = T(x.m, x.k, zeroNonRefed) ∂x[p] = Δ return ∂x, nothing @@ -47,13 +44,47 @@ end function rrule(::typeof(flatten), scatRes) function ∇flatten(Δarray) - return (NO_FIELDS, roll(Δarray, scatRes),) + return (NoTangent(), roll(Δarray, scatRes),) end return flatten(scatRes), ∇flatten end function rrule(::typeof(roll), toRoll, stOutput) function ∇roll(Δ) - return NO_FIELDS, flatten(Δ), NO_FIELDS + return NoTangent(), flatten(Δ), NoTangent() end return roll(toRoll, stOutput), ∇roll end + + +function ChainRulesCore.rrule(::typeof(normalize), x, Nd) + n = ndims(x) + totalThisLayer = prod(size(x)[(Nd+1):(n-1)]) + sumSqDims = 1:(n-1) + normSq = sum(abs.(x) .^ 2, dims=sumSqDims) + scale = totalThisLayer ./ sqrt.(normSq) + scale = ifelse.(isnan.(scale) .| isinf.(scale), one(eltype(scale)), scale) + y = x .* scale + + function normalize_pullback(Δy) + # All operations stay on the same device as x + ∂x = Δy .* scale + return NoTangent(), ∂x, NoTangent() + end + return y, normalize_pullback +end + +Zygote.@adjoint function normalize(x, Nd) + n = ndims(x) + totalThisLayer = prod(size(x)[(Nd+1):(n-1)]) + sumSqDims = 1:(n-1) + normSq = sum(abs.(x) .^ 2, dims=sumSqDims) + scale = totalThisLayer ./ sqrt.(normSq) + scale = ifelse.(isnan.(scale) .| isinf.(scale), one(eltype(scale)), scale) + y = x .* scale + function normalize_pullback(Δy) + scale_adapted = adapt(typeof(Δy), scale) + ∂x = Δy .* scale_adapted + return ∂x, nothing + end + return y, normalize_pullback +end \ No newline at end of file diff --git a/src/scattered.jl b/src/scattered.jl index 7721987..3a981b0 100644 --- a/src/scattered.jl +++ b/src/scattered.jl @@ -67,7 +67,7 @@ The resulting output of a scattering transform. `m` is the number of layers, whi ScatteredOut(output, k = 1) A less involved constructor given just a list or tuple of the output from each layer. `k` gives the signal dimension, as above, with the default that `k=1`. """ -function ScatteredOut(m, k, fixDim, n, q, T) +function ScatteredOut(m, k, fixDim, n, q, T; arrType=Array) @assert m + 1 == size(n, 1) @assert m + 1 == length(q) @assert k == size(n, 2) @@ -75,12 +75,12 @@ function ScatteredOut(m, k, fixDim, n, q, T) n = Int.(n) q = Int.(q) N = k + length(fixDim) + 1 - output = [zeros(T, n[i, :]..., prod(q[1:(i-1)] .- 1), fixDim...) for i = 1:m+1] + output = [adapt(arrType, zeros(T, n[i, :]..., prod(q[1:(i-1)] .- 1), fixDim...)) for i = 1:m+1] return ScatteredOut{T,N}(m, k, output) end function ScatteredFull(m, k, fixDim::Array{<:Real,1}, n::Array{<:Real,2}, - q::Array{<:Real,1}, T) + q::Array{<:Real,1}, T; arrType=Array) @assert m + 1 == size(n, 1) @assert m + 1 == length(q) @assert k == size(n, 2) @@ -88,8 +88,8 @@ function ScatteredFull(m, k, fixDim::Array{<:Real,1}, n::Array{<:Real,2}, n = Int.(n) q = Int.(q) N = k + length(fixDim) + 1 - data = [zeros(T, n[i, :]..., prod(q[1:i] .- 1), fixDim...) for i = 1:m+1] - output = [zeros(T, n[i, :]..., prod(q[1:(i-1)] .- 1), fixDim...) for i = 1:m+1] + data = [adapt(arrType, zeros(T, n[i, :]..., prod(q[1:i] .- 1), fixDim...)) for i = 1:m+1] + output = [adapt(arrType, zeros(T, n[i, :]..., prod(q[1:(i-1)] .- 1), fixDim...)) for i = 1:m+1] return ScatteredFull{T,N}(m, k, data, output) end @@ -136,7 +136,7 @@ return the input dimension size (also given by `sct.k`) ndims(sct::Scattered) = sct.k size(sct::ScatteredOut) = map(x -> size(x), sct.output) size(sct::ScatteredFull) = map(x -> size(x), sct.data) -similar(sct::ScatteredOut) = ScatteredOut(map(x -> similar(arrayType(sct), axes(x)), sct.output), sct.k) +similar(sct::ScatteredOut) = ScatteredOut(map(x -> similar(x), sct.output), sct.k) similar(sct::ScatteredFull) = ScatteredFull(map(x -> similar(x), sct.data), map(x -> similar(x), sct.output), sct.k) import Statistics.mean diff --git a/src/scatteringplots.jl b/src/scatteringplots.jl index b40b28d..0e9c278 100644 --- a/src/scatteringplots.jl +++ b/src/scatteringplots.jl @@ -2,89 +2,106 @@ plotZerothLayer1D(sf; saveTo=nothing, index=1) Function that plots the zeroth layer of the scattering transform at a specified example index. """ -function plotZerothLayer1D(sf; saveTo=nothing, index=1) - plt = plot(sf[0][:, 1, index], title="Zeroth Layer", legend=false, xlim=(0, length(sf[0][:, 1, index])+1), color=:blue, margin=5Plots.mm) +function plotOriginalSignal1D(f; title="Original Signal", saveTo=nothing, index=1) + plt = plot(f[:,1,index], title=title, legend=false, xlim=(0, length(f[:, 1, index])+1), color=:blue, margin=5Plots.mm, size=(720,480)) if !isnothing(saveTo) + mkpath(dirname(saveTo)) savefig(plt, saveTo) end return plt end """ - plotFirstLayer1D(j, origLoc, origSig; saveTo=nothing, index=1) -Function that plots the first layer gradient wavelet at index `j` across space, along with the original signal. -It also includes heatmaps of the gradient wavelet in both the spatial and frequency domains. -The variable `j` specifies which wavelet to plot from the first layer, `index` specifies which example in the batch to plot, + plotZerothLayer1D(sf; title="Zeroth Layer", saveTo=nothing, index=1) +Function that plots the zeroth layer of the scattering transform at a specified example index. +""" +function plotZerothLayer1D(sf; title="Zeroth Layer", saveTo=nothing, index=1) + plt = plot(sf[0][:, 1, index], title=title, legend=false, xlim=(0, length(sf[0][:, 1, index])+1), color=:blue, margin=5Plots.mm, size=(720,480)) + if !isnothing(saveTo) + mkpath(dirname(saveTo)) + savefig(plt, saveTo) + end + return plt +end + +""" + plotFirstLayer1DSingleWavelet(j, origLoc, origSig; title="First Layer", saveTo=nothing, index=1) +The variable `j` specifies which wavelet results to plot from the first layer, `index` specifies which example in the batch to plot, `origLoc` is the `ScatteredOut` object containing the scattering transform results, and `origSig` is the original input signal. +It also includes heatmaps of the gradient wavelet in both the spatial and frequency domains. """ -function plotFirstLayer1D(j, origLoc, origSig; saveTo=nothing, index=1) +function plotFirstLayer1DSingleWavelet(j, origLoc, origSig; title="First Layer", saveTo=nothing, index=1) space = plot(origLoc[1][:, j, index], xlim=(0, length(origLoc[1][:, j, index])+1), legend=false, - color=:red, title="First Layer - Gradient Wavelet $j - Varying Location") + color=:red, title="$title - Gradient Wavelet $j - Varying Location") org = plot(origSig[:,:,index], legend=false, color=:red, title="Original Signal", xlim=(0, length(origSig[:,:,index])+1)) ∇h = heatmap(origLoc[1][:, j, index]', xlabel="space", yticks=false, ylabel="", title="First Layer gradient - Wavelet j=$j") ∇̂h = heatmap(log.(abs.(rfft(origLoc[1][:, j, index], 1)) .^ 2)', xlabel="frequency", yticks=false, ylabel="", title="Log-power Frequency Domain - Wavelet j=$j") l = Plots.@layout [a; b{0.1h}; [b c]] - plt = plot(space, org, ∇h, ∇̂h, layout=l, size=(1200, 800), margin=5Plots.mm) + plt = plot(space, org, ∇h, ∇̂h, layout=l, size=(1280, 720), margin=5Plots.mm) if !isnothing(saveTo) + mkpath(dirname(saveTo)) savefig(plt, saveTo) end return plt end """ - gifFirstLayer(origLoc, origSig; fps=2, saveTo=nothing, index=1) + gifFirstLayer1D(origLoc, origSig; fps=2, title="First Layer", saveTo=nothing, index=1) Function to create a GIF visualizing all wavelets in the first layer across space for each example in the batch. The variable `origLoc` is the `ScatteredOut` object containing the scattering transform results, `index` specifies which example in the batch to plot, `origSig` is the original input signal, `saveTo` specifies the file path to save the GIF, and `fps` sets the frames per second for the GIF animation. -If `saveTo` is provided, the GIF is saved to that file path. +If `saveTo` is provided, the GIF is saved to that file path. The default `fps` is set to 2 frames per second. """ -function gifFirstLayer(origLoc, origSig; fps=2, saveTo=nothing, index=1) +function gifFirstLayer1D(origLoc, origSig; fps=2, title="First Layer", saveTo=nothing, index=1) anim = Animation() for j = 1:size(origLoc[1])[end-1] - plotFirstLayer1D(j, origLoc, origSig; index=index) + plotFirstLayer1DSingleWavelet(j, origLoc, origSig; index=index) frame(anim) end filepath = isnothing(saveTo) ? "tmp.gif" : saveTo + mkpath(dirname(filepath)) return gif(anim, filepath, fps=fps) end """ - plotFirstLayer1DAll(origLoc, origSig; saveTo=nothing, index=1, cline=:darkrainbow) + plotFirstLayer1DAll(origLoc, origSig; title="First Layer", saveTo=nothing, index=1, cline=:darkrainbow) Function that plots all first layer gradient wavelets for a specific example signal `index` across space, along with the original signal. It also includes heatmaps of the gradient wavelets in both the spatial and frequency domains. The variable `index` specifies which example in the batch to plot, `origLoc` is the `ScatteredOut` object containing the scattering transform results, `origSig` is the original input signal, and `saveTo` is the file path to save the plot. """ -function plotFirstLayer1DAll(origLoc, origSig; saveTo=nothing, index=1, cline=:darkrainbow) +function plotFirstLayer1DAll(origLoc, origSig; title="First Layer", saveTo=nothing, index=1, cline=:darkrainbow) space = plot(origLoc[1][:, :, index], line_z=(1:size(origLoc[1], 2))', xlim=(0, length(origLoc[1][:, 1, index])+1), - legend=false, colorbar=true, color=cline, title="first layer gradient wavelets") + legend=false, colorbar=true, color=cline, title="$title Gradient Wavelets") org = plot(origSig[:,:,index], legend=false, color=:red, title="Original Signal", xlim=(0, length(origSig[:,:,index])+1)) ∇h = heatmap(origLoc[1][:, 1:end, index]', xlabel="space", ylabel="wavelet index", title="First Layer gradients") ∇̂h = heatmap(log.(abs.(rfft(origLoc[1][:, 1:end, index], 1)) .^ 2)', xlabel="frequency", ylabel="wavelet index", title="Log-power Frequency Domain") l = Plots.@layout [a; b{0.1h}; [b c]] - plt = plot(space, org, ∇h, ∇̂h, layout=l, size=(1200, 800), margin=5Plots.mm) + plt = plot(space, org, ∇h, ∇̂h, layout=l, size=(1280, 720), margin=5Plots.mm) if !isnothing(saveTo) + mkpath(dirname(saveTo)) savefig(plt, saveTo) end return plt end """ - plotFirstLayer(stw, St; saveTo=nothing, index=1) + plotFirstLayer1D(stw, St; title="First Layer", saveTo=nothing, index=1) Function that creates a heatmap of the first layer scattering transform results at a specified example index. The variable `stw` is the scattered output, `St` is the scattering transform object, `saveTo` is the file path to save the plot, and `index` specifies which example in the batch to plot. """ -function plotFirstLayer(stw, St; saveTo=nothing, index=1) +function plotFirstLayer1D(stw, St; title="First Layer", saveTo=nothing, index=1) f1, f2, f3 = getMeanFreq(St) # the mean frequencies for the wavelets in each layer. plt = heatmap(1:size(stw[1], 1), f1[1:end-1], stw[1][:, :, index]', xlabel="time index", ylabel="Frequency (Hz)", margin=5Plots.mm, - color=:viridis, title="First Layer", size=(1200, 800)) + color=:viridis, title="$title", size=(1280, 720)) if !isnothing(saveTo) + mkpath(dirname(saveTo)) savefig(plt, saveTo) end return plt @@ -95,9 +112,9 @@ meanWave(wave) = sum(real.(range(0, stop=1, length=size(wave, 1)) .* wave), dims # As far as I can tell, this function is not used elsewhere. """ - plotSecondLayer1D(loc, origLoc, wave1, wave2, original=false, subsamSz=(128,85,)) + plotSecondLayer1DOld(loc, origLoc, wave1, wave2, original=false, subsamSz=(128,85,)) """ -function plotSecondLayer1D(loc, origLoc, wave1, wave2, original=false, subsamSz=(128, 85,), c=:thermal, lastDiagFreq=true) +function plotSecondLayer1DOld(loc, origLoc, wave1, wave2, original=false, subsamSz=(128, 85,), c=:thermal, lastDiagFreq=true) waveUsed = real.(ifftshift(irfft(wave1[:, loc[2]], subsamSz[1] * 2))) l1wave = plot(waveUsed, legend=false, titlefontsize=8, title="layer 1 ($(loc[2]))") annotate!(size(waveUsed, 1) * 5 / 6, maximum(waveUsed), Plots.text("freq = $(meanWave(wave1)[loc[2]])"[1:13], 5)) @@ -126,22 +143,22 @@ function plotSecondLayer1D(loc, origLoc, wave1, wave2, original=false, subsamSz= end """ - plotSecondLayerSpecificPath(stw, St, firstLayerWaveletIndex, secondLayerWaveletIndex, original; saveTo=nothing, index=1) + plotSecondLayer1DSpecificPath(stw, St, firstLayerWaveletIndex, secondLayerWaveletIndex, original; saveTo=nothing, index=1) `stw` is the scattered output, `St` is the scattering transform object, `firstLayerWaveletIndex` and `secondLayerWaveletIndex` specify the path to plot, `original` is the original signal, and `index` specifies which example in the batch to plot. This function creates a plot showing the original signal and the scattering result for the specified path. It also displays the mean frequencies associated with the selected wavelets. Finally, it displays the log-power norm of the second layer signal for the specified path. This value is used elsewhere to create heatmaps of the second layer scattering results. """ -function plotSecondLayerSpecificPath(stw, St, firstLayerWaveletIndex, secondLayerWaveletIndex, original; saveTo=nothing, index=1) +function plotSecondLayer1DSpecificPath(stw, St, firstLayerWaveletIndex, secondLayerWaveletIndex, original; saveTo=nothing, index=1) # Plot of original signal. org = plot(original[:,:,index], legend=false, color=:red, title="Original Signal", xlabel="time (samples)", ylabel="amplitude", xlims=(0, length(original[:,:,index])+1)) f1, f2, f3 = getMeanFreq(St) # Plot the signal for a specific path. signalLayer1Freq = f1[firstLayerWaveletIndex]; signalLayer2Freq = f2[secondLayerWaveletIndex] - titlePlot = plot(title="Path: First Layer - wavelet $firstLayerWaveletIndex, Second Layer - wavelet $secondLayerWaveletIndex\n" * - "First Layer Freq = $(round(signalLayer1Freq, sigdigits=3)) Hz | " * - "Second Layer Freq = $(round(signalLayer2Freq, sigdigits=3)) Hz", - grid=false, showaxis=false, xticks=nothing, yticks=nothing, bottom_margin=-5Plots.px, titlefontsize=11) + plotsTitle = "Path: First Layer - wavelet $firstLayerWaveletIndex, Second Layer - wavelet $secondLayerWaveletIndex\n" * + "First Layer Freq = $(round(signalLayer1Freq, sigdigits=3)) Hz | " * + "Second Layer Freq = $(round(signalLayer2Freq, sigdigits=3)) Hz" + titlePlot = plot(title=plotsTitle, grid=false, showaxis=false, xticks=nothing, yticks=nothing, bottom_margin=-5Plots.px, titlefontsize=11) path_spatial = stw[2][:, secondLayerWaveletIndex, firstLayerWaveletIndex, index] ∇h = plot(path_spatial, xlabel="time (samples)", ylabel="amplitude", title="Second Layer Plot", legend=false, linewidth=1.5, frame=:box, fill=0, fillalpha=0.5, @@ -151,38 +168,40 @@ function plotSecondLayerSpecificPath(stw, St, firstLayerWaveletIndex, secondLaye normPlot = plot(title="Second Layer Signal norm (log-power) = $(round(secondLayerNorm, sigdigits=4))", grid=false, showaxis=false, xticks=nothing, yticks=nothing, titlefontsize=11, framestyle=:none) l = Plots.@layout [a{0.4h}; title{0.05h}; c; d{0.05h}] - plt = plot(org, titlePlot, ∇h, normPlot, layout=l, margin=4Plots.mm, size=(900,600)) + plt = plot(org, titlePlot, ∇h, normPlot, layout=l, margin=4Plots.mm, size=(1080,720)) if !isnothing(saveTo) + mkpath(dirname(saveTo)) savefig(plt, saveTo) end return plt end """ - plotSecondLayer1DSubsetGif(stw, St, firstLayerWavelets, secondLayerWavelets, original; fps=2, saveTo=nothing, index=1) + gifSecondLayer1DSubset(stw, St, firstLayerWavelets, secondLayerWavelets, original; fps=2, saveTo=nothing, index=1) Create a GIF visualizing the second layer scattering results for specified subsets of wavelets from the first and second layers. The variables `firstLayerWavelets` and `secondLayerWavelets` are arrays containing the indices of the wavelets to be visualized from the first and second layers, respectively. For example, to visualize all the wavelets from the first layer with respect to a specific wavelet from the second layer, you can set `firstLayerWavelets = 1:size(stw[1], 2)` and `secondLayerWavelets = k`, where `k` is the index of the desired second layer wavelet. Once again, the `index` parameter specifies which example in the batch to plot. It defaults to the first example in the batch. If `saveTo` is not provided, the GIF is saved to "tmp.gif". """ -function plotSecondLayer1DSubsetGif(stw, St, firstLayerWavelets, secondLayerWavelets, original; fps=2, saveTo=nothing, index=1) +function gifSecondLayer1DSubset(stw, St, firstLayerWavelets, secondLayerWavelets, original; fps=2, saveTo=nothing, index=1) anim = Animation() for j in firstLayerWavelets, k in secondLayerWavelets - plt = plotSecondLayerSpecificPath(stw, St, j, k, original; index=index) + plt = plotSecondLayer1DSpecificPath(stw, St, j, k, original; saveTo=nothing, index=index) frame(anim, plt) end + mkpath(dirname(saveTo)) filepath = isnothing(saveTo) ? "tmp.gif" : saveTo return gif(anim, filepath, fps=fps) end """ - plotSecondLayerFixAndVary(stw, St, firstLayerWavelets, secondLayerWavelets; fps=2, saveTo=nothing, index=1) + plotSecondLayer1DFixAndVary(stw, St, firstLayerWavelets, secondLayerWavelets; fps=2, saveTo=nothing, index=1) Create a GIF visualizing slices of the second layer scattering results by fixing one layer's wavelet and varying the other layer's wavelet. The variables `firstLayerWavelets` and `secondLayerWavelets` are arrays containing the indices of the wavelets to be visualized from the first and second layers, respectively. If `firstLayerWavelets` contains only one index, the function fixes that wavelet and varies the second layer wavelets, and vice versa. If `saveTo` is not provided, the GIF is saved to "tmp.gif". """ -function plotSecondLayerFixAndVary(stw, St, firstLayerWavelets, secondLayerWavelets; fps=2, saveTo=nothing, index=1) +function plotSecondLayer1DFixAndVary(stw, St, firstLayerWavelets, secondLayerWavelets; fps=2, saveTo=nothing, index=1) f1, f2, f3 = getMeanFreq(St) anim = Animation() @@ -195,7 +214,7 @@ function plotSecondLayerFixAndVary(stw, St, firstLayerWavelets, secondLayerWavel end toPlot = stw[2][:, jj, :, index] plt = heatmap(1:size(toPlot, 1), f1[1:end-1], toPlot', title="Second Layer Wavelet $jj, Frequency=$(round(f2[jj], sigdigits=4))Hz", - xlabel="time (samples)", ylabel="First Layer Frequency (Hz)", c=cgrad(:viridis, scale=:exp)) + xlabel="time (samples)", ylabel="First Layer Frequency (Hz)", c=cgrad(:viridis, scale=:exp), margin=5Plots.mm, size=(1080,720)) frame(anim, plt) end else @@ -206,16 +225,17 @@ function plotSecondLayerFixAndVary(stw, St, firstLayerWavelets, secondLayerWavel end toPlot = stw[2][:, :, jj, index] plt = heatmap(1:size(toPlot, 1), f2[1:end-1], toPlot', title="First Layer Wavelet $jj, Frequency=$(round(f1[jj], sigdigits=4))Hz", - xlabel="time (samples)", ylabel="Second Layer Frequency (Hz)", c=cgrad(:viridis, scale=:exp)) + xlabel="time (samples)", ylabel="Second Layer Frequency (Hz)", c=cgrad(:viridis, scale=:exp), margin=5Plots.mm, size=(1080,720)) frame(anim, plt) end end + mkpath(dirname(saveTo)) filepath = isnothing(saveTo) ? "tmp.gif" : saveTo return gif(anim, filepath, fps=fps) end """ - plotSecondLayer(stw, St; saveTo=nothing, index=1, title="Second Layer results", xVals=-1, yVals=-1, logPower=true, toHeat=nothing, c=cgrad(:viridis, [0,.9]), threshold=0, linePalette=:greys, minLog=NaN, kwargs...) + plotSecondLayer1D(stw, St; saveTo=nothing, index=1, title="Second Layer results", xVals=-1, yVals=-1, logPower=true, toHeat=nothing, c=cgrad(:viridis, [0,.9]), threshold=0, linePalette=:greys, minLog=NaN, kwargs...) TODO fix the similarity of these names. xVals and yVals give the spacing of the grid, as it doesn't seem to be done correctly by default. xVals gives the distance from the left and the right @@ -224,16 +244,16 @@ also as a tuple. Default values are `xVals = (.037, .852), yVals = (.056, .939)` If you have no colorbar, set `xVals = (.0015, .997), yVals = (.002, .992)` In the case that arbitrary space has been introduced, if you have a title, use `xVals = (.037, .852), yVals = (.056, .939)`, or if you have no title, use `xVals = (.0105, .882), yVals = (.056, .939)` """ -function plotSecondLayer(stw::ScatteredOut, St; saveTo=nothing, index=1, kwargs...) +function plotSecondLayer1D(stw::ScatteredOut, St; saveTo=nothing, index=1, title="Second Layer results", xVals=-1, yVals=-1, logPower=true, toHeat=nothing, c=cgrad(:viridis, [0,.9]), threshold=0, linePalette=:greys, minLog=NaN, kwargs...) secondLayerRes = stw[2] if ndims(secondLayerRes) > 3 - return plotSecondLayer(secondLayerRes[:, :, :, index], St; saveTo=saveTo, kwargs...) + return plotSecondLayer1D(secondLayerRes[:, :, :, index], St; saveTo=saveTo, kwargs...) else - return plotSecondLayer(secondLayerRes, St; saveTo=saveTo, kwargs...) + return plotSecondLayer1D(secondLayerRes, St; saveTo=saveTo, kwargs...) end end -function plotSecondLayer(stw, St; saveTo=nothing, title="Second Layer results", xVals=-1, yVals=-1, logPower=true, toHeat=nothing, c=cgrad(:viridis, [0, 0.9]), threshold=0, freqsigdigits=3, linePalette=:greys, minLog=NaN, subClims=(Inf, -Inf), δt=1000, firstFreqSpacing=nothing, secondFreqSpacing=nothing, transp=true, labelRot=30, xlabel=nothing, ylabel=nothing, frameTypes=:box, miniFillAlpha=0.5, kwargs...) +function plotSecondLayer1D(stw, St; saveTo=nothing, title="Second Layer results", xVals=-1, yVals=-1, logPower=true, toHeat=nothing, c=cgrad(:viridis, [0, 0.9]), threshold=0, freqsigdigits=3, linePalette=:greys, minLog=NaN, subClims=(Inf, -Inf), δt=1000, firstFreqSpacing=nothing, secondFreqSpacing=nothing, transp=true, labelRot=30, xlabel=nothing, ylabel=nothing, frameTypes=:box, miniFillAlpha=0.5, kwargs...) n, m = size(stw)[2:3] freqs = getMeanFreq(St, δt) freqs = map(x -> round.(x, sigdigits=freqsigdigits), freqs)[1:2] @@ -301,10 +321,10 @@ function plotSecondLayer(stw, St; saveTo=nothing, title="Second Layer results", if title == "" plt = heatmap(toHeat; yticks=yTicksFreq, xticks=xTicksFreq, tick_direction=:out, rotation=labelRot, - xlabel=xlabel, ylabel=ylabel, c=c, clims=(bottom, top), kwargs...) + xlabel=xlabel, ylabel=ylabel, left_margin=5Plots.mm, bottom_margin=3Plots.mm, c=c, clims=(bottom, top), size=(1280,1080), kwargs...) else plt = heatmap(toHeat; yticks=yTicksFreq, xticks=xTicksFreq, tick_direction=:out, rotation=labelRot, - title=title, xlabel=xlabel, ylabel=ylabel, c=c, clims=(bottom, top), kwargs...) + title=title, xlabel=xlabel, ylabel=ylabel, left_margin=5Plots.mm, bottom_margin=3Plots.mm, c=c, clims=(bottom, top), size=(1280,1080), kwargs...) end nPlot = 2 for i in 1:n, j in 1:m @@ -318,18 +338,20 @@ function plotSecondLayer(stw, St; saveTo=nothing, title="Second Layer results", end end if !isnothing(saveTo) + mkpath(dirname(saveTo)) savefig(plt, saveTo) end return plt end """ - jointPlot(thingToPlot, thingName, cSymbol, St; saveTo=nothing, sharedColorScaling=:exp, targetExample=1, δt=1000, freqigdigits=3, sharedColorbar=false, extraPlot=nothing, allPositive=false, logPower=false) + jointPlot1DOld(thingToPlot, thingName, cSymbol, St; saveTo=nothing, sharedColorScaling=:exp, targetExample=1, δt=1000, freqigdigits=3, sharedColorbar=false, extraPlot=nothing, allPositive=false, logPower=false) Create a joint plot visualizing the zeroth, first, and second layer scattering results for a specified example. The variable `thingToPlot` is a tuple containing the scattering results for the zeroth, first, and second layers, `thingName` is the title for the plot, `cSymbol` specifies the color gradient to use, and `St` is the scattering transform object. The function allows for -various customization options, including shared color scaling, target example selection, frequency digit rounding, and additional plotting options. +various customization options, including shared color scaling, target example selection, frequency digit rounding, and additional plotting options. In this implementation, `targetExample` is the +same as `index` in other functions, specifying which example in the batch to plot. """ -function jointPlot(thingToPlot, thingName, cSymbol, St; saveTo=nothing, sharedColorScaling=:exp, targetExample=1, δt=1000, freqigdigits=3, sharedColorbar=false, extraPlot=nothing, allPositive=false, logPower=false, kwargs...) +function jointPlot1DOld(thingToPlot, thingName, cSymbol, St; saveTo=nothing, sharedColorScaling=:exp, targetExample=1, δt=1000, freqigdigits=3, sharedColorbar=false, extraPlot=nothing, allPositive=false, logPower=false, kwargs...) if sharedColorbar clims = (min(minimum.(thingToPlot)...), max(maximum.(thingToPlot)...)) climszero = clims @@ -385,9 +407,356 @@ function jointPlot(thingToPlot, thingName, cSymbol, St; saveTo=nothing, sharedCo end titlePlot = plot(title=thingName, grid=false, showaxis=false, xticks=nothing, yticks=nothing, bottom_margin=0Plots.px) lay = Plots.@layout [o{0.00001h}; [[a b; c{0.1h} d{0.1h}] b{0.04w}]] - plt = plot(titlePlot, p2, p1, extraPlot, p0, colorbarOnly, layout=lay, size=(1500,1000), margin=6Plots.mm) + plt = plot(titlePlot, p2, p1, extraPlot, p0, colorbarOnly, layout=lay, size=(1500,1000), margin=6Plots.mm, top_margin=2Plots.mm) + if !isnothing(saveTo) + mkpath(dirname(saveTo)) + savefig(plt, saveTo) + end + return plt +end + + +""" + jointPlot1D(thingToPlot, thingName, cSymbol, St, origSig; saveTo=nothing, sharedColorScaling=:exp, targetExample=1, δt=1000, freqigdigits=3, sharedColorbar=false, allPositive=false, logPower=false) +Create a joint plot visualizing the zeroth, first, and second layer scattering results for a specified example. The variable `thingToPlot` is a tuple containing the scattering results for the +zeroth, first, and second layers, `thingName` is the title for the plot, `cSymbol` specifies the color gradient to use, and `St` is the scattering transform object. The function allows for +various customization options, including shared color scaling, target example selection, frequency digit rounding, and additional plotting options. In this implementation, `targetExample` is the +same as `index` in other functions, specifying which example in the batch to plot. Finally, `origSig` is the original input signal to be displayed alongside the scattering results. +""" +function jointPlot1D(thingToPlot, thingName, cSymbol, St, origSig; saveTo=nothing, sharedColorScaling=:exp, targetExample=1, δt=1000, freqigdigits=3, sharedColorbar=false, allPositive=false, logPower=false, kwargs...) + if sharedColorbar + clims = (min(minimum.(thingToPlot)...), max(maximum.(thingToPlot)...)) + climszero = clims + climsfirst = clims + climssecond = clims + toHeat = [norm(thingToPlot[2][:, i, j, targetExample], Inf) for i = 1:size(thingToPlot[2], 2), j = 1:size(thingToPlot[2], 3)] + else + climszero = (min(minimum.(thingToPlot[0])...), max(maximum.(thingToPlot[0])...)) + climsfirst = (min(minimum.(thingToPlot[1])...), max(maximum.(thingToPlot[1])...)) + climssecond = (min(minimum.(thingToPlot[2])...), max(maximum.(thingToPlot[2])...)) + toHeat = [norm(thingToPlot[2][:, i, j, targetExample], Inf) for i = 1:size(thingToPlot[2], 2), j = 1:size(thingToPlot[2], 3)] + end + firstLay = thingToPlot[1][:, :, targetExample]' + zeroLay = thingToPlot[0][:, :, targetExample]' + toHeat[toHeat.==0] .= -Inf # we would like zeroes to not actually render + firstLay[firstLay.==0] .= -Inf # for either layer + + # adjust the other parts to be log if logPower is true + if logPower && allPositive + absThing = map(x -> abs.(x), (thingToPlot[0], thingToPlot[1], thingToPlot[2])) + clims = (min(minimum.(absThing)...), max(maximum.(absThing)...)) + firstLay = log10.(firstLay) + zeroLay = log10.(abs.(zeroLay)) + climszero = log10.(clims) + climsfirst = log10.(clims) + climssecond = log10.(clims) + elseif logPower + error("not currently plotting log power and negative values") + end + + if allPositive + c = cgrad(cSymbol, scale=sharedColorScaling) + cSecond = cgrad(cSymbol) + else + zeroAt = -climssecond[1] / (climssecond[2] - climssecond[1]) # set the mid color switch to zero + c = cgrad(cSymbol, [0, zeroAt], scale=sharedColorScaling) + cSecond = cgrad(cSymbol, [0, zeroAt]) + end + + # define the spatial locations as they correspond to the input + spaceLocs = range(1, size(St)[1], length=length(zeroLay)) + + p2 = plotSecondLayer1D(thingToPlot[2][:, :, :, targetExample], St; title="Second Layer", toHeat=toHeat, logPower=logPower, c=c, clims=climssecond, subClims=climssecond, cbar=false, xVals=(0.000, 0.993), yVals=(0.0, 0.994), transp=true, kwargs...) + freqs = getMeanFreq(St, δt) + freqs = map(x -> round.(x, sigdigits=freqigdigits), freqs) + p1 = heatmap(firstLay, c=c, title="First Layer", clims=climsfirst, cbar=false, yticks=((1:size(firstLay, 1)), ""), xticks=((1:size(firstLay, 2)), ""), bottom_margin=0Plots.px, left_margin=0Plots.mm) + p0 = heatmap(spaceLocs, 1:1, zeroLay, c=c, xlabel="Zeroth Layer", clims=climszero, cbar=false, yticks=nothing, top_margin=6Plots.mm, xguidefontsize=12, left_margin=0Plots.mm) + colorbarOnly = scatter([0, 0], [0, 1], zcolor=[0, 3], clims=climssecond, xlims=(1, 1.1), xshowaxis=false, yshowaxis=false, label="", c=c, grid=false, framestyle=:none) + originalSignalPlot = plot(origSig[:,1,targetExample], xlabel="Original Signal", legend=false, xlim=(0, length(origSig[:,1,targetExample])+1), color=:blue, top_margin=6Plots.mm, xguidefontsize=12) + titlePlot = plot(title=thingName, grid=false, showaxis=false, xticks=nothing, yticks=nothing, top_margin=0Plots.px, bottom_margin=0Plots.px) + lay = Plots.@layout [o{0.01h}; [[a b; c{0.1h} d{0.1h}] cb{0.04w}]] + plt = plot(titlePlot, p2, p1, originalSignalPlot, p0, colorbarOnly, layout=lay, size=(1920,1080), margin=5Plots.mm, left_margin=10Plots.mm) + if !isnothing(saveTo) + mkpath(dirname(saveTo)) + savefig(plt, saveTo) + end + return plt +end + + + +""" + plotOriginalSignal2D(origSig; color=:grays, saveTo=nothing, index=1) +Function that plots the original signal `origSig` and saves it to a desired location. +""" +function plotOriginalSignal2D(origSig; color=:grays, saveTo=nothing, index=1) + plt = heatmap(origSig[:,:,1,index], title="Original Signal", legend=false, axis=false, color=color, colorbar=false, margin=5Plots.mm, size=(720,480)) + if !isnothing(saveTo) + mkpath(dirname(saveTo)) + savefig(plt, saveTo) + end + return plt +end + +""" + plotZerothLayer2D(sf; color=:grays, saveTo=nothing, index=1) +Function that plots the zeroth layer of the scattering transform at a specified example index. +""" +function plotZerothLayer2D(sf; color=:grays, saveTo=nothing, index=1) + plt = heatmap(sf[0][:,:,1,index], title="Zeroth Layer", legend=false, axis=false, color=color, margin=5Plots.mm, size=(720,480)) if !isnothing(saveTo) + mkpath(dirname(saveTo)) savefig(plt, saveTo) end return plt end + +""" + plotFirstLayer2DSingleWavelet(j, sf, origSig; color=:grays, saveTo=nothing, index=1) +The variable `j` specifies which wavelet results to plot from the first layer, `index` specifies which example in the batch to plot, +`sf` is the `ScatteredOut` object containing the scattering transform results, and `origSig` is the original input signal. +""" +function plotFirstLayer2DSingleWavelet(j, sf, origSig; color=:grays, saveTo=nothing, index=1) + space = heatmap(sf[1][:,:,j,index], title="First Layer - Gradient Wavelet $j", legend=false, axis=false, color=color) + orig = heatmap(origSig[:,:,1,index], title="Original Signal", legend=false, axis=false, color=color) + l = Plots.@layout [a b] + plt = plot(space, orig, layout=l, size=(1280, 720), margin=5Plots.mm) + if !isnothing(saveTo) + mkpath(dirname(saveTo)) + savefig(plt, saveTo) + end + return plt +end + +""" + visualizeFirstLayer2D(sf; color=:grays, index=1) +Helper function that creates a grid of heatmaps visualizing all the first layer gradient wavelets across space for a specified example index. +The variable `sf` is the `ScatteredOut` object containing the scattering transform results, `index` specifies which example in the batch to plot, and +`color` sets the color gradient for the heatmaps. The function organizes the heatmaps in a grid format, where each row corresponds to a different +scale of the wavelets, and each column corresponds to a different channel (e.g., real and imaginary parts). +""" +function visualizeFirstLayer2D(sf; color=:grays, index=1, clims=nothing, colorScaling=false) + nchannel = 3 + scale = Int(size(sf[1], 3) / nchannel) + + clims = isnothing(clims) ? (minimum(sf[1][:,:,:,index]), maximum(sf[1][:,:,:,index])) : clims + + plots = [] + for scal = 1:scale + o1 = sf[1][:,:,1 + 3*(scal-1), index] + o2 = sf[1][:,:,2 + 3*(scal-1), index] + o3 = sf[1][:,:,3 + 3*(scal-1), index] + if colorScaling + if scal == 1 + push!(plots, heatmap(o1, color=color, colorbar=false, axis=false, legend=false, ylabel=string(scal), title="1", clims=clims)) + push!(plots, heatmap(o2, color=color, colorbar=false, axis=false, legend=false, title="i", clims=clims)) + push!(plots, heatmap(o3, color=color, colorbar=false, axis=false, legend=false, title="j", clims=clims)) + else + push!(plots, heatmap(o1, color=color, colorbar=false, axis=false, legend=false, ylabel=string(scal), clims=clims)) + push!(plots, heatmap(o2, color=color, colorbar=false, axis=false, legend=false, clims=clims)) + push!(plots, heatmap(o3, color=color, colorbar=false, axis=false, legend=false, clims=clims)) + end + else + if scal == 1 + push!(plots, heatmap(o1, color=color, colorbar=false, axis=false, legend=false, ylabel=string(scal), title="1")) + push!(plots, heatmap(o2, color=color, colorbar=false, axis=false, legend=false, title="i")) + push!(plots, heatmap(o3, color=color, colorbar=false, axis=false, legend=false, title="j")) + else + push!(plots, heatmap(o1, color=color, colorbar=false, axis=false, legend=false, ylabel=string(scal))) + push!(plots, heatmap(o2, color=color, colorbar=false, axis=false, legend=false)) + push!(plots, heatmap(o3, color=color, colorbar=false, axis=false, legend=false)) + end + end + end + return plot(plots..., layout=(scale, nchannel), size=(540, 1080)) +end + +""" + plotFirstLayer2D(sf; color=:grays, saveTo=nothing, index=1) +Function that creates a heatmap of the first layer scattering transform results at a specified example index. +The variable `sf` is the scattered output, `saveTo` is the file path to save the plot, and `index` specifies which example in the batch to plot. +The user can also specify the color gradient for the heatmap using the `color` argument. +""" +function plotFirstLayer2D(sf; color=:grays, saveTo=nothing, index=1) + nchannel = 3 + scale = Int(size(sf[1], 3) / nchannel) + clims = (minimum(sf[1][:,:,:,index]), maximum(sf[1][:,:,:,index])) + + plots = [] + for scal = 1:scale + o1 = sf[1][:,:,1 + 3*(scal-1), index] + o2 = sf[1][:,:,2 + 3*(scal-1), index] + o3 = sf[1][:,:,3 + 3*(scal-1), index] + if scal == 1 + push!(plots, heatmap(o1, color=color, colorbar=false, axis=false, legend=false, ylabel=string(scal), title="1", clims=clims)) + push!(plots, heatmap(o2, color=color, colorbar=false, axis=false, legend=false, title="i", clims=clims)) + push!(plots, heatmap(o3, color=color, colorbar=false, axis=false, legend=false, title="j", clims=clims)) + else + push!(plots, heatmap(o1, color=color, colorbar=false, axis=false, legend=false, ylabel=string(scal), clims=clims)) + push!(plots, heatmap(o2, color=color, colorbar=false, axis=false, legend=false, clims=clims)) + push!(plots, heatmap(o3, color=color, colorbar=false, axis=false, legend=false, clims=clims)) + end + end + plt = plot(plots..., layout=(scale, nchannel), plot_title="First Layer", size=(1080, 960), margin=5Plots.mm) + + if !isnothing(saveTo) + mkpath(dirname(saveTo)) + savefig(plt, saveTo) + end + return plt +end + +""" + plotFirstLayer2DAll(sf, origSig; saveTo=nothing, index=1, color=:grays) +Function that plots all first layer gradient wavelets for a specific example signal `index` across space, along with the original signal. +The variable `index` specifies which example in the batch to plot, `sf` is the `ScatteredOut` object +containing the scattering transform results, `origSig` is the original input signal, and `saveTo` is the file path to save the plot. +""" +function plotFirstLayer2DAll(sf, origSig; saveTo=nothing, index=1, color=:grays, colorScaling=false) + global_min = min(minimum(sf[1][:,:,:,index]), minimum(origSig[:,:,1,index])) + global_max = max(maximum(sf[1][:,:,:,index]), maximum(origSig[:,:,1,index])) + clims = (global_min, global_max) + + if colorScaling + space = visualizeFirstLayer2D(sf; color=color, index=index, clims=clims, colorScaling=false) + orig = heatmap(origSig[:,:,1,index], title="Original Signal", legend=false, axis=false, color=color, colorbar=false, clims=clims) + else + space = visualizeFirstLayer2D(sf; color=color, index=index) + orig = heatmap(origSig[:,:,1,index], title="Original Signal", legend=false, axis=false, color=color, colorbar=false) + end + + l = Plots.@layout [a b{1.0w}] + plt = plot(space, orig, layout=l, size=(1280, 960), margin=5Plots.mm) + if !isnothing(saveTo) + mkpath(dirname(saveTo)) + savefig(plt, saveTo) + end + return plt +end + + +""" + plotSecondLayer2DSingleWavelet(j, sf, origSig; color=:grays, saveTo=nothing, index=1) +The variable `j` specifies which wavelet results to plot from the second layer, `index` specifies which example in the batch to plot, +`sf` is the `ScatteredOut` object containing the scattering transform results, and `origSig` is the original input signal. +""" +function plotSecondLayer2DSingleWavelet(firstLayerWaveletIndex, secondLayerWaveletIndex, sf, origSig; color=:grays, saveTo=nothing, index=1) + space = heatmap(sf[2][:,:,secondLayerWaveletIndex,firstLayerWaveletIndex,index], title="Second Layer - Path $firstLayerWaveletIndex -> $secondLayerWaveletIndex", legend=false, axis=false, color=color) + orig = heatmap(origSig[:,:,1,index], title="Original Signal", legend=false, axis=false, color=color) + l = Plots.@layout [a b] + plt = plot(space, orig, layout=l, size=(1280, 720), margin=5Plots.mm) + if !isnothing(saveTo) + mkpath(dirname(saveTo)) + savefig(plt, saveTo) + end + return plt +end + +""" + visualizeSecondLayer2D(sf; color=:grays, index=1) +Function that creates a grid of heatmaps visualizing all the second layer gradient wavelets across space for a specified example index. +The variable `sf` is the `ScatteredOut` object containing the scattering transform results, `index` specifies which example in the batch to plot, and `color` sets the color gradient for the heatmaps. +The function organizes the heatmaps in a grid format, where each row corresponds to a different scale of the second layer wavelets, and each column corresponds to a different channel (e.g., real and imaginary parts). +""" +function visualizeSecondLayer2D(sf; color=:grays, index=1, clims=nothing) + nchannel = 3 + n_layer1 = size(sf[2], 4) + n_layer2 = size(sf[2], 3) + scale_layer1 = Int(n_layer1 / nchannel) + scale_layer2 = Int(n_layer2 / nchannel) + + clims = isnothing(clims) ? (minimum(sf[2][:,:,:,:,index]), maximum(sf[2][:,:,:,:,index])) : clims + plots = [] + + col_titles = ["(1,:)", "(i,:)", "(j,:)"] + row_titles = ["(:,1)", "(:,i)", "(:,j)"] + + for scal2 = 1:scale_layer2 + for k = 1:nchannel + for scal1 = 1:scale_layer1 + for j = 1:nchannel + o = sf[2][:, :, k + nchannel*(scal2-1), j + nchannel*(scal1-1), index] + + is_top_row = (scal2 == 1) && (k == 1) + is_left_col = (scal1 == 1) && (j == 1) + is_right_col = (scal1 == scale_layer1) && (j == nchannel) + is_bottom_row = (scal2 == scale_layer2) && (k == nchannel) + + title_str = is_top_row ? col_titles[j] : "" + ylabel_str = is_left_col ? string(scal2) : "" + xlabel_str = is_bottom_row ? string(scal1) : "" + + hm = heatmap(o, color=color, axis=false, legend=false, + title=title_str, ylabel=ylabel_str, xlabel=xlabel_str, clims=clims) + if is_right_col + annotate!(hm, [(size(o,2)*1.15, size(o,1)/2, + Plots.text(row_titles[k], 8, :left, :black))]) + end + push!(plots, hm) + end + end + end + end + + total_rows = n_layer2 + total_cols = n_layer1 + return plot(plots..., layout=(total_rows, total_cols), margin=2Plots.mm, size=(2000, 2200)) +end + +""" + plotSecondLayer2D(sf; color=:grays, saveTo=nothing, index=1) +Function that creates a grid of heatmaps visualizing all the second layer scattering transform results +at a specified example index. The variable `sf` is the `ScatteredOut` object, `saveTo` is the file path +to save the plot, `index` specifies which example in the batch to plot, and `color` sets the color gradient. +The grid rows correspond to second layer scales/channels and columns to first layer scales/channels. +""" +function plotSecondLayer2D(sf; color=:grays, saveTo=nothing, index=1) + nchannel = 3 + n_layer1 = size(sf[2], 4) + n_layer2 = size(sf[2], 3) + scale_layer1 = Int(n_layer1 / nchannel) + scale_layer2 = Int(n_layer2 / nchannel) + + clims = (minimum(sf[2][:,:,:,:,index]), maximum(sf[2][:,:,:,:,index])) + plots = [] + + col_titles = ["(1,:)", "(i,:)", "(j,:)"] + row_titles = ["(:,1)", "(:,i)", "(:,j)"] + + for scal2 = 1:scale_layer2 + for k = 1:nchannel + for scal1 = 1:scale_layer1 + for j = 1:nchannel + o = sf[2][:, :, k + nchannel*(scal2-1), j + nchannel*(scal1-1), index] + + is_top_row = (scal2 == 1) && (k == 1) + is_left_col = (scal1 == 1) && (j == 1) + is_right_col = (scal1 == scale_layer1) && (j == nchannel) + is_bottom_row = (scal2 == scale_layer2) && (k == nchannel) + + title_str = is_top_row ? col_titles[j] : "" + ylabel_str = is_left_col ? string(scal2) : "" + xlabel_str = is_bottom_row ? string(scal1) : "" + + hm = heatmap(o, color=color, axis=false, legend=false, + title=title_str, ylabel=ylabel_str, xlabel=xlabel_str, clims=clims) + if is_right_col + annotate!(hm, [(size(o,2)*1.15, size(o,1)/2, + Plots.text(row_titles[k], 8, :left, :black))]) + end + push!(plots, hm) + end + end + end + end + + total_rows = n_layer2 + total_cols = n_layer1 + plt = plot(plots..., layout=(total_rows, total_cols), plot_title="Second Layer", + size=(2000, 2200), margin=2Plots.mm, dpi=150) + if !isnothing(saveTo) + mkpath(dirname(saveTo)) + savefig(plt, saveTo) + end + return plt +end \ No newline at end of file diff --git a/src/transform.jl b/src/transform.jl index 64ce56c..c42b0bc 100644 --- a/src/transform.jl +++ b/src/transform.jl @@ -123,9 +123,7 @@ function stFlux(inputSize::NTuple{N}, m=2; outputPool = 2, poolBy = 3//2, also, in the first layer, merge the channels into the first shearing, since max pool isn't defined for arbitrary arrays =# if i == 1 - listOfSizes[i+1] = (pooledSize..., - (nFilters - 1) * listOfSizes[1][Nd+1], - listOfSizes[i][Nd+2:end]...) + listOfSizes[i+1] = (pooledSize..., (nFilters - 1) * listOfSizes[1][Nd+1], listOfSizes[i][Nd+2:end]...) interstitial[3*i-1] = x -> begin ax = axes(x) return σ.(reshape(x[ax[1:Nd]..., ax[Nd+1][1:end-1], ax[(Nd+2):end]...], @@ -157,6 +155,14 @@ function stFlux(inputSize::NTuple{N}, m=2; outputPool = 2, poolBy = 3//2, outputSizes, outputPool, settings) end +function stFlux(mainChain::C, normalize::Bool, outputSizes::D, + outputPool::E, settings::F) where {C,D,E,F} + # Extract Dimension and Depth from the settings or infer from chain + Nd = settings[:outputPool] |> first |> length # spatial dims + m = length(outputSizes) - 1 # depth + stFlux{Nd, m, C, D, E, F}(mainChain, normalize, outputSizes, outputPool, settings) +end + function dispatchLayer(listOfSizes, Nd::Val{1}; varargs...) #= For 1D input signals, we use the conventional wavelet filters available in our `ContinuousWavelets.jl` package. The following function is defined in yet @@ -223,8 +229,8 @@ function extractAddPadding(x, adr, chunkSize, N) justUsing = x[fill(Colon(), N)..., :, adr] if length(adr) < chunkSize actualSize = chunkSize - length(adr) - return cat(justUsing, zeros(eltype(justUsing), size(x)[1:end-1]..., - actualSize), dims=ndims(justUsing)), length(adr) + return cat(justUsing, fill!(similar(justUsing, size(x)[1:end-1]..., + actualSize), 0), dims=ndims(justUsing)), length(adr) else return justUsing, chunkSize end @@ -233,7 +239,12 @@ end trim(x, actualSize) = x[axes(x)[1:end-1]..., 1:actualSize] function breakAndAdapt(St::stFlux{N,D}, x) where {N,D} mc = St.mainChain.layers - chunkSize = size(mc[1].fftPlan)[end] + cpu_chunk = size(mc[1].fftPlan)[end] + chunkSize = if x isa CuArray + min(size(x)[end], cpu_chunk * 32) # up to 32x larger chunks on GPU + else + cpu_chunk + end nSteps = ceil(Int, (size(x)[end]) / chunkSize) # the first entry is taken care of already containerType = typeof(St.mainChain[1].weight[1]) xAxes = axes(x) @@ -242,7 +253,8 @@ function breakAndAdapt(St::stFlux{N,D}, x) where {N,D} # do the first beforehand to get the sizes out = applyScattering(mc, maybeAdapt(containerType, firstEx), ndims(St), St, 0) # create storage - outputs = map(out -> zeros(eltype(out), size(out)[1:end-1]..., size(x)[end]...), out) + # outputs = map(o -> maybeAdapt(typeof(o), zeros(eltype(o), size(o)[1:end-1]..., size(x)[end]...)), out) + outputs = map(o -> maybeAdapt(typeof(o), fill!(similar(o, size(o)[1:end-1]..., size(x)[end]), zero(eltype(o)))), out) # util to write out to outputs at location batchInds function writeOut!(out, batchInds, actualSize) for jj in 1:length(out) @@ -263,7 +275,6 @@ function breakAndAdapt(St::stFlux{N,D}, x) where {N,D} return outputs end - function applyScattering(c::Tuple, x, Nd, st, M) res = first(c)(x) if (typeof(first(c)) <: ConvFFT) || (typeof(first(c)) <: MonoConvFFT) @@ -278,8 +289,7 @@ function applyScattering(c::Tuple, x, Nd, st, M) return (tmpRes, apld...) else tmpRes = r(real.(tmpRes)) - return (tmpRes, - applyScattering(tail(c), res, Nd, st, M + 1)...) + return (tmpRes, applyScattering(tail(c), res, Nd, st, M + 1)...) end else # this is either a reshaping layer or a subsampling layer, so no output @@ -287,6 +297,7 @@ function applyScattering(c::Tuple, x, Nd, st, M) end end + applyScattering(::Tuple{}, x, Nd, st, M) = tuple() # all of the returns should # happen along the way, not at the end @@ -297,23 +308,15 @@ normalize `x` over the dimensions `Nd` through `ndims(x)-1`. For example, if `Nd function normalize(x, Nd) n = ndims(x) totalThisLayer = prod(size(x)[(Nd+1):(n-1)]) - ax = axes(x) - buf = Zygote.Buffer(x) - for i = 1:size(x)[end] - xSlice = x[ax[1:(n-1)]..., i] - thisExample = totalThisLayer / (sum(abs.(xSlice) .^ 2)) .^ (0.5f0) - if isnan(thisExample) || thisExample ≈ Inf - buf[ax[1:(n-1)]..., i] = xSlice - else - buf[ax[1:(n-1)]..., i] = xSlice .* thisExample - end - end - return copy(buf) + sumSqDims = 1:(n-1) + normSq = sum(abs.(x) .^ 2, dims=sumSqDims) + scale = totalThisLayer ./ sqrt.(normSq) + scale = ifelse.(isnan.(scale) .| isinf.(scale), one(eltype(scale)), scale) + return x .* scale end normalize(sct::ScatteredOut, Nd) = ScatteredOut(map(x -> normalize(x, Nd), sct.output), sct.k) normalize(sct::ScatteredFull, Nd) = ScatteredFull(sct.m, sct.k, sct.data, - map(x -> normalize(x, Nd), - sct.output)) + map(x -> normalize(x, Nd), sct.output)) \ No newline at end of file diff --git a/src/utilities.jl b/src/utilities.jl index 22ac36d..a673b5d 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -35,11 +35,11 @@ export adapt Get the wavelets used in each layer. If `spaceDomain` is `true`, then it will also convert the filters from the stored positive Fourier representation to a space version. """ function getWavelets(sc::stFlux; spaceDomain=false) - freqDomain = map(x -> x.weight, filter(x -> (typeof(x) <: ConvFFT), sc.mainChain.layers)) # filter to only have ConvFFTs, and then return the wavelets of those + # freqDomain = map(x -> x.weight, filter(x -> (typeof(x) <: ConvFFT), sc.mainChain.layers)) # filter to only have ConvFFTs, and then return the wavelets of those if spaceDomain return map(originalDomain, filter(x -> (typeof(x) <: ConvFFT), sc.mainChain.layers)) else - return map(x -> x.weight, filter(x -> (typeof(x) <: ConvFFT), sc.mainChain.layers)) + return map(x -> x.weight, filter(x -> (typeof(x) <: ConvFFT), sc.mainChain.layers)) end end @@ -101,23 +101,32 @@ flatten(scatRes) = scatRes Given a scattering transform `st` and an array `toRoll` that is `NCoeffs×extraDims`, "roll" up `toRoll` into a `ScatteredOut`. """ function roll(toRoll, st::stFlux) + toRoll = ChainRulesCore.unthunk(toRoll) + toRoll = collect(toRoll) + Nd = ndims(st) oS = st.outputSizes roll(toRoll, oS, Nd) end function roll(toRoll, stOutput::S) where {S<:Scattered} + toRoll = ChainRulesCore.unthunk(toRoll) + toRoll = collect(toRoll) + Nd = ndims(stOutput) oS = map(size, stOutput.output) return roll(toRoll, oS, Nd) end +#= function roll(toRoll, oS::Tuple, Nd) + toRoll = collect(toRoll) + nExamples = size(toRoll)[2:end] - rolled = ([adapt(typeof(toRoll), zeros(eltype(toRoll), - sz[1:Nd+nPathDims(ii)]..., - nExamples...)) for (ii, sz) in - enumerate(oS)]...,) + rolled = ([fill!(similar(toRoll, eltype(toRoll), + sz[1:Nd+nPathDims(ii)]..., + nExamples...),0) for (ii, sz) in + enumerate(oS)]...,) locSoFar = 0 for (ii, x) in enumerate(rolled) @@ -130,6 +139,32 @@ function roll(toRoll, oS::Tuple, Nd) end return ScatteredOut(rolled, Nd) end +=# +function roll(toRoll, oS::Tuple, Nd) + toRoll = collect(toRoll) + + nExamples = ntuple(i -> size(toRoll, i+1), ndims(toRoll)-1) + rolled = ntuple(length(oS)) do ii + sz = oS[ii] + dims = (sz[1:Nd+nPathDims(ii)]..., nExamples...) + zeros(eltype(toRoll), dims) + end + + locSoFar = 0 + for (ii, x) in enumerate(rolled) + szThisLayer = oS[ii][1:Nd+nPathDims(ii)] + totalThisLayer = prod(szThisLayer) + + range = (locSoFar+1):(locSoFar+totalThisLayer) + addresses = (szThisLayer..., nExamples...) + + rolled[ii] .= reshape(toRoll[range, :], addresses) + + locSoFar += totalThisLayer + end + + return ScatteredOut(rolled, Nd) +end """ @@ -170,18 +205,14 @@ transform `x` using `stack`, but where `x` and `stack` may have different batch function batchOff(stack, x, batchSize) nRounds = ceil(Int, size(x)[end] // batchSize) firstRes = stack(x[:, :, :, 1:batchSize]) - result = cu(zeros(size(firstRes)[1:end-1]..., size(x)[end])) + result = fill!(similar(firstRes, size(firstRes)[1:end-1]..., size(x)[end]), 0) result[:, 1:batchSize] = firstRes for i = 2:(nRounds-1) result[:, 1+(i-1)*batchSize:(i*batchSize)] = stack(x[:, :, :, 1+(i-1)*batchSize:(i*batchSize)]) end - result[:, (1+(nRounds-1)*batchSize):end] = stack(cat(x[:, :, :, - (1+(nRounds-1)*batchSize):end], - cu(zeros(size(x)[1:3]..., - nRounds * batchSize - - - size(x, 4))), - dims=4))[:, 1:(size(x, 4)-(nRounds-1)*batchSize)] + padding = fill!(similar(x, size(x)[1:3]..., nRounds*batchSize - size(x,4)), 0) + result[:, (1+(nRounds-1)*batchSize):end] = stack(cat(x[:, :, :, + (1+(nRounds-1)*batchSize):end], padding, dims=4))[:, 1:(size(x,4)-(nRounds-1)*batchSize)] return result end @@ -240,6 +271,57 @@ function size(st::stFlux) sz = l.fftPlan.sz es = originalSize(sz[1:ndims(l.weight) - 1], l.bc) end - + return es end + + +""" + reshapeInputs(dataMat; is2DData=false) -> reshapedData, dims + +Reshapes a data matrix from (N, numSamples) format to (N, 1, numSamples) format. +Alos converts a data matrix from (N, M, numSamples) format to (N, M, numSamples) format. +Make signals suitable as scatteringTransform input. Can be used for 1D and 2D signals. + +# Arguments +- `dataMat`: Matrix where rows are data points and columns are different samples +- `is2DData`: If `true`, treats the input as 2D data and reshapes accordingly. If `false`, treats the input as 1D data. + +# Returns +- Reshaped array of size (N, M, numSamples) +- Tuple of dimensions for scatteringTransform + +# Example +```julia +N = 2047 +f = testfunction(N, "Doppler") +g = testfunction(N, "Bumps") +dataMat = hcat(f, g) +signals, dims = reshapeInputs(dataMat) +St = scatteringTransform(dims, 2, cw=Morlet(π), β=2, σ=abs) +sOut = St(signals) +""" +function reshapeInputs(dataMat; is2DData=false) + # Convert vector to column matrix if needed + if ndims(dataMat) == 1 + dataMat = reshape(dataMat, :, 1) + end + + if ndims(dataMat) == 2 + N, M = size(dataMat) + if is2DData + reshapedData = reshape(dataMat, N, M, 1, 1) + dims = (N, M, 1, 1) + else + reshapedData = reshape(dataMat, N, 1, M) + dims = (N, 1, M) + end + elseif ndims(dataMat) == 3 + N, M, numInputs = size(dataMat) + reshapedData = reshape(dataMat, N, M, 1, numInputs) + dims = (N, M, 1, numInputs) + else + error("Input data must be a vector, matrix, or 3D array.") + end + return reshapedData, dims +end diff --git a/test/GPUTests.jl b/test/GPUTests.jl new file mode 100644 index 0000000..e31d43f --- /dev/null +++ b/test/GPUTests.jl @@ -0,0 +1,157 @@ +using ScatteringTransform +using ContinuousWavelets +using AbstractFFTs, FFTW +using Test, LinearAlgebra, Statistics +using Flux, FourierFilterFlux, CUDA +using Zygote +using BenchmarkTools + +const gpu_available = CUDA.functional() + +@testset "GPU Tests" begin + if !gpu_available + @warn "No functional GPU found — skipping GPU tests" + else + @info "CUDA functional — running GPU comparison and timing tests" + + @testset "CPU/GPU consistency, 1D" begin + init = randn(Float32, 64, 1, 2) + sst = stFlux(size(init), 2, poolBy=3 // 2) + + resCPU = sst(init) + + sstGPU = cu(sst) + initGPU = cu(init) + resGPU = sstGPU(initGPU) + + @test typeof(resGPU.output[1]) <: CuArray + + for (cpuLayer, gpuLayer) in zip(resCPU.output, resGPU.output) + @test cpuLayer ≈ Array(gpuLayer) atol = 1e-3 + end + end + + #= + @testset "CPU/GPU consistency, 2D" begin + n_init_channels = 2 + batch_size = 2 + init = randn(Float32, 32, 32, n_init_channels, batch_size) + sst = stFlux(size(init), 2, poolBy=3 // 2, outputPool=(2,)) + + resCPU = sst(init) + + sstGPU = cu(sst) + initGPU = cu(init) + resGPU = sstGPU(initGPU) + + @test typeof(resGPU.output[1]) <: CuArray + + for (cpuLayer, gpuLayer) in zip(resCPU.output, resGPU.output) + @test cpuLayer ≈ Array(gpuLayer) atol = 1e-3 + end + end + =# + + @testset "roll/flatten CPU vs GPU" begin + initCPU = randn(Float32, 64, 1, 2) + sst = stFlux(size(initCPU), 2, poolBy=3 // 2) + resCPU = sst(initCPU) + smooshedCPU = ScatteringTransform.flatten(resCPU) + + sstGPU = cu(sst) + initGPU = cu(initCPU) + resGPU = sstGPU(initGPU) + smooshedGPU = ScatteringTransform.flatten(resGPU) + + @test typeof(smooshedGPU) <: CuArray + @test Array(smooshedGPU) ≈ smooshedCPU atol = 1e-3 + + reconstCPU = roll(smooshedCPU, sst) + reconstGPU = roll(smooshedGPU, sstGPU) + + @test all(reconstCPU .≈ resCPU) + for (cpuLayer, gpuLayer) in zip(reconstGPU.output, resGPU.output) + @test Array(cpuLayer) ≈ Array(gpuLayer) atol = 1e-3 + end + end + + @testset "normalize CPU vs GPU" begin + x = randn(Float32, 10, 4, 3, 5, 7) + xGPU = cu(x) + + xpCPU = ScatteringTransform.normalize(x, 2) + xpGPU = ScatteringTransform.normalize(xGPU, 2) + + @test typeof(xpGPU) <: CuArray + @test Array(xpGPU) ≈ xpCPU atol = 1e-3 + + for w in eachslice(xpCPU, dims=ndims(x)) + @test norm(w, 2) ≈ 3 * 5 + end + for w in eachslice(Array(xpGPU), dims=ndims(x)) + @test norm(w, 2) ≈ 3 * 5 + end + end + + @testset "Gradients CPU vs GPU" begin + init = randn(Float32, 64, 1, 1) + initGPU = cu(init) + sst = stFlux(size(init), 2, poolBy=3 // 2) + sstGPU = cu(sst) + + CUDA.allowscalar(true) + ∇CPU_Zeroth = Zygote.gradient(x -> sst(x)[0][19,1,1], init)[1] + ∇GPU_Zeroth = Zygote.gradient(x -> sstGPU(x)[0][19,1,1], initGPU)[1] + + ∇CPU_First = Zygote.gradient(x -> sst(x)[1][11,5,1], init)[1] + ∇GPU_First = Zygote.gradient(x -> sstGPU(x)[1][11,5,1], initGPU)[1] + + ∇CPU_Second = Zygote.gradient(x -> sst(x)[2][3,5,5,1], init)[1] + ∇GPU_Second = Zygote.gradient(x -> sstGPU(x)[2][3,5,5,1], initGPU)[1] + + @test typeof(∇GPU_Zeroth) <: CuArray + @test Array(∇GPU_Zeroth) ≈ ∇CPU_Zeroth atol = 1e-3 + + @test typeof(∇GPU_First) <: CuArray + @test Array(∇GPU_First) ≈ ∇CPU_First atol = 1e-3 + + @test typeof(∇GPU_Second) <: CuArray + @test Array(∇GPU_Second) ≈ ∇CPU_Second atol = 1e-3 + end + + @testset "CPU/GPU timing" begin + sizes = [256, 2048, 16384, 131072] + cpu_max_size = 16384 + + for sz in sizes + GC.gc() + CUDA.reclaim() + + init = randn(Float32, sz, 1, 1) + sst = stFlux(size(init), 2, poolBy=3 // 2) + sstGPU = cu(sst) + initGPU = cu(init) + + if sz > cpu_max_size + CUDA.@sync sstGPU(initGPU) # warmup + GC.gc(); CUDA.reclaim() + tGPU = @elapsed (CUDA.@sync sstGPU(initGPU)) + else + tGPU = @belapsed (CUDA.@sync $sstGPU($initGPU)) + end + + tCPU = @belapsed $sst($init) + speedup = tCPU / tGPU + @info "size=$sz" tCPU tGPU speedup + if sz >= 512 + @test tGPU < tCPU + end + + sstGPU = nothing + initGPU = nothing + GC.gc() + CUDA.reclaim() + end + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e86c93f..e2ab2fc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,3 +9,4 @@ using Zygote include("pathTests.jl") include("fluxtests.jl") +include("GPUTests.jl")