Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
0a004ad
Update plots, documentation, and utility functions
JaredW40 Jan 29, 2026
a29f4ae
Merge branch 'master' of https://github.com/JaredW40/ScatteringTransf…
JaredW40 Jan 29, 2026
1e53b3c
Updated function calls in documentation.
JaredW40 Jan 29, 2026
efe3461
Update plots documentation and scattering transform implementation
JaredW40 Feb 2, 2026
bb2ca63
Fix syntax error in reshapeInputs function
JaredW40 Feb 2, 2026
2491357
Merge branch 'BoundaryValueProblems:master' into master
JaredW40 Feb 13, 2026
4e55074
Merge branch 'BoundaryValueProblems:master' into master
JaredW40 Feb 21, 2026
0343058
Modify reshapeInputs to include is2DData parameter
JaredW40 Feb 21, 2026
a6744c2
your message here
JaredW40 Feb 21, 2026
ba42522
Couple of changes to functions and attempt to avoid test errors.
JaredW40 Feb 21, 2026
6f29c59
Added a number of 2D plotting functions
JaredW40 Feb 22, 2026
327a9e5
Added support for a number of 2D plotting functions (some are still u…
JaredW40 Apr 2, 2026
c28b171
Added GPU support and tests.
JaredW40 Jun 23, 2026
fbffd95
Remove Manifest.toml from tracking
JaredW40 Jun 23, 2026
977d9e4
Resolve merge conflict in scatteringplots.jl
JaredW40 Jun 23, 2026
45cdd2f
Fix docs example blocks to load Plots and trigger ScatteringPlotsExt
JaredW40 Jun 23, 2026
a6b91ec
Fix docs make.jl and plots.md
JaredW40 Jun 23, 2026
f90c3b0
Fix docs make.jl and remove duplicate plotOriginalSignal1D
JaredW40 Jun 23, 2026
443ed09
Enhance example with ScatteringPlotsExt imports
JaredW40 Jun 23, 2026
8391287
Fixed ScatteringPlotsExt imports and docs errors
JaredW40 Jun 23, 2026
546b123
Remove Manifest.toml from tracking
JaredW40 Jun 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ concurrency:
env:
JULIA_PKG_SERVER: ""
# Fix for Windows: Forces GR/Plots to be headless
JULIA_NUM_THREADS: 1
JULIA_PKG_PRECOMPILE_AUTO: 0
GKSwstype: "100"
# Fix for macOS: Forces Matplotlib to use non-interactive backend
MPLBACKEND: "Agg"
Expand Down
Binary file modified .gitignore
Binary file not shown.
17 changes: 14 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FourierFilterFlux = "3d7dfd45-6c90-4c9b-b697-194a05757159"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GLMNet = "8d5ece8b-de18-5317-b113-243142960cc6"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Expand All @@ -44,26 +46,35 @@ Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
WaveletsExt = "8f464e1e-25db-479f-b0a5-b7680379e03f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
cuFFT = "533571aa-0936-420e-b4be-9c66f5f626ca"

[extensions]
ScatteringPlotsExt = "Plots"

[compat]
AbstractFFTs = "1"
Adapt = "3, 4"
CUDA = "4, 5"
CUDA = "4 - 6"
ContinuousWavelets = "1"
Distributions = "0.25"
FFTW = "1"
FileIO = "1"
FiniteDifferences = "0.12.34"
Flux = "0.13, 0.14, 0.15, 0.16"
Functors = "0.5.2"
JLD2 = "0.4, 0.5, 0.6"
Plots = "1"
Plots = "1.41.6"
Reexport = "1"
Wavelets = "0.9, 0.10"
WaveletsExt = "0.2.3"
Zygote = "0.6, 0.7"
cuFFT = "6.2.0"
julia = "1.10"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "CUDA", "BenchmarkTools"]
5 changes: 3 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ ENV["PLOTS_TEST"] = "true"
ENV["GKSwstype"] = "100"
ENV["LINES"] = "9"
ENV["COLUMNS"] = "60"
using Documenter, ScatteringTransform, ScatteringPlots
using Documenter, ScatteringTransform, Plots
mkpath("docs/src/figures")
const ScatteringPlotsExt = Base.get_extension(ScatteringTransform, :ScatteringPlotsExt)

makedocs(
sitename = "ScatteringTransform.jl",
format = Documenter.HTML(),
modules = [ScatteringTransform, ScatteringPlots],
modules = [ScatteringTransform, ScatteringPlotsExt],
authors="David Weber, Naoki Saito, Jared White",
clean=true,
checkdocs = :exports,
Expand Down
Binary file modified docs/src/figures/firstLayer.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/src/figures/firstLayerAll.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/src/figures/jointPlot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/src/figures/secondLayer.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/src/figures/secondLayerSpecificPath.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/src/figures/sliceByFirst.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/src/figures/sliceBySecond.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/src/figures/zerothLayer.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 25 additions & 16 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,34 @@ For a comparable package in python, see [Kymatio](https://www.kymat.io/).

## Basic Example

```@setup
using ScatteringTransform, Wavelets, Plots
```@example ex
using ScatteringTransform, ContinuousWavelets, Wavelets, Plots
const ScatteringPlotsExt = Base.get_extension(ScatteringTransform, :ScatteringPlotsExt)
import .ScatteringPlotsExt: plotZerothLayer1D, plotFirstLayer1D, plotFirstLayer1DAll,
plotSecondLayer1DFixAndVary, plotSecondLayer1DSpecificPath,
plotSecondLayer1D, jointPlot1D
nothing # hide
```

As an example signal, lets work with the doppler signal:
As an example signal, let's work with a doppler signal:

```@example ex
using Wavelets, Plots
N = 2047
f = testfunction(N, "Doppler")
plot(f, legend=false, title="Doppler signal")
signal = testfunction(N, "Doppler")
plot(signal, legend=false, title="Doppler signal")
savefig("figures/rawDoppler.svg"); #hide
nothing # hide
```

![](figures/rawDoppler.svg)

First we need to make a `scatteringTransform` instance, which will create and store all of the necessary filters, subsampling operators, nonlinear functions, etc.
The parameters are described in the `scatteringTransform` type.
The parameters are described in the `scatteringTransform` type. The function `reshapeInputs` converts data matrices and vectors into a usable form and returns the reshaped array and its dimensions. It works for any shaped input signal.
Since the Doppler signal is smooth, but with varying frequency, let's set the wavelet family `cw=Morlet(π)` specifies the mother wavelet to be a Morlet wavelet with mean frequency π, and frequency spacing `β=2`:

```@example ex
using ScatteringTransform, ContinuousWavelets
St = scatteringTransform((N, 1, 1), 2, cw=Morlet(π), β=2, σ=abs)
f, dims = reshapeInputs(signal)
St = scatteringTransform(dims, 2, cw=Morlet(π), β=2, σ=abs)
sf = St(f)
```

Expand All @@ -56,7 +60,7 @@ plotZerothLayer1D(sf)
The first layer is the average of the absolute value of the scalogram:

```@example ex
plotFirstLayer(sf, St)
plotFirstLayer1D(sf, St)
```

With the plotting utilities included in this package, you are able to display the previous plot along with the original signal and the first layer wavelet gradients:
Expand All @@ -73,7 +77,7 @@ With our plotting utilities, you can display the second layer with respect to sp
To this end, lets make two gifs, the first with the _first_ layer frequency varying with time:

```@example ex
plotSecondLayerFixAndVary(sf, St, 1:30, 1, fps=1, saveTo="figures/sliceByFirst.gif")
plotSecondLayer1DFixAndVary(sf, St, 1:30, 1, fps=1, saveTo="figures/sliceByFirst.gif")
nothing # hide
```
![](figures/sliceByFirst.gif)
Expand All @@ -84,15 +88,15 @@ As the first layer frequency increases, the energy concentrates to the beginning
The second has the _second_ layer frequency varying with time:

```@example ex
plotSecondLayerFixAndVary(sf, St, 1, 1:28, fps=1, saveTo="figures/sliceBySecond.gif")
plotSecondLayer1DFixAndVary(sf, St, 1, 1:28, fps=1, saveTo="figures/sliceBySecond.gif")
nothing # hide
```
![](figures/sliceBySecond.gif)

If desired, this package allows one to plot the results of a specific path. Here is an example, where we are plotting the resulting plot if we were to use first layer wavelet 3 and second layer wavelet 1.

```@example ex
plotSecondLayerSpecificPath(sf, St, 3, 1, f)
plotSecondLayer1DSpecificPath(sf, St, 3, 1, f)
```

For any fixed second layer frequency, we get approximately the curve in the first layer scalogram, with different portions emphasized, and the overall mass decreasing as the frequency increases, corresponding to the decreasing amplitude of the envelope for the doppler signal.
Expand All @@ -101,16 +105,21 @@ These plots can also be created using various plotting utilities defined in this
For example, we can generate a denser representation with the `plotSecondLayer` function:

```@example ex
plotSecondLayer(sf, St)
plotSecondLayer1D(sf, St)
```

where the frequencies are along the axes, the heatmap gives the largest value across time for that path, and at each path is a small plot of the averaged timecourse.


### Joint Plot

Finally, we can constuct a joint plot of much of our prior information. This plot will display the zeroth layer, first layer and second layer information for a given example.
Finally, we can constuct a joint plot of much of our prior information. This plot will display the original signal, zeroth layer, first layer and second layer information for a given example.

```@example ex
jointPlot(sf, "Scattering Transform", :viridis, St)
jointPlot1D(sf, "Scattering Transform", :viridis, St, f)
```


## Future Updates

In the future we will be adding plotting support for the third layer of the Scattering Transform. In addition, 2D variants of the plotting functions will also be created and documented here with examples.
35 changes: 25 additions & 10 deletions docs/src/plots.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
# Plotting Scattering Transforms


## 1D Plotting Functions
```@docs
ScatteringTransform.plotZerothLayer1D
ScatteringTransform.plotFirstLayer1D
ScatteringTransform.gifFirstLayer
ScatteringTransform.plotFirstLayer1DAll
ScatteringTransform.plotFirstLayer
ScatteringTransform.plotSecondLayerSpecificPath
ScatteringTransform.plotSecondLayer1DSubsetGif
ScatteringTransform.plotSecondLayerFixAndVary
ScatteringTransform.plotSecondLayer
ScatteringTransform.jointPlot
ScatteringPlotsExt.plotZerothLayer1D
ScatteringPlotsExt.plotFirstLayer1DSingleWavelet
ScatteringPlotsExt.gifFirstLayer1D
ScatteringPlotsExt.plotFirstLayer1DAll
ScatteringPlotsExt.plotFirstLayer1D
ScatteringPlotsExt.plotSecondLayer1DSpecificPath
ScatteringPlotsExt.gifSecondLayer1DSubset
ScatteringPlotsExt.plotSecondLayer1DFixAndVary
ScatteringPlotsExt.plotSecondLayer1D
ScatteringPlotsExt.jointPlot1D
```

## 2D Plotting Functions
```@docs
ScatteringPlotsExt.plotOriginalSignal2D
ScatteringPlotsExt.plotZerothLayer2D
ScatteringPlotsExt.plotFirstLayer2DSingleWavelet
ScatteringPlotsExt.visualizeFirstLayer2D
ScatteringPlotsExt.plotFirstLayer2D
ScatteringPlotsExt.plotFirstLayer2DAll
ScatteringPlotsExt.plotSecondLayer2DSingleWavelet
ScatteringPlotsExt.visualizeSecondLayer2D
ScatteringPlotsExt.plotSecondLayer2D
```
1 change: 1 addition & 0 deletions docs/src/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ ScatteringTransform.normalize
ScatteringTransform.processArgs
ScatteringTransform.getParameters
ScatteringTransform.extractAddPadding
ScatteringTransform.reshapeInputs
```
12 changes: 12 additions & 0 deletions ext/ScatteringPlotsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module ScatteringPlotsExt

using ScatteringTransform
using Plots
using FFTW: rfft, irfft, fft, ifft
using LinearAlgebra: norm

include(joinpath(@__DIR__, "..", "src", "scatteringplots.jl"))
export plotOriginalSignal1D, plotZerothLayer1D, plotFirstLayer1DSingleWavelet, gifFirstLayer1D, plotFirstLayer1DAll, plotFirstLayer1D, plotSecondLayer1DOld, plotSecondLayer1DSpecificPath, gifSecondLayer1DSubset, plotSecondLayer1DFixAndVary, plotSecondLayer1D, jointPlot1D
export plotOriginalSignal2D, plotZerothLayer2D, plotFirstLayer2DSingleWavelet, visualizeFirstLayer2D, plotFirstLayer2D, plotFirstLayer2DAll, plotSecondLayer2DSingleWavelet, visualizeSecondLayer2D, plotSecondLayer2D

end
24 changes: 20 additions & 4 deletions src/ScatteringTransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using Adapt
using RecipesBase
using Base: tail
using ChainRulesCore
using Plots # who cares about weight really?
# using Plots # who cares about weight really?
using Statistics
using Dates

Expand Down Expand Up @@ -43,8 +43,24 @@ end

include("utilities.jl")
export getWavelets, flatten, roll, importantCoords, batchOff, getParameters, getMeanFreq, computeLoc
export roll, wrap, flatten
export roll, wrap, flatten, reshapeInputs
include("adjoints.jl")
include("scatteringplots.jl")
export plotZerothLayer1D, plotFirstLayer1D, gifFirstLayer, plotFirstLayer1DAll, plotFirstLayer, plotSecondLayer, plotSecondLayer1D, plotSecondLayerSpecificPath, plotSecondLayer1DSubsetGif, plotSecondLayerFixAndVary, jointPlot

for f in [:plotOriginalSignal1D, :plotZerothLayer1D, :plotFirstLayer1DSingleWavelet,
:gifFirstLayer1D, :plotFirstLayer1DAll, :plotFirstLayer1D,
:plotSecondLayer1DSpecificPath, :gifSecondLayer1DSubset,
:plotSecondLayer1DFixAndVary, :plotSecondLayer1D, :jointPlot1D,
:plotOriginalSignal2D, :plotZerothLayer2D, :plotFirstLayer2DSingleWavelet,
:visualizeFirstLayer2D, :plotFirstLayer2D, :plotFirstLayer2DAll,
:plotSecondLayer2DSingleWavelet, :visualizeSecondLayer2D, :plotSecondLayer2D]
@eval function $f(args...; kwargs...)
ext = Base.get_extension(ScatteringTransform, :ScatteringPlotsExt)
if isnothing(ext)
error("Load Plots.jl to use plotting functions: `using Plots`")
end
getfield(ext, $(QuoteNode(f)))(args...; kwargs...)
end
@eval export $f
end

end # end Module
53 changes: 42 additions & 11 deletions src/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ end

Zygote.@adjoint function getindex(F::T, i::Integer) where {T<:Scattered}
function getInd_rrule(Ȳ)
zeroNonRefed = map(ii -> ii - 1 == i ? Ȳ : zeros(eltype(F.output[ii]),
size(F.output[ii])...),
(1:length(F.output)...,))
# zeroNonRefed = map(ii -> ii - 1 == i ? Ȳ : zeros(eltype(F.output[ii]), size(F.output[ii])...), (1:length(F.output)...,))
zeroNonRefed = map(ii -> ii - 1 == i ? Ȳ : zero(F.output[ii]), (1:length(F.output)...,))
∂F = T(F.m, F.k, zeroNonRefed)
return ∂F, nothing
end
Expand All @@ -24,9 +23,8 @@ end

Zygote.@adjoint function getindex(F::T, inds::AbstractArray) where {T<:Scattered}
function getInd_rrule(Ȳ)
zeroNonRefed = map(ii -> ii - 1 in inds ? Ȳ[indexin(ii - 1, inds)[1]] :
zeros(eltype(F.output[ii]), size(F.output[ii])...),
(1:length(F.output)...,))
# zeroNonRefed = map(ii -> ii - 1 in inds ? Ȳ[indexin(ii - 1, inds)[1]] : zeros(eltype(F.output[ii]), size(F.output[ii])...), (1:length(F.output)...,))
zeroNonRefed = map(ii -> ii - 1 in inds ? Ȳ[indexin(ii - 1, inds)[1]] : zero(F.output[ii]), (1:length(F.output)...,))
∂F = T(F.m, F.k, zeroNonRefed)
return ∂F, nothing
end
Expand All @@ -35,9 +33,8 @@ end

Zygote.@adjoint function getindex(x::T, p::pathLocs) where {T<:Scattered}
function getInd_rrule(Δ)
zeroNonRefed = map(ii -> zeros(eltype(x.output[ii]),
size(x.output[ii])...),
(1:length(x.output)...,))
# zeroNonRefed = map(ii -> zeros(eltype(x.output[ii]), size(x.output[ii])...), (1:length(x.output)...,))
zeroNonRefed = map(ii -> zero(x.output[ii]), (1:length(x.output)...,))
∂x = T(x.m, x.k, zeroNonRefed)
∂x[p] = Δ
return ∂x, nothing
Expand All @@ -47,13 +44,47 @@ end

function rrule(::typeof(flatten), scatRes)
function ∇flatten(Δarray)
return (NO_FIELDS, roll(Δarray, scatRes),)
return (NoTangent(), roll(Δarray, scatRes),)
end
return flatten(scatRes), ∇flatten
end
function rrule(::typeof(roll), toRoll, stOutput)
function ∇roll(Δ)
return NO_FIELDS, flatten(Δ), NO_FIELDS
return NoTangent(), flatten(Δ), NoTangent()
end
return roll(toRoll, stOutput), ∇roll
end


function ChainRulesCore.rrule(::typeof(normalize), x, Nd)
n = ndims(x)
totalThisLayer = prod(size(x)[(Nd+1):(n-1)])
sumSqDims = 1:(n-1)
normSq = sum(abs.(x) .^ 2, dims=sumSqDims)
scale = totalThisLayer ./ sqrt.(normSq)
scale = ifelse.(isnan.(scale) .| isinf.(scale), one(eltype(scale)), scale)
y = x .* scale

function normalize_pullback(Δy)
# All operations stay on the same device as x
∂x = Δy .* scale
return NoTangent(), ∂x, NoTangent()
end
return y, normalize_pullback
end

Zygote.@adjoint function normalize(x, Nd)
n = ndims(x)
totalThisLayer = prod(size(x)[(Nd+1):(n-1)])
sumSqDims = 1:(n-1)
normSq = sum(abs.(x) .^ 2, dims=sumSqDims)
scale = totalThisLayer ./ sqrt.(normSq)
scale = ifelse.(isnan.(scale) .| isinf.(scale), one(eltype(scale)), scale)
y = x .* scale
function normalize_pullback(Δy)
scale_adapted = adapt(typeof(Δy), scale)
∂x = Δy .* scale_adapted
return ∂x, nothing
end
return y, normalize_pullback
end
Loading
Loading