From 29c4116da30b8b37a0ac956abf5a273bdbb9f415 Mon Sep 17 00:00:00 2001 From: AlexDBlack Date: Sat, 8 Feb 2020 00:16:11 +1100 Subject: [PATCH 1/6] First steps for SameDiff graph optimizer Signed-off-by: AlexDBlack --- .../org/nd4j/autodiff/samediff/SameDiff.java | 2 + .../samediff/optimize/GraphOptimizer.java | 101 ++++++++++++++++++ .../samediff/optimize/OptimizationConfig.java | 7 ++ .../autodiff/samediff/optimize/Optimizer.java | 26 +++++ .../samediff/optimize/OptimizerSet.java | 13 +++ .../optimizations/BaseOptimizerSet.java | 52 +++++++++ .../ConstantFunctionOptimizations.java | 101 ++++++++++++++++++ .../optimization/TestOptimization.java | 42 ++++++++ 8 files changed, 344 insertions(+) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index de421b2970b9..f4abec1b087e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -118,7 +118,9 @@ public class SameDiff extends SDBaseOps { @Getter private final Map sessions = new ConcurrentHashMap<>(); //Key: thread ID + @Getter //TODO shouldn't be in public API private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true); + @Getter //TODO shouldn't be in public API private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true); private final Map> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java new file mode 100644 index 000000000000..3bbdd5baa56c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java @@ -0,0 +1,101 @@ +package org.nd4j.autodiff.samediff.optimize; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.optimizations.ConstantFunctionOptimizations; + +import java.util.Arrays; +import java.util.List; + +/** + * + * @author Alex Black + */ +@Slf4j +public class GraphOptimizer { + + public static List defaultOptimizations(){ + return Arrays.asList( + new ConstantFunctionOptimizations() + ); + } + + public static SameDiff optimize(SameDiff graph){ + return optimize(graph, defaultOptimizations()); + } + + public static SameDiff optimize(SameDiff graph, List optimizations){ + SameDiff sd = graph.dup(); + + ArrayHolder cArr = sd.getConstantArrays(); + ArrayHolder vArr = sd.getVariablesArrays(); + + OptimizationConfig config = new OptimizationConfig(); //TODO + + for( int i=0; i<3; i++ ) { //Run multiple times - one run isn't enough, as some more optimizations may need to be applied to the output of earlier optimizations + for (OptimizerSet s : optimizations) { + List l = s.getOptimizers(); + for(Optimizer o : l ){ + for(SameDiffOp op : sd.getOps().values()) { + boolean applied = o.checkAndApply(sd, config, op, cArr, vArr); + if(applied) { + log.info("Operation was applied: "); + } + } + } + } + } + + int constBefore = 0; + int constAfter = 0; + int varBefore = 0; + int varAfter = 0; + int arrBefore = 0; + int arrAfter = 0; + + for(SDVariable v : graph.variables()){ + switch(v.getVariableType()){ + case VARIABLE: + varBefore++; + break; + case CONSTANT: + constBefore++; + break; + case ARRAY: + arrBefore++; + break; + case PLACEHOLDER: + break; + } + } + + for(SDVariable v : sd.variables()){ + switch(v.getVariableType()){ + case VARIABLE: + varAfter++; + break; + case CONSTANT: + constAfter++; + break; + case ARRAY: + arrAfter++; + break; + case PLACEHOLDER: + break; + } + } + + + log.info("Total variables: {} before, {} after", graph.getVariables().size(), sd.getVariables().size()); + log.info("Constant variables: {} before, {} after", constBefore, constAfter); + log.info("Array type variables: {} before, {} after", arrBefore, arrAfter); + log.info("Variable type variables: {} before, {} after", varBefore, varAfter); + log.info("Ops: {} before, {} after", graph.getOps().size(), sd.getOps().size()); + + return sd; + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java new file mode 100644 index 000000000000..2bc3bf1a087d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java @@ -0,0 +1,7 @@ +package org.nd4j.autodiff.samediff.optimize; + +import java.util.Properties; + +public class OptimizationConfig extends Properties { + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java new file mode 100644 index 000000000000..c84e9708d81c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java @@ -0,0 +1,26 @@ +package org.nd4j.autodiff.samediff.optimize; + +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; + +import java.util.Properties; + +/** + * + * @author Alex Black + */ +public interface Optimizer { + + /** + * + * @param sd Current SameDiff instance to optimize + * @param optimizationConfig Optimization configuration + * @param op Operation to check for optimization + * @param constantArrays + * @param variablesArrays + * @return True if the optimization was applied + */ + boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays); + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java new file mode 100644 index 000000000000..a971c7174967 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java @@ -0,0 +1,13 @@ +package org.nd4j.autodiff.samediff.optimize; + +import java.util.List; + +/** + * + * @author Alex Black + */ +public interface OptimizerSet { + + List getOptimizers(); + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java new file mode 100644 index 000000000000..ce1787bdb9b5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java @@ -0,0 +1,52 @@ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.autodiff.samediff.optimize.OptimizerSet; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.List; + +/** + * + * @author Alex Black + */ +@Slf4j +public abstract class BaseOptimizerSet implements OptimizerSet { + + + @Override + public List getOptimizers() { + Method[] methods = this.getClass().getDeclaredMethods(); + List out = new ArrayList<>(methods.length); + for(Method m : methods){ + int modifiers = m.getModifiers(); + Class retType = m.getReturnType(); + if(retType != null && Modifier.isPublic(modifiers) && Optimizer.class.isAssignableFrom(retType) ){ + try { + Optimizer o = (Optimizer) m.invoke(null); + out.add(o); + } catch (IllegalAccessException | InvocationTargetException e) { + log.warn("Could not create optimizer from method: {}", m, e); + } + } + } + + Class[] declaredClasses = this.getClass().getDeclaredClasses(); + for(Class c : declaredClasses){ + int modifiers = c.getModifiers(); + if(Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers) && Optimizer.class.isAssignableFrom(c)){ + try{ + out.add((Optimizer) c.newInstance()); + } catch (IllegalAccessException | InstantiationException e) { + log.warn("Could not create optimizer from inner class: {}", c, e); + } + } + } + + return out; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java new file mode 100644 index 000000000000..db12ad4ac4b0 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java @@ -0,0 +1,101 @@ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.List; +import java.util.Properties; + +/** + * This set of optimizations looks for functions that are applied to constants, and "pre executes" them, so they don't have + * to be calculated (returning the same value) on each run. + * + * @author Alex Black + */ +public class ConstantFunctionOptimizations extends BaseOptimizerSet { + + public static final String CONSTANT_FN_FOLDING_MAX_SIZE = "optimizer.constants.function.max.output.size"; + public static final long CONSTANT_FN_FOLDING_MAX_SIZE_DEFAULT = 4 * 1024 * 1024; //4MB + + public static class FoldConstantFunctions implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + //TODO This function needs to check for non-deterministic ops - i.e., random ops - and not apply the optimization to these + + List in = op.getInputsToOp(); + if (in == null || in.isEmpty()) + return false; + for (String s : in) { + if (!sd.getVariable(s).isConstant()) + return false; + } + + long maxSizeToApply = Long.parseLong(optimizationConfig.getProperty(CONSTANT_FN_FOLDING_MAX_SIZE, String.valueOf(CONSTANT_FN_FOLDING_MAX_SIZE_DEFAULT))); + //Apply the optimization: + DifferentialFunction df = op.getOp(); + df.clearArrays(); + for (int i = 0; i < in.size(); i++) { + String s = in.get(i); + INDArray arr = sd.getVariable(s).getArr(); + if (df instanceof CustomOp) { + ((CustomOp) df).addInputArgument(arr); + } else { + if (i == 0) + ((Op) df).setX(arr); + else + ((Op) df).setY(arr); + } + } + + INDArray[] outputs; + if (df instanceof CustomOp) { + CustomOp o = (CustomOp) df; + Nd4j.exec(o); + outputs = new INDArray[o.numOutputArguments()]; + for (int j = 0; j < outputs.length; j++) { + outputs[j] = o.getOutputArgument(j); + } + } else { + Op o = (Op) df; + Nd4j.exec(o); + outputs = new INDArray[]{o.z()}; + } + long sizeCount = 0; + for (INDArray i : outputs) { + if (!i.dataType().isNumerical()) + continue; + sizeCount += i.length() * i.dataType().width(); + } + + if (sizeCount > maxSizeToApply) + return false; + + //Convert outputs to constants + List outputNames = op.getOutputsOfOp(); + for(int i=0; i Date: Sat, 8 Feb 2020 18:12:32 +1100 Subject: [PATCH 2/6] Next set of optimizations Signed-off-by: AlexDBlack --- .../ConstantFunctionOptimizations.java | 9 +- .../CuDNNFunctionOptimizations.java | 93 +++++++++++++++++++ .../IdentityFunctionOptimizations.java | 32 +++++++ .../optimizations/OptimizationUtils.java | 46 +++++++++ .../ShapeFunctionOptimizations.java | 77 +++++++++++++++ 5 files changed, 251 insertions(+), 6 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/CuDNNFunctionOptimizations.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java index db12ad4ac4b0..435d8c86ecfd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java @@ -89,12 +89,9 @@ public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDif sd.getVariables().get(n).setOutputOfOp(null); } - //Remove the op: TODO Make util method? - sd.getOps().remove(df.getOwnName()); - for(String s : op.getInputsToOp()){ - Variable v = sd.getVariables().get(s); - v.getInputsForOp().remove(op.getName()); - } + //Remove the op + OptimizationUtils.removeOp(sd, df.getOwnName()); + return true; } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/CuDNNFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/CuDNNFunctionOptimizations.java new file mode 100644 index 000000000000..b792e2d1c138 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/CuDNNFunctionOptimizations.java @@ -0,0 +1,93 @@ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.List; +import java.util.Properties; + +public class CuDNNFunctionOptimizations extends BaseOptimizerSet { + + protected static final boolean isCudaBackend; + + static { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); +// isCudaBackend = "CUDA".equalsIgnoreCase(backend); + isCudaBackend = true; //For testing only + } + + /** + * https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html#tensor-layout + * For tensor cores: we want NHWC layout: + * Section 7.3.1 + * "Layout choice has an effect on performance, as convolutions implemented for Tensor Cores require NHWC layout and are fastest when input tensors are laid out in NHWC." + * "To maximize performance, we recommend using NHWC tensor layout." + * + * Asn for weights format: cuDNN docs are vague - but TF uses NCHW+OIHW or NHWC+OHWI + */ + public static class CudnnConv2dNCHWtoNHWCConversion implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + if(!(op.getOp() instanceof Conv2D)) + return false; + + Conv2D c2d = (Conv2D)op.getOp(); + boolean weightsCorrect = false; + boolean activationsCorrect = c2d.getConfig().isNHWC(); + + if(activationsCorrect && weightsCorrect) + return false; //Nothing to do here + + //Now, we need to do 2 things + //(a) replace NCHW to NHWC for input + //(b) set weight format to OHWI (OYXI) + + List inputs = op.getInputsToOp(); + String wArgName = inputs.get(1); + + //Step 1 - replace activations + if(!activationsCorrect) { + String inArgName = inputs.get(0); + SDVariable in = sd.getVariable(inArgName); + //Replace [in -> Conv2d(NCHW) -> out] with [in -> permute -> Conv2d(NHWC) -> permute -> out] + String newName = in.name() + "_cudnn_nchw_to_nhwc"; + OptimizationUtils.replaceOpInputsWith(sd, in.name(), newName); + SDVariable nhwc = in.permute(0, 2, 3, 1).rename(newName); //NCHW to NHWC + + SDVariable outNhwc = sd.getVariable(op.getOutputsOfOp().get(0)); + String newName2 = outNhwc.name() + "_cudnn_nhwc_to_nchw"; + SDVariable outNchw = outNhwc.permute(0, 3, 1, 2).rename(newName2); //NHWC to NCHW + + OptimizationUtils.replaceOpInputsWith(sd, outNhwc.name(), outNchw.name()); + + c2d.getConfig().isNHWC(true); + } + + //Step 2 - replace YXIO weights (default) with OYXI weights + //We'll just add a permute here, and let other optimizer steps fix the (variable -> permute -> op ==> permutedVariable -> op) part + if(!weightsCorrect) { + SDVariable w = sd.getVariable(wArgName); + String newWname = w.name() + "_cudnn_yxio_to_oyxi"; + OptimizationUtils.replaceOpInputsWith(sd, w.name(), newWname); + SDVariable wPermuted = w.permute(3, 0, 1, 2).rename(newWname); + + + //TODO once config supports weight layout, set it here + } + + + return true; + } + } + + /* + TODO: Also do pooling2d, batchnorm, etc + */ + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java new file mode 100644 index 000000000000..d32f14d34da6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java @@ -0,0 +1,32 @@ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.Optimizer; + +import java.util.Properties; + +public class IdentityFunctionOptimizations extends BaseOptimizerSet { + + /** + * Remove permute(0,1,2,...,rank-1) as this is a no-op + */ + public static class RemoveIdentityPermute implements Optimizer { + + @Override + public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + return false; + } + } + + /** + * Remove identity(x) + */ + public static class RemoveIdentityOps implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + return false; + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java new file mode 100644 index 000000000000..98fd2303324e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java @@ -0,0 +1,46 @@ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +public class OptimizationUtils { + + private OptimizationUtils(){ } + + public static void replaceOpInputsWith(SameDiff sd, @NonNull String replaceInput, @NonNull String newInput){ + if(replaceInput.equals(newInput)) + return; + + //Update op input structure: Replace all instances replaceInput->X with newInput->X + Collection ops = sd.getOps().values(); + for(SameDiffOp o : ops){ + List l = o.getInputsToOp(); + while(l != null && l.contains(replaceInput)){ + int idx = l.indexOf(replaceInput); + l.set(idx, newInput); + } + } + + //Update variable structure + Variable v = sd.getVariables().get(replaceInput); + Variable v2 = sd.getVariables().get(newInput); + //NOTE: this only works if we carefully control the order in which replaceOpInputsWith is called! + v2.setInputsForOp(v.getInputsForOp()); + v.setInputsForOp(new ArrayList()); + } + + public static void removeOp(@NonNull SameDiff sd, @NonNull String opToRemove){ + SameDiffOp op = sd.getOps().remove(opToRemove); + for(String s : op.getInputsToOp()){ + Variable v = sd.getVariables().get(s); + v.getInputsForOp().remove(op.getName()); + } + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java new file mode 100644 index 000000000000..899fa2ad90e7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java @@ -0,0 +1,77 @@ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.linalg.api.ops.impl.shape.Permute; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +public class ShapeFunctionOptimizations extends BaseOptimizerSet { + + /** + * Fuse [permute1 -> permute2 -> ... -> permuteN] into a single permute op, + * as long as the intermediate permute outputs aren't needed for another op + */ + public static class FuseChainedPermutes implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + if(!(op.getOp() instanceof Permute)) + return false; + + List inputs = op.getInputsToOp(); + String input = inputs.get(0); + + List toFuse = new ArrayList<>(); + toFuse.add(op.getName()); + String currInput = input; + while(currInput != null){ + Variable v = sd.getVariables().get(currInput); + //In order to fuse permute operations, we require: + // (a) the intermediate variable is ONLY needed by the next permute + // (b) the permute dimensions are constant, + + if(v.getInputsForOp().size() > 1) + break; + } + + if(toFuse.size() > 1){ + //Fuse the permute ops + +// return true; + return false; + } + + + return false; + } + } + + /** + * Fuse [reshape1 -> reshape2 -> ... -> reshapeN] into a single reshape op, + * as long as the intermediate reshape ops aren't needed for another op + */ + public static class FuseChainedReshapes implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + return false; + } + } + + /** + * Fuse [concat(concat(concat(x,y,dim=D), z, dim=D), a, dim=D)] into a single concat op, concat(x,y,z,a, dim=D) + * As long as the intermediate outputs aren't needed elsewhere + */ + public static class FuseChainedConcatOps implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + return false; + } + } + +} From b991deab5a03c57d2066d5dbe864551a003e5d7d Mon Sep 17 00:00:00 2001 From: AlexDBlack Date: Sat, 8 Feb 2020 20:05:25 +1100 Subject: [PATCH 3/6] Next steps Signed-off-by: AlexDBlack --- .../org/nd4j/autodiff/samediff/SameDiff.java | 4 +- .../array/OptimizedGraphArrayHolder.java | 78 +++++++++++++++++++ .../samediff/optimize/OptimizationHelper.java | 48 ++++++++++++ 3 files changed, 128 insertions(+), 2 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/OptimizedGraphArrayHolder.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index f4abec1b087e..132d38b64f8b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -118,9 +118,9 @@ public class SameDiff extends SDBaseOps { @Getter private final Map sessions = new ConcurrentHashMap<>(); //Key: thread ID - @Getter //TODO shouldn't be in public API + @Getter @Setter //TODO shouldn't be in public API private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true); - @Getter //TODO shouldn't be in public API + @Getter @Setter //TODO shouldn't be in public API private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true); private final Map> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/OptimizedGraphArrayHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/OptimizedGraphArrayHolder.java new file mode 100644 index 000000000000..0b2470676ffc --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/OptimizedGraphArrayHolder.java @@ -0,0 +1,78 @@ +package org.nd4j.autodiff.samediff.array; + +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.function.Supplier; + +import java.util.*; + +public class OptimizedGraphArrayHolder implements ArrayHolder { + + private final ArrayHolder underlyingHolder; + private final Map> functions; + + public OptimizedGraphArrayHolder(ArrayHolder underlyingHolder){ + this.underlyingHolder = underlyingHolder; + this.functions = new HashMap<>(); + } + + public void setFunction(String name, Supplier fn){ + if(underlyingHolder.hasArray(name)) + underlyingHolder.removeArray(name); + functions.put(name, fn); + } + + @Override + public boolean hasArray(String name) { + return functions.containsKey(name) || underlyingHolder.hasArray(name); + } + + @Override + public INDArray getArray(String name) { + if(functions.containsKey(name)) + return functions.get(name).get(); + return underlyingHolder.getArray(name); + } + + @Override + public void setArray(String name, INDArray array) { + Preconditions.checkState(!functions.containsKey(name), "Cannot set array when existing array is only accessible via a function"); + underlyingHolder.setArray(name, array); + } + + @Override + public INDArray removeArray(String name) { + Supplier s = functions.remove(name); + if(s != null) + return s.get(); + return underlyingHolder.removeArray(name); + } + + @Override + public int size() { + return underlyingHolder.size() + functions.size(); + } + + @Override + public void initFrom(ArrayHolder arrayHolder) { + underlyingHolder.initFrom(arrayHolder); + } + + @Override + public Collection arrayNames() { + Set set = new HashSet<>(); + set.addAll(underlyingHolder.arrayNames()); + set.addAll(functions.keySet()); + return set; + } + + @Override + public void rename(String from, String to) { + if(functions.containsKey(from)) { + functions.put(to, functions.remove(from)); + } else { + underlyingHolder.rename(from, to); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java new file mode 100644 index 000000000000..f56f8c4f44fd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java @@ -0,0 +1,48 @@ +package org.nd4j.autodiff.samediff.optimize; + +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.array.OptimizedGraphArrayHolder; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.function.Supplier; + +public class OptimizationHelper { + + private final SameDiff originalGraph; + private boolean setConstantHolder = false; + private boolean setVariableHolder = false; + + public OptimizationHelper(SameDiff originalGraph){ + this.originalGraph = originalGraph; + } + + public OptimizationHelper arrayRecoveryFunction(String arrayName, Supplier fn){ + SDVariable v = originalGraph.getVariable(arrayName); + Preconditions.checkState(v.getVariableType() == VariableType.VARIABLE || v.getVariableType() == VariableType.CONSTANT, + "Can only set an array recovery function for a variable or a constant"); + + if(v.getVariableType() == VariableType.VARIABLE){ + ArrayHolder h = originalGraph.getVariablesArrays(); + if(!setVariableHolder){ + originalGraph.setVariablesArrays(new OptimizedGraphArrayHolder(h)); + h = originalGraph.getVariablesArrays(); + setVariableHolder = true; + } + ((OptimizedGraphArrayHolder)h).setFunction(arrayName, fn); + } else { + ArrayHolder h = originalGraph.getConstantArrays(); + if(!setConstantHolder){ + originalGraph.setConstantArrays(new OptimizedGraphArrayHolder(h)); + h = originalGraph.getConstantArrays(); + setConstantHolder = true; + } + ((OptimizedGraphArrayHolder)h).setFunction(arrayName, fn); + } + + return this; + } + +} From 61bf5a8c8309b4722bb5a66bff1b13c04d7caadf Mon Sep 17 00:00:00 2001 From: AlexDBlack Date: Sat, 8 Feb 2020 22:21:52 +1100 Subject: [PATCH 4/6] API tweaks Signed-off-by: AlexDBlack --- .../samediff/optimize/GraphOptimizer.java | 40 +++++++++-- .../samediff/optimize/OptimizationConfig.java | 15 ++++ .../samediff/optimize/OptimizationHelper.java | 23 ++++++- .../autodiff/samediff/optimize/Optimizer.java | 2 +- .../optimizations/BaseOptimizerSet.java | 15 ++++ .../ConstantFunctionOptimizations.java | 23 +++++-- .../CuDNNFunctionOptimizations.java | 22 ++++-- .../IdentityFunctionOptimizations.java | 20 +++++- .../optimizations/OptimizationUtils.java | 15 ++++ .../ShapeFunctionOptimizations.java | 24 +++++-- .../UnusedFunctionOptimizations.java | 68 +++++++++++++++++++ .../optimization/TestOptimization.java | 18 +++-- 12 files changed, 257 insertions(+), 28 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/UnusedFunctionOptimizations.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java index 3bbdd5baa56c..145c1b46cf90 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize; import lombok.extern.slf4j.Slf4j; @@ -5,9 +20,11 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; -import org.nd4j.autodiff.samediff.optimize.optimizations.ConstantFunctionOptimizations; +import org.nd4j.autodiff.samediff.optimize.optimizations.*; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; /** @@ -19,7 +36,12 @@ public class GraphOptimizer { public static List defaultOptimizations(){ return Arrays.asList( - new ConstantFunctionOptimizations() + new UnusedFunctionOptimizations(), + new ConstantFunctionOptimizations(), + new IdentityFunctionOptimizations(), + new ShapeFunctionOptimizations(), + new UnusedFunctionOptimizations(), + new CuDNNFunctionOptimizations() ); } @@ -33,16 +55,22 @@ public static SameDiff optimize(SameDiff graph, List optimizations ArrayHolder cArr = sd.getConstantArrays(); ArrayHolder vArr = sd.getVariablesArrays(); - OptimizationConfig config = new OptimizationConfig(); //TODO + OptimizationHelper h = new OptimizationHelper(graph, new OptimizationConfig()); //TODO defaults for config for( int i=0; i<3; i++ ) { //Run multiple times - one run isn't enough, as some more optimizations may need to be applied to the output of earlier optimizations for (OptimizerSet s : optimizations) { List l = s.getOptimizers(); for(Optimizer o : l ){ - for(SameDiffOp op : sd.getOps().values()) { - boolean applied = o.checkAndApply(sd, config, op, cArr, vArr); + Collection startingOps = new ArrayList<>(sd.getOps().values()); //Create list to avoid concurrent modification exception + for(SameDiffOp op : startingOps) { + //Because ops might disappear from previous optimization steps, we need to check if the previous op + // still exists when iterating... + if(!sd.getOps().containsKey(op.getName())) + continue; + + boolean applied = o.checkAndApply(sd, h, op, cArr, vArr); if(applied) { - log.info("Operation was applied: "); + log.info("Operation was applied: {}", o); } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java index 2bc3bf1a087d..ebc1b036eca8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize; import java.util.Properties; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java index f56f8c4f44fd..a8470b1926a7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java @@ -1,5 +1,21 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize; +import lombok.Getter; import org.nd4j.autodiff.samediff.ArrayHolder; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -9,14 +25,19 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.function.Supplier; +import java.util.Properties; + public class OptimizationHelper { private final SameDiff originalGraph; + @Getter + private final Properties properties; private boolean setConstantHolder = false; private boolean setVariableHolder = false; - public OptimizationHelper(SameDiff originalGraph){ + public OptimizationHelper(SameDiff originalGraph, Properties properties){ this.originalGraph = originalGraph; + this.properties = properties; } public OptimizationHelper arrayRecoveryFunction(String arrayName, Supplier fn){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java index c84e9708d81c..fab056bf92b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java @@ -21,6 +21,6 @@ public interface Optimizer { * @param variablesArrays * @return True if the optimization was applied */ - boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays); + boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java index ce1787bdb9b5..7a60745ed3f4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize.optimizations; import lombok.extern.slf4j.Slf4j; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java index 435d8c86ecfd..5f3d0a2af09e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java @@ -1,12 +1,26 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize.optimizations; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.ArrayHolder; -import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.internal.SameDiffOp; -import org.nd4j.autodiff.samediff.internal.Variable; +import org.nd4j.autodiff.samediff.optimize.OptimizationHelper; import org.nd4j.autodiff.samediff.optimize.Optimizer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; @@ -14,7 +28,6 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.List; -import java.util.Properties; /** * This set of optimizations looks for functions that are applied to constants, and "pre executes" them, so they don't have @@ -29,7 +42,7 @@ public class ConstantFunctionOptimizations extends BaseOptimizerSet { public static class FoldConstantFunctions implements Optimizer { @Override - public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { //TODO This function needs to check for non-deterministic ops - i.e., random ops - and not apply the optimization to these List in = op.getInputsToOp(); @@ -40,7 +53,7 @@ public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDif return false; } - long maxSizeToApply = Long.parseLong(optimizationConfig.getProperty(CONSTANT_FN_FOLDING_MAX_SIZE, String.valueOf(CONSTANT_FN_FOLDING_MAX_SIZE_DEFAULT))); + long maxSizeToApply = Long.parseLong(helper.getProperties().getProperty(CONSTANT_FN_FOLDING_MAX_SIZE, String.valueOf(CONSTANT_FN_FOLDING_MAX_SIZE_DEFAULT))); //Apply the optimization: DifferentialFunction df = op.getOp(); df.clearArrays(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/CuDNNFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/CuDNNFunctionOptimizations.java index b792e2d1c138..c1540088c77a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/CuDNNFunctionOptimizations.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/CuDNNFunctionOptimizations.java @@ -1,16 +1,30 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize.optimizations; import org.nd4j.autodiff.samediff.ArrayHolder; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.OptimizationHelper; import org.nd4j.autodiff.samediff.optimize.Optimizer; import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D; import org.nd4j.linalg.factory.Nd4j; import java.util.List; -import java.util.Properties; public class CuDNNFunctionOptimizations extends BaseOptimizerSet { @@ -29,11 +43,11 @@ public class CuDNNFunctionOptimizations extends BaseOptimizerSet { * "Layout choice has an effect on performance, as convolutions implemented for Tensor Cores require NHWC layout and are fastest when input tensors are laid out in NHWC." * "To maximize performance, we recommend using NHWC tensor layout." * - * Asn for weights format: cuDNN docs are vague - but TF uses NCHW+OIHW or NHWC+OHWI + * As for weights format: cuDNN docs are vague - but TF uses NCHW+OIHW or NHWC+OHWI */ public static class CudnnConv2dNCHWtoNHWCConversion implements Optimizer { @Override - public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { if(!(op.getOp() instanceof Conv2D)) return false; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java index d32f14d34da6..3801848def3d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java @@ -1,8 +1,24 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize.optimizations; import org.nd4j.autodiff.samediff.ArrayHolder; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.OptimizationHelper; import org.nd4j.autodiff.samediff.optimize.Optimizer; import java.util.Properties; @@ -15,7 +31,7 @@ public class IdentityFunctionOptimizations extends BaseOptimizerSet { public static class RemoveIdentityPermute implements Optimizer { @Override - public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { return false; } } @@ -25,7 +41,7 @@ public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDif */ public static class RemoveIdentityOps implements Optimizer { @Override - public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { return false; } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java index 98fd2303324e..8e0167820969 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize.optimizations; import lombok.NonNull; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java index 899fa2ad90e7..f45b92bce372 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java @@ -1,16 +1,30 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize.optimizations; import org.nd4j.autodiff.samediff.ArrayHolder; -import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.Variable; +import org.nd4j.autodiff.samediff.optimize.OptimizationHelper; import org.nd4j.autodiff.samediff.optimize.Optimizer; import org.nd4j.linalg.api.ops.impl.shape.Permute; import java.util.ArrayList; import java.util.List; -import java.util.Properties; public class ShapeFunctionOptimizations extends BaseOptimizerSet { @@ -20,7 +34,7 @@ public class ShapeFunctionOptimizations extends BaseOptimizerSet { */ public static class FuseChainedPermutes implements Optimizer { @Override - public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { if(!(op.getOp() instanceof Permute)) return false; @@ -58,7 +72,7 @@ public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDif */ public static class FuseChainedReshapes implements Optimizer { @Override - public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { return false; } } @@ -69,7 +83,7 @@ public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDif */ public static class FuseChainedConcatOps implements Optimizer { @Override - public boolean checkAndApply(SameDiff sd, Properties optimizationConfig, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { return false; } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/UnusedFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/UnusedFunctionOptimizations.java new file mode 100644 index 000000000000..96d3d34de95d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/UnusedFunctionOptimizations.java @@ -0,0 +1,68 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; +import org.nd4j.autodiff.samediff.optimize.OptimizationHelper; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.function.Supplier; + +import java.util.ArrayList; +import java.util.List; + +@Slf4j +public class UnusedFunctionOptimizations extends BaseOptimizerSet { + + public static class RemoveUnusedConstants implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + //TODO check this once _per graph_ not per op + List variables = new ArrayList<>(sd.getVariables().values()); + boolean anyRemoved = false; + for(Variable v : variables){ + if(v.getVariable().getVariableType() == VariableType.CONSTANT){ + List inputFor = v.getInputsForOp(); + if(inputFor == null || inputFor.isEmpty()){ + //This constant isn't used... + + //TODO let's put these on disk instead of keeping them in memory... + final INDArray arr = v.getVariable().getArr(); + helper.arrayRecoveryFunction(v.getName(), new Supplier() { + @Override + public INDArray get() { + return arr; + } + }); + + sd.getVariables().remove(v.getName()); + log.info("Removed unused constant: {}", v.getName()); + anyRemoved = true; + } + } + } + return anyRemoved; + } + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java index 79402a363825..238a5aa0e6ac 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java @@ -12,6 +12,7 @@ import java.util.Collections; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; public class TestOptimization extends BaseNd4jTest { @@ -24,19 +25,28 @@ public char ordering() { return 'c'; } + @Override + public long getTimeoutMilliseconds() { + return 1_000_000_000L; + } @Test public void testConstantOpFolding(){ SameDiff sd = SameDiff.create(); SDVariable c = sd.constant("c", Nd4j.scalar(1.0)); - SDVariable v = c.add("add", 1); + SDVariable c2 = c.add("add", 1); + SDVariable v = sd.var("variable", Nd4j.scalar(1.0)); + SDVariable out = v.sub("out", c2); SameDiff optimized = GraphOptimizer.optimize(sd); - assertEquals(2, optimized.getVariables().size()); + assertEquals(3, optimized.getVariables().size()); //"add", "variable", "out" -> "c" should be removed assertEquals(VariableType.CONSTANT, optimized.getVariable("add").getVariableType()); - assertEquals(0, optimized.getOps().size()); + assertEquals(1, optimized.getOps().size()); + assertEquals("subtract", optimized.getOps().values().iterator().next().getName()); + + assertFalse(optimized.hasVariable("c")); - assertEquals(sd.outputSingle(Collections.emptyMap(), "add"), optimized.outputSingle(Collections.emptyMap(), "add")); + assertEquals(sd.outputSingle(Collections.emptyMap(), "out"), optimized.outputSingle(Collections.emptyMap(), "out")); } } From 813c6f9919a5882f2ae5675aab7c6a2e4f54ca04 Mon Sep 17 00:00:00 2001 From: AlexDBlack Date: Sun, 9 Feb 2020 00:08:46 +1100 Subject: [PATCH 5/6] Proper debugging/tracking API for graph optimization Signed-off-by: AlexDBlack --- .../samediff/optimize/GraphOptimizer.java | 13 ++- .../optimize/debug/OptimizationDebugger.java | 13 +++ .../optimization/TestOptimization.java | 51 ++++++++++ .../optimization/util/OptTestConfig.java | 91 +++++++++++++++++ .../util/OptimizationRecordingDebugger.java | 28 ++++++ .../util/OptimizationTestUtil.java | 99 +++++++++++++++++++ 6 files changed, 294 insertions(+), 1 deletion(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptTestConfig.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationRecordingDebugger.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java index 145c1b46cf90..1e350deb5fb6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.debug.OptimizationDebugger; import org.nd4j.autodiff.samediff.optimize.optimizations.*; import java.util.ArrayList; @@ -49,7 +50,11 @@ public static SameDiff optimize(SameDiff graph){ return optimize(graph, defaultOptimizations()); } - public static SameDiff optimize(SameDiff graph, List optimizations){ + public static SameDiff optimize(SameDiff graph, List optimizations) { + return optimize(graph, optimizations, null); + } + + public static SameDiff optimize(SameDiff graph, List optimizations, OptimizationDebugger debugger){ SameDiff sd = graph.dup(); ArrayHolder cArr = sd.getConstantArrays(); @@ -68,10 +73,16 @@ public static SameDiff optimize(SameDiff graph, List optimizations if(!sd.getOps().containsKey(op.getName())) continue; + if(debugger != null) + debugger.beforeOptimizationCheck(sd, op, o); + boolean applied = o.checkAndApply(sd, h, op, cArr, vArr); if(applied) { log.info("Operation was applied: {}", o); } + + if(debugger != null) + debugger.afterOptimizationsCheck(sd, op, o, applied); } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java new file mode 100644 index 000000000000..19f750c17f33 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java @@ -0,0 +1,13 @@ +package org.nd4j.autodiff.samediff.optimize.debug; + +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.Optimizer; + +public interface OptimizationDebugger { + + void beforeOptimizationCheck(SameDiff sd, SameDiffOp op, Optimizer o); + + void afterOptimizationsCheck(SameDiff sd, SameDiffOp op, Optimizer o, boolean wasApplied); + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java index 238a5aa0e6ac..faa2fe9177fa 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java @@ -1,14 +1,20 @@ package org.nd4j.autodiff.optimization; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.autodiff.optimization.util.OptTestConfig; +import org.nd4j.autodiff.optimization.util.OptimizationTestUtil; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.optimize.GraphOptimizer; +import org.nd4j.autodiff.samediff.optimize.optimizations.ConstantFunctionOptimizations; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import java.util.Arrays; import java.util.Collections; import static org.junit.Assert.assertEquals; @@ -16,6 +22,9 @@ public class TestOptimization extends BaseNd4jTest { + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + public TestOptimization(Nd4jBackend backend) { super(backend); } @@ -32,6 +41,10 @@ public long getTimeoutMilliseconds() { @Test public void testConstantOpFolding(){ + //We expect 2 things in this test: + //(a) the output of add(constant, constant) is pre-calculated and itself becomes a constant + //(b) the + SameDiff sd = SameDiff.create(); SDVariable c = sd.constant("c", Nd4j.scalar(1.0)); @@ -39,6 +52,8 @@ public void testConstantOpFolding(){ SDVariable v = sd.var("variable", Nd4j.scalar(1.0)); SDVariable out = v.sub("out", c2); + SameDiff copy = sd.dup(); + SameDiff optimized = GraphOptimizer.optimize(sd); assertEquals(3, optimized.getVariables().size()); //"add", "variable", "out" -> "c" should be removed assertEquals(VariableType.CONSTANT, optimized.getVariable("add").getVariableType()); @@ -48,5 +63,41 @@ public void testConstantOpFolding(){ assertFalse(optimized.hasVariable("c")); assertEquals(sd.outputSingle(Collections.emptyMap(), "out"), optimized.outputSingle(Collections.emptyMap(), "out")); + + //Check the + + //Check that the original can be saved and loaded, and still gives the same results + + } + + @Test + public void testConstantOpFolding2(){ + //We expect 2 things in this test: + //(a) the output of add(constant, constant) is pre-calculated and itself becomes a constant + //(b) the + + + SameDiff sd = SameDiff.create(); + SDVariable c = sd.constant("c", Nd4j.scalar(1.0)); + SDVariable c2 = c.add("add", 1); + SDVariable v = sd.var("variable", Nd4j.scalar(1.0)); + SDVariable out = v.sub("out", c2); + + OptTestConfig conf = OptTestConfig.builder() + .original(sd) + .outputs(Collections.singletonList("out")) + .mustApply(sd.getVariables().get("add").getOutputOfOp(), ConstantFunctionOptimizations.FoldConstantFunctions.class) + .build(); + + SameDiff optimized = OptimizationTestUtil.testOptimization(conf); + assertEquals(3, optimized.getVariables().size()); //"add", "variable", "out" -> "c" should be removed + assertEquals(VariableType.CONSTANT, optimized.getVariable("add").getVariableType()); + assertEquals(1, optimized.getOps().size()); + assertEquals("subtract", optimized.getOps().values().iterator().next().getName()); + + assertFalse(optimized.hasVariable("c")); + + assertEquals(sd.outputSingle(Collections.emptyMap(), "out"), optimized.outputSingle(Collections.emptyMap(), "out")); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptTestConfig.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptTestConfig.java new file mode 100644 index 000000000000..57a099a72bea --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptTestConfig.java @@ -0,0 +1,91 @@ +package org.nd4j.autodiff.optimization.util; + +import lombok.Builder; +import lombok.Data; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.autodiff.samediff.optimize.OptimizerSet; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.io.File; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@Data +public class OptTestConfig { + + private SameDiff original; + private Map placeholders; + private List outputs; + private File tempFolder; + private Map> mustApply; + private List optimizerSets; + + public static Builder builder(){ + return new Builder(); + } + + public static class Builder { + + private SameDiff original; + private Map placeholders; + private List outputs; + private File tempFolder; + private Map> mustApply; + private List optimizerSets; + + public Builder original(SameDiff sd){ + original = sd; + return this; + } + + public Builder placeholder(String ph, INDArray arr){ + if(placeholders == null) + placeholders = new HashMap<>(); + placeholders.put(ph, arr); + return this; + } + + public Builder placeholders(Map map){ + placeholders = map; + return this; + } + + public Builder outputs(String... outputs){ + this.outputs = Arrays.asList(outputs); + return this; + } + + public Builder outputs(List outputs){ + this.outputs = outputs; + return this; + } + + public Builder mustApply(String opName, Class optimizerClass){ + if(mustApply == null) + mustApply = new HashMap<>(); + mustApply.put(opName, optimizerClass); + return this; + } + + public Builder optimizerSets(List list){ + this.optimizerSets = list; + return this; + } + + public OptTestConfig build(){ + OptTestConfig c = new OptTestConfig(); + c.original = original; + c.placeholders = placeholders; + c.outputs = outputs; + c.tempFolder = tempFolder; + c.mustApply = mustApply; + c.optimizerSets = optimizerSets; + return c; + } + + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationRecordingDebugger.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationRecordingDebugger.java new file mode 100644 index 000000000000..74c134f4e9a6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationRecordingDebugger.java @@ -0,0 +1,28 @@ +package org.nd4j.autodiff.optimization.util; + +import lombok.Getter; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.autodiff.samediff.optimize.debug.OptimizationDebugger; + +import java.util.HashMap; +import java.util.Map; + +public class OptimizationRecordingDebugger implements OptimizationDebugger { + + @Getter + private Map applied = new HashMap<>(); + + @Override + public void beforeOptimizationCheck(SameDiff sd, SameDiffOp op, Optimizer o) { + //No op + } + + @Override + public void afterOptimizationsCheck(SameDiff sd, SameDiffOp op, Optimizer o, boolean wasApplied) { + if(wasApplied){ + applied.put(op.getName(), o); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java new file mode 100644 index 000000000000..4fd1a8c21601 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java @@ -0,0 +1,99 @@ +package org.nd4j.autodiff.optimization.util; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.optimize.GraphOptimizer; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.autodiff.samediff.optimize.OptimizerSet; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.io.File; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * TODO: + * - Add ability to track which optimization functions exactly were applied! + */ +public class OptimizationTestUtil { + + private OptimizationTestUtil(){ } + + public static SameDiff testOptimization(OptTestConfig config){ + + List optimizerSets = config.getOptimizerSets(); + if(optimizerSets == null) + optimizerSets = GraphOptimizer.defaultOptimizations(); + OptimizationRecordingDebugger debugger = new OptimizationRecordingDebugger(); + + // + Map ph = config.getPlaceholders(); + List outputs = config.getOutputs(); + SameDiff original = config.getOriginal(); + SameDiff copy = original.dup(); + SameDiff optimized = GraphOptimizer.optimize(original, optimizerSets, debugger); + + //Check that SOMETHING changed in the optimized - number of constants, variables, or ops; or the settings for ops; or the values of some arrays + //TODO + boolean sameNumConst = original.getConstantArrays().size() == optimized.getConstantArrays().size(); + boolean sameNumVars = original.getVariablesArrays().size() == optimized.getVariablesArrays().size(); + boolean sameNumSDVars = original.getVariables().size() == optimized.getVariables().size(); + boolean sameNumOps = original.getOps().size() == optimized.getOps().size(); + + if(sameNumConst && sameNumVars && sameNumSDVars && sameNumOps){ + + + throw new IllegalStateException("Did not detect any changes to the graph structure after optimization (but check is AS YET WIP)"); + } + + //Check that optimizations we expected to be applied were in fact applied: + Map> mustApply = config.getMustApply(); + Map applied = debugger.getApplied(); + for(String s : mustApply.keySet()){ + assertTrue("Expected optimizer of type " + mustApply.get(s).getSimpleName() + " to be applied to op " + s, + applied.containsKey(s)); + } + + + //Second: check that they all produce the same + //TODO this won't work for random ops! + Map origOut = original.output(ph, outputs); + Map copyOut = copy.output(ph, outputs); + Map optimizedOut = optimized.output(ph, outputs); + + assertEquals(copyOut, origOut); + assertEquals(copyOut, optimizedOut); + + File f = new File(config.getTempFolder(), "optimized.sd"); + optimized.save(f, true); + + SameDiff loaded = SameDiff.load(f, true); + Map loadedOut = loaded.output(ph, outputs); + assertEquals(copyOut, loadedOut); + + //TODO add support for training checks! + //This is especially important for updaters... if we permute the weights, we should permute the updater state also + + //Check that nothing has changed (from the user API perspective) for the original graph + //i.e., + for(SDVariable v : copy.variables()){ + SDVariable ov = original.getVariable(v.name()); + + assertEquals(v.dataType(), ov.dataType()); + assertEquals(v.getVariableType(), ov.getVariableType()); + if(v.getVariableType() == VariableType.CONSTANT || v.getVariableType() == VariableType.VARIABLE){ + INDArray arrCopy = v.getArr(); + INDArray arrOrig = ov.getArr(); + assertEquals(arrCopy, arrOrig); + } + + } + + return optimized; + } + +} From 8e8d4439531bfd4b5a2b395515aed8291e7897d4 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 10 Mar 2020 00:29:31 +1100 Subject: [PATCH 6/6] Add seamless/automatic 'internal' optimization for SameDiff output methods Signed-off-by: Alex Black --- .../org/nd4j/autodiff/samediff/SameDiff.java | 31 ++++ .../samediff/optimize/GraphOptimizer.java | 16 +- .../autodiff/samediff/optimize/Optimizer.java | 27 +++- .../samediff/optimize/OptimizerSet.java | 15 ++ .../optimize/debug/OptimizationDebugger.java | 20 +++ .../IdentityFunctionOptimizations.java | 10 ++ .../optimizations/OptimizationUtils.java | 4 + .../optimization/TestOptimization.java | 37 ++++- .../TestSeamlessOptimization.java | 137 ++++++++++++++++++ .../util/OptimizationTestUtil.java | 4 +- 10 files changed, 285 insertions(+), 16 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestSeamlessOptimization.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 132d38b64f8b..a0e80bc9fbc1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -38,6 +38,8 @@ import org.nd4j.autodiff.samediff.config.OutputConfig; import org.nd4j.autodiff.samediff.internal.*; import org.nd4j.autodiff.samediff.ops.*; +import org.nd4j.autodiff.samediff.optimize.GraphOptimizer; +import org.nd4j.autodiff.samediff.optimize.OptimizationConfig; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; @@ -109,6 +111,7 @@ @Slf4j public class SameDiff extends SDBaseOps { protected static final String GRAD_FN_KEY = "grad"; + protected static final String OPTIMIZED_FN_KEY = "optimized"; //Fields for graph structure and execution @Getter @@ -240,6 +243,11 @@ public SDBitwise bitwise(){ return bitwise; } + @Setter @Getter + private boolean allowOptimization = true; + private String[] optimizedWRT = null; + + private Map sameDiffFunctionInstances; private Table fieldVariableResolutionMapping; @@ -2556,6 +2564,23 @@ protected Map batchOutputHelper(Map placehol activeListeners.add(l); } + if(allowOptimization){ + if(!sameDiffFunctionInstances.containsKey(OPTIMIZED_FN_KEY) || optimizedWRT == null || !Arrays.equals(optimizedWRT, outputs)){ + //Need to create optimized version + + SameDiff sd = optimize(Arrays.asList(outputs)); + sameDiffFunctionInstances.put(OPTIMIZED_FN_KEY, sd); + + + //TODO clean up old version optimized SameDiff if necessary + } + SameDiff optimized = sameDiffFunctionInstances.get(OPTIMIZED_FN_KEY); + if(optimized.isAllowOptimization()) + optimized.setAllowOptimization(false); //Prevent recursive optimizations + + return optimized.batchOutputHelper(placeholders, activeListeners, operation, outputs); + } + for (Listener l : activeListeners) { l.operationStart(this, operation); } @@ -5865,4 +5890,10 @@ public String generateDistinctCustomVariableName(String base){ return base + "_" + inc; } + + protected SameDiff optimize(List withRespectToOutputs){ + SameDiff sd = GraphOptimizer.optimize(this, withRespectToOutputs); + sd.setAllowOptimization(false); //Prevent recursive optimization attempts when output is called + return sd; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java index 1e350deb5fb6..96d1ad574f0f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java @@ -46,15 +46,21 @@ public static List defaultOptimizations(){ ); } - public static SameDiff optimize(SameDiff graph){ - return optimize(graph, defaultOptimizations()); + public static SameDiff optimize(SameDiff graph, String... requiredOutputs){ + return optimize(graph, Arrays.asList(requiredOutputs)); } - public static SameDiff optimize(SameDiff graph, List optimizations) { - return optimize(graph, optimizations, null); + public static SameDiff optimize(SameDiff graph, List requiredOutputs){ + return optimize(graph, requiredOutputs, defaultOptimizations()); } - public static SameDiff optimize(SameDiff graph, List optimizations, OptimizationDebugger debugger){ + public static SameDiff optimize(SameDiff graph, List requiredOutputs, List optimizations) { + return optimize(graph, requiredOutputs, optimizations, null); + } + + public static SameDiff optimize(SameDiff graph, List requiredOutputs, List optimizations, OptimizationDebugger debugger){ + //TODO Use required outputs - strip unnecessary graph components + SameDiff sd = graph.dup(); ArrayHolder cArr = sd.getConstantArrays(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java index fab056bf92b8..2411562603d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize; import org.nd4j.autodiff.samediff.ArrayHolder; @@ -7,18 +22,16 @@ import java.util.Properties; /** - * * @author Alex Black */ public interface Optimizer { /** - * - * @param sd Current SameDiff instance to optimize - * @param optimizationConfig Optimization configuration - * @param op Operation to check for optimization - * @param constantArrays - * @param variablesArrays + * @param sd Current SameDiff instance to optimize + * @param helper Helper class for optimization + * @param op Operation to check for optimization + * @param constantArrays Array holder for constant arrays + * @param variablesArrays Array holder for variable arrays * @return True if the optimization was applied */ boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java index a971c7174967..6c3bca83df05 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java index 19f750c17f33..db4663b1abc3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java @@ -1,9 +1,29 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.autodiff.samediff.optimize.debug; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.optimize.Optimizer; +/** + * Used as a listener for + * + * @author Alex Black + */ public interface OptimizationDebugger { void beforeOptimizationCheck(SameDiff sd, SameDiffOp op, Optimizer o); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java index 3801848def3d..be1c1dc09d5c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.optimize.OptimizationHelper; import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; import java.util.Properties; @@ -42,6 +43,15 @@ public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp public static class RemoveIdentityOps implements Optimizer { @Override public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + if(op.getOp() instanceof Identity){ + String inName = op.getInputsToOp().get(0); + String outputName = op.getOutputsOfOp().get(0); + OptimizationUtils.removeOp(sd, op.getName()); + OptimizationUtils.replaceOpInputsWith(sd, outputName, inName); + OptimizationUtils.removeVariable(sd, outputName); + return true; + } + return false; } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java index 8e0167820969..4c6503c6da67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java @@ -58,4 +58,8 @@ public static void removeOp(@NonNull SameDiff sd, @NonNull String opToRemove){ } } + public static void removeVariable(@NonNull SameDiff sd, @NonNull String varToRemove){ + sd.getVariables().remove(varToRemove); + } + } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java index faa2fe9177fa..ecdd77afc8f3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java @@ -10,15 +10,16 @@ import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.optimize.GraphOptimizer; import org.nd4j.autodiff.samediff.optimize.optimizations.ConstantFunctionOptimizations; +import org.nd4j.autodiff.samediff.optimize.optimizations.IdentityFunctionOptimizations; import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.Arrays; import java.util.Collections; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.*; public class TestOptimization extends BaseNd4jTest { @@ -54,7 +55,7 @@ public void testConstantOpFolding(){ SameDiff copy = sd.dup(); - SameDiff optimized = GraphOptimizer.optimize(sd); + SameDiff optimized = GraphOptimizer.optimize(sd, "out"); assertEquals(3, optimized.getVariables().size()); //"add", "variable", "out" -> "c" should be removed assertEquals(VariableType.CONSTANT, optimized.getVariable("add").getVariableType()); assertEquals(1, optimized.getOps().size()); @@ -100,4 +101,34 @@ public void testConstantOpFolding2(){ assertEquals(sd.outputSingle(Collections.emptyMap(), "out"), optimized.outputSingle(Collections.emptyMap(), "out")); } + + @Test + public void testIdentityRemoval(){ + + //Ensure that optimizer is actually used when calling output methods: + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3)); + SDVariable i1 = sd.identity(in); + SDVariable i2 = sd.identity(w); + SDVariable i3 = sd.identity(b); + SDVariable out = sd.nn.softmax("out", sd.identity(i1.mmul(i2).add(i3))); + + OptTestConfig conf = OptTestConfig.builder() + .original(sd) + .outputs(Collections.singletonList("out")) + .placeholder("in", Nd4j.rand(DataType.FLOAT, 5, 4)) + .mustApply(sd.getVariables().get(i1.name()).getOutputOfOp(), IdentityFunctionOptimizations.RemoveIdentityOps.class) + .mustApply(sd.getVariables().get(i2.name()).getOutputOfOp(), IdentityFunctionOptimizations.RemoveIdentityOps.class) + .mustApply(sd.getVariables().get(i3.name()).getOutputOfOp(), IdentityFunctionOptimizations.RemoveIdentityOps.class) + .build(); + + SameDiff optimized = OptimizationTestUtil.testOptimization(conf); + assertEquals(3, optimized.getOps().size()); + assertFalse(optimized.hasVariable(i1.name())); + assertFalse(optimized.hasVariable(i2.name())); + assertFalse(optimized.hasVariable(i3.name())); + assertTrue(optimized.hasVariable("out")); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestSeamlessOptimization.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestSeamlessOptimization.java new file mode 100644 index 000000000000..ec1914af1171 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestSeamlessOptimization.java @@ -0,0 +1,137 @@ +package org.nd4j.autodiff.optimization; + +import lombok.Data; +import org.junit.Test; +import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; +import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +import java.util.*; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class TestSeamlessOptimization extends BaseNd4jTest { + + public TestSeamlessOptimization(Nd4jBackend backend) { + super(backend); + } + + + @Test + public void testOutput(){ + + //Ensure that optimizer is actually used when calling output methods: + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3)); + + SDVariable i1 = sd.identity(in); + SDVariable i2 = sd.identity(w); + SDVariable i3 = sd.identity(b); + + SDVariable out = sd.nn.softmax("out", sd.identity(i1.mmul(i2).add(i3))); + + RecordOpsListener l = new RecordOpsListener(); + sd.setListeners(new AssertNoOpsOfTypeListener(Identity.class), l); + + Map ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 10, 4)); + + for( int i=0; i<3; i++ ) { + l.ops.clear(); + + switch (i){ + case 0: + sd.outputSingle(ph, "out"); + break; + case 1: + sd.output(ph, "out"); + break; + case 2: + sd.batchOutput().output("out") + .input("in", ph.get("in")) + .outputSingle(); + break; + } + + + List> expClasses = Arrays.asList(Mmul.class, AddOp.class, SoftMax.class); + assertEquals(3, l.ops.size()); + for (int j = 0; j < 3; j++) { + assertEquals(expClasses.get(j), l.ops.get(j).getOp().getClass()); + } + + } + } + + @Test + public void testDifferentOutputs(){ + //Test when the user requests different outputs instead + } + + @Test + public void testGraphModification(){ + //User modifies the graph -> should reoptimize? + + fail("Not yet implemented"); + } + + public static class AssertNoOpsOfTypeListener extends BaseListener { + private List> list; + + public AssertNoOpsOfTypeListener(Class... c) { + Preconditions.checkState(c != null && c.length > 0, "No classes provided"); + this.list = Arrays.asList(c); + } + + @Override + public boolean isActive(Operation operation) { + return true; + } + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + if(list.contains(op.getOp().getClass())){ + throw new IllegalStateException("Encountered unexpected class: " + op.getOp().getClass().getName()); + } + } + } + + @Data + public static class RecordOpsListener extends BaseListener { + + private List ops = new ArrayList<>(); + + @Override + public boolean isActive(Operation operation) { + return true; + } + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + ops.add(op); + } + } + + + @Override + public char ordering() { + return 'c'; + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java index 4fd1a8c21601..0e0cade8a745 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java @@ -6,6 +6,7 @@ import org.nd4j.autodiff.samediff.optimize.GraphOptimizer; import org.nd4j.autodiff.samediff.optimize.Optimizer; import org.nd4j.autodiff.samediff.optimize.OptimizerSet; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import java.io.File; @@ -24,6 +25,7 @@ public class OptimizationTestUtil { private OptimizationTestUtil(){ } public static SameDiff testOptimization(OptTestConfig config){ + Preconditions.checkNotNull(config.getTempFolder(), "Temp folder should be specified before running test"); List optimizerSets = config.getOptimizerSets(); if(optimizerSets == null) @@ -35,7 +37,7 @@ public static SameDiff testOptimization(OptTestConfig config){ List outputs = config.getOutputs(); SameDiff original = config.getOriginal(); SameDiff copy = original.dup(); - SameDiff optimized = GraphOptimizer.optimize(original, optimizerSets, debugger); + SameDiff optimized = GraphOptimizer.optimize(original, outputs, optimizerSets, debugger); //Check that SOMETHING changed in the optimized - number of constants, variables, or ops; or the settings for ops; or the values of some arrays //TODO