From 81d7838fbbc65e9cdc94fd8ff6a36a6848b635b1 Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 19 Nov 2023 19:37:11 +0100 Subject: [PATCH 1/8] Update to dl4j 1.0.0-M2.1 (and java 21) --- pom.xml | 85 +++++++++++++------ .../java/ampcontrol/admin/AmpControlMain.java | 9 +- .../amp/midi/MidiProgChangeFactory.java | 2 +- .../ampcontrol/amp/probabilities/ArgMax.java | 2 +- .../amp/probabilities/ThresholdFilter.java | 2 +- .../audio/asio/AsioAudioInputBuffer.java | 2 +- .../ProcessingFactoryFromString.java | 3 - .../EnsembleWeightedSumClassifier.java | 4 +- .../model/inference/SpyClassifier.java | 2 +- .../inference/StoredGraphClassifier.java | 2 - .../training/ModelEvaluationWorkBench.java | 6 +- .../model/training/TrainingDescription.java | 5 +- .../model/training/TrainingHarness.java | 6 +- .../training/data/CyclingLabelSupplier.java | 2 +- .../data/iterators/factory/AutoFromSize.java | 6 +- .../Cnn2DtoCnn1DInputPreprocessor.java | 2 +- .../CnnHeightWidthSwapInputPreprocessor.java | 6 +- .../SequentialHoldFileSupplier.java | 2 +- .../listen/ActivationContribution.java | 4 +- .../description/MutatingConv2dFactory.java | 9 +- .../crossover/graph/CrossoverPoint.java | 3 - .../evolve/mutate/MemoryAwareMutation.java | 10 +-- .../layer/blockfunctions/SpyFunction.java | 2 +- .../evolve/transfer/ParameterTransfer.java | 12 ++- .../evolve/transfer/SingleTransferTask.java | 16 ++-- .../evolve/transfer/TransferRegistry.java | 10 ++- .../model/vertex/ChannelMultVertex.java | 5 +- .../model/vertex/ChannelMultVertexImpl.java | 16 ++-- .../model/vertex/ElementWiseVertexLatest.java | 5 +- .../vertex/ElementWiseVertexLatestImpl.java | 14 +-- .../model/vertex/EpsilonSpyVertex.java | 5 +- .../model/vertex/EpsilonSpyVertexImpl.java | 11 +-- .../audio/Cnn2DInputProviderTest.java | 12 ++- .../EnsembleWeightedSumClassifierTest.java | 11 +-- .../model/inference/SpyClassifierTest.java | 10 ++- .../model/training/TrainingHarnessTest.java | 2 +- .../CnnToManyToOneRnnPreProcessorTest.java | 
4 +- .../validate/ValidateCachingIter.java | 4 +- .../listen/ActivationContributionTest.java | 6 +- .../model/training/listen/MockModel.java | 7 +- .../model/GenericModelHandleTest.java | 4 +- .../training/model/evolve/GraphUtils.java | 4 +- .../graph/ParameterTransferCrossover.java | 8 ++ .../fitness/InstrumentEpsilonSpiesTest.java | 2 +- .../model/evolve/mutate/NoutMutationTest.java | 5 +- .../layer/blockfunctions/SpyFunctionTest.java | 2 +- .../evolve/mutate/util/ForwardOfTest.java | 4 +- .../transfer/MergeTransferBufferTest.java | 8 +- .../ParameterTransferNoutMutationTest.java | 53 +++++++----- .../transfer/SingleTransferTaskTest.java | 20 +++-- .../layerblocks/graph/DummyOutputLayer.java | 4 +- .../model/layerblocks/graph/ForkAggTest.java | 8 +- .../layerblocks/graph/MinMaxPoolTest.java | 11 ++- .../model/layerblocks/graph/ResBlockTest.java | 9 +- .../vertex/ChannelMultVertexImplTest.java | 5 +- 55 files changed, 276 insertions(+), 197 deletions(-) diff --git a/pom.xml b/pom.xml index 4ced7eae..0e6ef1cf 100644 --- a/pom.xml +++ b/pom.xml @@ -3,7 +3,7 @@ 4.0.0 AmpControl AmpControl - 0.5.2-SNAPSHOT + 0.5.3-SNAPSHOT AmpControl Controls amplifiers @@ -11,32 +11,22 @@ UTF-8 bin - 1.8 - 1.8 - 1.8 - 1.0.0-beta3 - 1.0.0-beta3 + 21 + 21 + 1.21 + 1.0.0-M2.1 + 1.0.0-M2.1 1.7.21 - 1.0.0-beta3 - 1.0.0-beta3 - 1.0.0-beta3 + 1.0.0-beta7 4.11 3.5.2 1.72 - - - 1.0.0-beta3_spark_1 - 1.0.0-beta3_spark_1 - - - - 2.11 19.0 1.1.7 1.0.13 1.0.23 - 3.7.0 + 3.11.0 2.4.3 1.4.0 3.3.1 @@ -44,7 +34,7 @@ 1.11.109 2.6.6 3.2.2 - 1.18.2 + 1.18.30 @@ -58,6 +48,13 @@ ${maven.compiler.source} ${maven.compiler.target} + + + org.projectlombok + lombok + ${lombok.version} + + @@ -86,7 +83,7 @@ - + org.jacoco jacoco-maven-plugin @@ -131,7 +128,25 @@ org.apache.maven.plugins maven-surefire-plugin - 2.21.0 + 3.2.2 + + + + org.codehaus.mojo + exec-maven-plugin + 3.1.1 + + + + exec + java + + + + + ampcontrol.model.training.model.perftest.ConvPerfTest + maven + @@ -192,17 +207,22 @@ true - 
nd4j-cuda-10.0-platform + nd4j-cuda-11.4-platform + org.bytedeco + cuda-platform-redist + 11.4-8.2-1.5.6 + + org.deeplearning4j deeplearning4j-core @@ -215,12 +235,11 @@ org.deeplearning4j - deeplearning4j-ui_2.11 + deeplearning4j-ui ${dl4j.version} org.nd4j - ${nd4j-backend} ${nd4j.version} @@ -288,6 +307,16 @@ slf4j-api ${slf4j.version} + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + org.jetbrains + annotations + 24.1.0 + @@ -295,12 +324,12 @@ jzy3d-snapshots Jzy3d Snapshots - http://maven.jzy3d.org/snapshots + https://maven.jzy3d.org/snapshots jzy3d-releases Jzy3d Releases - http://maven.jzy3d.org/releases + https://maven.jzy3d.org/releases jitpack.io diff --git a/src/main/java/ampcontrol/admin/AmpControlMain.java b/src/main/java/ampcontrol/admin/AmpControlMain.java index 8264efa4..53237044 100644 --- a/src/main/java/ampcontrol/admin/AmpControlMain.java +++ b/src/main/java/ampcontrol/admin/AmpControlMain.java @@ -8,11 +8,8 @@ import ampcontrol.audio.asio.AsioClassifierInputFactory; import ampcontrol.model.inference.Classifier; import ampcontrol.model.inference.ClassifierFromParameters; -import ampcontrol.model.training.model.vertex.ChannelMultVertex; -import ampcontrol.model.training.model.vertex.ElementWiseVertexLatest; import com.beust.jcommander.JCommander; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.factory.Nd4j; @@ -62,9 +59,9 @@ public static void main(String[] args) { try { // Might need to move into concrete Classifiers if something else is used in training - DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF); + DataTypeUtil.setDTypeForContext(DataType.FLOAT16); Nd4j.getMemoryManager().setAutoGcWindow(5000); - NeuralNetConfiguration.registerLegacyCustomClassesForJSON(ChannelMultVertex.class, ElementWiseVertexLatest.class); + 
//NeuralNetConfiguration.registerLegacyCustomClassesForJSON(ChannelMultVertex.class, ElementWiseVertexLatest.class); final Classifier classifier = classifierFromParameters.getClassifier(inputProviderFactory); audioClassificationService.initialize( classificationListenerAgg, diff --git a/src/main/java/ampcontrol/amp/midi/MidiProgChangeFactory.java b/src/main/java/ampcontrol/amp/midi/MidiProgChangeFactory.java index 542c254c..c4cc0125 100644 --- a/src/main/java/ampcontrol/amp/midi/MidiProgChangeFactory.java +++ b/src/main/java/ampcontrol/amp/midi/MidiProgChangeFactory.java @@ -31,7 +31,7 @@ public class MidiProgChangeFactory implements AmpInterface.Factory { private Predicate device = Devices.audioBox44Vsl; @Parameter(names = {"-labelToProg", "-ltp"}, - description = "Comma separated list of how to map labels programs. First program is mapped to label 0 etc") + description = "Comma separated list of how to map labels to programs. First program is mapped to label 0 etc") private List programChangesList = Arrays.asList( 2, // Silence 2, // Noise diff --git a/src/main/java/ampcontrol/amp/probabilities/ArgMax.java b/src/main/java/ampcontrol/amp/probabilities/ArgMax.java index c677dbd6..06c6217c 100644 --- a/src/main/java/ampcontrol/amp/probabilities/ArgMax.java +++ b/src/main/java/ampcontrol/amp/probabilities/ArgMax.java @@ -14,6 +14,6 @@ public class ArgMax implements Interpreter { @Override public List apply(INDArray probabilities) { - return Collections.singletonList(probabilities.argMax(1).getInt(0)); + return Collections.singletonList(probabilities.argMax(probabilities.rank()-1).getInt(0)); } } diff --git a/src/main/java/ampcontrol/amp/probabilities/ThresholdFilter.java b/src/main/java/ampcontrol/amp/probabilities/ThresholdFilter.java index 88590bd3..404650eb 100644 --- a/src/main/java/ampcontrol/amp/probabilities/ThresholdFilter.java +++ b/src/main/java/ampcontrol/amp/probabilities/ThresholdFilter.java @@ -26,7 +26,7 @@ public ThresholdFilter(int index, double 
threshold, Interpreter next) { @Override public List apply(INDArray indArray) { - if(indArray.argMax(1).getInt(0) == index) { + if(indArray.argMax(indArray.rank()-1).getInt(0) == index) { if (indArray.getDouble(index) < threshold) { return new ArrayList<>(); } diff --git a/src/main/java/ampcontrol/audio/asio/AsioAudioInputBuffer.java b/src/main/java/ampcontrol/audio/asio/AsioAudioInputBuffer.java index 17210ba2..58c9a810 100644 --- a/src/main/java/ampcontrol/audio/asio/AsioAudioInputBuffer.java +++ b/src/main/java/ampcontrol/audio/asio/AsioAudioInputBuffer.java @@ -111,7 +111,7 @@ public static void main(String[] args) { driver.createBuffers(activeChannels); driver.start(); - + System.out.println("Driver started, sleep now"); Thread.sleep(2000); diff --git a/src/main/java/ampcontrol/audio/processing/ProcessingFactoryFromString.java b/src/main/java/ampcontrol/audio/processing/ProcessingFactoryFromString.java index 118d8260..b0be0391 100644 --- a/src/main/java/ampcontrol/audio/processing/ProcessingFactoryFromString.java +++ b/src/main/java/ampcontrol/audio/processing/ProcessingFactoryFromString.java @@ -1,7 +1,5 @@ package ampcontrol.audio.processing; -import org.jetbrains.annotations.Nullable; - /** * Creates a {@link ProcessingResult.Factory} based on a name string. Used for persistence since * all preprocessing needs be recreated when restoring a saved model. 
@@ -99,7 +97,6 @@ public ProcessingResult.Factory get(String nameStr) { return new UnitMaxZeroMean(); } - @Nullable private ProcessingResult.Factory getPipedSupplier( String nameStr, ProcessingResult.Factory first, diff --git a/src/main/java/ampcontrol/model/inference/EnsembleWeightedSumClassifier.java b/src/main/java/ampcontrol/model/inference/EnsembleWeightedSumClassifier.java index 485ed48a..cbc749fb 100644 --- a/src/main/java/ampcontrol/model/inference/EnsembleWeightedSumClassifier.java +++ b/src/main/java/ampcontrol/model/inference/EnsembleWeightedSumClassifier.java @@ -1,7 +1,7 @@ package ampcontrol.model.inference; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.OldSoftMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; @@ -19,7 +19,7 @@ class EnsembleWeightedSumClassifier implements Classifier { final static BiFunction avgNormalizer = (sumAcc, aggClass) -> aggClass.div(sumAcc); - final static BiFunction softMaxNormalizer = (sumAcc, aggClass) -> Nd4j.getExecutioner().execAndReturn(new OldSoftMax(aggClass)); + final static BiFunction softMaxNormalizer = (sumAcc, aggClass) -> Nd4j.getExecutioner().execAndReturn(new SoftMax(aggClass)).getOutputArgument(0); private final List ensemble; private final BiFunction normalizer; diff --git a/src/main/java/ampcontrol/model/inference/SpyClassifier.java b/src/main/java/ampcontrol/model/inference/SpyClassifier.java index f47a0734..8d505631 100644 --- a/src/main/java/ampcontrol/model/inference/SpyClassifier.java +++ b/src/main/java/ampcontrol/model/inference/SpyClassifier.java @@ -136,7 +136,7 @@ public INDArray classify() { } private void accumInput(INDArray classification) { - int highestProb = classification.argMax(1).getInt(0); + int highestProb = classification.argMax(classification.rank()-1).getInt(0); if(classification.getDouble(highestProb) > threshold) { if(highestProb == right) { diff 
--git a/src/main/java/ampcontrol/model/inference/StoredGraphClassifier.java b/src/main/java/ampcontrol/model/inference/StoredGraphClassifier.java index f793c7d7..54f58062 100644 --- a/src/main/java/ampcontrol/model/inference/StoredGraphClassifier.java +++ b/src/main/java/ampcontrol/model/inference/StoredGraphClassifier.java @@ -4,7 +4,6 @@ import ampcontrol.model.training.model.validation.listen.BestEvalScore; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.util.ModelSerializer; -import org.jetbrains.annotations.NotNull; import org.nd4j.linalg.api.ndarray.INDArray; import java.io.IOException; @@ -31,7 +30,6 @@ public class StoredGraphClassifier implements Classifier { this.accuracy = new BestEvalScore(realFileName + ".score").get(); } - @NotNull private static String getHashedFileNameFromModelName(String path) { final String modelName = Paths.get(path).getFileName().toString(); String pathHashCode = modelName; diff --git a/src/main/java/ampcontrol/model/training/ModelEvaluationWorkBench.java b/src/main/java/ampcontrol/model/training/ModelEvaluationWorkBench.java index a667c8c9..05451cae 100644 --- a/src/main/java/ampcontrol/model/training/ModelEvaluationWorkBench.java +++ b/src/main/java/ampcontrol/model/training/ModelEvaluationWorkBench.java @@ -114,8 +114,7 @@ private void evalDecreaseKernelSize(ComputationGraph graph, String layer) { .map(layerConf -> (ConvolutionLayer) layerConf) .map(convConf -> { convConf.setKernelSize(new int[]{convConf.getKernelSize()[0] - 1, convConf.getKernelSize()[1]}); - convConf.setWeightInit(WeightInit.DISTRIBUTION); - convConf.setDist(new ConstantDistribution(0)); + convConf.setWeightInitFn(WeightInit.DISTRIBUTION.getWeightInitFunction(new ConstantDistribution(0))); return convConf; }) .orElseThrow(() -> new IllegalArgumentException("Could not mutate layer from " + layerConfIn))) @@ -138,8 +137,7 @@ private void evalIncreaseKernelSize(ComputationGraph graph, String layer) { .map(layerConf -> 
(ConvolutionLayer) layerConf) .map(convConf -> { convConf.setKernelSize(new int[]{convConf.getKernelSize()[0] + 1, convConf.getKernelSize()[1]}); - convConf.setWeightInit(WeightInit.DISTRIBUTION); - convConf.setDist(new ConstantDistribution(0)); + convConf.setWeightInitFn(WeightInit.DISTRIBUTION.getWeightInitFunction(new ConstantDistribution(0))); return convConf; }) .orElseThrow(() -> new IllegalArgumentException("Could not mutate layer from " + layerConfIn))) diff --git a/src/main/java/ampcontrol/model/training/TrainingDescription.java b/src/main/java/ampcontrol/model/training/TrainingDescription.java index 0ad5c25d..38dbfda2 100644 --- a/src/main/java/ampcontrol/model/training/TrainingDescription.java +++ b/src/main/java/ampcontrol/model/training/TrainingDescription.java @@ -17,7 +17,7 @@ import ampcontrol.model.visualize.RealTimePlot; import ch.qos.logback.classic.Level; import org.jetbrains.annotations.NotNull; -import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -79,7 +79,8 @@ public static void main(String[] args) { ch.qos.logback.classic.Logger root = (ch.qos.logback.classic.Logger) LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME); root.setLevel(Level.INFO); - DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF); + DataTypeUtil.setDTypeForContext(DataType.FLOAT16); + DataTypeUtil.getDtypeFromContext(); List modelData = new ArrayList<>(); diff --git a/src/main/java/ampcontrol/model/training/TrainingHarness.java b/src/main/java/ampcontrol/model/training/TrainingHarness.java index 0582213e..501eca81 100644 --- a/src/main/java/ampcontrol/model/training/TrainingHarness.java +++ b/src/main/java/ampcontrol/model/training/TrainingHarness.java @@ -9,10 +9,10 @@ import ampcontrol.model.training.model.validation.*; import ampcontrol.model.training.model.validation.listen.*; import ampcontrol.model.visualize.Plot; 
-import org.deeplearning4j.api.storage.StatsStorage; +import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.ui.api.UIServer; -import org.deeplearning4j.ui.stats.StatsListener; -import org.deeplearning4j.ui.storage.FileStatsStorage; +import org.deeplearning4j.ui.model.stats.StatsListener; +import org.deeplearning4j.ui.model.storage.FileStatsStorage; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; diff --git a/src/main/java/ampcontrol/model/training/data/CyclingLabelSupplier.java b/src/main/java/ampcontrol/model/training/data/CyclingLabelSupplier.java index 50d02b33..94a2fc24 100644 --- a/src/main/java/ampcontrol/model/training/data/CyclingLabelSupplier.java +++ b/src/main/java/ampcontrol/model/training/data/CyclingLabelSupplier.java @@ -1,7 +1,7 @@ package ampcontrol.model.training.data; import ampcontrol.model.training.data.state.StateFactory; -import org.apache.commons.lang.mutable.MutableInt; +import org.apache.commons.lang3.mutable.MutableInt; import java.util.List; import java.util.function.Supplier; diff --git a/src/main/java/ampcontrol/model/training/data/iterators/factory/AutoFromSize.java b/src/main/java/ampcontrol/model/training/data/iterators/factory/AutoFromSize.java index 1c8e7fa3..ecdce9d9 100644 --- a/src/main/java/ampcontrol/model/training/data/iterators/factory/AutoFromSize.java +++ b/src/main/java/ampcontrol/model/training/data/iterators/factory/AutoFromSize.java @@ -3,7 +3,6 @@ import ampcontrol.model.training.data.iterators.MiniEpochDataSetIterator; import ampcontrol.model.training.data.state.ResetableState; import lombok.Builder; -import org.jetbrains.annotations.NotNull; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -82,7 +81,7 @@ public MiniEpochDataSetIterator create(Input input) { return factory.create(input.sourceInput); } - @NotNull + private DataSetIteratorFactory 
createFactory(Input input, long sizeOfOneBatch, long sizeOfWholeDataSet) { DataSetIteratorFactory factory; if (margin * sizeOfWholeDataSet > memoryAllowance) { @@ -93,7 +92,7 @@ private DataSetIteratorFactory createFactory(Input< return factory; } - @NotNull + private DataSetIteratorFactory createAsynchFactory(Input input, long sizeOfOneBatch, long sizeOfWholeDataSet) { DataSetIteratorFactory factory; log.info("Create Asynch iter for set of size {} with memory allowance {}", sizeOfWholeDataSet, memoryAllowance); @@ -105,7 +104,6 @@ private DataSetIteratorFactory createAsynchFactory( return factory; } - @NotNull private DataSetIteratorFactory createCachingFactory(Input input, long sizeOfWholeDataSet) { DataSetIteratorFactory factory; log.info("Create Caching iter for set of size {} with memory allowance {}", sizeOfWholeDataSet, memoryAllowance); diff --git a/src/main/java/ampcontrol/model/training/data/iterators/preprocs/Cnn2DtoCnn1DInputPreprocessor.java b/src/main/java/ampcontrol/model/training/data/iterators/preprocs/Cnn2DtoCnn1DInputPreprocessor.java index dc492cf7..eac84034 100644 --- a/src/main/java/ampcontrol/model/training/data/iterators/preprocs/Cnn2DtoCnn1DInputPreprocessor.java +++ b/src/main/java/ampcontrol/model/training/data/iterators/preprocs/Cnn2DtoCnn1DInputPreprocessor.java @@ -4,8 +4,8 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; /** * {@link InputPreProcessor} which changes CNN 2D input to CNN 1D input. Assumes only one channel. 
diff --git a/src/main/java/ampcontrol/model/training/data/iterators/preprocs/CnnHeightWidthSwapInputPreprocessor.java b/src/main/java/ampcontrol/model/training/data/iterators/preprocs/CnnHeightWidthSwapInputPreprocessor.java index fbb1b586..1d5e4f9c 100644 --- a/src/main/java/ampcontrol/model/training/data/iterators/preprocs/CnnHeightWidthSwapInputPreprocessor.java +++ b/src/main/java/ampcontrol/model/training/data/iterators/preprocs/CnnHeightWidthSwapInputPreprocessor.java @@ -4,10 +4,8 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; - -import javax.ws.rs.NotSupportedException; /** * {@link InputPreProcessor} which swaps height and width dimensions of CNN input. Intended use is when doing 1D @@ -45,6 +43,6 @@ public InputType getOutputType(InputType inputType) { @Override public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { - throw new NotSupportedException("Not implemented yet!"); + throw new UnsupportedOperationException("Not implemented yet!"); } } \ No newline at end of file diff --git a/src/main/java/ampcontrol/model/training/data/processing/SequentialHoldFileSupplier.java b/src/main/java/ampcontrol/model/training/data/processing/SequentialHoldFileSupplier.java index df4905ca..95346e5c 100644 --- a/src/main/java/ampcontrol/model/training/data/processing/SequentialHoldFileSupplier.java +++ b/src/main/java/ampcontrol/model/training/data/processing/SequentialHoldFileSupplier.java @@ -1,7 +1,7 @@ package ampcontrol.model.training.data.processing; import ampcontrol.model.training.data.state.StateFactory; -import org.apache.commons.lang.mutable.MutableInt; +import org.apache.commons.lang3.mutable.MutableInt; import java.nio.file.Path; import java.util.List; diff --git 
a/src/main/java/ampcontrol/model/training/listen/ActivationContribution.java b/src/main/java/ampcontrol/model/training/listen/ActivationContribution.java index 1f7f920d..0796a2e9 100644 --- a/src/main/java/ampcontrol/model/training/listen/ActivationContribution.java +++ b/src/main/java/ampcontrol/model/training/listen/ActivationContribution.java @@ -13,6 +13,7 @@ import org.nd4j.linalg.api.memory.enums.SpillPolicy; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Map; import java.util.function.Consumer; @@ -61,7 +62,8 @@ private void setEps(INDArray eps) { try (MemoryWorkspace wss = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, wsName)) { final INDArray tmpEps = eps.migrate(false); int[] meanDims = IntStream.range(0, act.rank()).filter(dim -> dim != 1).toArray(); - listener.accept(tmpEps.muli(act).amean(meanDims)); + // amean seems broken in M2.1: it always returns a single element array + listener.accept(Transforms.abs(tmpEps.muli(act)).mean(meanDims)); } } } diff --git a/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java b/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java index d138b0ea..3a32b420 100644 --- a/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java +++ b/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java @@ -49,13 +49,12 @@ import com.fasterxml.jackson.databind.ObjectMapper; import lombok.Builder; import lombok.Getter; -import org.apache.commons.lang.mutable.MutableLong; +import org.apache.commons.lang3.mutable.MutableLong; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder; import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; 
-import org.jetbrains.annotations.NotNull; import org.nd4j.jita.memory.CudaMemoryManager; import org.nd4j.linalg.activations.impl.ActivationReLU; import org.nd4j.linalg.factory.Nd4j; @@ -403,7 +402,7 @@ private AccessibleState> createInitialEvolutionState(Mu } } - @NotNull + private Population createPopulation( Map modelAgeMap, ModelComparatorRegistry comparatorRegistry, @@ -595,7 +594,7 @@ private Mutation createAddBlockMutation( .filter(mut -> rng.nextDouble() < 0.1)); } - @NotNull + private static Function createAfterGlobPoolLayerFactory(UnaryOperator spyFactory, Random rng) { final Function spyConfig = lbc -> new SpyBlock(lbc) @@ -608,7 +607,7 @@ private static Function createAfterGlobPoolLayerFactory( return nOut -> afterGpBlocks.andThen(spyConfig).apply(nOut); } - @NotNull + private static Function createBeforeGlobPoolLayerFactory(UnaryOperator spyFactory, Random rng) { final Function spyConfig = lbc -> new SpyBlock(lbc) .setFactory(spyFactory); diff --git a/src/main/java/ampcontrol/model/training/model/evolve/crossover/graph/CrossoverPoint.java b/src/main/java/ampcontrol/model/training/model/evolve/crossover/graph/CrossoverPoint.java index 7d811dce..9d934401 100644 --- a/src/main/java/ampcontrol/model/training/model/evolve/crossover/graph/CrossoverPoint.java +++ b/src/main/java/ampcontrol/model/training/model/evolve/crossover/graph/CrossoverPoint.java @@ -7,7 +7,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; -import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -82,7 +81,6 @@ GraphInfo execute() { return new GraphInfo.Result(builder, infoMap); } - @NotNull private GraphBuilder initBuilder() { final ComputationGraphConfiguration conf = bottom.builder().build(); final GraphBuilder builder = new GraphBuilder( @@ -93,7 +91,6 @@ private GraphBuilder initBuilder() { return builder; } - @NotNull 
private Map addTop(GraphBuilder builder) { // Need to access names which are already added to the builder below, so first create the mapping // between names in top and a name which is unique in builder. diff --git a/src/main/java/ampcontrol/model/training/model/evolve/mutate/MemoryAwareMutation.java b/src/main/java/ampcontrol/model/training/model/evolve/mutate/MemoryAwareMutation.java index 6f92908b..cf93925e 100644 --- a/src/main/java/ampcontrol/model/training/model/evolve/mutate/MemoryAwareMutation.java +++ b/src/main/java/ampcontrol/model/training/model/evolve/mutate/MemoryAwareMutation.java @@ -1,7 +1,5 @@ package ampcontrol.model.training.model.evolve.mutate; -import org.bytedeco.javacpp.IntPointer; -import org.bytedeco.javacpp.Pointer; import org.nd4j.nativeblas.NativeOpsHolder; import java.util.function.Function; @@ -27,17 +25,17 @@ interface MemoryProvider { private static class DeviceMemoryProvider implements MemoryProvider { - private final Pointer devicePointer; + private final int device; private final double totalMemory; private DeviceMemoryProvider() { - this.devicePointer = new IntPointer(NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice()); - this.totalMemory = NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceTotalMemory(devicePointer); + this.device = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice(); + this.totalMemory = NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceTotalMemory(device); } @Override public double getUsage() { - return (totalMemory - NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceFreeMemory(devicePointer)) / totalMemory; + return (totalMemory - NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceFreeMemory(device)) / totalMemory; } } diff --git a/src/main/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunction.java b/src/main/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunction.java index 
66062d2e..71c4fe7e 100644 --- a/src/main/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunction.java +++ b/src/main/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunction.java @@ -38,7 +38,7 @@ public LayerBlockConfig apply(Long nOut) { public static SpyFunction weightInit(Function source, WeightInit weightInit) { return new SpyFunction(factory -> new LayerSpyAdapter((layerName, layer, layerInputs) -> { if (layer instanceof BaseLayer) { - ((BaseLayer) layer).setWeightInit(weightInit); + ((BaseLayer) layer).setWeightInitFn(weightInit.getWeightInitFunction()); } }, factory), source); } diff --git a/src/main/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransfer.java b/src/main/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransfer.java index 883cf15a..b90afffb 100644 --- a/src/main/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransfer.java +++ b/src/main/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransfer.java @@ -14,8 +14,6 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.INDArrayIndex; -import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; import java.util.function.Function; @@ -269,7 +267,7 @@ private TransferTask.ListBuilder transferOtherWeights(TransferContext transferCo .addDependentTask(createDependentTask( registry.register(paramPair.source.get(DefaultParamInitializer.BIAS_KEY), paramPair.layerName + "_source_b"), registry.register(paramPair.target.get(DefaultParamInitializer.BIAS_KEY), paramPair.layerName + "_target_b"), - dim -> 1, // Always 1 for bias! + dim -> 0, // Always 0 for bias! 
transferContext.inputDimension, 2, 3, 4) ); } @@ -526,11 +524,11 @@ private static void setIdentityMapping(INDArray weights) { weights.assign(Nd4j.eye(weights.size(0))); } else if (weights.shape().length == 4 && weights.size(2) % 2 == 1 && weights.size(3) % 2 == 1) { weights.assign(Nd4j.zeros(weights.shape())); - final long centerH = weights.size(2) / 2; - final long centerW = weights.size(3) / 2; + final int centerH = (int)weights.size(2) / 2; + final int centerW = (int)weights.size(3) / 2; + for (int i = 0; i < weights.size(0); i++) { - weights.put(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.point(i), NDArrayIndex.point(centerH), NDArrayIndex.point(centerW)}, - Nd4j.ones(1)); + weights.put(new int[] {i, i, centerH, centerW}, Nd4j.ones(1)); } } } diff --git a/src/main/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTask.java b/src/main/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTask.java index aa89116f..9bd18af5 100644 --- a/src/main/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTask.java +++ b/src/main/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTask.java @@ -23,15 +23,12 @@ @Builder(builderClassName = "Builder", buildMethodName = "buildInternal") public class SingleTransferTask implements TransferTask { - @lombok.Builder.Default - private final IndMapping source = IndMapping.builder().build(); - @lombok.Builder.Default - private final IndMapping target = IndMapping.builder().build(); + private final IndMapping source; + private final IndMapping target; private final Function> compFactory; @lombok.Singular("maskDim") private final Set dimensionMask; - @lombok.Builder.Default - private final TransferTask dependentTask = new NoTransferTask(); + private final TransferTask dependentTask; @lombok.Builder(builderClassName = "Builder") @@ -111,12 +108,15 @@ public void execute() { } public static class Builder implements ListBuilder { + private IndMapping source = 
IndMapping.builder().build(); + private IndMapping target = IndMapping.builder().build(); + private TransferTask dependentTask = new NoTransferTask(); private ListBuilder dependentTaskBuilder = NoTransferTask.builder(); // To trick lombok so all boilerplate is done in autogenerated buildInternal public SingleTransferTask build() { target.getEntry().put(source.getEntry()); - dependentTask(dependentTaskBuilder.build()); + setDependentTask(dependentTaskBuilder.build()); return this.buildInternal(); } @@ -132,7 +132,7 @@ public Builder addDependentTask(ListBuilder dependentTaskBuilder) { return this; } - private Builder dependendTask(TransferTask dependentTask) { + private Builder setDependentTask(TransferTask dependentTask) { this.dependentTask = dependentTask; return this; } diff --git a/src/main/java/ampcontrol/model/training/model/evolve/transfer/TransferRegistry.java b/src/main/java/ampcontrol/model/training/model/evolve/transfer/TransferRegistry.java index 68a3dcfe..9ed1c661 100644 --- a/src/main/java/ampcontrol/model/training/model/evolve/transfer/TransferRegistry.java +++ b/src/main/java/ampcontrol/model/training/model/evolve/transfer/TransferRegistry.java @@ -20,7 +20,7 @@ */ class TransferRegistry { - private final Map registry = new HashMap<>(); + private final Map registry = new IdentityHashMap<>(); private final Map actions = new LinkedHashMap<>(); class ArrayEntry { @@ -61,6 +61,9 @@ public int compare(Integer e1, Integer e2) { if(e1.equals(e2)) { return 0; } + if(tensorDimensions.length == 0) { + return -Double.compare(Math.abs(array.getDouble(e1)), Math.abs(array.getDouble(e2))); + } return -Double.compare( abs(array.tensorAlongDimension(e1, tensorDimensions)).sumNumber().doubleValue(), @@ -140,6 +143,10 @@ private INDArray get() { } catch (ND4JIllegalStateException e) { throw new ND4JIllegalStateException("Could not get array " + debugName + "! Target array of shape " + Arrays.toString(array.shape()) + ". 
Wanted indexes " + Arrays.toString(asIndArray()), e); + } catch (NullPointerException e) { + // Workaround for what seems to be a bug in dl4j: It sometimes throws a NullPointerException for some array indices + // which only shows up when using certain INDArrayIndex types (e.g. index, indices, SpecifiedIndex). + return addBackSingletonDimensions(array.dup().get(asIndArray())); } } @@ -173,7 +180,6 @@ private INDArrayIndex merge(INDArrayIndex index1, Optional index2 private INDArrayIndex[] asIndArray() { return IntStream.range(0, array.rank()) .mapToObj(dim -> indexMap.getOrDefault(dim, NDArrayIndex.all())) - .peek(INDArrayIndex::reset) .toArray(INDArrayIndex[]::new); } diff --git a/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertex.java b/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertex.java index 0612e3f6..601448be 100644 --- a/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertex.java +++ b/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertex.java @@ -6,6 +6,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -51,8 +52,8 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { - return new ChannelMultVertexImpl(graph, name, idx); + INDArray paramsView, boolean initializeParams, DataType dataType) { + return new ChannelMultVertexImpl(graph, name, idx, dataType); } @Override diff --git a/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImpl.java b/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImpl.java index 759f17b9..963fcafb 100644 --- 
a/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImpl.java +++ b/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImpl.java @@ -9,8 +9,10 @@ import org.deeplearning4j.nn.graph.vertex.VertexIndices; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.common.primitives.Pair; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; + /** * {@link BaseGraphVertex} which multiplies each channel in a convolutional activation of size [b,c,h,w] with a scalar @@ -21,14 +23,14 @@ */ public class ChannelMultVertexImpl extends BaseGraphVertex { - public ChannelMultVertexImpl(ComputationGraph graph, String name, int vertexIndex) { - this(graph, name, vertexIndex, null, null); + public ChannelMultVertexImpl(ComputationGraph graph, String name, int vertexIndex, DataType dataType) { + this(graph, name, vertexIndex, null, null, dataType); } public ChannelMultVertexImpl(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); } @Override @@ -87,8 +89,8 @@ private INDArray scaleChannels(INDArray channelActivations, INDArray scaleFactor // Goal: multiply each h*w activation a_i with the corresponding scale factor s_i, i in range [0, b*c-1] - final long nrofChannelBatch = channelActivations.tensorssAlongDimension(2, 3); // view each channel and each batch - final long nrofFeaturesInActivation = channelActivations.tensorssAlongDimension(0, 1); // view each h and w activation for all batches + final long nrofChannelBatch = channelActivations.tensorsAlongDimension(2, 3); // view each channel and each batch + final long nrofFeaturesInActivation = 
channelActivations.tensorsAlongDimension(0, 1); // view each h and w activation for all batches // From empiric testing: Whatever makes the fewest number of loops is fastest if (nrofChannelBatch < nrofFeaturesInActivation) { diff --git a/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatest.java b/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatest.java index eed0bb20..ddbab6b6 100644 --- a/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatest.java +++ b/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatest.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -90,7 +91,7 @@ public int maxVertexInputs() { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, - INDArray paramsView, boolean initializeParams) { + INDArray paramsView, boolean initializeParams, DataType dataType) { ElementWiseVertexLatestImpl.Op op; switch (this.op) { case Add: @@ -108,7 +109,7 @@ public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGra default: throw new UnsupportedOperationException("No support for op: " + this.op); } - return new ElementWiseVertexLatestImpl(graph, name, idx, op); + return new ElementWiseVertexLatestImpl(graph, name, idx, op, dataType); } @Override diff --git a/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatestImpl.java b/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatestImpl.java index 607f6463..e78a2ada 100644 --- a/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatestImpl.java +++ 
b/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatestImpl.java @@ -25,10 +25,11 @@ import org.deeplearning4j.nn.graph.vertex.VertexIndices; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.common.primitives.Pair; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.Or; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; /** An ElementWiseVertex is used to combine the activations of two or more layer in an element-wise manner
* For example, the activations may be combined by addition, subtraction or multiplication or by selecting the maximum. @@ -46,13 +47,13 @@ public enum Op { private Op op; private int nInForwardPass; - public ElementWiseVertexLatestImpl(ComputationGraph graph, String name, int vertexIndex, Op op) { - this(graph, name, vertexIndex, null, null, op); + public ElementWiseVertexLatestImpl(ComputationGraph graph, String name, int vertexIndex, Op op, DataType dataType) { + this(graph, name, vertexIndex, null, null, op, dataType); } public ElementWiseVertexLatestImpl(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, Op op) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + VertexIndices[] outputVertices, Op op, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.op = op; } @@ -179,6 +180,7 @@ public Pair feedForwardMaskArrays(INDArray[] maskArrays, Ma return new Pair<>(maskArrays[0], currentMaskState); } else { INDArray ret = maskArrays[0].dup(maskArrays[0].ordering()); + Nd4j.getExecutioner().exec(new Or(maskArrays[0], maskArrays[1], ret)); for (int i = 2; i < maskArrays.length; i++) { Nd4j.getExecutioner().exec(new Or(maskArrays[i], ret, ret)); diff --git a/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertex.java b/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertex.java index 8cda9fac..e823853b 100644 --- a/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertex.java +++ b/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertex.java @@ -6,6 +6,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; @@ -50,8 +51,8 @@ public int maxVertexInputs() 
{ } @Override - public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) { - return new EpsilonSpyVertexImpl(graph, name, idx); + public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType dataType) { + return new EpsilonSpyVertexImpl(graph, name, idx, dataType); } diff --git a/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertexImpl.java b/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertexImpl.java index 6d6d43f5..cb95047e 100644 --- a/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertexImpl.java +++ b/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertexImpl.java @@ -8,8 +8,9 @@ import org.deeplearning4j.nn.graph.vertex.VertexIndices; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.common.primitives.Pair; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; import java.util.Optional; import java.util.function.Consumer; @@ -25,12 +26,12 @@ public class EpsilonSpyVertexImpl extends BaseGraphVertex { private Consumer listener; - protected EpsilonSpyVertexImpl(ComputationGraph graph, String name, int vertexIndex) { - this(graph, name, vertexIndex, null, null); + protected EpsilonSpyVertexImpl(ComputationGraph graph, String name, int vertexIndex, DataType dataType) { + this(graph, name, vertexIndex, null, null, dataType); } - protected EpsilonSpyVertexImpl(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices) { - super(graph, name, vertexIndex, inputVertices, outputVertices); + protected EpsilonSpyVertexImpl(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] 
inputVertices, VertexIndices[] outputVertices, DataType dataType) { + super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); } public void setListener(Consumer listener) { diff --git a/src/test/java/ampcontrol/audio/Cnn2DInputProviderTest.java b/src/test/java/ampcontrol/audio/Cnn2DInputProviderTest.java index 236a9eba..e7178cc2 100644 --- a/src/test/java/ampcontrol/audio/Cnn2DInputProviderTest.java +++ b/src/test/java/ampcontrol/audio/Cnn2DInputProviderTest.java @@ -11,7 +11,7 @@ import java.util.function.Supplier; import java.util.stream.Stream; -import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertArrayEquals; /** * Test cases for {@link Cnn2DInputProvider}. @@ -38,7 +38,10 @@ public void getModelInput() { final ProcessingResult res0 = proc.create(new SingletonDoubleInput(audioFrames[0])); final INDArray expected0 = Nd4j.create(res0.stream().findFirst().get()); for(int channel = 0; channel < nrofChannels; channel++) { - assertEquals("Incorrect model input!", expected0, inputProvider.getModelInput().get(NDArrayIndex.point(0), NDArrayIndex.point(channel))); + final INDArray res = inputProvider.getModelInput().get(NDArrayIndex.point(0), NDArrayIndex.point(channel)); + assertArrayEquals("Incorrect shape!", expected0.shape(), res.shape()); + assertArrayEquals("Incorrect model input!", expected0.reshape(expected0.length()).toDoubleVector(), + res.reshape(res.length()).toDoubleVector(), 1e-10); } mockBuffer.advance(); @@ -46,7 +49,10 @@ public void getModelInput() { final ProcessingResult res1 = proc.create(new SingletonDoubleInput(audioFrames[1])); final INDArray expected1 = Nd4j.create(res1.stream().findFirst().get()); for(int channel = 0; channel < nrofChannels; channel++) { - assertEquals("Incorrect model input!", expected1, inputProvider.getModelInput().get(NDArrayIndex.point(0), NDArrayIndex.point(channel))); + final INDArray res = inputProvider.getModelInput().get(NDArrayIndex.point(0), 
NDArrayIndex.point(channel)); + assertArrayEquals("Incorrect shape!", expected1.shape(), res.shape()); + assertArrayEquals("Incorrect model input!", expected1.reshape(expected1.length()).toDoubleVector(), + res.reshape(res.length()).toDoubleVector(), 1e-10); } } diff --git a/src/test/java/ampcontrol/model/inference/EnsembleWeightedSumClassifierTest.java b/src/test/java/ampcontrol/model/inference/EnsembleWeightedSumClassifierTest.java index d5f4b188..aff9c3f1 100644 --- a/src/test/java/ampcontrol/model/inference/EnsembleWeightedSumClassifierTest.java +++ b/src/test/java/ampcontrol/model/inference/EnsembleWeightedSumClassifierTest.java @@ -62,19 +62,10 @@ void testMultiModelEnsemble(BiFunction normalizer) { final INDArray result = classifier.classify(); mockEnsemble.forEach(mockClassifier -> mockClassifier.assertCalled(true)); - assertEquals("Incorrect result!", 1d, result.sum(1).getDouble(0), 1e-10); + assertEquals("Incorrect result!", 1d, result.sum(0).getDouble(0), 1e-10); assertEquals("Incorrect result!", 3, result.argMax().getInt(0)); } - /** - * Fails with CPU backend - */ - @Test - public void testSum() { - final INDArray test = Nd4j.create(new double[]{1, 2}); - assertEquals("Incorrect result!", Nd4j.create(new double[] {3}), test.sum(1)); - } - private static void testSingleModelEnsemble(BiFunction normalizer) { final List mockEnsemble = Arrays.asList( diff --git a/src/test/java/ampcontrol/model/inference/SpyClassifierTest.java b/src/test/java/ampcontrol/model/inference/SpyClassifierTest.java index b04410f4..2fb60b91 100644 --- a/src/test/java/ampcontrol/model/inference/SpyClassifierTest.java +++ b/src/test/java/ampcontrol/model/inference/SpyClassifierTest.java @@ -8,6 +8,7 @@ import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** @@ -44,8 +45,9 @@ public void spyInputTypes() { incorrectClassifiction, // 12 incorrectClassifiction // 13 ); - final INDArray 
expectedCorrect = Nd4j.create(new double[]{6, 7, 8}); - final INDArray expectedIncorrect = Nd4j.create(new double[]{9, 10, 11}); + + final INDArray expectedCorrect = Nd4j.create(new double[]{6, 7, 8}).reshape(3, 1); + final INDArray expectedIncorrect = Nd4j.create(new double[]{9, 10, 11}).reshape(3, 1); final SpyClassifier spyClassifier = new SpyClassifier( new MockClassifier("mock", 0.7, classifications), new CountingInputProvider(), @@ -60,8 +62,8 @@ public void spyInputTypes() { assertEquals("Incorrect hasInput after " + i +" samples!", i >= 11, listener.hasInput()); } - assertEquals("Incorrect correctly classified input!", expectedCorrect, listener.getCorrectlyClassifiedInput()); - assertEquals("Incorrect incorrectly classified input!", expectedIncorrect, listener.getIncorrectlyClassifiedInput()); + assertArrayEquals("Incorrect correctly classified input!", expectedCorrect.toDoubleVector(), listener.getCorrectlyClassifiedInput().toDoubleVector(), 1e-10); + assertArrayEquals("Incorrect incorrectly classified input!", expectedIncorrect.toDoubleVector(), listener.getIncorrectlyClassifiedInput().toDoubleVector(), 1e-10); } private static class CountingInputProvider implements ClassifierInputProvider { diff --git a/src/test/java/ampcontrol/model/training/TrainingHarnessTest.java b/src/test/java/ampcontrol/model/training/TrainingHarnessTest.java index 855e3720..44730d78 100644 --- a/src/test/java/ampcontrol/model/training/TrainingHarnessTest.java +++ b/src/test/java/ampcontrol/model/training/TrainingHarnessTest.java @@ -76,7 +76,7 @@ public void eval() { nrofEvalCalls++; validations.stream().map(Validation::get) - .forEach(ieOpt -> ieOpt.ifPresent(ie -> ie.eval(Nd4j.create(result), Nd4j.zeros(labels.size())))); + .forEach(ieOpt -> ieOpt.ifPresent(ie -> ie.eval(Nd4j.create(result).reshape(labels.size(), 1), Nd4j.zeros(labels.size(), 1)))); validations.forEach(Validation::notifyComplete); } diff --git 
a/src/test/java/ampcontrol/model/training/data/iterators/preprocs/CnnToManyToOneRnnPreProcessorTest.java b/src/test/java/ampcontrol/model/training/data/iterators/preprocs/CnnToManyToOneRnnPreProcessorTest.java index 43c8ca6c..4d89c6e0 100644 --- a/src/test/java/ampcontrol/model/training/data/iterators/preprocs/CnnToManyToOneRnnPreProcessorTest.java +++ b/src/test/java/ampcontrol/model/training/data/iterators/preprocs/CnnToManyToOneRnnPreProcessorTest.java @@ -7,7 +7,6 @@ import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; /** * Test cases for {@link CnnToManyToOneRnnPreProcessor}. @@ -32,7 +31,8 @@ public void preProcess() { pp.preProcess(testSet); assertArrayEquals("Incorrect feature shape!", expectedFeatureShape, testSet.getFeatures().shape()); assertArrayEquals("Incorrect labels shape!", expectedLabelsShape, testSet.getLabels().shape()); - assertEquals("Incorrect mask!", expectedLabelsMask, testSet.getLabelsMaskArray()); + assertArrayEquals("Incorrect mask!", expectedLabelsMask.reshape(expectedLabelsMask.length()).toDoubleVector(), + testSet.getLabelsMaskArray().reshape(testSet.getLabelsMaskArray().length()).toDoubleVector(), 1e-10); } diff --git a/src/test/java/ampcontrol/model/training/data/iterators/validate/ValidateCachingIter.java b/src/test/java/ampcontrol/model/training/data/iterators/validate/ValidateCachingIter.java index 28154c4a..b44c61fb 100644 --- a/src/test/java/ampcontrol/model/training/data/iterators/validate/ValidateCachingIter.java +++ b/src/test/java/ampcontrol/model/training/data/iterators/validate/ValidateCachingIter.java @@ -106,10 +106,10 @@ static DataSetIterator createTestDataSetIterator(Supplier state) { @Override public synchronized void fetch(int numExamples) { - final double[] features = new double[batchSize]; + final double[][] features = new double[batchSize][1]; final double[][] labels = new double[batchSize][10]; IntStream.range(0, 
batchSize).forEach(batch -> { - features[batch] = state.get().intValue(); + features[batch][0] = state.get().intValue(); labels[batch][state.get().intValue() % 10] = 1; state.get().increment(); }); diff --git a/src/test/java/ampcontrol/model/training/listen/ActivationContributionTest.java b/src/test/java/ampcontrol/model/training/listen/ActivationContributionTest.java index a80e8572..d904f4bf 100644 --- a/src/test/java/ampcontrol/model/training/listen/ActivationContributionTest.java +++ b/src/test/java/ampcontrol/model/training/listen/ActivationContributionTest.java @@ -63,7 +63,7 @@ public void dense() { activationContribution.onEpochEnd(graph); final INDArray expected = Nd4j.create(new double[]{0, 2}); // gradient is 0,2 and "activation" is 1,1 probe.assertNrofCalls(1); - assertEquals("Incorrect output!", expected, probe.last); + assertArrayEquals("Incorrect output!", expected.toDoubleVector(), probe.last.toDoubleVector(), 1e-10); } /** @@ -93,7 +93,7 @@ public void convWithBias() { final Probe probe = new Probe(); final ActivationContribution activationContribution = new ActivationContribution(layerName, probe); graph.addListeners(activationContribution); - final int nrofOutputs = graph.layerSize(outputName); + final long nrofOutputs = graph.layerSize(outputName); final INDArray feature = Nd4j.linspace(-10, 10, heigh * width * depth * miniBatchSize).reshape(miniBatchSize, depth, heigh, width); final INDArray label = Nd4j.linspace(-2, 2, nrofOutputs * miniBatchSize).reshape(miniBatchSize, nrofOutputs); @@ -101,7 +101,7 @@ public void convWithBias() { activationContribution.onEpochEnd(graph); probe.assertNrofCalls(1); - assertArrayEquals("Incorrect output shape", new long[] {1, nOut}, probe.last.shape()); + assertArrayEquals("Incorrect output shape", new long[] {nOut}, probe.last.shape()); } @NotNull diff --git a/src/test/java/ampcontrol/model/training/listen/MockModel.java b/src/test/java/ampcontrol/model/training/listen/MockModel.java index 789c4a20..e6a9127c 
100644 --- a/src/test/java/ampcontrol/model/training/listen/MockModel.java +++ b/src/test/java/ampcontrol/model/training/listen/MockModel.java @@ -6,8 +6,8 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.primitives.Pair; import java.util.Collection; import java.util.Map; @@ -169,4 +169,9 @@ public void clear() { public void applyConstraints(int iteration, int epoch) { //Ignore } + + @Override + public void close() { + //Ignore + } } \ No newline at end of file diff --git a/src/test/java/ampcontrol/model/training/model/GenericModelHandleTest.java b/src/test/java/ampcontrol/model/training/model/GenericModelHandleTest.java index 67ecb51d..0405d940 100644 --- a/src/test/java/ampcontrol/model/training/model/GenericModelHandleTest.java +++ b/src/test/java/ampcontrol/model/training/model/GenericModelHandleTest.java @@ -91,7 +91,9 @@ private static class PosNegDataSetIter implements DataSetIterator { public DataSet next(int num) { final double[] terms1 = rng.ints(num, -10, 10).mapToDouble(i -> i).toArray(); final double[][] evenOrOdd = DoubleStream.of(terms1).mapToInt(d -> (int) d).mapToObj(i -> i > 0 ? 
positive : negative).collect(Collectors.toList()).toArray(new double[][]{}); - return new DataSet(Nd4j.create(terms1).transpose(), Nd4j.create(evenOrOdd)); + + + return new DataSet(Nd4j.create(terms1).reshape(terms1.length, 1), Nd4j.create(evenOrOdd)); } @Override diff --git a/src/test/java/ampcontrol/model/training/model/evolve/GraphUtils.java b/src/test/java/ampcontrol/model/training/model/evolve/GraphUtils.java index 69ca92c3..e363a3c5 100644 --- a/src/test/java/ampcontrol/model/training/model/evolve/GraphUtils.java +++ b/src/test/java/ampcontrol/model/training/model/evolve/GraphUtils.java @@ -53,7 +53,9 @@ public static ComputationGraph getCnnGraph(String conv1Name, String conv2Name, S .addLayer(conv1Name, new Convolution2D.Builder(3, 3) .nOut(10) .build(), inputName) - .addLayer(batchNormName, new BatchNormalization.Builder().build(), conv1Name) + .addLayer(batchNormName, new BatchNormalization.Builder() + .useLogStd(false) + .build(), conv1Name) .addLayer(conv2Name, new Convolution2D.Builder(1, 1) .nOut(5) .build(), batchNormName) diff --git a/src/test/java/ampcontrol/model/training/model/evolve/crossover/graph/ParameterTransferCrossover.java b/src/test/java/ampcontrol/model/training/model/evolve/crossover/graph/ParameterTransferCrossover.java index ab92eb63..1dcce1eb 100644 --- a/src/test/java/ampcontrol/model/training/model/evolve/crossover/graph/ParameterTransferCrossover.java +++ b/src/test/java/ampcontrol/model/training/model/evolve/crossover/graph/ParameterTransferCrossover.java @@ -61,6 +61,14 @@ public void crossSimpleDense() { final ComputationGraph graphFirst = GraphUtils.getGraph(first[0], first[1], first[2]); final ComputationGraph graphSecond = GraphUtils.getGraph(second[0], second[1], second[2]); + crossover( + InputType.feedForward(33), + graphFirst, + first[1], + graphSecond, + second[0] + ); + for (String firstCp : first) { for (String secondCp : second) { crossover( diff --git 
a/src/test/java/ampcontrol/model/training/model/evolve/fitness/InstrumentEpsilonSpiesTest.java b/src/test/java/ampcontrol/model/training/model/evolve/fitness/InstrumentEpsilonSpiesTest.java index dab5f2e5..c4c4e295 100644 --- a/src/test/java/ampcontrol/model/training/model/evolve/fitness/InstrumentEpsilonSpiesTest.java +++ b/src/test/java/ampcontrol/model/training/model/evolve/fitness/InstrumentEpsilonSpiesTest.java @@ -61,6 +61,6 @@ public void instrumentAndFit() { assertFalse("Did not expect comparator!", registry.get(graph).apply("second").isPresent()); assertTrue("Expected comparator!", registry.get(graph).apply("third").isPresent()); - graph.fit(new DataSet(arr, Nd4j.ones(1))); + graph.fit(new DataSet(arr, Nd4j.ones(1,1))); } } \ No newline at end of file diff --git a/src/test/java/ampcontrol/model/training/model/evolve/mutate/NoutMutationTest.java b/src/test/java/ampcontrol/model/training/model/evolve/mutate/NoutMutationTest.java index 4821b22e..a89b629e 100644 --- a/src/test/java/ampcontrol/model/training/model/evolve/mutate/NoutMutationTest.java +++ b/src/test/java/ampcontrol/model/training/model/evolve/mutate/NoutMutationTest.java @@ -12,13 +12,14 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; import org.junit.Test; +import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT; import java.util.stream.Stream; import static junit.framework.TestCase.assertEquals; +import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.XENT; /** * Test cases for {@link NoutMutation} @@ -416,7 +417,7 @@ public void mutateAfterResBeforeSizeTransparentFork() { .addVertex("rbMvInput0", new MergeVertex(), "fb1_branch_0_0", "fb1_branch_1_0") .addLayer("3", new Convolution2D.Builder().convolutionMode(ConvolutionMode.Same).nOut(7).build(), "rbMvInput0") .addLayer("4", new 
GlobalPoolingLayer(), "3") - .addLayer("output", new CenterLossOutputLayer.Builder().nOut(2).lossFunction(new LossBinaryXENT()).build(), "4") + .addLayer("output", new CenterLossOutputLayer.Builder().nOut(2).activation(Activation.SIGMOID).lossFunction(XENT).build(), "4") .build()); graph.init(); diff --git a/src/test/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunctionTest.java b/src/test/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunctionTest.java index 3166cb03..05b03041 100644 --- a/src/test/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunctionTest.java +++ b/src/test/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunctionTest.java @@ -53,6 +53,6 @@ public void weightInit() { spiedLayers.stream() .filter(layer -> layer instanceof BaseLayer) .map(layer -> (BaseLayer)layer) - .forEach(baseLayer -> assertEquals("Incorrect weight init!", expected, baseLayer.getWeightInit())); + .forEach(baseLayer -> assertEquals("Incorrect weight init!", expected.getWeightInitFunction(), baseLayer.getWeightInitFn())); } } \ No newline at end of file diff --git a/src/test/java/ampcontrol/model/training/model/evolve/mutate/util/ForwardOfTest.java b/src/test/java/ampcontrol/model/training/model/evolve/mutate/util/ForwardOfTest.java index 8b65b0cc..8a385bfd 100644 --- a/src/test/java/ampcontrol/model/training/model/evolve/mutate/util/ForwardOfTest.java +++ b/src/test/java/ampcontrol/model/training/model/evolve/mutate/util/ForwardOfTest.java @@ -95,13 +95,13 @@ public void childrenOfComputationGraph() { final Graph graph = new ForwardOf(compGraph); assertEquals("Incorrect children!", - Arrays.asList("vertex1_0", "vertex1_1"), graph.children("input1").collect(Collectors.toList())); + Arrays.asList("vertex1_0", "vertex1_1"), graph.children("input1").sorted().collect(Collectors.toList())); assertEquals("Incorrect children!", Collections.singletonList("vertex2_0"), 
graph.children("input2").collect(Collectors.toList())); assertEquals("Incorrect children!", - Arrays.asList("vertex3_0", "vertex3_1"), graph.children("vertex1_0").collect(Collectors.toList())); + Arrays.asList("vertex3_0", "vertex3_1"), graph.children("vertex1_0").sorted().collect(Collectors.toList())); assertEquals("Incorrect children!", Collections.emptyList(), graph.children("vertex2_0").collect(Collectors.toList())); diff --git a/src/test/java/ampcontrol/model/training/model/evolve/transfer/MergeTransferBufferTest.java b/src/test/java/ampcontrol/model/training/model/evolve/transfer/MergeTransferBufferTest.java index cfa80fbc..9f20ad45 100644 --- a/src/test/java/ampcontrol/model/training/model/evolve/transfer/MergeTransferBufferTest.java +++ b/src/test/java/ampcontrol/model/training/model/evolve/transfer/MergeTransferBufferTest.java @@ -22,17 +22,17 @@ public class MergeTransferBufferTest { public void transferFirstSource1D() { final long length1 = 5; final long length2 = 8; - final INDArray source1 = Nd4j.linspace(0, length1 - 1, length1); + final INDArray source1 = Nd4j.linspace(0, length1 - 1, length1).reshape(1, length1); final INDArray target1 = Nd4j.zeros(1, source1.length() - 2); - final INDArray source2 = Nd4j.linspace(source1.length(), source1.length() + length2 - 1, length2); + final INDArray source2 = Nd4j.linspace(source1.length(), source1.length() + length2 - 1, length2).reshape(1, length2); final INDArray target2 = Nd4j.zeros(1, source2.length()); - final INDArray mergedSource = Nd4j.linspace(0, length1 + length2 - 1, length1 + length2); + final INDArray mergedSource = Nd4j.linspace(0, length1 + length2 - 1, length1 + length2).reshape(1, length1+length2); final INDArray mergedTarget = Nd4j.zeros(1, length1 + length2 - 2); final int[] expectedIndexes = {0, 2, 4, 1, 3}; - final INDArray expectedTarget1 = source1.get(new SpecifiedIndex(expectedIndexes)).get(NDArrayIndex.interval(0, target1.length())).transpose(); + final INDArray expectedTarget1 = 
source1.get(NDArrayIndex.point(0), new SpecifiedIndex(expectedIndexes)).get(NDArrayIndex.interval(0, target1.length())).reshape(1, target1.length() ); final INDArray expectedMergedTarget = Nd4j.concat(1, expectedTarget1, source2); final TransferRegistry registry = new TransferRegistry(); diff --git a/src/test/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransferNoutMutationTest.java b/src/test/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransferNoutMutationTest.java index a41751bd..9b4a0dfa 100644 --- a/src/test/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransferNoutMutationTest.java +++ b/src/test/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransferNoutMutationTest.java @@ -24,10 +24,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; -import java.util.Comparator; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; +import java.util.*; import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -251,7 +248,7 @@ public void decreaseForkedConv() { newGraph.init(); newGraph.output(Nd4j.randn(new long[]{1, 3, 33, 33})); - final int oldNout = graph.layerSize(fork1NameToMutate); + final int oldNout = (int)graph.layerSize(fork1NameToMutate); // Drop the first oldNout - newNout elements from fork1NameToMutate final int[] orderToKeep = IntStream.range(0, oldNout) @@ -271,7 +268,7 @@ public void decreaseForkedConv() { final int[] expectedToKeep = IntStream.concat( IntStream.of(orderToKeep).limit(newNout), - IntStream.range(0, graph.layerSize(fork2Name)).map(i -> i + oldNout)).toArray(); + IntStream.range(0, (int)graph.layerSize(fork2Name)).map(i -> i + oldNout)).toArray(); final INDArray sourceAfter = graph.getLayer(afterName).getParam(GraphUtils.W); final INDArray targetAfter = mutatedGraph.getLayer(afterName).getParam(GraphUtils.W); assertDims(1, expectedToKeep, sourceAfter, 
targetAfter); @@ -417,7 +414,7 @@ ComputationGraph decreaseDecreaseNout( final INDArray sourceBias = graph.getLayer(mutationName).getParam(GraphUtils.B); final INDArray targetBias = mutatedGraph.getLayer(mutationName).getParam(GraphUtils.B); - assertDims(1, orderToKeepFirst, sourceBias, targetBias); + assertDims(0, orderToKeepFirst, sourceBias, targetBias); final INDArray sourceNext = graph.getLayer(nextMutationName).getParam(GraphUtils.W); final INDArray targetNext = mutatedGraph.getLayer(nextMutationName).getParam(GraphUtils.W); @@ -425,7 +422,7 @@ ComputationGraph decreaseDecreaseNout( final INDArray sourceNextBias = graph.getLayer(nextMutationName).getParam(GraphUtils.B); final INDArray targetNextBias = mutatedGraph.getLayer(nextMutationName).getParam(GraphUtils.B); - assertDims(1, orderToKeepSecond, sourceNextBias, targetNextBias); + assertDims(0, orderToKeepSecond, sourceNextBias, targetNextBias); final INDArray sourceOutput = graph.getLayer(afterName).getParam(GraphUtils.W); final INDArray targetOutput = mutatedGraph.getLayer(afterName).getParam(GraphUtils.W); @@ -451,7 +448,7 @@ ComputationGraph decreaseIncreaseNout( final long mutationNewNout = 5; final long nextMutationNewNout = 9; - final int nextMutationPrevNout = graph.layerSize(nextMutationName); + final int nextMutationPrevNout = (int)graph.layerSize(nextMutationName); final double nextMutationNewVal = 666d; // Is this obtainable somehow? 
final ComputationGraph newGraph = new ComputationGraph(new NoutMutation( @@ -477,7 +474,7 @@ ComputationGraph decreaseIncreaseNout( final INDArray sourceBias = graph.getLayer(mutationName).getParam(GraphUtils.B); final INDArray targetBias = mutatedGraph.getLayer(mutationName).getParam(GraphUtils.B); - assertDims(1, orderToKeepFirst, sourceBias, targetBias); + assertDims(0, orderToKeepFirst, sourceBias, targetBias); final INDArray sourceNext = graph.getLayer(nextMutationName).getParam(GraphUtils.W); final INDArray targetNext = mutatedGraph.getLayer(nextMutationName).getParam(GraphUtils.W); @@ -486,8 +483,8 @@ ComputationGraph decreaseIncreaseNout( final INDArray sourceNextBias = graph.getLayer(nextMutationName).getParam(GraphUtils.B); final INDArray targetNextBias = mutatedGraph.getLayer(nextMutationName).getParam(GraphUtils.B); - assertDims(1, IntStream.range(0, (int) nextMutationNewNout).toArray(), sourceNextBias, targetNextBias); - assertScalar(1, nextMutationPrevNout, 0, targetNextBias); + assertDims(0, IntStream.range(0, (int) nextMutationNewNout).toArray(), sourceNextBias, targetNextBias); + assertScalar(0, nextMutationPrevNout, 0, targetNextBias); final INDArray sourceOutput = graph.getLayer(afterName).getParam(GraphUtils.W); final INDArray targetOutput = mutatedGraph.getLayer(afterName).getParam(GraphUtils.W); @@ -511,10 +508,18 @@ private static void assertDims( final long[] shapeTarget = target.shape(); final long[] shapeSource = source.shape(); final int[] dims = IntStream.range(0, shapeTarget.length).filter(i -> i != dim).toArray(); - for (int elemInd = 0; elemInd < Math.min(shapeTarget[dim], shapeSource[dim]); elemInd++) { - assertEquals("Incorrect target for element index " + elemInd + "!", - source.tensorAlongDimension(orderToKeep[elemInd], dims), - target.tensorAlongDimension(elemInd, dims)); + if(dims.length == 0) { + for (int elemInd = 0; elemInd < Math.min(shapeTarget[dim], shapeSource[dim]); elemInd++) { + assertEquals("Incorrect target for 
element index " + elemInd + "!", + source.getDouble(orderToKeep[elemInd]), + target.getDouble(elemInd), 1e-10); + } + } else { + for (int elemInd = 0; elemInd < Math.min(shapeTarget[dim], shapeSource[dim]); elemInd++) { + assertEquals("Incorrect target for element index " + elemInd + "!", + source.tensorAlongDimension(orderToKeep[elemInd], dims), + target.tensorAlongDimension(elemInd, dims)); + } } } @@ -567,9 +572,19 @@ private static void assertDoubleDims( for (int elemInd0 = 0; elemInd0 < Math.min(shapeTarget[outputDim], shapeSource[outputDim]); elemInd0++) { for (int elemInd1 = 0; elemInd1 < Math.min(shapeTarget[inputDim], shapeSource[inputDim]); elemInd1++) { - assertEquals("Incorrect target output for element index " + elemInd0 + ", " + elemInd1 + "!", - source.tensorAlongDimension(expectedElementOrderDim0[elemInd0], firstTensorDims).tensorAlongDimension(expectedElementOrderDim1[elemInd1], secondTensorDims), - target.tensorAlongDimension(elemInd0, firstTensorDims).tensorAlongDimension(elemInd1, secondTensorDims)); + // Somewhat convoluted code to get old behaviour where tensorAlongDimension could return a 0d array from a 1d array + final INDArray sourcedim0 = source.tensorAlongDimension(expectedElementOrderDim0[elemInd0], firstTensorDims); + final INDArray targetdim0 = target.tensorAlongDimension(elemInd0, firstTensorDims); + + if(shapeSource.length == 2) { + assertEquals("Incorrect target output for element index " + elemInd0 + ", " + elemInd1 + "!", + sourcedim0.getDouble(expectedElementOrderDim1[elemInd1]), + targetdim0.getDouble(elemInd1), 1e-10); + } else { + assertEquals("Incorrect target output for element index " + elemInd0 + ", " + elemInd1 + "!", + sourcedim0.tensorAlongDimension(expectedElementOrderDim1[elemInd1], secondTensorDims), + targetdim0.tensorAlongDimension(elemInd1, secondTensorDims)); + } } } } diff --git a/src/test/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTaskTest.java 
b/src/test/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTaskTest.java index dc16ccfe..fc7d91f6 100644 --- a/src/test/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTaskTest.java +++ b/src/test/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTaskTest.java @@ -55,7 +55,10 @@ public void applySizeDecrease() { } }}); - assertEquals("Incorrect output!", expected, target); + assertEquals("Lengths does not match!", expected.length(), target.length()); + for(int i = 0; i < expected.length(); i++) { + assertEquals("Incorrect output for element " + i + "!", expected.getDouble(i), target.getDouble(i), 1e-6); + } } /** @@ -89,7 +92,10 @@ public void applySizeDecreaseEdges() { {{0}, {3}}}, }); - assertEquals("Incorrect output!", expected, target); + assertEquals("Lengths does not match!", expected.length(), target.length()); + for(int i = 0; i < expected.length(); i++) { + assertEquals("Incorrect output for element " + i + "!", expected.getDouble(i), target.getDouble(i), 1e-10); + } } /** @@ -143,7 +149,11 @@ public int compare(Integer e1, Integer e2) { {{{3}}, {{2}}, {{0}}}, }); - assertEquals("Incorrect output!", expected, target); + assertEquals("Lengths does not match!", expected.length(), target.length()); + for(int i = 0; i < expected.length(); i++) { + assertEquals("Incorrect output for element " + i + "!", expected.getDouble(i), target.getDouble(i), 1e-10); + } + } /** @@ -246,8 +256,8 @@ public void applyMaskCoupled() { for (int elemInd0 = 0; elemInd0 < shapeTarget[0]; elemInd0++) { for (int elemInd1 = 0; elemInd1 < shapeTarget[1]; elemInd1++) { assertEquals("Incorrect target for element index " + elemInd0 + "," + elemInd1 + "!", - source.tensorAlongDimension(orderToKeep[elemInd0], 1).tensorAlongDimension(orderToKeep[elemInd1], 0), - target.tensorAlongDimension(elemInd0, 1).tensorAlongDimension(elemInd1, 0)); + source.tensorAlongDimension(orderToKeep[elemInd0], 1).getDouble(orderToKeep[elemInd1]), + 
target.tensorAlongDimension(elemInd0, 1).getDouble(elemInd1), 1e-10); } } diff --git a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/DummyOutputLayer.java b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/DummyOutputLayer.java index 9094efb5..1c8fadf6 100644 --- a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/DummyOutputLayer.java +++ b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/DummyOutputLayer.java @@ -5,7 +5,6 @@ import org.deeplearning4j.nn.conf.distribution.BinomialDistribution; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.weights.WeightInit; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.impl.LossMSE; @@ -29,8 +28,7 @@ public BlockInfo addLayers(BuilderAdapter builder, BlockInfo info) { .nOut(info.getPrevNrofOutputs()) .activation(new ActivationIdentity()) .biasInit(0) - .weightInit(WeightInit.DISTRIBUTION) - .dist(new BinomialDistribution(1,1)) // 100% probability of 1 + .weightInit(new BinomialDistribution(1,1)) // 100% probability of 1 .build(); return builder.layer(info,output); } diff --git a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ForkAggTest.java b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ForkAggTest.java index 86107a18..ba5d1798 100644 --- a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ForkAggTest.java +++ b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ForkAggTest.java @@ -43,7 +43,7 @@ public String name() { @Test public void addLayers() { final String inputName = "input"; - final INDArray input = Nd4j.create(new double[] {-1.23, 2.34}).transpose(); + final INDArray input = Nd4j.create(new double[] {-1.23, 2.34}).reshape(2, 1); final double scal0 = 10.11; final double scal1 = -13.3; @@ -66,6 +66,10 @@ public void 
addLayers() { DummyOutputLayer.setEyeOutput(graph); final INDArray expected = Nd4j.hstack(input.mul(scal0), input.mul(scal1)); - assertEquals("Incorrect output!", expected, graph.output(input)[0]); + final INDArray result = graph.output(input)[0]; + assertEquals("Lengths does not match!", expected.length(), result.length()); + for(int i = 0; i < expected.length(); i++) { + assertEquals("Incorrect output for element " + i + "!", expected.getDouble(i), result.getDouble(i), 1e-6); + } } } \ No newline at end of file diff --git a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/MinMaxPoolTest.java b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/MinMaxPoolTest.java index 56e08a4b..2b6e58ff 100644 --- a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/MinMaxPoolTest.java +++ b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/MinMaxPoolTest.java @@ -1,14 +1,13 @@ package ampcontrol.model.training.model.layerblocks.graph; import ampcontrol.model.training.model.layerblocks.LayerBlockConfig; +import junit.framework.TestCase; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.graph.ComputationGraph; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; - /** * Test cases for {@link MinMaxPool} * @@ -39,6 +38,12 @@ public void addLayers() { DummyOutputLayer.setEyeOutput(graph); - assertEquals("Incorrect output", Nd4j.create(expected), graph.output(input)[0]); + final INDArray expectedarr = Nd4j.create(expected); + final INDArray result = graph.output(input)[0]; + TestCase.assertEquals("Lengths does not match!", expectedarr.length(), result.length()); + for(int i = 0; i < expectedarr.length(); i++) { + TestCase.assertEquals("Incorrect output for element " + i + "!", expectedarr.getDouble(i), result.getDouble(i), 1e-10); + } + } } \ No newline at end of file diff --git 
a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ResBlockTest.java b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ResBlockTest.java index 2ceb0d3b..34977c1b 100644 --- a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ResBlockTest.java +++ b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ResBlockTest.java @@ -26,7 +26,7 @@ public class ResBlockTest { @Test public void addLayers() { final String inputName = "input"; - final INDArray input = Nd4j.create(new double[] {77}); + final INDArray input = Nd4j.create(new double[] {77}).reshape(1,1); final ComputationGraphConfiguration.GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder().graphBuilder() .addInputs(inputName) @@ -43,6 +43,11 @@ public void addLayers() { graphBuilder.setOutputs(output.getInputsNames()); final ComputationGraph graph = new ComputationGraph(graphBuilder.build()); graph.init(); - assertEquals("Incorrect output!", input.mul(2), graph.output(input)[0]); + final INDArray expected = input.mul(2); + final INDArray result = graph.output(input)[0]; + assertEquals("Lengths does not match!", expected.length(), result.length()); + for(int i = 0; i < expected.length(); i++) { + assertEquals("Incorrect output for element " + i + "!", expected.getDouble(i), result.getDouble(i), 1e-10); + } } } \ No newline at end of file diff --git a/src/test/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImplTest.java b/src/test/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImplTest.java index 5074ae7c..d656ecc0 100644 --- a/src/test/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImplTest.java +++ b/src/test/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImplTest.java @@ -3,6 +3,7 @@ import org.deeplearning4j.nn.graph.vertex.GraphVertex; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -51,7 +52,7 @@ public void doBackward() { epsilon.putScalar(4, 0d); epsilon.putScalar(110, 0d); epsilon.putScalar(7300, 0d); - final GraphVertex toTest = new ChannelMultVertexImpl(null, "test", 0); + final GraphVertex toTest = new ChannelMultVertexImpl(null, "test", 0, DataType.FLOAT16); toTest.setInputs(new INDArray[]{convInput, gates}); toTest.setEpsilon(epsilon); final INDArray[] result = toTest.doBackward(false, wsMgr).getSecond(); @@ -68,7 +69,7 @@ private static void testDoForward(int batchSize, int nrofChannels) { gates.putScalar(oneZeroedChannel, 0); gates.putScalar(11, 0); gates.putScalar(13, 0); - final GraphVertex toTest = new ChannelMultVertexImpl(null, "test", 0); + final GraphVertex toTest = new ChannelMultVertexImpl(null, "test", 0, DataType.FLOAT16); toTest.setInputs(new INDArray[]{convInput, gates}); final INDArray result = toTest.doForward(false, wsMgr); assertEquals("Incorrect mean!", gates.mean(1).getDouble(0), result.mean(1).getDouble(0), 1e-10); From 228c252dd30797f783b34b574a1458c1b92c8773 Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sat, 25 Nov 2023 11:22:03 +0100 Subject: [PATCH 2/8] Fix assembly plugin config --- pom.xml | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/pom.xml b/pom.xml index 0e6ef1cf..cca0cdc7 100644 --- a/pom.xml +++ b/pom.xml @@ -130,24 +130,6 @@ maven-surefire-plugin 3.2.2
- - - org.codehaus.mojo - exec-maven-plugin - 3.1.1 - - - - exec - java - - - - - ampcontrol.model.training.model.perftest.ConvPerfTest - maven - - @@ -178,14 +160,16 @@ --> maven-assembly-plugin + 3.6.0 + true ampcontrol.admin.AmpControlMain - AmpControl + jar-with-dependencies From b504dec7659c2179a1e5dad2c2f10154ff844bff Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 26 Nov 2023 22:05:17 +0100 Subject: [PATCH 3/8] Update a couple of other deps. --- README.md | 4 +++- pom.xml | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 5996b1fd..065d40a7 100644 --- a/README.md +++ b/README.md @@ -56,12 +56,14 @@ I don't really expect anyone but me to use this project and as of now it does no ### Prerequisites -Maven, Git, Data set or models. Current configuration uses the CUDA 9.2 backend for the models and this requires CUDA 9.2 along with an NVIDIA GPU which supports CUDA 8.0. See https://deeplearning4j.org/gpu +Maven, Git, Data set or models. Current configuration uses the CUDA 11.4 backend for the models which requires an NVIDIA GPU which supports CUDA. See https://deeplearning4j.org/gpu ### Installing Clone the repo and fire it up in your IDE of choice. I had some issues getting CUDA to work in eclipse due to some m2e bug/feature but this resolved when using IntelliJ. After I got it to work with IntelliJ it worked in eclipse for me as well. +Using `mvn compile assembly:single` will produce a jar file under `target` which runs `ampcontrol.admin.AmpControlMain`. 
+ ## Running the tests `mvn clean test` diff --git a/pom.xml b/pom.xml index cca0cdc7..ec2de0a0 100644 --- a/pom.xml +++ b/pom.xml @@ -16,14 +16,15 @@ 1.21 1.0.0-M2.1 1.0.0-M2.1 - 1.7.21 + 2.0.9 1.0.0-beta7 4.11 3.5.2 1.72 19.0 - 1.1.7 + 1.2.5 + 1.4.11 1.0.13 1.0.23 3.11.0 @@ -251,7 +252,7 @@ org.eclipse.paho org.eclipse.paho.client.mqttv3 - 1.1.0 + ${paho.version} com.github.wendykierp From 6592cc37687b2e4a3e7cf8d4ac0fdf29e4fa7c75 Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 26 Nov 2023 22:20:40 +0100 Subject: [PATCH 4/8] add github ci action --- .github/workflows/maven.yaml | 43 ++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 .github/workflows/maven.yaml diff --git a/.github/workflows/maven.yaml b/.github/workflows/maven.yaml new file mode 100644 index 00000000..c74da9d3 --- /dev/null +++ b/.github/workflows/maven.yaml @@ -0,0 +1,43 @@ +# This workflow will build a Java project with Maven, and cache/restore any dependencies to improve the workflow execution time +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-java-with-maven + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. 
+ +name: Java CI with Maven + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +jobs: + build: + + runs-on: ${{ matrix.os }} + strategy: + fail-fast: true + matrix: + os: + - ubuntu-latest + - windows-latest + arch: + - x64 + + steps: + - uses: actions/checkout@v3 + - name: Set up JDK 21 + uses: actions/setup-java@v3 + with: + java-version: '21' + distribution: 'temurin' + cache: maven + - name: Build with Maven + run: mvn clean test -P backend-CPU -B -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn + + # Optional: Uploads the full dependency graph to GitHub to improve the quality of Dependabot alerts this repository can receive + - name: Update dependency graph + uses: advanced-security/maven-dependency-submission-action@571e99aab1055c2e71a1e2309b9691de18d6b7d6 \ No newline at end of file From 1c1a56fd78ffef19ce3210ed3f8ea332e95733bb Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 26 Nov 2023 22:31:21 +0100 Subject: [PATCH 5/8] Remove slf4j config --- .github/workflows/maven.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/maven.yaml b/.github/workflows/maven.yaml index c74da9d3..5e2b5c88 100644 --- a/.github/workflows/maven.yaml +++ b/.github/workflows/maven.yaml @@ -36,7 +36,7 @@ jobs: distribution: 'temurin' cache: maven - name: Build with Maven - run: mvn clean test -P backend-CPU -B -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn + run: mvn clean test -P backend-CPU -B # Optional: Uploads the full dependency graph to GitHub to improve the quality of Dependabot alerts this repository can receive - name: Update dependency graph From e0ad933e7dd5d9a2ed2ac3be29224d33f37f32db Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 26 Nov 2023 22:56:48 +0100 Subject: [PATCH 6/8] Remove import of GPU only class --- README.md | 6 +-- pom.xml | 2 +- .../description/MutatingConv2dFactory.java | 7 +-- 
.../model/evolve/EvolvingPopulation.java | 7 +-- .../service/control/mqtt/MockMqttClient.java | 45 +++++++++++++++++++ .../MutatingConv2dFactoryTest.java | 1 + 6 files changed, 52 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 065d40a7..48f682d8 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,11 @@ -Travis [![Build Status](https://travis-ci.org/DrChainsaw/AmpControl.svg?branch=master)](https://travis-ci.org/DrChainsaw/AmpControl) -AppVeyor [![Build status](https://ci.appveyor.com/api/projects/status/b9e4h8g0em7r7c1v/branch/master?svg=true)](https://ci.appveyor.com/project/DrChainsaw/ampcontrol-fsko3) +![Build Status](https://github.com/DrChainsaw/AmpControl/actions/workflows/maven.yaml/badge.svg) + [![codebeat badge](https://codebeat.co/badges/998446a1-99e0-4f8f-9b62-d8dda4ef780d)](https://codebeat.co/projects/github-com-drchainsaw-ampcontrol-master) -[![Codacy Badge](https://api.codacy.com/project/badge/Grade/1b55604515c3475cb7d4826fd67f7817)](https://www.codacy.com/app/DrChainsaw/AmpControl?utm_source=github.com&utm_medium=referral&utm_content=DrChainsaw/AmpControl&utm_campaign=Badge_Grade) [![Maintainability](https://api.codeclimate.com/v1/badges/94297293618a7a420e6d/maintainability)](https://codeclimate.com/github/DrChainsaw/AmpControl/maintainability) [![Test Coverage](https://api.codeclimate.com/v1/badges/94297293618a7a420e6d/test_coverage)](https://codeclimate.com/github/DrChainsaw/AmpControl/test_coverage) - # AmpControl Do you have an amplifier hidden from plain sight with a dedicated wireless system for casual guitar practice in your living room? Are you sometimes annoyed that when playing along to recordings you can't find a sound which is a good balance for both rhythm and leads, but you also don't want to have a foot switch in your living room? Do you happen to have a moderately powerful computer standing close enough to the amplifier? 
diff --git a/pom.xml b/pom.xml index 76579180..fc407a06 100644 --- a/pom.xml +++ b/pom.xml @@ -110,7 +110,7 @@ org.jacoco jacoco-maven-plugin - 0.8.5 + 0.8.11 default-prepare-agent diff --git a/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java b/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java index 3a32b420..da2a3585 100644 --- a/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java +++ b/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java @@ -55,7 +55,6 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.jita.memory.CudaMemoryManager; import org.nd4j.linalg.activations.impl.ActivationReLU; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Nesterovs; @@ -257,10 +256,8 @@ public void addModelData(List modelData) { initialPopulation.add(adapter); }); - // Its either this or catch an exception since everything but the CudaMemoryManager throws an exception - if (Nd4j.getMemoryManager() instanceof CudaMemoryManager) { - Nd4j.getMemoryManager().purgeCaches(); - } + + Nd4j.getMemoryManager().purgeCaches(); final Population population; try { diff --git a/src/main/java/ampcontrol/model/training/model/evolve/EvolvingPopulation.java b/src/main/java/ampcontrol/model/training/model/evolve/EvolvingPopulation.java index 8e5067bf..77ab9d65 100644 --- a/src/main/java/ampcontrol/model/training/model/evolve/EvolvingPopulation.java +++ b/src/main/java/ampcontrol/model/training/model/evolve/EvolvingPopulation.java @@ -2,7 +2,6 @@ import ampcontrol.model.training.model.evolve.fitness.FitnessPolicy; import ampcontrol.model.training.model.evolve.selection.Selection; -import org.nd4j.jita.memory.CudaMemoryManager; import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -73,11 
+72,7 @@ public EvolvingPopulation evolve() { .collect(Collectors.toList())); evalCands.clear(); - // Its either this or catch an exception since everything but the CudaMemoryManager throws an exception - if(Nd4j.getMemoryManager() instanceof CudaMemoryManager) { - // Needed to free memory? - Nd4j.getMemoryManager().purgeCaches(); - } + Nd4j.getMemoryManager().purgeCaches(); return this; } diff --git a/src/test/java/ampcontrol/admin/service/control/mqtt/MockMqttClient.java b/src/test/java/ampcontrol/admin/service/control/mqtt/MockMqttClient.java index 8773c9a1..c645ead2 100644 --- a/src/test/java/ampcontrol/admin/service/control/mqtt/MockMqttClient.java +++ b/src/test/java/ampcontrol/admin/service/control/mqtt/MockMqttClient.java @@ -116,6 +116,46 @@ public void subscribe(String[] topicFilters, int[] qos, IMqttMessageListener[] m subscribedTopics.addAll(Arrays.asList(topicFilters)); } + @Override + public IMqttToken subscribeWithResponse(String s) throws MqttException { + return null; + } + + @Override + public IMqttToken subscribeWithResponse(String s, IMqttMessageListener iMqttMessageListener) throws MqttException { + return null; + } + + @Override + public IMqttToken subscribeWithResponse(String s, int i) throws MqttException { + return null; + } + + @Override + public IMqttToken subscribeWithResponse(String s, int i, IMqttMessageListener iMqttMessageListener) throws MqttException { + return null; + } + + @Override + public IMqttToken subscribeWithResponse(String[] strings) throws MqttException { + return null; + } + + @Override + public IMqttToken subscribeWithResponse(String[] strings, IMqttMessageListener[] iMqttMessageListeners) throws MqttException { + return null; + } + + @Override + public IMqttToken subscribeWithResponse(String[] strings, int[] ints) throws MqttException { + return null; + } + + @Override + public IMqttToken subscribeWithResponse(String[] strings, int[] ints, IMqttMessageListener[] iMqttMessageListeners) throws MqttException { + return 
null; + } + @Override public void unsubscribe(String topicFilter) { //Ignore @@ -173,6 +213,11 @@ public void setManualAcks(boolean manualAcks) { //Ignore } + @Override + public void reconnect() throws MqttException { + + } + @Override public void messageArrivedComplete(int messageId, int qos) { //Ignore diff --git a/src/test/java/ampcontrol/model/training/model/description/MutatingConv2dFactoryTest.java b/src/test/java/ampcontrol/model/training/model/description/MutatingConv2dFactoryTest.java index 072af0e9..418007a3 100644 --- a/src/test/java/ampcontrol/model/training/model/description/MutatingConv2dFactoryTest.java +++ b/src/test/java/ampcontrol/model/training/model/description/MutatingConv2dFactoryTest.java @@ -1,6 +1,7 @@ package ampcontrol.model.training.model.description; import org.junit.Test; +import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; From 812c95ec01963a80fc439812ecc22950f1a9df89 Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 26 Nov 2023 23:09:10 +0100 Subject: [PATCH 7/8] Remove dependency graph step due to maximum call stack size exceeded --- .github/workflows/maven.yaml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/maven.yaml b/.github/workflows/maven.yaml index 5e2b5c88..942ab522 100644 --- a/.github/workflows/maven.yaml +++ b/.github/workflows/maven.yaml @@ -37,7 +37,4 @@ jobs: cache: maven - name: Build with Maven run: mvn clean test -P backend-CPU -B - - # Optional: Uploads the full dependency graph to GitHub to improve the quality of Dependabot alerts this repository can receive - - name: Update dependency graph - uses: advanced-security/maven-dependency-submission-action@571e99aab1055c2e71a1e2309b9691de18d6b7d6 \ No newline at end of file + \ No newline at end of file From 3fa75b3bc7b4b7426e4da97641a598404bd4c008 Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Mon, 27 Nov 2023 21:11:49 +0100 Subject: [PATCH 8/8] Try to add coverage stats to maven.yaml. 
--- .appveyor.yml | 8 -------- .github/workflows/maven.yaml | 24 +++++++++++++++++++++--- .travis.yml | 14 -------------- 3 files changed, 21 insertions(+), 25 deletions(-) delete mode 100644 .appveyor.yml delete mode 100644 .travis.yml diff --git a/.appveyor.yml b/.appveyor.yml deleted file mode 100644 index d45c0741..00000000 --- a/.appveyor.yml +++ /dev/null @@ -1,8 +0,0 @@ -version: '{build}' -platform: x86 -cache: - - '%USERPROFILE%\.m2' -environment: - BACKEND_PRIORITY_CPU: 100000 -build_script: -- cmd: mvn clean test -P backend-CPU -B -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn diff --git a/.github/workflows/maven.yaml b/.github/workflows/maven.yaml index 942ab522..0b8ec39b 100644 --- a/.github/workflows/maven.yaml +++ b/.github/workflows/maven.yaml @@ -35,6 +35,24 @@ jobs: java-version: '21' distribution: 'temurin' cache: maven - - name: Build with Maven - run: mvn clean test -P backend-CPU -B - \ No newline at end of file + - name: Setup code coverage + run: | + if [ "$RUNNER_OS" == "Linux" ]; then + curl -L https://codeclimate.com/downloads/test-reporter/test-reporter-latest-linux-amd64 > ./cc-test-reporter + chmod +x ./cc-test-reporter + ./cc-test-reporter before-build + fi + - name: Build and Test + run: | + if [ "$RUNNER_OS" == "Linux" ]; then + mvn clean test jacoco:report -P backend-CPU -B + else + mvn clean test -P backend-CPU -B + fi + + - name: Upload coverage + run: | + if [ "$RUNNER_OS" == "Linux" ]; then + JACOCO_SOURCE_PATH=src/main/java ./cc-test-reporter format-coverage target/site/jacoco/jacoco.xml --input-type jacoco + ./cc-test-reporter upload-coverage + fi diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 663f9559..00000000 --- a/.travis.yml +++ /dev/null @@ -1,14 +0,0 @@ -language: java -sudo: false -env: - global: - - CC_TEST_REPORTER_ID=e35562b8a6ef581798be079625fa963012327fc15a4707ec8ee2d2f28d580558 - - BACKEND_PRIORITY_CPU="100000" -before_script: - - curl -L 
https://codeclimate.com/downloads/test-reporter/test-reporter-latest-linux-amd64 > ./cc-test-reporter - - chmod +x ./cc-test-reporter - - ./cc-test-reporter before-build -script: mvn clean test jacoco:report -P backend-CPU -B -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -after_script: - - JACOCO_SOURCE_PATH=src/main/java ./cc-test-reporter format-coverage target/site/jacoco/jacoco.xml --input-type jacoco - - ./cc-test-reporter upload-coverage \ No newline at end of file