diff --git a/.appveyor.yml b/.appveyor.yml
deleted file mode 100644
index d45c0741..00000000
--- a/.appveyor.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-version: '{build}'
-platform: x86
-cache:
- - '%USERPROFILE%\.m2'
-environment:
- BACKEND_PRIORITY_CPU: 100000
-build_script:
-- cmd: mvn clean test -P backend-CPU -B -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn
diff --git a/.github/workflows/maven.yaml b/.github/workflows/maven.yaml
new file mode 100644
index 00000000..0b8ec39b
--- /dev/null
+++ b/.github/workflows/maven.yaml
@@ -0,0 +1,58 @@
+# This workflow will build a Java project with Maven, and cache/restore any dependencies to improve the workflow execution time
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-java-with-maven
+
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+name: Java CI with Maven
+
+on:
+ push:
+ branches: [ "master" ]
+ pull_request:
+ branches: [ "master" ]
+
+jobs:
+ build:
+
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: true
+ matrix:
+ os:
+ - ubuntu-latest
+ - windows-latest
+ arch:
+ - x64
+
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up JDK 21
+ uses: actions/setup-java@v3
+ with:
+ java-version: '21'
+ distribution: 'temurin'
+ cache: maven
+ - name: Setup code coverage
+ run: |
+ if [ "$RUNNER_OS" == "Linux" ]; then
+ curl -L https://codeclimate.com/downloads/test-reporter/test-reporter-latest-linux-amd64 > ./cc-test-reporter
+ chmod +x ./cc-test-reporter
+ ./cc-test-reporter before-build
+ fi
+ - name: Build and Test
+ run: |
+ if [ "$RUNNER_OS" == "Linux" ]; then
+ mvn clean test jacoco:report -P backend-CPU -B
+ else
+ mvn clean test -P backend-CPU -B
+ fi
+
+ - name: Upload coverage
+ run: |
+ if [ "$RUNNER_OS" == "Linux" ]; then
+ JACOCO_SOURCE_PATH=src/main/java ./cc-test-reporter format-coverage target/site/jacoco/jacoco.xml --input-type jacoco
+ ./cc-test-reporter upload-coverage
+ fi
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 663f9559..00000000
--- a/.travis.yml
+++ /dev/null
@@ -1,14 +0,0 @@
-language: java
-sudo: false
-env:
- global:
- - CC_TEST_REPORTER_ID=e35562b8a6ef581798be079625fa963012327fc15a4707ec8ee2d2f28d580558
- - BACKEND_PRIORITY_CPU="100000"
-before_script:
- - curl -L https://codeclimate.com/downloads/test-reporter/test-reporter-latest-linux-amd64 > ./cc-test-reporter
- - chmod +x ./cc-test-reporter
- - ./cc-test-reporter before-build
-script: mvn clean test jacoco:report -P backend-CPU -B -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn
-after_script:
- - JACOCO_SOURCE_PATH=src/main/java ./cc-test-reporter format-coverage target/site/jacoco/jacoco.xml --input-type jacoco
- - ./cc-test-reporter upload-coverage
\ No newline at end of file
diff --git a/README.md b/README.md
index 5996b1fd..48f682d8 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,11 @@
-Travis [](https://travis-ci.org/DrChainsaw/AmpControl)
-AppVeyor [](https://ci.appveyor.com/project/DrChainsaw/ampcontrol-fsko3)
+
+
[](https://codebeat.co/projects/github-com-drchainsaw-ampcontrol-master)
-[](https://www.codacy.com/app/DrChainsaw/AmpControl?utm_source=github.com&utm_medium=referral&utm_content=DrChainsaw/AmpControl&utm_campaign=Badge_Grade)
[](https://codeclimate.com/github/DrChainsaw/AmpControl/maintainability)
[](https://codeclimate.com/github/DrChainsaw/AmpControl/test_coverage)
-
# AmpControl
Do you have an amplifier hidden from plain sight with a dedicated wireless system for casual guitar practice in your living room? Are you sometimes annoyed that when playing along to recordings you can't find a sound which is a good balance for both rhythm and leads, but you also don't want to have a foot switch in your living room? Do you happen to have a moderately powerful computer standing close enough to the amplifier?
@@ -56,12 +54,14 @@ I don't really expect anyone but me to use this project and as of now it does no
### Prerequisites
-Maven, Git, Data set or models. Current configuration uses the CUDA 9.2 backend for the models and this requires CUDA 9.2 along with an NVIDIA GPU which supports CUDA 8.0. See https://deeplearning4j.org/gpu
+Maven, Git, Data set or models. Current configuration uses the CUDA 11.4 backend for the models which requires a NVIDIA GPU which supports CUDA. See https://deeplearning4j.org/gpu
### Installing
Clone the repo and fire it up in your IDE of choice. I had some issues getting CUDA to work in eclipse due to some m2e bug/feature but this resolved when using IntelliJ. After I got it to work with IntelliJ it worked in eclipse for me as well.
+Using `mvn compile assembly:single` will produce a jar file under `target` which runs `ampcontrol.admin.AmpControlMain`.
+
## Running the tests
`mvn clean test`
diff --git a/pom.xml b/pom.xml
index 9dadcb75..fc407a06 100644
--- a/pom.xml
+++ b/pom.xml
@@ -3,7 +3,7 @@
4.0.0
AmpControl
AmpControl
- 0.5.2-SNAPSHOT
+ 0.5.3-SNAPSHOT
AmpControl
Controls amplifiers
@@ -11,32 +11,23 @@
UTF-8
bin
- 1.8
- 1.8
- 1.8
- 1.0.0-beta3
- 1.0.0-beta3
- 1.7.21
- 1.0.0-beta3
- 1.0.0-beta3
- 1.0.0-beta3
+ 21
+ 21
+ 1.21
+ 1.0.0-M2.1
+ 1.0.0-M2.1
+ 2.0.9
+ 1.0.0-beta7
4.13.1
3.5.2
1.72
-
-
- 1.0.0-beta3_spark_1
- 1.0.0-beta3_spark_1
-
-
-
- 2.11
19.0
- 1.1.7
+ 1.2.5
+ 1.4.11
1.0.13
1.0.23
- 3.7.0
+ 3.11.0
2.4.3
1.4.0
3.3.1
@@ -44,7 +35,7 @@
1.11.109
2.6.6
3.2.2
- 1.18.2
+ 1.18.30
@@ -58,6 +49,13 @@
${maven.compiler.source}
${maven.compiler.target}
+
+
+ org.projectlombok
+ lombok
+ ${lombok.version}
+
+
@@ -86,7 +84,7 @@
-
+
org.jacoco
jacoco-maven-plugin
- 0.8.5
+ 0.8.11
default-prepare-agent
@@ -131,7 +129,7 @@
org.apache.maven.plugins
maven-surefire-plugin
- 2.21.0
+ 3.2.2
@@ -163,14 +161,16 @@
-->
maven-assembly-plugin
+ 3.6.0
+ true
ampcontrol.admin.AmpControlMain
- AmpControl
+ jar-with-dependencies
@@ -192,17 +192,22 @@
true
- nd4j-cuda-10.0-platform
+ nd4j-cuda-11.4-platform
+ org.bytedeco
+ cuda-platform-redist
+ 11.4-8.2-1.5.6
+
+
org.deeplearning4j
deeplearning4j-core
@@ -215,12 +220,11 @@
org.deeplearning4j
- deeplearning4j-ui_2.11
+ deeplearning4j-ui
${dl4j.version}
org.nd4j
-
${nd4j-backend}
${nd4j.version}
@@ -248,7 +252,7 @@
org.eclipse.paho
org.eclipse.paho.client.mqttv3
- 1.1.0
+ ${paho.version}
com.github.wendykierp
@@ -288,6 +292,16 @@
slf4j-api
${slf4j.version}
+
+ com.fasterxml.jackson.core
+ jackson-databind
+ ${jackson.version}
+
+
+ org.jetbrains
+ annotations
+ 24.1.0
+
diff --git a/src/main/java/ampcontrol/admin/AmpControlMain.java b/src/main/java/ampcontrol/admin/AmpControlMain.java
index 8264efa4..53237044 100644
--- a/src/main/java/ampcontrol/admin/AmpControlMain.java
+++ b/src/main/java/ampcontrol/admin/AmpControlMain.java
@@ -8,11 +8,8 @@
import ampcontrol.audio.asio.AsioClassifierInputFactory;
import ampcontrol.model.inference.Classifier;
import ampcontrol.model.inference.ClassifierFromParameters;
-import ampcontrol.model.training.model.vertex.ChannelMultVertex;
-import ampcontrol.model.training.model.vertex.ElementWiseVertexLatest;
import com.beust.jcommander.JCommander;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.nd4j.linalg.api.buffer.DataBuffer;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.factory.Nd4j;
@@ -62,9 +59,9 @@ public static void main(String[] args) {
try {
// Might need to move into concrete Classifiers if something else is used in training
- DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);
+ DataTypeUtil.setDTypeForContext(DataType.FLOAT16);
Nd4j.getMemoryManager().setAutoGcWindow(5000);
- NeuralNetConfiguration.registerLegacyCustomClassesForJSON(ChannelMultVertex.class, ElementWiseVertexLatest.class);
+ //NeuralNetConfiguration.registerLegacyCustomClassesForJSON(ChannelMultVertex.class, ElementWiseVertexLatest.class);
final Classifier classifier = classifierFromParameters.getClassifier(inputProviderFactory);
audioClassificationService.initialize(
classificationListenerAgg,
diff --git a/src/main/java/ampcontrol/amp/midi/MidiProgChangeFactory.java b/src/main/java/ampcontrol/amp/midi/MidiProgChangeFactory.java
index 542c254c..c4cc0125 100644
--- a/src/main/java/ampcontrol/amp/midi/MidiProgChangeFactory.java
+++ b/src/main/java/ampcontrol/amp/midi/MidiProgChangeFactory.java
@@ -31,7 +31,7 @@ public class MidiProgChangeFactory implements AmpInterface.Factory {
private Predicate device = Devices.audioBox44Vsl;
@Parameter(names = {"-labelToProg", "-ltp"},
- description = "Comma separated list of how to map labels programs. First program is mapped to label 0 etc")
+ description = "Comma separated list of how to map labels to programs. First program is mapped to label 0 etc")
private List programChangesList = Arrays.asList(
2, // Silence
2, // Noise
diff --git a/src/main/java/ampcontrol/amp/probabilities/ArgMax.java b/src/main/java/ampcontrol/amp/probabilities/ArgMax.java
index c677dbd6..06c6217c 100644
--- a/src/main/java/ampcontrol/amp/probabilities/ArgMax.java
+++ b/src/main/java/ampcontrol/amp/probabilities/ArgMax.java
@@ -14,6 +14,6 @@ public class ArgMax implements Interpreter {
@Override
public List apply(INDArray probabilities) {
- return Collections.singletonList(probabilities.argMax(1).getInt(0));
+ return Collections.singletonList(probabilities.argMax(probabilities.rank()-1).getInt(0));
}
}
diff --git a/src/main/java/ampcontrol/amp/probabilities/ThresholdFilter.java b/src/main/java/ampcontrol/amp/probabilities/ThresholdFilter.java
index 88590bd3..404650eb 100644
--- a/src/main/java/ampcontrol/amp/probabilities/ThresholdFilter.java
+++ b/src/main/java/ampcontrol/amp/probabilities/ThresholdFilter.java
@@ -26,7 +26,7 @@ public ThresholdFilter(int index, double threshold, Interpreter next) {
@Override
public List apply(INDArray indArray) {
- if(indArray.argMax(1).getInt(0) == index) {
+ if(indArray.argMax(indArray.rank()-1).getInt(0) == index) {
if (indArray.getDouble(index) < threshold) {
return new ArrayList<>();
}
diff --git a/src/main/java/ampcontrol/audio/asio/AsioAudioInputBuffer.java b/src/main/java/ampcontrol/audio/asio/AsioAudioInputBuffer.java
index 17210ba2..58c9a810 100644
--- a/src/main/java/ampcontrol/audio/asio/AsioAudioInputBuffer.java
+++ b/src/main/java/ampcontrol/audio/asio/AsioAudioInputBuffer.java
@@ -111,7 +111,7 @@ public static void main(String[] args) {
driver.createBuffers(activeChannels);
driver.start();
-
+ System.out.println("Driver started, sleep now");
Thread.sleep(2000);
diff --git a/src/main/java/ampcontrol/audio/processing/ProcessingFactoryFromString.java b/src/main/java/ampcontrol/audio/processing/ProcessingFactoryFromString.java
index 118d8260..b0be0391 100644
--- a/src/main/java/ampcontrol/audio/processing/ProcessingFactoryFromString.java
+++ b/src/main/java/ampcontrol/audio/processing/ProcessingFactoryFromString.java
@@ -1,7 +1,5 @@
package ampcontrol.audio.processing;
-import org.jetbrains.annotations.Nullable;
-
/**
* Creates a {@link ProcessingResult.Factory} based on a name string. Used for persistence since
* all preprocessing needs be recreated when restoring a saved model.
@@ -99,7 +97,6 @@ public ProcessingResult.Factory get(String nameStr) {
return new UnitMaxZeroMean();
}
- @Nullable
private ProcessingResult.Factory getPipedSupplier(
String nameStr,
ProcessingResult.Factory first,
diff --git a/src/main/java/ampcontrol/model/inference/EnsembleWeightedSumClassifier.java b/src/main/java/ampcontrol/model/inference/EnsembleWeightedSumClassifier.java
index 485ed48a..cbc749fb 100644
--- a/src/main/java/ampcontrol/model/inference/EnsembleWeightedSumClassifier.java
+++ b/src/main/java/ampcontrol/model/inference/EnsembleWeightedSumClassifier.java
@@ -1,7 +1,7 @@
package ampcontrol.model.inference;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.OldSoftMax;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
@@ -19,7 +19,7 @@ class EnsembleWeightedSumClassifier implements Classifier {
final static BiFunction avgNormalizer = (sumAcc, aggClass) -> aggClass.div(sumAcc);
- final static BiFunction softMaxNormalizer = (sumAcc, aggClass) -> Nd4j.getExecutioner().execAndReturn(new OldSoftMax(aggClass));
+ final static BiFunction softMaxNormalizer = (sumAcc, aggClass) -> Nd4j.getExecutioner().execAndReturn(new SoftMax(aggClass)).getOutputArgument(0);
private final List ensemble;
private final BiFunction normalizer;
diff --git a/src/main/java/ampcontrol/model/inference/SpyClassifier.java b/src/main/java/ampcontrol/model/inference/SpyClassifier.java
index f47a0734..8d505631 100644
--- a/src/main/java/ampcontrol/model/inference/SpyClassifier.java
+++ b/src/main/java/ampcontrol/model/inference/SpyClassifier.java
@@ -136,7 +136,7 @@ public INDArray classify() {
}
private void accumInput(INDArray classification) {
- int highestProb = classification.argMax(1).getInt(0);
+ int highestProb = classification.argMax(classification.rank()-1).getInt(0);
if(classification.getDouble(highestProb) > threshold) {
if(highestProb == right) {
diff --git a/src/main/java/ampcontrol/model/inference/StoredGraphClassifier.java b/src/main/java/ampcontrol/model/inference/StoredGraphClassifier.java
index f793c7d7..54f58062 100644
--- a/src/main/java/ampcontrol/model/inference/StoredGraphClassifier.java
+++ b/src/main/java/ampcontrol/model/inference/StoredGraphClassifier.java
@@ -4,7 +4,6 @@
import ampcontrol.model.training.model.validation.listen.BestEvalScore;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
-import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.IOException;
@@ -31,7 +30,6 @@ public class StoredGraphClassifier implements Classifier {
this.accuracy = new BestEvalScore(realFileName + ".score").get();
}
- @NotNull
private static String getHashedFileNameFromModelName(String path) {
final String modelName = Paths.get(path).getFileName().toString();
String pathHashCode = modelName;
diff --git a/src/main/java/ampcontrol/model/training/ModelEvaluationWorkBench.java b/src/main/java/ampcontrol/model/training/ModelEvaluationWorkBench.java
index a667c8c9..05451cae 100644
--- a/src/main/java/ampcontrol/model/training/ModelEvaluationWorkBench.java
+++ b/src/main/java/ampcontrol/model/training/ModelEvaluationWorkBench.java
@@ -114,8 +114,7 @@ private void evalDecreaseKernelSize(ComputationGraph graph, String layer) {
.map(layerConf -> (ConvolutionLayer) layerConf)
.map(convConf -> {
convConf.setKernelSize(new int[]{convConf.getKernelSize()[0] - 1, convConf.getKernelSize()[1]});
- convConf.setWeightInit(WeightInit.DISTRIBUTION);
- convConf.setDist(new ConstantDistribution(0));
+ convConf.setWeightInitFn(WeightInit.DISTRIBUTION.getWeightInitFunction(new ConstantDistribution(0)));
return convConf;
})
.orElseThrow(() -> new IllegalArgumentException("Could not mutate layer from " + layerConfIn)))
@@ -138,8 +137,7 @@ private void evalIncreaseKernelSize(ComputationGraph graph, String layer) {
.map(layerConf -> (ConvolutionLayer) layerConf)
.map(convConf -> {
convConf.setKernelSize(new int[]{convConf.getKernelSize()[0] + 1, convConf.getKernelSize()[1]});
- convConf.setWeightInit(WeightInit.DISTRIBUTION);
- convConf.setDist(new ConstantDistribution(0));
+ convConf.setWeightInitFn(WeightInit.DISTRIBUTION.getWeightInitFunction(new ConstantDistribution(0)));
return convConf;
})
.orElseThrow(() -> new IllegalArgumentException("Could not mutate layer from " + layerConfIn)))
diff --git a/src/main/java/ampcontrol/model/training/TrainingDescription.java b/src/main/java/ampcontrol/model/training/TrainingDescription.java
index 0ad5c25d..38dbfda2 100644
--- a/src/main/java/ampcontrol/model/training/TrainingDescription.java
+++ b/src/main/java/ampcontrol/model/training/TrainingDescription.java
@@ -17,7 +17,7 @@
import ampcontrol.model.visualize.RealTimePlot;
import ch.qos.logback.classic.Level;
import org.jetbrains.annotations.NotNull;
-import org.nd4j.linalg.api.buffer.DataBuffer;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -79,7 +79,8 @@ public static void main(String[] args) {
ch.qos.logback.classic.Logger root = (ch.qos.logback.classic.Logger) LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME);
root.setLevel(Level.INFO);
- DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);
+ DataTypeUtil.setDTypeForContext(DataType.FLOAT16);
+ DataTypeUtil.getDtypeFromContext();
List modelData = new ArrayList<>();
diff --git a/src/main/java/ampcontrol/model/training/TrainingHarness.java b/src/main/java/ampcontrol/model/training/TrainingHarness.java
index 0582213e..501eca81 100644
--- a/src/main/java/ampcontrol/model/training/TrainingHarness.java
+++ b/src/main/java/ampcontrol/model/training/TrainingHarness.java
@@ -9,10 +9,10 @@
import ampcontrol.model.training.model.validation.*;
import ampcontrol.model.training.model.validation.listen.*;
import ampcontrol.model.visualize.Plot;
-import org.deeplearning4j.api.storage.StatsStorage;
+import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.ui.api.UIServer;
-import org.deeplearning4j.ui.stats.StatsListener;
-import org.deeplearning4j.ui.storage.FileStatsStorage;
+import org.deeplearning4j.ui.model.stats.StatsListener;
+import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
diff --git a/src/main/java/ampcontrol/model/training/data/CyclingLabelSupplier.java b/src/main/java/ampcontrol/model/training/data/CyclingLabelSupplier.java
index 50d02b33..94a2fc24 100644
--- a/src/main/java/ampcontrol/model/training/data/CyclingLabelSupplier.java
+++ b/src/main/java/ampcontrol/model/training/data/CyclingLabelSupplier.java
@@ -1,7 +1,7 @@
package ampcontrol.model.training.data;
import ampcontrol.model.training.data.state.StateFactory;
-import org.apache.commons.lang.mutable.MutableInt;
+import org.apache.commons.lang3.mutable.MutableInt;
import java.util.List;
import java.util.function.Supplier;
diff --git a/src/main/java/ampcontrol/model/training/data/iterators/factory/AutoFromSize.java b/src/main/java/ampcontrol/model/training/data/iterators/factory/AutoFromSize.java
index 1c8e7fa3..ecdce9d9 100644
--- a/src/main/java/ampcontrol/model/training/data/iterators/factory/AutoFromSize.java
+++ b/src/main/java/ampcontrol/model/training/data/iterators/factory/AutoFromSize.java
@@ -3,7 +3,6 @@
import ampcontrol.model.training.data.iterators.MiniEpochDataSetIterator;
import ampcontrol.model.training.data.state.ResetableState;
import lombok.Builder;
-import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -82,7 +81,7 @@ public MiniEpochDataSetIterator create(Input input) {
return factory.create(input.sourceInput);
}
- @NotNull
+
private DataSetIteratorFactory createFactory(Input input, long sizeOfOneBatch, long sizeOfWholeDataSet) {
DataSetIteratorFactory factory;
if (margin * sizeOfWholeDataSet > memoryAllowance) {
@@ -93,7 +92,7 @@ private DataSetIteratorFactory createFactory(Input<
return factory;
}
- @NotNull
+
private DataSetIteratorFactory createAsynchFactory(Input input, long sizeOfOneBatch, long sizeOfWholeDataSet) {
DataSetIteratorFactory factory;
log.info("Create Asynch iter for set of size {} with memory allowance {}", sizeOfWholeDataSet, memoryAllowance);
@@ -105,7 +104,6 @@ private DataSetIteratorFactory createAsynchFactory(
return factory;
}
- @NotNull
private DataSetIteratorFactory createCachingFactory(Input input, long sizeOfWholeDataSet) {
DataSetIteratorFactory factory;
log.info("Create Caching iter for set of size {} with memory allowance {}", sizeOfWholeDataSet, memoryAllowance);
diff --git a/src/main/java/ampcontrol/model/training/data/iterators/preprocs/Cnn2DtoCnn1DInputPreprocessor.java b/src/main/java/ampcontrol/model/training/data/iterators/preprocs/Cnn2DtoCnn1DInputPreprocessor.java
index dc492cf7..eac84034 100644
--- a/src/main/java/ampcontrol/model/training/data/iterators/preprocs/Cnn2DtoCnn1DInputPreprocessor.java
+++ b/src/main/java/ampcontrol/model/training/data/iterators/preprocs/Cnn2DtoCnn1DInputPreprocessor.java
@@ -4,8 +4,8 @@
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.primitives.Pair;
/**
* {@link InputPreProcessor} which changes CNN 2D input to CNN 1D input. Assumes only one channel.
diff --git a/src/main/java/ampcontrol/model/training/data/iterators/preprocs/CnnHeightWidthSwapInputPreprocessor.java b/src/main/java/ampcontrol/model/training/data/iterators/preprocs/CnnHeightWidthSwapInputPreprocessor.java
index fbb1b586..1d5e4f9c 100644
--- a/src/main/java/ampcontrol/model/training/data/iterators/preprocs/CnnHeightWidthSwapInputPreprocessor.java
+++ b/src/main/java/ampcontrol/model/training/data/iterators/preprocs/CnnHeightWidthSwapInputPreprocessor.java
@@ -4,10 +4,8 @@
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.primitives.Pair;
-
-import javax.ws.rs.NotSupportedException;
/**
* {@link InputPreProcessor} which swaps height and width dimensions of CNN input. Intended use is when doing 1D
@@ -45,6 +43,6 @@ public InputType getOutputType(InputType inputType) {
@Override
public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
- throw new NotSupportedException("Not implemented yet!");
+ throw new UnsupportedOperationException("Not implemented yet!");
}
}
\ No newline at end of file
diff --git a/src/main/java/ampcontrol/model/training/data/processing/SequentialHoldFileSupplier.java b/src/main/java/ampcontrol/model/training/data/processing/SequentialHoldFileSupplier.java
index df4905ca..95346e5c 100644
--- a/src/main/java/ampcontrol/model/training/data/processing/SequentialHoldFileSupplier.java
+++ b/src/main/java/ampcontrol/model/training/data/processing/SequentialHoldFileSupplier.java
@@ -1,7 +1,7 @@
package ampcontrol.model.training.data.processing;
import ampcontrol.model.training.data.state.StateFactory;
-import org.apache.commons.lang.mutable.MutableInt;
+import org.apache.commons.lang3.mutable.MutableInt;
import java.nio.file.Path;
import java.util.List;
diff --git a/src/main/java/ampcontrol/model/training/listen/ActivationContribution.java b/src/main/java/ampcontrol/model/training/listen/ActivationContribution.java
index 1f7f920d..0796a2e9 100644
--- a/src/main/java/ampcontrol/model/training/listen/ActivationContribution.java
+++ b/src/main/java/ampcontrol/model/training/listen/ActivationContribution.java
@@ -13,6 +13,7 @@
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Map;
import java.util.function.Consumer;
@@ -61,7 +62,8 @@ private void setEps(INDArray eps) {
try (MemoryWorkspace wss = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, wsName)) {
final INDArray tmpEps = eps.migrate(false);
int[] meanDims = IntStream.range(0, act.rank()).filter(dim -> dim != 1).toArray();
- listener.accept(tmpEps.muli(act).amean(meanDims));
+ // amean seems broken in M2.1: it always returns a single element array
+ listener.accept(Transforms.abs(tmpEps.muli(act)).mean(meanDims));
}
}
}
diff --git a/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java b/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java
index d138b0ea..da2a3585 100644
--- a/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java
+++ b/src/main/java/ampcontrol/model/training/model/description/MutatingConv2dFactory.java
@@ -49,14 +49,12 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Builder;
import lombok.Getter;
-import org.apache.commons.lang.mutable.MutableLong;
+import org.apache.commons.lang3.mutable.MutableLong;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
-import org.jetbrains.annotations.NotNull;
-import org.nd4j.jita.memory.CudaMemoryManager;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
@@ -258,10 +256,8 @@ public void addModelData(List modelData) {
initialPopulation.add(adapter);
});
- // Its either this or catch an exception since everything but the CudaMemoryManager throws an exception
- if (Nd4j.getMemoryManager() instanceof CudaMemoryManager) {
- Nd4j.getMemoryManager().purgeCaches();
- }
+
+ Nd4j.getMemoryManager().purgeCaches();
final Population population;
try {
@@ -403,7 +399,7 @@ private AccessibleState> createInitialEvolutionState(Mu
}
}
- @NotNull
+
private Population createPopulation(
Map modelAgeMap,
ModelComparatorRegistry comparatorRegistry,
@@ -595,7 +591,7 @@ private Mutation createAddBlockMutation(
.filter(mut -> rng.nextDouble() < 0.1));
}
- @NotNull
+
private static Function createAfterGlobPoolLayerFactory(UnaryOperator spyFactory, Random rng) {
final Function spyConfig = lbc -> new SpyBlock(lbc)
@@ -608,7 +604,7 @@ private static Function createAfterGlobPoolLayerFactory(
return nOut -> afterGpBlocks.andThen(spyConfig).apply(nOut);
}
- @NotNull
+
private static Function createBeforeGlobPoolLayerFactory(UnaryOperator spyFactory, Random rng) {
final Function spyConfig = lbc -> new SpyBlock(lbc)
.setFactory(spyFactory);
diff --git a/src/main/java/ampcontrol/model/training/model/evolve/EvolvingPopulation.java b/src/main/java/ampcontrol/model/training/model/evolve/EvolvingPopulation.java
index 8e5067bf..77ab9d65 100644
--- a/src/main/java/ampcontrol/model/training/model/evolve/EvolvingPopulation.java
+++ b/src/main/java/ampcontrol/model/training/model/evolve/EvolvingPopulation.java
@@ -2,7 +2,6 @@
import ampcontrol.model.training.model.evolve.fitness.FitnessPolicy;
import ampcontrol.model.training.model.evolve.selection.Selection;
-import org.nd4j.jita.memory.CudaMemoryManager;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -73,11 +72,7 @@ public EvolvingPopulation evolve() {
.collect(Collectors.toList()));
evalCands.clear();
- // Its either this or catch an exception since everything but the CudaMemoryManager throws an exception
- if(Nd4j.getMemoryManager() instanceof CudaMemoryManager) {
- // Needed to free memory?
- Nd4j.getMemoryManager().purgeCaches();
- }
+ Nd4j.getMemoryManager().purgeCaches();
return this;
}
diff --git a/src/main/java/ampcontrol/model/training/model/evolve/crossover/graph/CrossoverPoint.java b/src/main/java/ampcontrol/model/training/model/evolve/crossover/graph/CrossoverPoint.java
index 7d811dce..9d934401 100644
--- a/src/main/java/ampcontrol/model/training/model/evolve/crossover/graph/CrossoverPoint.java
+++ b/src/main/java/ampcontrol/model/training/model/evolve/crossover/graph/CrossoverPoint.java
@@ -7,7 +7,6 @@
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
-import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -82,7 +81,6 @@ GraphInfo execute() {
return new GraphInfo.Result(builder, infoMap);
}
- @NotNull
private GraphBuilder initBuilder() {
final ComputationGraphConfiguration conf = bottom.builder().build();
final GraphBuilder builder = new GraphBuilder(
@@ -93,7 +91,6 @@ private GraphBuilder initBuilder() {
return builder;
}
- @NotNull
private Map addTop(GraphBuilder builder) {
// Need to access names which are already added to the builder below, so first create the mapping
// between names in top and a name which is unique in builder.
diff --git a/src/main/java/ampcontrol/model/training/model/evolve/mutate/MemoryAwareMutation.java b/src/main/java/ampcontrol/model/training/model/evolve/mutate/MemoryAwareMutation.java
index 6f92908b..cf93925e 100644
--- a/src/main/java/ampcontrol/model/training/model/evolve/mutate/MemoryAwareMutation.java
+++ b/src/main/java/ampcontrol/model/training/model/evolve/mutate/MemoryAwareMutation.java
@@ -1,7 +1,5 @@
package ampcontrol.model.training.model.evolve.mutate;
-import org.bytedeco.javacpp.IntPointer;
-import org.bytedeco.javacpp.Pointer;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.function.Function;
@@ -27,17 +25,17 @@ interface MemoryProvider {
private static class DeviceMemoryProvider implements MemoryProvider {
- private final Pointer devicePointer;
+ private final int device;
private final double totalMemory;
private DeviceMemoryProvider() {
- this.devicePointer = new IntPointer(NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice());
- this.totalMemory = NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceTotalMemory(devicePointer);
+ this.device = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
+ this.totalMemory = NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceTotalMemory(device);
}
@Override
public double getUsage() {
- return (totalMemory - NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceFreeMemory(devicePointer)) / totalMemory;
+ return (totalMemory - NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceFreeMemory(device)) / totalMemory;
}
}
diff --git a/src/main/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunction.java b/src/main/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunction.java
index 66062d2e..71c4fe7e 100644
--- a/src/main/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunction.java
+++ b/src/main/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunction.java
@@ -38,7 +38,7 @@ public LayerBlockConfig apply(Long nOut) {
public static SpyFunction weightInit(Function source, WeightInit weightInit) {
return new SpyFunction(factory -> new LayerSpyAdapter((layerName, layer, layerInputs) -> {
if (layer instanceof BaseLayer) {
- ((BaseLayer) layer).setWeightInit(weightInit);
+ ((BaseLayer) layer).setWeightInitFn(weightInit.getWeightInitFunction());
}
}, factory), source);
}
diff --git a/src/main/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransfer.java b/src/main/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransfer.java
index 883cf15a..b90afffb 100644
--- a/src/main/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransfer.java
+++ b/src/main/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransfer.java
@@ -14,8 +14,6 @@
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.INDArrayIndex;
-import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.*;
import java.util.function.Function;
@@ -269,7 +267,7 @@ private TransferTask.ListBuilder transferOtherWeights(TransferContext transferCo
.addDependentTask(createDependentTask(
registry.register(paramPair.source.get(DefaultParamInitializer.BIAS_KEY), paramPair.layerName + "_source_b"),
registry.register(paramPair.target.get(DefaultParamInitializer.BIAS_KEY), paramPair.layerName + "_target_b"),
- dim -> 1, // Always 1 for bias!
+ dim -> 0, // Always 0 for bias!
transferContext.inputDimension, 2, 3, 4)
);
}
@@ -526,11 +524,11 @@ private static void setIdentityMapping(INDArray weights) {
weights.assign(Nd4j.eye(weights.size(0)));
} else if (weights.shape().length == 4 && weights.size(2) % 2 == 1 && weights.size(3) % 2 == 1) {
weights.assign(Nd4j.zeros(weights.shape()));
- final long centerH = weights.size(2) / 2;
- final long centerW = weights.size(3) / 2;
+ final int centerH = (int)weights.size(2) / 2;
+ final int centerW = (int)weights.size(3) / 2;
+
for (int i = 0; i < weights.size(0); i++) {
- weights.put(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.point(i), NDArrayIndex.point(centerH), NDArrayIndex.point(centerW)},
- Nd4j.ones(1));
+ weights.put(new int[] {i, i, centerH, centerW}, Nd4j.ones(1));
}
}
}
diff --git a/src/main/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTask.java b/src/main/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTask.java
index aa89116f..9bd18af5 100644
--- a/src/main/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTask.java
+++ b/src/main/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTask.java
@@ -23,15 +23,12 @@
@Builder(builderClassName = "Builder", buildMethodName = "buildInternal")
public class SingleTransferTask implements TransferTask {
- @lombok.Builder.Default
- private final IndMapping source = IndMapping.builder().build();
- @lombok.Builder.Default
- private final IndMapping target = IndMapping.builder().build();
+ private final IndMapping source;
+ private final IndMapping target;
private final Function> compFactory;
@lombok.Singular("maskDim")
private final Set dimensionMask;
- @lombok.Builder.Default
- private final TransferTask dependentTask = new NoTransferTask();
+ private final TransferTask dependentTask;
@lombok.Builder(builderClassName = "Builder")
@@ -111,12 +108,15 @@ public void execute() {
}
public static class Builder implements ListBuilder {
+ private IndMapping source = IndMapping.builder().build();
+ private IndMapping target = IndMapping.builder().build();
+ private TransferTask dependentTask = new NoTransferTask();
private ListBuilder dependentTaskBuilder = NoTransferTask.builder();
// To trick lombok so all boilerplate is done in autogenerated buildInternal
public SingleTransferTask build() {
target.getEntry().put(source.getEntry());
- dependentTask(dependentTaskBuilder.build());
+ setDependentTask(dependentTaskBuilder.build());
return this.buildInternal();
}
@@ -132,7 +132,7 @@ public Builder addDependentTask(ListBuilder dependentTaskBuilder) {
return this;
}
- private Builder dependendTask(TransferTask dependentTask) {
+ private Builder setDependentTask(TransferTask dependentTask) {
this.dependentTask = dependentTask;
return this;
}
diff --git a/src/main/java/ampcontrol/model/training/model/evolve/transfer/TransferRegistry.java b/src/main/java/ampcontrol/model/training/model/evolve/transfer/TransferRegistry.java
index 68a3dcfe..9ed1c661 100644
--- a/src/main/java/ampcontrol/model/training/model/evolve/transfer/TransferRegistry.java
+++ b/src/main/java/ampcontrol/model/training/model/evolve/transfer/TransferRegistry.java
@@ -20,7 +20,7 @@
*/
class TransferRegistry {
- private final Map registry = new HashMap<>();
+ private final Map registry = new IdentityHashMap<>();
private final Map actions = new LinkedHashMap<>();
class ArrayEntry {
@@ -61,6 +61,9 @@ public int compare(Integer e1, Integer e2) {
if(e1.equals(e2)) {
return 0;
}
+ if(tensorDimensions.length == 0) {
+ return -Double.compare(Math.abs(array.getDouble(e1)), Math.abs(array.getDouble(e2)));
+ }
return -Double.compare(
abs(array.tensorAlongDimension(e1, tensorDimensions)).sumNumber().doubleValue(),
@@ -140,6 +143,10 @@ private INDArray get() {
} catch (ND4JIllegalStateException e) {
throw new ND4JIllegalStateException("Could not get array " + debugName + "! Target array of shape "
+ Arrays.toString(array.shape()) + ". Wanted indexes " + Arrays.toString(asIndArray()), e);
+ } catch (NullPointerException e) {
+ // Workaround for what seems to be a bug in dl4j: It sometimes gets has a nullpointer for some array indices
+ // which only shows up when using certain INDArrayIndex types (e.g. index, indices, SpecifiedIndex).
+ return addBackSingletonDimensions(array.dup().get(asIndArray()));
}
}
@@ -173,7 +180,6 @@ private INDArrayIndex merge(INDArrayIndex index1, Optional index2
private INDArrayIndex[] asIndArray() {
return IntStream.range(0, array.rank())
.mapToObj(dim -> indexMap.getOrDefault(dim, NDArrayIndex.all()))
- .peek(INDArrayIndex::reset)
.toArray(INDArrayIndex[]::new);
}
diff --git a/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertex.java b/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertex.java
index 0612e3f6..601448be 100644
--- a/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertex.java
+++ b/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertex.java
@@ -6,6 +6,7 @@
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
@@ -51,8 +52,8 @@ public int maxVertexInputs() {
@Override
public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
- INDArray paramsView, boolean initializeParams) {
- return new ChannelMultVertexImpl(graph, name, idx);
+ INDArray paramsView, boolean initializeParams, DataType dataType) {
+ return new ChannelMultVertexImpl(graph, name, idx, dataType);
}
@Override
diff --git a/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImpl.java b/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImpl.java
index 759f17b9..963fcafb 100644
--- a/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImpl.java
+++ b/src/main/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImpl.java
@@ -9,8 +9,10 @@
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.primitives.Pair;
+
/**
* {@link BaseGraphVertex} which multiplies each channel in a convolutional activation of size [b,c,h,w] with a scalar
@@ -21,14 +23,14 @@
*/
public class ChannelMultVertexImpl extends BaseGraphVertex {
- public ChannelMultVertexImpl(ComputationGraph graph, String name, int vertexIndex) {
- this(graph, name, vertexIndex, null, null);
+ public ChannelMultVertexImpl(ComputationGraph graph, String name, int vertexIndex, DataType dataType) {
+ this(graph, name, vertexIndex, null, null, dataType);
}
public ChannelMultVertexImpl(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
- VertexIndices[] outputVertices) {
- super(graph, name, vertexIndex, inputVertices, outputVertices);
+ VertexIndices[] outputVertices, DataType dataType) {
+ super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
}
@Override
@@ -87,8 +89,8 @@ private INDArray scaleChannels(INDArray channelActivations, INDArray scaleFactor
// Goal: multiply each h*w activation a_i with the corresponding scale factor s_i, i in range [0, b*c-1]
- final long nrofChannelBatch = channelActivations.tensorssAlongDimension(2, 3); // view each channel and each batch
- final long nrofFeaturesInActivation = channelActivations.tensorssAlongDimension(0, 1); // view each h and w activation for all batches
+ final long nrofChannelBatch = channelActivations.tensorsAlongDimension(2, 3); // view each channel and each batch
+ final long nrofFeaturesInActivation = channelActivations.tensorsAlongDimension(0, 1); // view each h and w activation for all batches
// From empiric testing: Whatever makes the fewest number of loops is fastest
if (nrofChannelBatch < nrofFeaturesInActivation) {
diff --git a/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatest.java b/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatest.java
index eed0bb20..ddbab6b6 100644
--- a/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatest.java
+++ b/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatest.java
@@ -22,6 +22,7 @@
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;
@@ -90,7 +91,7 @@ public int maxVertexInputs() {
@Override
public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
- INDArray paramsView, boolean initializeParams) {
+ INDArray paramsView, boolean initializeParams, DataType dataType) {
ElementWiseVertexLatestImpl.Op op;
switch (this.op) {
case Add:
@@ -108,7 +109,7 @@ public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGra
default:
throw new UnsupportedOperationException("No support for op: " + this.op);
}
- return new ElementWiseVertexLatestImpl(graph, name, idx, op);
+ return new ElementWiseVertexLatestImpl(graph, name, idx, op, dataType);
}
@Override
diff --git a/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatestImpl.java b/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatestImpl.java
index 607f6463..e78a2ada 100644
--- a/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatestImpl.java
+++ b/src/main/java/ampcontrol/model/training/model/vertex/ElementWiseVertexLatestImpl.java
@@ -25,10 +25,11 @@
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.Or;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.primitives.Pair;
/** An ElementWiseVertex is used to combine the activations of two or more layer in an element-wise manner
* For example, the activations may be combined by addition, subtraction or multiplication or by selecting the maximum.
@@ -46,13 +47,13 @@ public enum Op {
private Op op;
private int nInForwardPass;
- public ElementWiseVertexLatestImpl(ComputationGraph graph, String name, int vertexIndex, Op op) {
- this(graph, name, vertexIndex, null, null, op);
+ public ElementWiseVertexLatestImpl(ComputationGraph graph, String name, int vertexIndex, Op op, DataType dataType) {
+ this(graph, name, vertexIndex, null, null, op, dataType);
}
public ElementWiseVertexLatestImpl(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
- VertexIndices[] outputVertices, Op op) {
- super(graph, name, vertexIndex, inputVertices, outputVertices);
+ VertexIndices[] outputVertices, Op op, DataType dataType) {
+ super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
this.op = op;
}
@@ -179,6 +180,7 @@ public Pair feedForwardMaskArrays(INDArray[] maskArrays, Ma
return new Pair<>(maskArrays[0], currentMaskState);
} else {
INDArray ret = maskArrays[0].dup(maskArrays[0].ordering());
+
Nd4j.getExecutioner().exec(new Or(maskArrays[0], maskArrays[1], ret));
for (int i = 2; i < maskArrays.length; i++) {
Nd4j.getExecutioner().exec(new Or(maskArrays[i], ret, ret));
diff --git a/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertex.java b/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertex.java
index 8cda9fac..e823853b 100644
--- a/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertex.java
+++ b/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertex.java
@@ -6,6 +6,7 @@
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
@@ -50,8 +51,8 @@ public int maxVertexInputs() {
}
@Override
- public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) {
- return new EpsilonSpyVertexImpl(graph, name, idx);
+ public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType dataType) {
+ return new EpsilonSpyVertexImpl(graph, name, idx, dataType);
}
diff --git a/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertexImpl.java b/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertexImpl.java
index 6d6d43f5..cb95047e 100644
--- a/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertexImpl.java
+++ b/src/main/java/ampcontrol/model/training/model/vertex/EpsilonSpyVertexImpl.java
@@ -8,8 +8,9 @@
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.primitives.Pair;
import java.util.Optional;
import java.util.function.Consumer;
@@ -25,12 +26,12 @@ public class EpsilonSpyVertexImpl extends BaseGraphVertex {
private Consumer listener;
- protected EpsilonSpyVertexImpl(ComputationGraph graph, String name, int vertexIndex) {
- this(graph, name, vertexIndex, null, null);
+ protected EpsilonSpyVertexImpl(ComputationGraph graph, String name, int vertexIndex, DataType dataType) {
+ this(graph, name, vertexIndex, null, null, dataType);
}
- protected EpsilonSpyVertexImpl(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices) {
- super(graph, name, vertexIndex, inputVertices, outputVertices);
+ protected EpsilonSpyVertexImpl(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices, DataType dataType) {
+ super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
}
public void setListener(Consumer listener) {
diff --git a/src/test/java/ampcontrol/admin/service/control/mqtt/MockMqttClient.java b/src/test/java/ampcontrol/admin/service/control/mqtt/MockMqttClient.java
index 8773c9a1..c645ead2 100644
--- a/src/test/java/ampcontrol/admin/service/control/mqtt/MockMqttClient.java
+++ b/src/test/java/ampcontrol/admin/service/control/mqtt/MockMqttClient.java
@@ -116,6 +116,46 @@ public void subscribe(String[] topicFilters, int[] qos, IMqttMessageListener[] m
subscribedTopics.addAll(Arrays.asList(topicFilters));
}
+ @Override
+ public IMqttToken subscribeWithResponse(String s) throws MqttException {
+ return null;
+ }
+
+ @Override
+ public IMqttToken subscribeWithResponse(String s, IMqttMessageListener iMqttMessageListener) throws MqttException {
+ return null;
+ }
+
+ @Override
+ public IMqttToken subscribeWithResponse(String s, int i) throws MqttException {
+ return null;
+ }
+
+ @Override
+ public IMqttToken subscribeWithResponse(String s, int i, IMqttMessageListener iMqttMessageListener) throws MqttException {
+ return null;
+ }
+
+ @Override
+ public IMqttToken subscribeWithResponse(String[] strings) throws MqttException {
+ return null;
+ }
+
+ @Override
+ public IMqttToken subscribeWithResponse(String[] strings, IMqttMessageListener[] iMqttMessageListeners) throws MqttException {
+ return null;
+ }
+
+ @Override
+ public IMqttToken subscribeWithResponse(String[] strings, int[] ints) throws MqttException {
+ return null;
+ }
+
+ @Override
+ public IMqttToken subscribeWithResponse(String[] strings, int[] ints, IMqttMessageListener[] iMqttMessageListeners) throws MqttException {
+ return null;
+ }
+
@Override
public void unsubscribe(String topicFilter) {
//Ignore
@@ -173,6 +213,11 @@ public void setManualAcks(boolean manualAcks) {
//Ignore
}
+ @Override
+ public void reconnect() throws MqttException {
+
+ }
+
@Override
public void messageArrivedComplete(int messageId, int qos) {
//Ignore
diff --git a/src/test/java/ampcontrol/audio/Cnn2DInputProviderTest.java b/src/test/java/ampcontrol/audio/Cnn2DInputProviderTest.java
index 236a9eba..e7178cc2 100644
--- a/src/test/java/ampcontrol/audio/Cnn2DInputProviderTest.java
+++ b/src/test/java/ampcontrol/audio/Cnn2DInputProviderTest.java
@@ -11,7 +11,7 @@
import java.util.function.Supplier;
import java.util.stream.Stream;
-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
/**
* Test cases for {@link Cnn2DInputProvider}.
@@ -38,7 +38,10 @@ public void getModelInput() {
final ProcessingResult res0 = proc.create(new SingletonDoubleInput(audioFrames[0]));
final INDArray expected0 = Nd4j.create(res0.stream().findFirst().get());
for(int channel = 0; channel < nrofChannels; channel++) {
- assertEquals("Incorrect model input!", expected0, inputProvider.getModelInput().get(NDArrayIndex.point(0), NDArrayIndex.point(channel)));
+ final INDArray res = inputProvider.getModelInput().get(NDArrayIndex.point(0), NDArrayIndex.point(channel));
+ assertArrayEquals("Incorrect shape!", expected0.shape(), res.shape());
+ assertArrayEquals("Incorrect model input!", expected0.reshape(expected0.length()).toDoubleVector(),
+ res.reshape(res.length()).toDoubleVector(), 1e-10);
}
mockBuffer.advance();
@@ -46,7 +49,10 @@ public void getModelInput() {
final ProcessingResult res1 = proc.create(new SingletonDoubleInput(audioFrames[1]));
final INDArray expected1 = Nd4j.create(res1.stream().findFirst().get());
for(int channel = 0; channel < nrofChannels; channel++) {
- assertEquals("Incorrect model input!", expected1, inputProvider.getModelInput().get(NDArrayIndex.point(0), NDArrayIndex.point(channel)));
+ final INDArray res = inputProvider.getModelInput().get(NDArrayIndex.point(0), NDArrayIndex.point(channel));
+ assertArrayEquals("Incorrect shape!", expected1.shape(), res.shape());
+ assertArrayEquals("Incorrect model input!", expected1.reshape(expected1.length()).toDoubleVector(),
+ res.reshape(res.length()).toDoubleVector(), 1e-10);
}
}
diff --git a/src/test/java/ampcontrol/model/inference/EnsembleWeightedSumClassifierTest.java b/src/test/java/ampcontrol/model/inference/EnsembleWeightedSumClassifierTest.java
index d5f4b188..aff9c3f1 100644
--- a/src/test/java/ampcontrol/model/inference/EnsembleWeightedSumClassifierTest.java
+++ b/src/test/java/ampcontrol/model/inference/EnsembleWeightedSumClassifierTest.java
@@ -62,19 +62,10 @@ void testMultiModelEnsemble(BiFunction normalizer) {
final INDArray result = classifier.classify();
mockEnsemble.forEach(mockClassifier -> mockClassifier.assertCalled(true));
- assertEquals("Incorrect result!", 1d, result.sum(1).getDouble(0), 1e-10);
+ assertEquals("Incorrect result!", 1d, result.sum(0).getDouble(0), 1e-10);
assertEquals("Incorrect result!", 3, result.argMax().getInt(0));
}
- /**
- * Fails with CPU backend
- */
- @Test
- public void testSum() {
- final INDArray test = Nd4j.create(new double[]{1, 2});
- assertEquals("Incorrect result!", Nd4j.create(new double[] {3}), test.sum(1));
- }
-
private static void testSingleModelEnsemble(BiFunction normalizer) {
final List mockEnsemble = Arrays.asList(
diff --git a/src/test/java/ampcontrol/model/inference/SpyClassifierTest.java b/src/test/java/ampcontrol/model/inference/SpyClassifierTest.java
index b04410f4..2fb60b91 100644
--- a/src/test/java/ampcontrol/model/inference/SpyClassifierTest.java
+++ b/src/test/java/ampcontrol/model/inference/SpyClassifierTest.java
@@ -8,6 +8,7 @@
import java.util.Arrays;
import java.util.List;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**
@@ -44,8 +45,9 @@ public void spyInputTypes() {
incorrectClassifiction, // 12
incorrectClassifiction // 13
);
- final INDArray expectedCorrect = Nd4j.create(new double[]{6, 7, 8});
- final INDArray expectedIncorrect = Nd4j.create(new double[]{9, 10, 11});
+
+ final INDArray expectedCorrect = Nd4j.create(new double[]{6, 7, 8}).reshape(3, 1);
+ final INDArray expectedIncorrect = Nd4j.create(new double[]{9, 10, 11}).reshape(3, 1);
final SpyClassifier spyClassifier = new SpyClassifier(
new MockClassifier("mock", 0.7, classifications),
new CountingInputProvider(),
@@ -60,8 +62,8 @@ public void spyInputTypes() {
assertEquals("Incorrect hasInput after " + i +" samples!", i >= 11, listener.hasInput());
}
- assertEquals("Incorrect correctly classified input!", expectedCorrect, listener.getCorrectlyClassifiedInput());
- assertEquals("Incorrect incorrectly classified input!", expectedIncorrect, listener.getIncorrectlyClassifiedInput());
+ assertArrayEquals("Incorrect correctly classified input!", expectedCorrect.toDoubleVector(), listener.getCorrectlyClassifiedInput().toDoubleVector(), 1e-10);
+ assertArrayEquals("Incorrect incorrectly classified input!", expectedIncorrect.toDoubleVector(), listener.getIncorrectlyClassifiedInput().toDoubleVector(), 1e-10);
}
private static class CountingInputProvider implements ClassifierInputProvider {
diff --git a/src/test/java/ampcontrol/model/training/TrainingHarnessTest.java b/src/test/java/ampcontrol/model/training/TrainingHarnessTest.java
index 855e3720..44730d78 100644
--- a/src/test/java/ampcontrol/model/training/TrainingHarnessTest.java
+++ b/src/test/java/ampcontrol/model/training/TrainingHarnessTest.java
@@ -76,7 +76,7 @@ public void eval() {
nrofEvalCalls++;
validations.stream().map(Validation::get)
- .forEach(ieOpt -> ieOpt.ifPresent(ie -> ie.eval(Nd4j.create(result), Nd4j.zeros(labels.size()))));
+ .forEach(ieOpt -> ieOpt.ifPresent(ie -> ie.eval(Nd4j.create(result).reshape(labels.size(), 1), Nd4j.zeros(labels.size(), 1))));
validations.forEach(Validation::notifyComplete);
}
diff --git a/src/test/java/ampcontrol/model/training/data/iterators/preprocs/CnnToManyToOneRnnPreProcessorTest.java b/src/test/java/ampcontrol/model/training/data/iterators/preprocs/CnnToManyToOneRnnPreProcessorTest.java
index 43c8ca6c..4d89c6e0 100644
--- a/src/test/java/ampcontrol/model/training/data/iterators/preprocs/CnnToManyToOneRnnPreProcessorTest.java
+++ b/src/test/java/ampcontrol/model/training/data/iterators/preprocs/CnnToManyToOneRnnPreProcessorTest.java
@@ -7,7 +7,6 @@
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
/**
* Test cases for {@link CnnToManyToOneRnnPreProcessor}.
@@ -32,7 +31,8 @@ public void preProcess() {
pp.preProcess(testSet);
assertArrayEquals("Incorrect feature shape!", expectedFeatureShape, testSet.getFeatures().shape());
assertArrayEquals("Incorrect labels shape!", expectedLabelsShape, testSet.getLabels().shape());
- assertEquals("Incorrect mask!", expectedLabelsMask, testSet.getLabelsMaskArray());
+ assertArrayEquals("Incorrect mask!", expectedLabelsMask.reshape(expectedLabelsMask.length()).toDoubleVector(),
+ testSet.getLabelsMaskArray().reshape(testSet.getLabelsMaskArray().length()).toDoubleVector(), 1e-10);
}
diff --git a/src/test/java/ampcontrol/model/training/data/iterators/validate/ValidateCachingIter.java b/src/test/java/ampcontrol/model/training/data/iterators/validate/ValidateCachingIter.java
index 28154c4a..b44c61fb 100644
--- a/src/test/java/ampcontrol/model/training/data/iterators/validate/ValidateCachingIter.java
+++ b/src/test/java/ampcontrol/model/training/data/iterators/validate/ValidateCachingIter.java
@@ -106,10 +106,10 @@ static DataSetIterator createTestDataSetIterator(Supplier state) {
@Override
public synchronized void fetch(int numExamples) {
- final double[] features = new double[batchSize];
+ final double[][] features = new double[batchSize][1];
final double[][] labels = new double[batchSize][10];
IntStream.range(0, batchSize).forEach(batch -> {
- features[batch] = state.get().intValue();
+ features[batch][0] = state.get().intValue();
labels[batch][state.get().intValue() % 10] = 1;
state.get().increment();
});
diff --git a/src/test/java/ampcontrol/model/training/listen/ActivationContributionTest.java b/src/test/java/ampcontrol/model/training/listen/ActivationContributionTest.java
index a80e8572..d904f4bf 100644
--- a/src/test/java/ampcontrol/model/training/listen/ActivationContributionTest.java
+++ b/src/test/java/ampcontrol/model/training/listen/ActivationContributionTest.java
@@ -63,7 +63,7 @@ public void dense() {
activationContribution.onEpochEnd(graph);
final INDArray expected = Nd4j.create(new double[]{0, 2}); // gradient is 0,2 and "activation" is 1,1
probe.assertNrofCalls(1);
- assertEquals("Incorrect output!", expected, probe.last);
+ assertArrayEquals("Incorrect output!", expected.toDoubleVector(), probe.last.toDoubleVector(), 1e-10);
}
/**
@@ -93,7 +93,7 @@ public void convWithBias() {
final Probe probe = new Probe();
final ActivationContribution activationContribution = new ActivationContribution(layerName, probe);
graph.addListeners(activationContribution);
- final int nrofOutputs = graph.layerSize(outputName);
+ final long nrofOutputs = graph.layerSize(outputName);
final INDArray feature = Nd4j.linspace(-10, 10, heigh * width * depth * miniBatchSize).reshape(miniBatchSize, depth, heigh, width);
final INDArray label = Nd4j.linspace(-2, 2, nrofOutputs * miniBatchSize).reshape(miniBatchSize, nrofOutputs);
@@ -101,7 +101,7 @@ public void convWithBias() {
activationContribution.onEpochEnd(graph);
probe.assertNrofCalls(1);
- assertArrayEquals("Incorrect output shape", new long[] {1, nOut}, probe.last.shape());
+ assertArrayEquals("Incorrect output shape", new long[] {nOut}, probe.last.shape());
}
@NotNull
diff --git a/src/test/java/ampcontrol/model/training/listen/MockModel.java b/src/test/java/ampcontrol/model/training/listen/MockModel.java
index 789c4a20..e6a9127c 100644
--- a/src/test/java/ampcontrol/model/training/listen/MockModel.java
+++ b/src/test/java/ampcontrol/model/training/listen/MockModel.java
@@ -6,8 +6,8 @@
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.primitives.Pair;
import java.util.Collection;
import java.util.Map;
@@ -169,4 +169,9 @@ public void clear() {
public void applyConstraints(int iteration, int epoch) {
//Ignore
}
+
+ @Override
+ public void close() {
+ //Ignore
+ }
}
\ No newline at end of file
diff --git a/src/test/java/ampcontrol/model/training/model/GenericModelHandleTest.java b/src/test/java/ampcontrol/model/training/model/GenericModelHandleTest.java
index 67ecb51d..0405d940 100644
--- a/src/test/java/ampcontrol/model/training/model/GenericModelHandleTest.java
+++ b/src/test/java/ampcontrol/model/training/model/GenericModelHandleTest.java
@@ -91,7 +91,9 @@ private static class PosNegDataSetIter implements DataSetIterator {
public DataSet next(int num) {
final double[] terms1 = rng.ints(num, -10, 10).mapToDouble(i -> i).toArray();
final double[][] evenOrOdd = DoubleStream.of(terms1).mapToInt(d -> (int) d).mapToObj(i -> i > 0 ? positive : negative).collect(Collectors.toList()).toArray(new double[][]{});
- return new DataSet(Nd4j.create(terms1).transpose(), Nd4j.create(evenOrOdd));
+
+
+ return new DataSet(Nd4j.create(terms1).reshape(terms1.length, 1), Nd4j.create(evenOrOdd));
}
@Override
diff --git a/src/test/java/ampcontrol/model/training/model/description/MutatingConv2dFactoryTest.java b/src/test/java/ampcontrol/model/training/model/description/MutatingConv2dFactoryTest.java
index 072af0e9..418007a3 100644
--- a/src/test/java/ampcontrol/model/training/model/description/MutatingConv2dFactoryTest.java
+++ b/src/test/java/ampcontrol/model/training/model/description/MutatingConv2dFactoryTest.java
@@ -1,6 +1,7 @@
package ampcontrol.model.training.model.description;
import org.junit.Test;
+import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
diff --git a/src/test/java/ampcontrol/model/training/model/evolve/GraphUtils.java b/src/test/java/ampcontrol/model/training/model/evolve/GraphUtils.java
index 69ca92c3..e363a3c5 100644
--- a/src/test/java/ampcontrol/model/training/model/evolve/GraphUtils.java
+++ b/src/test/java/ampcontrol/model/training/model/evolve/GraphUtils.java
@@ -53,7 +53,9 @@ public static ComputationGraph getCnnGraph(String conv1Name, String conv2Name, S
.addLayer(conv1Name, new Convolution2D.Builder(3, 3)
.nOut(10)
.build(), inputName)
- .addLayer(batchNormName, new BatchNormalization.Builder().build(), conv1Name)
+ .addLayer(batchNormName, new BatchNormalization.Builder()
+ .useLogStd(false)
+ .build(), conv1Name)
.addLayer(conv2Name, new Convolution2D.Builder(1, 1)
.nOut(5)
.build(), batchNormName)
diff --git a/src/test/java/ampcontrol/model/training/model/evolve/crossover/graph/ParameterTransferCrossover.java b/src/test/java/ampcontrol/model/training/model/evolve/crossover/graph/ParameterTransferCrossover.java
index ab92eb63..1dcce1eb 100644
--- a/src/test/java/ampcontrol/model/training/model/evolve/crossover/graph/ParameterTransferCrossover.java
+++ b/src/test/java/ampcontrol/model/training/model/evolve/crossover/graph/ParameterTransferCrossover.java
@@ -61,6 +61,14 @@ public void crossSimpleDense() {
final ComputationGraph graphFirst = GraphUtils.getGraph(first[0], first[1], first[2]);
final ComputationGraph graphSecond = GraphUtils.getGraph(second[0], second[1], second[2]);
+ crossover(
+ InputType.feedForward(33),
+ graphFirst,
+ first[1],
+ graphSecond,
+ second[0]
+ );
+
for (String firstCp : first) {
for (String secondCp : second) {
crossover(
diff --git a/src/test/java/ampcontrol/model/training/model/evolve/fitness/InstrumentEpsilonSpiesTest.java b/src/test/java/ampcontrol/model/training/model/evolve/fitness/InstrumentEpsilonSpiesTest.java
index dab5f2e5..c4c4e295 100644
--- a/src/test/java/ampcontrol/model/training/model/evolve/fitness/InstrumentEpsilonSpiesTest.java
+++ b/src/test/java/ampcontrol/model/training/model/evolve/fitness/InstrumentEpsilonSpiesTest.java
@@ -61,6 +61,6 @@ public void instrumentAndFit() {
assertFalse("Did not expect comparator!", registry.get(graph).apply("second").isPresent());
assertTrue("Expected comparator!", registry.get(graph).apply("third").isPresent());
- graph.fit(new DataSet(arr, Nd4j.ones(1)));
+ graph.fit(new DataSet(arr, Nd4j.ones(1,1)));
}
}
\ No newline at end of file
diff --git a/src/test/java/ampcontrol/model/training/model/evolve/mutate/NoutMutationTest.java b/src/test/java/ampcontrol/model/training/model/evolve/mutate/NoutMutationTest.java
index 4821b22e..a89b629e 100644
--- a/src/test/java/ampcontrol/model/training/model/evolve/mutate/NoutMutationTest.java
+++ b/src/test/java/ampcontrol/model/training/model/evolve/mutate/NoutMutationTest.java
@@ -12,13 +12,14 @@
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Test;
+import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import java.util.stream.Stream;
import static junit.framework.TestCase.assertEquals;
+import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.XENT;
/**
* Test cases for {@link NoutMutation}
@@ -416,7 +417,7 @@ public void mutateAfterResBeforeSizeTransparentFork() {
.addVertex("rbMvInput0", new MergeVertex(), "fb1_branch_0_0", "fb1_branch_1_0")
.addLayer("3", new Convolution2D.Builder().convolutionMode(ConvolutionMode.Same).nOut(7).build(), "rbMvInput0")
.addLayer("4", new GlobalPoolingLayer(), "3")
- .addLayer("output", new CenterLossOutputLayer.Builder().nOut(2).lossFunction(new LossBinaryXENT()).build(), "4")
+ .addLayer("output", new CenterLossOutputLayer.Builder().nOut(2).activation(Activation.SIGMOID).lossFunction(XENT).build(), "4")
.build());
graph.init();
diff --git a/src/test/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunctionTest.java b/src/test/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunctionTest.java
index 3166cb03..05b03041 100644
--- a/src/test/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunctionTest.java
+++ b/src/test/java/ampcontrol/model/training/model/evolve/mutate/layer/blockfunctions/SpyFunctionTest.java
@@ -53,6 +53,6 @@ public void weightInit() {
spiedLayers.stream()
.filter(layer -> layer instanceof BaseLayer)
.map(layer -> (BaseLayer)layer)
- .forEach(baseLayer -> assertEquals("Incorrect weight init!", expected, baseLayer.getWeightInit()));
+ .forEach(baseLayer -> assertEquals("Incorrect weight init!", expected.getWeightInitFunction(), baseLayer.getWeightInitFn()));
}
}
\ No newline at end of file
diff --git a/src/test/java/ampcontrol/model/training/model/evolve/mutate/util/ForwardOfTest.java b/src/test/java/ampcontrol/model/training/model/evolve/mutate/util/ForwardOfTest.java
index 8b65b0cc..8a385bfd 100644
--- a/src/test/java/ampcontrol/model/training/model/evolve/mutate/util/ForwardOfTest.java
+++ b/src/test/java/ampcontrol/model/training/model/evolve/mutate/util/ForwardOfTest.java
@@ -95,13 +95,13 @@ public void childrenOfComputationGraph() {
final Graph graph = new ForwardOf(compGraph);
assertEquals("Incorrect children!",
- Arrays.asList("vertex1_0", "vertex1_1"), graph.children("input1").collect(Collectors.toList()));
+ Arrays.asList("vertex1_0", "vertex1_1"), graph.children("input1").sorted().collect(Collectors.toList()));
assertEquals("Incorrect children!",
Collections.singletonList("vertex2_0"), graph.children("input2").collect(Collectors.toList()));
assertEquals("Incorrect children!",
- Arrays.asList("vertex3_0", "vertex3_1"), graph.children("vertex1_0").collect(Collectors.toList()));
+ Arrays.asList("vertex3_0", "vertex3_1"), graph.children("vertex1_0").sorted().collect(Collectors.toList()));
assertEquals("Incorrect children!",
Collections.emptyList(), graph.children("vertex2_0").collect(Collectors.toList()));
diff --git a/src/test/java/ampcontrol/model/training/model/evolve/transfer/MergeTransferBufferTest.java b/src/test/java/ampcontrol/model/training/model/evolve/transfer/MergeTransferBufferTest.java
index cfa80fbc..9f20ad45 100644
--- a/src/test/java/ampcontrol/model/training/model/evolve/transfer/MergeTransferBufferTest.java
+++ b/src/test/java/ampcontrol/model/training/model/evolve/transfer/MergeTransferBufferTest.java
@@ -22,17 +22,17 @@ public class MergeTransferBufferTest {
public void transferFirstSource1D() {
final long length1 = 5;
final long length2 = 8;
- final INDArray source1 = Nd4j.linspace(0, length1 - 1, length1);
+ final INDArray source1 = Nd4j.linspace(0, length1 - 1, length1).reshape(1, length1);
final INDArray target1 = Nd4j.zeros(1, source1.length() - 2);
- final INDArray source2 = Nd4j.linspace(source1.length(), source1.length() + length2 - 1, length2);
+ final INDArray source2 = Nd4j.linspace(source1.length(), source1.length() + length2 - 1, length2).reshape(1, length2);
final INDArray target2 = Nd4j.zeros(1, source2.length());
- final INDArray mergedSource = Nd4j.linspace(0, length1 + length2 - 1, length1 + length2);
+ final INDArray mergedSource = Nd4j.linspace(0, length1 + length2 - 1, length1 + length2).reshape(1, length1+length2);
final INDArray mergedTarget = Nd4j.zeros(1, length1 + length2 - 2);
final int[] expectedIndexes = {0, 2, 4, 1, 3};
- final INDArray expectedTarget1 = source1.get(new SpecifiedIndex(expectedIndexes)).get(NDArrayIndex.interval(0, target1.length())).transpose();
+ final INDArray expectedTarget1 = source1.get(NDArrayIndex.point(0), new SpecifiedIndex(expectedIndexes)).get(NDArrayIndex.interval(0, target1.length())).reshape(1, target1.length() );
final INDArray expectedMergedTarget = Nd4j.concat(1, expectedTarget1, source2);
final TransferRegistry registry = new TransferRegistry();
diff --git a/src/test/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransferNoutMutationTest.java b/src/test/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransferNoutMutationTest.java
index a41751bd..9b4a0dfa 100644
--- a/src/test/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransferNoutMutationTest.java
+++ b/src/test/java/ampcontrol/model/training/model/evolve/transfer/ParameterTransferNoutMutationTest.java
@@ -24,10 +24,7 @@
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
+import java.util.*;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -251,7 +248,7 @@ public void decreaseForkedConv() {
newGraph.init();
newGraph.output(Nd4j.randn(new long[]{1, 3, 33, 33}));
- final int oldNout = graph.layerSize(fork1NameToMutate);
+ final int oldNout = (int)graph.layerSize(fork1NameToMutate);
// Drop the first oldNout - newNout elements from fork1NameToMutate
final int[] orderToKeep = IntStream.range(0, oldNout)
@@ -271,7 +268,7 @@ public void decreaseForkedConv() {
final int[] expectedToKeep = IntStream.concat(
IntStream.of(orderToKeep).limit(newNout),
- IntStream.range(0, graph.layerSize(fork2Name)).map(i -> i + oldNout)).toArray();
+ IntStream.range(0, (int)graph.layerSize(fork2Name)).map(i -> i + oldNout)).toArray();
final INDArray sourceAfter = graph.getLayer(afterName).getParam(GraphUtils.W);
final INDArray targetAfter = mutatedGraph.getLayer(afterName).getParam(GraphUtils.W);
assertDims(1, expectedToKeep, sourceAfter, targetAfter);
@@ -417,7 +414,7 @@ ComputationGraph decreaseDecreaseNout(
final INDArray sourceBias = graph.getLayer(mutationName).getParam(GraphUtils.B);
final INDArray targetBias = mutatedGraph.getLayer(mutationName).getParam(GraphUtils.B);
- assertDims(1, orderToKeepFirst, sourceBias, targetBias);
+ assertDims(0, orderToKeepFirst, sourceBias, targetBias);
final INDArray sourceNext = graph.getLayer(nextMutationName).getParam(GraphUtils.W);
final INDArray targetNext = mutatedGraph.getLayer(nextMutationName).getParam(GraphUtils.W);
@@ -425,7 +422,7 @@ ComputationGraph decreaseDecreaseNout(
final INDArray sourceNextBias = graph.getLayer(nextMutationName).getParam(GraphUtils.B);
final INDArray targetNextBias = mutatedGraph.getLayer(nextMutationName).getParam(GraphUtils.B);
- assertDims(1, orderToKeepSecond, sourceNextBias, targetNextBias);
+ assertDims(0, orderToKeepSecond, sourceNextBias, targetNextBias);
final INDArray sourceOutput = graph.getLayer(afterName).getParam(GraphUtils.W);
final INDArray targetOutput = mutatedGraph.getLayer(afterName).getParam(GraphUtils.W);
@@ -451,7 +448,7 @@ ComputationGraph decreaseIncreaseNout(
final long mutationNewNout = 5;
final long nextMutationNewNout = 9;
- final int nextMutationPrevNout = graph.layerSize(nextMutationName);
+ final int nextMutationPrevNout = (int)graph.layerSize(nextMutationName);
final double nextMutationNewVal = 666d; // Is this obtainable somehow?
final ComputationGraph newGraph = new ComputationGraph(new NoutMutation(
@@ -477,7 +474,7 @@ ComputationGraph decreaseIncreaseNout(
final INDArray sourceBias = graph.getLayer(mutationName).getParam(GraphUtils.B);
final INDArray targetBias = mutatedGraph.getLayer(mutationName).getParam(GraphUtils.B);
- assertDims(1, orderToKeepFirst, sourceBias, targetBias);
+ assertDims(0, orderToKeepFirst, sourceBias, targetBias);
final INDArray sourceNext = graph.getLayer(nextMutationName).getParam(GraphUtils.W);
final INDArray targetNext = mutatedGraph.getLayer(nextMutationName).getParam(GraphUtils.W);
@@ -486,8 +483,8 @@ ComputationGraph decreaseIncreaseNout(
final INDArray sourceNextBias = graph.getLayer(nextMutationName).getParam(GraphUtils.B);
final INDArray targetNextBias = mutatedGraph.getLayer(nextMutationName).getParam(GraphUtils.B);
- assertDims(1, IntStream.range(0, (int) nextMutationNewNout).toArray(), sourceNextBias, targetNextBias);
- assertScalar(1, nextMutationPrevNout, 0, targetNextBias);
+ assertDims(0, IntStream.range(0, (int) nextMutationNewNout).toArray(), sourceNextBias, targetNextBias);
+ assertScalar(0, nextMutationPrevNout, 0, targetNextBias);
final INDArray sourceOutput = graph.getLayer(afterName).getParam(GraphUtils.W);
final INDArray targetOutput = mutatedGraph.getLayer(afterName).getParam(GraphUtils.W);
@@ -511,10 +508,18 @@ private static void assertDims(
final long[] shapeTarget = target.shape();
final long[] shapeSource = source.shape();
final int[] dims = IntStream.range(0, shapeTarget.length).filter(i -> i != dim).toArray();
- for (int elemInd = 0; elemInd < Math.min(shapeTarget[dim], shapeSource[dim]); elemInd++) {
- assertEquals("Incorrect target for element index " + elemInd + "!",
- source.tensorAlongDimension(orderToKeep[elemInd], dims),
- target.tensorAlongDimension(elemInd, dims));
+ if(dims.length == 0) {
+ for (int elemInd = 0; elemInd < Math.min(shapeTarget[dim], shapeSource[dim]); elemInd++) {
+ assertEquals("Incorrect target for element index " + elemInd + "!",
+ source.getDouble(orderToKeep[elemInd]),
+ target.getDouble(elemInd), 1e-10);
+ }
+ } else {
+ for (int elemInd = 0; elemInd < Math.min(shapeTarget[dim], shapeSource[dim]); elemInd++) {
+ assertEquals("Incorrect target for element index " + elemInd + "!",
+ source.tensorAlongDimension(orderToKeep[elemInd], dims),
+ target.tensorAlongDimension(elemInd, dims));
+ }
}
}
@@ -567,9 +572,19 @@ private static void assertDoubleDims(
for (int elemInd0 = 0; elemInd0 < Math.min(shapeTarget[outputDim], shapeSource[outputDim]); elemInd0++) {
for (int elemInd1 = 0; elemInd1 < Math.min(shapeTarget[inputDim], shapeSource[inputDim]); elemInd1++) {
- assertEquals("Incorrect target output for element index " + elemInd0 + ", " + elemInd1 + "!",
- source.tensorAlongDimension(expectedElementOrderDim0[elemInd0], firstTensorDims).tensorAlongDimension(expectedElementOrderDim1[elemInd1], secondTensorDims),
- target.tensorAlongDimension(elemInd0, firstTensorDims).tensorAlongDimension(elemInd1, secondTensorDims));
+ // Somewhat convoluted code to get old behaviour where tensorAlongDimension could return a 0d array from a 1d array
+ final INDArray sourcedim0 = source.tensorAlongDimension(expectedElementOrderDim0[elemInd0], firstTensorDims);
+ final INDArray targetdim0 = target.tensorAlongDimension(elemInd0, firstTensorDims);
+
+ if(shapeSource.length == 2) {
+ assertEquals("Incorrect target output for element index " + elemInd0 + ", " + elemInd1 + "!",
+ sourcedim0.getDouble(expectedElementOrderDim1[elemInd1]),
+ targetdim0.getDouble(elemInd1), 1e-10);
+ } else {
+ assertEquals("Incorrect target output for element index " + elemInd0 + ", " + elemInd1 + "!",
+ sourcedim0.tensorAlongDimension(expectedElementOrderDim1[elemInd1], secondTensorDims),
+ targetdim0.tensorAlongDimension(elemInd1, secondTensorDims));
+ }
}
}
}
diff --git a/src/test/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTaskTest.java b/src/test/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTaskTest.java
index dc16ccfe..fc7d91f6 100644
--- a/src/test/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTaskTest.java
+++ b/src/test/java/ampcontrol/model/training/model/evolve/transfer/SingleTransferTaskTest.java
@@ -55,7 +55,10 @@ public void applySizeDecrease() {
}
}});
- assertEquals("Incorrect output!", expected, target);
+ assertEquals("Lengths does not match!", expected.length(), target.length());
+ for(int i = 0; i < expected.length(); i++) {
+ assertEquals("Incorrect output for element " + i + "!", expected.getDouble(i), target.getDouble(i), 1e-6);
+ }
}
/**
@@ -89,7 +92,10 @@ public void applySizeDecreaseEdges() {
{{0}, {3}}},
});
- assertEquals("Incorrect output!", expected, target);
+ assertEquals("Lengths does not match!", expected.length(), target.length());
+ for(int i = 0; i < expected.length(); i++) {
+ assertEquals("Incorrect output for element " + i + "!", expected.getDouble(i), target.getDouble(i), 1e-10);
+ }
}
/**
@@ -143,7 +149,11 @@ public int compare(Integer e1, Integer e2) {
{{{3}}, {{2}}, {{0}}},
});
- assertEquals("Incorrect output!", expected, target);
+ assertEquals("Lengths does not match!", expected.length(), target.length());
+ for(int i = 0; i < expected.length(); i++) {
+ assertEquals("Incorrect output for element " + i + "!", expected.getDouble(i), target.getDouble(i), 1e-10);
+ }
+
}
/**
@@ -246,8 +256,8 @@ public void applyMaskCoupled() {
for (int elemInd0 = 0; elemInd0 < shapeTarget[0]; elemInd0++) {
for (int elemInd1 = 0; elemInd1 < shapeTarget[1]; elemInd1++) {
assertEquals("Incorrect target for element index " + elemInd0 + "," + elemInd1 + "!",
- source.tensorAlongDimension(orderToKeep[elemInd0], 1).tensorAlongDimension(orderToKeep[elemInd1], 0),
- target.tensorAlongDimension(elemInd0, 1).tensorAlongDimension(elemInd1, 0));
+ source.tensorAlongDimension(orderToKeep[elemInd0], 1).getDouble(orderToKeep[elemInd1]),
+ target.tensorAlongDimension(elemInd0, 1).getDouble(elemInd1), 1e-10);
}
}
diff --git a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/DummyOutputLayer.java b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/DummyOutputLayer.java
index 9094efb5..1c8fadf6 100644
--- a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/DummyOutputLayer.java
+++ b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/DummyOutputLayer.java
@@ -5,7 +5,6 @@
import org.deeplearning4j.nn.conf.distribution.BinomialDistribution;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
@@ -29,8 +28,7 @@ public BlockInfo addLayers(BuilderAdapter builder, BlockInfo info) {
.nOut(info.getPrevNrofOutputs())
.activation(new ActivationIdentity())
.biasInit(0)
- .weightInit(WeightInit.DISTRIBUTION)
- .dist(new BinomialDistribution(1,1)) // 100% probability of 1
+ .weightInit(new BinomialDistribution(1,1)) // 100% probability of 1
.build();
return builder.layer(info,output);
}
diff --git a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ForkAggTest.java b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ForkAggTest.java
index 86107a18..ba5d1798 100644
--- a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ForkAggTest.java
+++ b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ForkAggTest.java
@@ -43,7 +43,7 @@ public String name() {
@Test
public void addLayers() {
final String inputName = "input";
- final INDArray input = Nd4j.create(new double[] {-1.23, 2.34}).transpose();
+ final INDArray input = Nd4j.create(new double[] {-1.23, 2.34}).reshape(2, 1);
final double scal0 = 10.11;
final double scal1 = -13.3;
@@ -66,6 +66,10 @@ public void addLayers() {
DummyOutputLayer.setEyeOutput(graph);
final INDArray expected = Nd4j.hstack(input.mul(scal0), input.mul(scal1));
- assertEquals("Incorrect output!", expected, graph.output(input)[0]);
+ final INDArray result = graph.output(input)[0];
+ assertEquals("Lengths does not match!", expected.length(), result.length());
+ for(int i = 0; i < expected.length(); i++) {
+ assertEquals("Incorrect output for element " + i + "!", expected.getDouble(i), result.getDouble(i), 1e-6);
+ }
}
}
\ No newline at end of file
diff --git a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/MinMaxPoolTest.java b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/MinMaxPoolTest.java
index 56e08a4b..2b6e58ff 100644
--- a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/MinMaxPoolTest.java
+++ b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/MinMaxPoolTest.java
@@ -1,14 +1,13 @@
package ampcontrol.model.training.model.layerblocks.graph;
import ampcontrol.model.training.model.layerblocks.LayerBlockConfig;
+import junit.framework.TestCase;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
-import static org.junit.Assert.assertEquals;
-
/**
* Test cases for {@link MinMaxPool}
*
@@ -39,6 +38,12 @@ public void addLayers() {
DummyOutputLayer.setEyeOutput(graph);
- assertEquals("Incorrect output", Nd4j.create(expected), graph.output(input)[0]);
+ final INDArray expectedarr = Nd4j.create(expected);
+ final INDArray result = graph.output(input)[0];
+ TestCase.assertEquals("Lengths does not match!", expectedarr.length(), result.length());
+ for(int i = 0; i < expectedarr.length(); i++) {
+ TestCase.assertEquals("Incorrect output for element " + i + "!", expectedarr.getDouble(i), result.getDouble(i), 1e-10);
+ }
+
}
}
\ No newline at end of file
diff --git a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ResBlockTest.java b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ResBlockTest.java
index 2ceb0d3b..34977c1b 100644
--- a/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ResBlockTest.java
+++ b/src/test/java/ampcontrol/model/training/model/layerblocks/graph/ResBlockTest.java
@@ -26,7 +26,7 @@ public class ResBlockTest {
@Test
public void addLayers() {
final String inputName = "input";
- final INDArray input = Nd4j.create(new double[] {77});
+ final INDArray input = Nd4j.create(new double[] {77}).reshape(1,1);
final ComputationGraphConfiguration.GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder().graphBuilder()
.addInputs(inputName)
@@ -43,6 +43,11 @@ public void addLayers() {
graphBuilder.setOutputs(output.getInputsNames());
final ComputationGraph graph = new ComputationGraph(graphBuilder.build());
graph.init();
- assertEquals("Incorrect output!", input.mul(2), graph.output(input)[0]);
+ final INDArray expected = input.mul(2);
+ final INDArray result = graph.output(input)[0];
+ assertEquals("Lengths does not match!", expected.length(), result.length());
+ for(int i = 0; i < expected.length(); i++) {
+ assertEquals("Incorrect output for element " + i + "!", expected.getDouble(i), result.getDouble(i), 1e-10);
+ }
}
}
\ No newline at end of file
diff --git a/src/test/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImplTest.java b/src/test/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImplTest.java
index 5074ae7c..d656ecc0 100644
--- a/src/test/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImplTest.java
+++ b/src/test/java/ampcontrol/model/training/model/vertex/ChannelMultVertexImplTest.java
@@ -3,6 +3,7 @@
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -51,7 +52,7 @@ public void doBackward() {
epsilon.putScalar(4, 0d);
epsilon.putScalar(110, 0d);
epsilon.putScalar(7300, 0d);
- final GraphVertex toTest = new ChannelMultVertexImpl(null, "test", 0);
+ final GraphVertex toTest = new ChannelMultVertexImpl(null, "test", 0, DataType.FLOAT16);
toTest.setInputs(new INDArray[]{convInput, gates});
toTest.setEpsilon(epsilon);
final INDArray[] result = toTest.doBackward(false, wsMgr).getSecond();
@@ -68,7 +69,7 @@ private static void testDoForward(int batchSize, int nrofChannels) {
gates.putScalar(oneZeroedChannel, 0);
gates.putScalar(11, 0);
gates.putScalar(13, 0);
- final GraphVertex toTest = new ChannelMultVertexImpl(null, "test", 0);
+ final GraphVertex toTest = new ChannelMultVertexImpl(null, "test", 0, DataType.FLOAT16);
toTest.setInputs(new INDArray[]{convInput, gates});
final INDArray result = toTest.doForward(false, wsMgr);
assertEquals("Incorrect mean!", gates.mean(1).getDouble(0), result.mean(1).getDouble(0), 1e-10);