diff --git a/.gitignore b/.gitignore index 8a6e890..4fc127e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ result.txt .stack-work/ .envrc .direnv +.vscode/ +dist-newstyle/ diff --git a/hie.yaml b/hie.yaml index bfee1c0..241b30e 100644 --- a/hie.yaml +++ b/hie.yaml @@ -9,6 +9,72 @@ cradle: - path: "./rules/debug/Paths_tensor_right.hs" component: "tensor-right:exe:rules-debug" + - path: "./rules/taso/concat/Main.hs" + component: "tensor-right:exe:rules-taso-concat" + + - path: "./rules/taso/concat/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-concat" + + - path: "./rules/taso/conv/Main.hs" + component: "tensor-right:exe:rules-taso-conv" + + - path: "./rules/taso/conv/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-conv" + + - path: "./rules/taso/enlarge/Main.hs" + component: "tensor-right:exe:rules-taso-enlarge" + + - path: "./rules/taso/enlarge/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-enlarge" + + - path: "./rules/taso/ewadd/Main.hs" + component: "tensor-right:exe:rules-taso-ewadd" + + - path: "./rules/taso/ewadd/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-ewadd" + + - path: "./rules/taso/ewmul/Main.hs" + component: "tensor-right:exe:rules-taso-ewmul" + + - path: "./rules/taso/ewmul/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-ewmul" + + - path: "./rules/taso/matmul2D/Main.hs" + component: "tensor-right:exe:rules-taso-matmul2D" + + - path: "./rules/taso/matmul2D/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-matmul2D" + + - path: "./rules/taso/matmul3D/Main.hs" + component: "tensor-right:exe:rules-taso-matmul3D" + + - path: "./rules/taso/matmul3D/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-matmul3D" + + - path: "./rules/taso/relu/Main.hs" + component: "tensor-right:exe:rules-taso-relu" + + - path: "./rules/taso/relu/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-relu" + + - path: "./rules/taso/smul/Main.hs" + component: "tensor-right:exe:rules-taso-smul" + + - path: "./rules/taso/smul/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-smul" + + - path: "./rules/taso/split/Main.hs" + component: "tensor-right:exe:rules-taso-split" + + - path: "./rules/taso/split/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-split" + + - path: "./rules/taso/transpose/Main.hs" + component: "tensor-right:exe:rules-taso-transpose" + + - path: "./rules/taso/transpose/Paths_tensor_right.hs" + component: "tensor-right:exe:rules-taso-transpose" + - path: "./rules/xla/add/Main.hs" component: "tensor-right:exe:rules-xla-add" diff --git a/package.yaml b/package.yaml index a1d319c..fd426fc 100644 --- a/package.yaml +++ b/package.yaml @@ -189,6 +189,73 @@ executables: dependencies: tensor-right ghc-options: *exe-ghc-options default-extensions: *exe-extensions + # TASO Executables + rules-taso-ewadd: + source-dirs: rules/taso/ewadd + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions + rules-taso-ewmul: + source-dirs: rules/taso/ewmul + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions + rules-taso-smul: + source-dirs: rules/taso/smul + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions + rules-taso-relu: + source-dirs: rules/taso/relu + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions + rules-taso-concat: + source-dirs: rules/taso/concat + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions + rules-taso-transpose: + source-dirs: rules/taso/transpose + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions + rules-taso-enlarge: + source-dirs: rules/taso/enlarge + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions + rules-taso-matmul2D: + source-dirs: rules/taso/matmul2D + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions + rules-taso-matmul3D: + source-dirs: rules/taso/matmul3D + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions + rules-taso-split: + source-dirs: rules/taso/split + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions + rules-taso-conv: + source-dirs: rules/taso/conv + main: Main.hs + dependencies: tensor-right + ghc-options: *exe-ghc-options + default-extensions: *exe-extensions # Other Executables rules-debug: source-dirs: rules/debug diff --git a/plot/timing_plot.py b/plot/timing_plot.py index ca65779..e6004db 100644 --- a/plot/timing_plot.py +++ b/plot/timing_plot.py @@ -28,7 +28,7 @@ def first_num_of_tasks(self) -> int: def overall_time(self) -> float: return sum(x.time for x in self.results) - +# TODO: Handle ANSI color escape codes def parse_file(lines: Sequence[str]) -> list[Rule]: """ The file looks like this: diff --git a/rules/taso/concat/Main.hs b/rules/taso/concat/Main.hs new file mode 100644 index 0000000..eb854e0 --- /dev/null +++ b/rules/taso/concat/Main.hs @@ -0,0 +1,95 @@ +{-# OPTIONS_GHC -Wno-missing-import-lists #-} + +import Grisette hiding ((-->)) +import TensorRight +import TensorRight.Internal.DSL.TASO (concat, ewadd, ewmul, relu, smul) +import Prelude hiding (concat) + +desugar :: forall a. NumRule a +desugar _ = do + r <- newRClass "r" + [sa, sb] <- newMaps ["sa", "sb"] r + a <- newTensor @a "A" [r --> sa] + b <- newTensor @a "B" [r --> sb] + let d = ByRClass r + lhs <- concat d a b + rhs <- concatTensor a b d + rewrite "concat(d, A, B) ⇒ Concatenate((A, B), d)" lhs rhs + +smulAssociativity :: forall a. NumRule a +smulAssociativity _ = do + let w = ("w" :: a) + r <- newRClass "r" + s <- newMap "s" r + x <- newTensor @a "x" [r --> s] + y <- newTensor @a "y" [r --> s] + let d = ByRClass r + lhs <- concat d (smul x w) (smul y w) + rhs <- smul (concat d x y) w + rewrite "concat(d, smul(x, w), smul(y, w)) ⇒ smul(concat(d, x, y), w)" lhs rhs + +ewaddAssociativity :: forall a. NumRule a +ewaddAssociativity _ = do + r <- newRClass "r" + s <- newMap "s" r + x <- newTensor @a "x" [r --> s] + y <- newTensor @a "y" [r --> s] + z <- newTensor @a "z" [r --> s] + w <- newTensor @a "w" [r --> s] + let d = ByRClass r + lhs <- concat d (ewadd x y) (ewadd z w) + rhs <- ewadd (concat d x z) (concat d y w) + rewrite "concat(d, ewadd(x, y), ewadd(z, w)) ⇒ ewadd(concat(d, x, z), concat(d, y, w))" lhs rhs + +ewmulAssociativity :: forall a. NumRule a +ewmulAssociativity _ = do + r <- newRClass "r" + s <- newMap "s" r + x <- newTensor @a "x" [r --> s] + y <- newTensor @a "y" [r --> s] + z <- newTensor @a "z" [r --> s] + w <- newTensor @a "w" [r --> s] + let d = ByRClass r + lhs <- concat d (ewmul x y) (ewmul z w) + rhs <- ewmul (concat d x z) (concat d y w) + rewrite "concat(d, ewmul(x, y), ewmul(z, w)) ⇒ ewmul(concat(d, x, z), concat(d, y, w))" lhs rhs + +reluAssociativity :: forall a. NumRule a +reluAssociativity _ = do + r <- newRClass "r" + s <- newMap "s" r + x <- newTensor @a "x" [r --> s] + y <- newTensor @a "y" [r --> s] + let d = ByRClass r + lhs <- concat d (relu @a x) (relu @a y) + rhs <- relu @a $ concat d x y + rewrite "" lhs rhs + +geometry :: forall a. NumRule a +geometry _ = do + [d0, d1, d2] <- newRClasses ["d0", "d1", "d2"] + d0S <- newMap "d0S" d0 + d1S <- newMap "d1S" d1 + d2S <- newMap "d2S" d2 + x <- newTensor @a "x" [d0 --> d0S, d1 --> d1S, d2 --> d2S] + y <- newTensor @a "y" [d0 --> d0S, d1 --> d1S, d2 --> d2S] + z <- newTensor @a "z" [d0 --> d0S, d1 --> d1S, d2 --> d2S] + w <- newTensor @a "w" [d0 --> d0S, d1 --> d1S, d2 --> d2S] + lhs <- concat (ByRClass d0) (concat (ByRClass d1) x y) (concat (ByRClass d1) z w) + rhs <- concat (ByRClass d1) (concat (ByRClass d0) x z) (concat (ByRClass d0) y w) + rewrite "concat(d0, concat(d1, x, y), concat(d1, z, w)) ⇒ concat(d1, concat(d0, x, z), concat(0, y, w))" lhs rhs + +main :: IO () +main = do + printTitle "######################## desugarOneRole ########################" + verifyNumDSL desugar + printTitle "######################## smulAssociativity #####################" + verifyNumDSL smulAssociativity + printTitle "######################## ewaddAssociativity ####################" + verifyNumDSL ewaddAssociativity + printTitle "######################## ewmulAssociativity ####################" + verifyNumDSL ewmulAssociativity + printTitle "######################## reluAssociativity ####################" + verifyNumDSL reluAssociativity + printTitle "######################## geometry #############################" + verifyNumDSL geometry diff --git a/rules/taso/conv/Main.hs b/rules/taso/conv/Main.hs new file mode 100644 index 0000000..b74ee46 --- /dev/null +++ b/rules/taso/conv/Main.hs @@ -0,0 +1,696 @@ +module Main (main) where + +import Data.Proxy +import Grisette hiding (dot, (-->)) +import TensorRight +import TensorRight.Internal.DSL.DSL (ConvConfig (..), ConvPadding (..), Padding (..), checkSIMap, combineMap, monitorExprOnFailure, newConstMap, newNonNegMap, newSingletonRClass, newSingletonRClasses, pad, precondition, siRelation) +import TensorRight.Internal.DSL.Identifier (MapIdentifier, RClassIdentifier) +import TensorRight.Internal.DSL.TASO (Activation (..), PaddingMode (..), concat, enlarge, ewadd, ewmul, matmul3D, relu, smul, tasoConv) +import Prelude hiding (concat) + +-- Helper to create standard 2D conv config (explicit H and W) +mkConvConfig :: + RClassIdentifier -> -- B + RClassIdentifier -> -- C (input feature) + RClassIdentifier -> -- F (output feature) + RClassIdentifier -> -- H + RClassIdentifier -> -- W + MapIdentifier -> -- strideH + MapIdentifier -> -- strideW + MapIdentifier -> -- siC + MapIdentifier -> -- siH + MapIdentifier -> -- siW + ConvConfig +mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW = + ConvConfig + { batchRClasses = [ByRClass rclassB], + featureRClasses = [ByRClass rclassC], + outputFeatureRClasses = [ByRClass rclassF], + strides = [ByRClass rclassH --> stridesH, ByRClass rclassW --> stridesW], + contractingSIMaps = [ByRClass rclassC --> siC, ByRClass rclassH --> siH, ByRClass rclassW --> siW] + } + +-- | Rule 1: Convolution bilinearity - scalar multiplication swaps between input and kernel +convBilinearScalarSwapSame :: forall a. NumRule a +convBilinearScalarSwapSame _ = do + let w = ("w" :: a) + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + sizeB <- newMap "sizeB" rclassB + + x <- newTensor @a "x" [rclassB --> sizeB, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + xw <- smul x w + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + lhs <- tasoConv @a config Same None inputSizes kernelSizes xw y + + yw <- smul y w + rhs <- tasoConv @a config Same None inputSizes kernelSizes x yw + + rewrite "∀s, p, c, x, y, w. conv(s, p, c, smul(x, w), y) = conv(s, p, c, x, smul(y, w))" lhs rhs + +-- | Rule 1: Convolution bilinearity - scalar multiplication swaps between input and kernel +convBilinearScalarSwapValid :: forall a. NumRule a +convBilinearScalarSwapValid _ = do + let w = ("w" :: a) + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + sizeB <- newMap "sizeB" rclassB + + x <- newTensor @a "x" [rclassB --> sizeB, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + xw <- smul x w + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + lhs <- tasoConv @a config Valid None inputSizes kernelSizes xw y + + yw <- smul y w + rhs <- tasoConv @a config Valid None inputSizes kernelSizes x yw + + rewrite "∀s, p, c, x, y, w. conv(s, p, c, smul(x, w), y) = conv(s, p, c, x, smul(y, w))" lhs rhs + +-- | Rule 2: Convolution bilinearity - scalar multiplication on output +convBilinearScalarOutputSame :: forall a. NumRule a +convBilinearScalarOutputSame _ = do + let w = ("w" :: a) + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + sizeB <- newMap "sizeB" rclassB + + x <- newTensor @a "x" [rclassB --> sizeB, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + convXY <- tasoConv @a config Same None inputSizes kernelSizes x y + lhs <- smul convXY w + + xw <- smul x w + rhs <- tasoConv @a config Same None inputSizes kernelSizes xw y + + rewrite "∀s, p, x, y, w. smul(conv(s, p, Anone, x, y), w) = conv(s, p, Anone, smul(x, w), y)" lhs rhs + +-- | Rule 2: Convolution bilinearity - scalar multiplication on output +convBilinearScalarOutputValid :: forall a. NumRule a +convBilinearScalarOutputValid _ = do + let w = ("w" :: a) + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + sizeB <- newMap "sizeB" rclassB + + x <- newTensor @a "x" [rclassB --> sizeB, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + convXY <- tasoConv @a config Valid None inputSizes kernelSizes x y + lhs <- smul convXY w + + xw <- smul x w + rhs <- tasoConv @a config Valid None inputSizes kernelSizes xw y + + rewrite "∀s, p, x, y, w. smul(conv(s, p, Anone, x, y), w) = conv(s, p, Anone, smul(x, w), y)" lhs rhs + +-- | Rule 3: Convolution bilinearity - addition on kernel +convBilinearKernelAddSame :: forall a. NumRule a +convBilinearKernelAddSame _ = do + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + sizeB <- newMap "sizeB" rclassB + + x <- newTensor @a "x" [rclassB --> sizeB, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + z <- newTensor @a "z" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + yz <- ewadd y z + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + lhs <- tasoConv @a config Same None inputSizes kernelSizes x yz + + convXY <- tasoConv @a config Same None inputSizes kernelSizes x y + convXZ <- tasoConv @a config Same None inputSizes kernelSizes x z + rhs <- ewadd convXY convXZ + + rewrite "∀s, p, x, y, z. conv(s, p, Anone, x, ewadd(y, z)) = ewadd(conv(s, p, Anone, x, y), conv(s, p, Anone, x, z))" lhs rhs + +-- | Rule 3: Convolution bilinearity - addition on kernel +convBilinearKernelAddValid :: forall a. NumRule a +convBilinearKernelAddValid _ = do + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + sizeB <- newMap "sizeB" rclassB + + x <- newTensor @a "x" [rclassB --> sizeB, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + z <- newTensor @a "z" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + yz <- ewadd y z + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + lhs <- tasoConv @a config Valid None inputSizes kernelSizes x yz + + convXY <- tasoConv @a config Valid None inputSizes kernelSizes x y + convXZ <- tasoConv @a config Valid None inputSizes kernelSizes x z + rhs <- ewadd convXY convXZ + + rewrite "∀s, p, x, y, z. conv(s, p, Anone, x, ewadd(y, z)) = ewadd(conv(s, p, Anone, x, y), conv(s, p, Anone, x, z))" lhs rhs + +-- | Rule 4: Convolution bilinearity - addition on input +convBilinearInputAddSame :: forall a. NumRule a +convBilinearInputAddSame _ = do + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + [sizeB1, sizeB2] <- newMaps ["sizeB1", "sizeB2"] rclassB + + x <- newTensor @a "x" [rclassB --> sizeB1, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassB --> sizeB2, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + z <- newTensor @a "z" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + xy <- ewadd x y + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + lhs <- tasoConv @a config Same None inputSizes kernelSizes xy z + + convXZ <- tasoConv @a config Same None inputSizes kernelSizes x z + convYZ <- tasoConv @a config Same None inputSizes kernelSizes y z + rhs <- ewadd convXZ convYZ + + rewrite "∀s, p, x, y, z. conv(s, p, Anone, ewadd(x, y), z) = ewadd(conv(s, p, Anone, x, z), conv(s, p, Anone, y, z))" lhs rhs + +-- | Rule 4: Convolution bilinearity - addition on input +convBilinearInputAddValid :: forall a. NumRule a +convBilinearInputAddValid _ = do + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + [sizeB1, sizeB2] <- newMaps ["sizeB1", "sizeB2"] rclassB + + x <- newTensor @a "x" [rclassB --> sizeB1, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassB --> sizeB2, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + z <- newTensor @a "z" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + xy <- ewadd x y + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + lhs <- tasoConv @a config Valid None inputSizes kernelSizes xy z + + convXZ <- tasoConv @a config Valid None inputSizes kernelSizes x z + convYZ <- tasoConv @a config Valid None inputSizes kernelSizes y z + rhs <- ewadd convXZ convYZ + + rewrite "∀s, p, x, y, z. conv(s, p, Anone, ewadd(x, y), z) = ewadd(conv(s, p, Anone, x, z), conv(s, p, Anone, y, z))" lhs rhs + +-- | Rule 5: Convolution with SAME padding and kernel enlarge (2D) +convSameEnlarge :: forall a. NumRule a +convSameEnlarge _ = do + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + sizeB <- newMap "sizeB" rclassB + + x <- newTensor @a "x" [rclassB --> sizeB, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = + ConvConfig + { batchRClasses = [ByRClass rclassB], + featureRClasses = [ByRClass rclassC], + outputFeatureRClasses = [ByRClass rclassF], + strides = [ByRClass rclassH --> stridesH, ByRClass rclassW --> stridesW], + contractingSIMaps = [ByRClass rclassC --> siC, ByRClass rclassH --> siH, ByRClass rclassW --> siW] + } + + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + lhs <- tasoConv @a config Same None inputSizes kernelSizes x y + + -- Enlarge kernel along H and W using built-in 2D enlarge + hLow <- newNonNegMap "hLow" rclassH + wLow <- newNonNegMap "wLow" rclassW + let kHy = ssym "kHy" :: SymInteger + let kWx = ssym "kWx" :: SymInteger + yEnlarged <- enlarge @a (ByRClass rclassH --> kernelH) (ByRClass rclassW --> kernelW) hLow wLow kHy kWx y + + -- Compute enlarged kernel size maps (max with kHy/kWx) for SAME conv sizes + kHMap <- newConstMap "kHMap" kHy rclassH + kWMap <- newConstMap "kWMap" kWx rclassW + precondition [kHMap] $ \[k] -> k .>= 0 + precondition [kWMap] $ \[k] -> k .>= 0 + kernelH' <- combineMap "kernelH'" (\[s, k] -> symIte (s .>= k) s k) [kernelH, kHMap] + kernelW' <- combineMap "kernelW'" (\[s, k] -> symIte (s .>= k) s k) [kernelW, kWMap] + let kernelSizesRHS = [ByRClass rclassH --> kernelH', ByRClass rclassW --> kernelW'] + rhs <- tasoConv @a config Same None inputSizes kernelSizesRHS x yEnlarged + + rewrite "conv(SAME, x, y) = conv(SAME, x, enlargeKernel2D(y))" lhs rhs + +-- convIdentitySame :: forall a. NumRule a +-- convIdentitySame _ = do +-- -- RClasses: batch, channel/feature (shared!), spatial H/W +-- [rclassB, rclassCF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassCF", "rclassH", "rclassW"] + +-- -- Size maps +-- sizeB <- newMap "sizeB" rclassB +-- sizeCF <- newMap "sizeCF" rclassCF -- ONE map for both input C and output F +-- inputH <- newMap "inputH" rclassH +-- inputW <- newMap "inputW" rclassW +-- kernelH <- newMap "kernelH" rclassH +-- kernelW <- newMap "kernelW" rclassW + +-- -- Input tensor [B, CF, H, W] +-- x <- newTensor @a "x" [rclassB --> sizeB, rclassCF --> sizeCF, rclassH --> inputH, rclassW --> inputW] + +-- -- Strides and SI maps +-- stridesH <- newMap "stridesH" rclassH +-- stridesW <- newMap "stridesW" rclassW +-- siCF <- newMap "siCF" rclassCF -- ONE SI map for channel/feature +-- siH <- newMap "siH" rclassH +-- siW <- newMap "siW" rclassW + +-- let config = +-- ConvConfig +-- { batchRClasses = [ByRClass rclassB], +-- featureRClasses = [ByRClass rclassCF], -- Same RClass for input/output! +-- outputFeatureRClasses = [ByRClass rclassCF], -- Same RClass for input/output! +-- strides = [ByRClass rclassH --> stridesH, ByRClass rclassW --> stridesW], +-- contractingSIMaps = [ByRClass rclassCF --> siCF, ByRClass rclassH --> siH, ByRClass rclassW --> siW] +-- } + +-- -- SAME padding + stride=1 preconditions +-- precondition [stridesH] $ \[s] -> s .== 1 +-- precondition [stridesW] $ \[s] -> s .== 1 + +-- -- Odd kernel sizes via centers: k = 2*c + 1 +-- let cH = ssym "centerH" :: SymInteger +-- let cW = ssym "centerW" :: SymInteger +-- precondition [kernelH] $ \[k] -> k .== (2 * cH + 1) +-- precondition [kernelW] $ \[k] -> k .== (2 * cW + 1) + +-- -- Build explicit identity kernel: 1 when (CF_out==CF_in && H==cH && W==cW), else 0 +-- -- Since kernel has rclassCF twice, we MUST use labels to disambiguate +-- let kShape = [rclassCF --> sizeCF @@ "CFout", rclassCF --> sizeCF @@ "CFin", rclassH --> kernelH, rclassW --> kernelW] +-- iCF_out <- iota kShape (ByLabel "CFout") -- Output feature index +-- iCF_in <- iota kShape (ByLabel "CFin") -- Input feature index +-- iH <- iota kShape (ByRClass rclassH) +-- iW <- iota kShape (ByRClass rclassW) + +-- cHTensor <- constant @TensorInt (nonInf cH) kShape +-- cWTensor <- constant @TensorInt (nonInf cW) kShape + +-- condCF <- compareOp Eqv iCF_out iCF_in -- Diagonal: channel i maps to channel i +-- condH <- compareOp Eqv iH cHTensor +-- condW <- compareOp Eqv iW cWTensor +-- condHW <- boolBinOp And condH condW +-- condAll <- boolBinOp And condCF condHW + +-- one <- constant @a 1 kShape +-- zero <- constant @a 0 kShape +-- idKernel <- select condAll one zero + +-- let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] +-- let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + +-- lhs <- tasoConv @a config Same None inputSizes kernelSizes x idKernel +-- let rhs = x -- No transformation needed! Shapes already match: [B, CF, H, W] +-- rewrite "∀x. conv(SAME, stride=1; identity-kernel) = x" lhs rhs + +-- | Rule 6: Convolution with ReLU activation +convReluActivation :: forall a. NumRule a +convReluActivation _ = do + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + sizeB <- newMap "sizeB" rclassB + + x <- newTensor @a "x" [rclassB --> sizeB, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + lhs <- tasoConv @a config Valid Relu inputSizes kernelSizes x y + + convXY <- tasoConv @a config Valid None inputSizes kernelSizes x y + rhs <- relu @a convXY + + rewrite "∀s, p, x, y. conv(s, p, Arelu, x, y) = relu(conv(s, p, Anone, x, y))" lhs rhs + +-- | Rule 7: Concatenation along batch dimension distributes over conv +convConcatInputBatchSame :: forall a. NumRule a +convConcatInputBatchSame _ = do + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + [sizeB1, sizeB2] <- newMaps ["sizeB1", "sizeB2"] rclassB + + x <- newTensor @a "x" [rclassB --> sizeB1, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassB --> sizeB2, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + z <- newTensor @a "z" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + convXZ <- tasoConv @a config Same None inputSizes kernelSizes x z + convYZ <- tasoConv @a config Same None inputSizes kernelSizes y z + lhs <- concat (ByRClass rclassB) convXZ convYZ + + xy <- concat (ByRClass rclassB) x y + rhs <- tasoConv @a config Same None inputSizes kernelSizes xy z + + rewrite "∀s, p, c, x, y, z. concat(0, conv(s, p, c, x, z), conv(s, p, c, y, z)) = conv(s, p, c, concat(0, x, y), z)" lhs rhs + +-- | Rule 7: Concatenation along batch dimension distributes over conv +convConcatInputBatchValid :: forall a. NumRule a +convConcatInputBatchValid _ = do + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + sizeF <- newMap "sizeF" rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + [sizeB1, sizeB2] <- newMaps ["sizeB1", "sizeB2"] rclassB + + x <- newTensor @a "x" [rclassB --> sizeB1, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassB --> sizeB2, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + z <- newTensor @a "z" [rclassF --> sizeF, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + convXZ <- tasoConv @a config Valid None inputSizes kernelSizes x z + convYZ <- tasoConv @a config Valid None inputSizes kernelSizes y z + lhs <- concat (ByRClass rclassB) convXZ convYZ + + xy <- concat (ByRClass rclassB) x y + rhs <- tasoConv @a config Valid None inputSizes kernelSizes xy z + + rewrite "∀s, p, c, x, y, z. concat(0, conv(s, p, c, x, z), conv(s, p, c, y, z)) = conv(s, p, c, concat(0, x, y), z)" lhs rhs + +-- | Rule 8: Concatenation along output feature dimension distributes over conv +convConcatKernelSame :: forall a. NumRule a +convConcatKernelSame _ = do + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + [sizeF1, sizeF2] <- newMaps ["sizeF1", "sizeF2"] rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + sizeB <- newMap "sizeB" rclassB + + x <- newTensor @a "x" [rclassB --> sizeB, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassF --> sizeF1, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + z <- newTensor @a "z" [rclassF --> sizeF2, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + convXY <- tasoConv @a config Same None inputSizes kernelSizes x y + convXZ <- tasoConv @a config Same None inputSizes kernelSizes x z + lhs <- concat (ByRClass rclassF) convXY convXZ + + yz <- concat (ByRClass rclassF) y z + rhs <- tasoConv @a config Same None inputSizes kernelSizes x yz + + rewrite "∀s, p, c, x, y, z. concat(1, conv(s, p, c, x, y), conv(s, p, c, x, z)) = conv(s, p, c, x, concat(0, y, z))" lhs rhs + +-- | Rule 8: Concatenation along output feature dimension distributes over conv +convConcatKernelValid :: forall a. NumRule a +convConcatKernelValid _ = do + [rclassB, rclassC, rclassF, rclassH, rclassW] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassH", "rclassW"] + sizeC <- newMap "sizeC" rclassC + [sizeF1, sizeF2] <- newMaps ["sizeF1", "sizeF2"] rclassF + inputH <- newMap "inputH" rclassH + inputW <- newMap "inputW" rclassW + kernelH <- newMap "kernelH" rclassH + kernelW <- newMap "kernelW" rclassW + sizeB <- newMap "sizeB" rclassB + + x <- newTensor @a "x" [rclassB --> sizeB, rclassC --> sizeC, rclassH --> inputH, rclassW --> inputW] + y <- newTensor @a "y" [rclassF --> sizeF1, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + z <- newTensor @a "z" [rclassF --> sizeF2, rclassC --> sizeC, rclassH --> kernelH, rclassW --> kernelW] + + stridesH <- newMap "stridesH" rclassH + stridesW <- newMap "stridesW" rclassW + siC <- newMap "siC" rclassC + siH <- newMap "siH" rclassH + siW <- newMap "siW" rclassW + + let config = mkConvConfig rclassB rclassC rclassF rclassH rclassW stridesH stridesW siC siH siW + + let inputSizes = [ByRClass rclassH --> inputH, ByRClass rclassW --> inputW] + let kernelSizes = [ByRClass rclassH --> kernelH, ByRClass rclassW --> kernelW] + convXY <- tasoConv @a config Valid None inputSizes kernelSizes x y + convXZ <- tasoConv @a config Valid None inputSizes kernelSizes x z + lhs <- concat (ByRClass rclassF) convXY convXZ + + yz <- concat (ByRClass rclassF) y z + rhs <- tasoConv @a config Valid None inputSizes kernelSizes x yz + + rewrite "∀s, p, c, x, y, z. concat(1, conv(s, p, c, x, y), conv(s, p, c, x, z)) = conv(s, p, c, x, concat(0, y, z))" lhs rhs + +-- | Rule 9: Concatenation on input channels with matching concatenation on kernel features +-- convConcatMixed :: forall a. NumRule a +-- convConcatMixed _ = do +-- [rclassB, rclassC, rclassF, rclassSpatial] <- newSingletonRClasses ["rclassB", "rclassC", "rclassF", "rclassSpatial"] +-- [sizeC1, sizeC2] <- newMaps ["sizeC1", "sizeC2"] rclassC +-- sizeF <- newMap "sizeF" rclassF +-- inputSpatial <- newMap "inputSpatial" rclassSpatial +-- kernelSpatial <- newMap "kernelSpatial" rclassSpatial +-- [sizeB1, sizeB2] <- newMaps ["sizeB1", "sizeB2"] rclassB + +-- x <- newTensor @a "x" [rclassB --> sizeB1, rclassC --> sizeC1, rclassSpatial --> inputSpatial] +-- z <- newTensor @a "z" [rclassB --> sizeB2, rclassC --> sizeC2, rclassSpatial --> inputSpatial] +-- y <- newTensor @a "y" [rclassF --> sizeF, rclassC --> sizeC1, rclassSpatial --> kernelSpatial] +-- w <- newTensor @a "w" [rclassF --> sizeF, rclassC --> sizeC2, rclassSpatial --> kernelSpatial] + +-- strides <- newMap "strides" rclassSpatial +-- siSpatial <- newMap "siSpatial" rclassSpatial + +-- let inputSizes = [ByRClass rclassSpatial --> inputSpatial] +-- let kernelSizes = [ByRClass rclassSpatial --> kernelSpatial] + +-- -- LHS: concat on channels then convolve with one SI map +-- xz <- concat (ByRClass rclassC) x z +-- yw <- concat (ByRClass rclassC) y w +-- siC_LHS <- newMap "siC_LHS" rclassC +-- let configLHS = mkConvConfig rclassB rclassC rclassF rclassSpatial strides siC_LHS siSpatial +-- lhs <- tasoConv @a configLHS Valid None inputSizes kernelSizes xz yw + +-- -- RHS: separate convs with separate SI maps, then add +-- siC_RHS1 <- newMap "siC_RHS1" rclassC +-- siC_RHS2 <- newMap "siC_RHS2" rclassC +-- let configRHS1 = mkConvConfig rclassB rclassC rclassF rclassSpatial strides siC_RHS1 siSpatial +-- let configRHS2 = mkConvConfig rclassB rclassC rclassF rclassSpatial strides siC_RHS2 siSpatial +-- convXY <- tasoConv @a configRHS1 Valid None inputSizes kernelSizes x y +-- convZW <- tasoConv @a configRHS2 Valid None inputSizes kernelSizes z w +-- rhs <- ewadd convXY convZW + +-- -- SI relation: LHS channel SI maps to the same value in both RHS convs +-- siRelation [siC_LHS, siC_RHS1] $ \[l, r] -> l .== r +-- siRelation [siC_LHS, siC_RHS2] $ \[l, r] -> l .== r +-- checkSIMap [siC_LHS] [siC_RHS1, siC_RHS2] + +-- rewrite "concat on channels splits convolution" lhs rhs + +main :: IO () +main = do + printTitle "#################### convBilinearScalarSwapSame ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convBilinearScalarSwapSame + + printTitle "#################### convBilinearScalarSwapValid ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convBilinearScalarSwapValid + + printTitle "#################### convBilinearScalarOutputSame ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convBilinearScalarOutputSame + + printTitle "#################### convBilinearScalarOutputValid ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convBilinearScalarOutputValid + + printTitle "#################### convBilinearKernelAddSame ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convBilinearKernelAddSame + + printTitle "#################### convBilinearKernelAddValid ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convBilinearKernelAddValid + + printTitle "#################### convBilinearInputAddSame ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convBilinearInputAddSame + + printTitle "#################### convBilinearInputAddValid ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convBilinearInputAddValid + + printTitle "#################### convSameEnlarge ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convSameEnlarge + + -- printTitle "#################### convIdentitySame ####################" + -- verifyNumDSLWith (withTimeout 15000000 z3) convIdentitySame + + printTitle "#################### convReluActivation ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convReluActivation + + printTitle "#################### convConcatInputBatchSame ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convConcatInputBatchSame + + printTitle "#################### convConcatInputBatchValid ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convConcatInputBatchValid + + printTitle "#################### convConcatKernelSame ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convConcatKernelSame + + printTitle "#################### convConcatKernelValid ####################" + verifyNumDSLWith (withTimeout 15000000 z3) convConcatKernelValid + +-- printTitle "#################### convConcatMixed ####################" +-- verifyNumDSL convConcatMixed diff --git a/rules/taso/enlarge/Main.hs b/rules/taso/enlarge/Main.hs new file mode 100644 index 0000000..1ea5604 --- /dev/null +++ b/rules/taso/enlarge/Main.hs @@ -0,0 +1,60 @@ +module Main (main) where + +import Grisette hiding ((-->)) +import TensorRight +import TensorRight.Internal.DSL.DSL (newSingletonRClass) +import TensorRight.Internal.DSL.TASO + +desugarEnlarge :: forall a. NumRule a +desugarEnlarge _ = do + rclass <- newSingletonRClass "rclass" + [n, c, h, w] <- newMaps ["n", "c", "h", "w"] rclass + tA <- newTensor @a "A" [rclass --> n @@ "N", rclass --> c @@ "C", rclass --> h @@ "H", rclass --> w @@ "W"] + let kx = ssym "kx" :: SymInteger + let ky = ssym "ky" :: SymInteger + + -- Prepare shared low padding maps to be used by both LHS and RHS + hLow <- newNonNegMap "hLow" rclass + wLow <- newNonNegMap "wLow" rclass + + lhs <- enlarge @a (ByLabel "H" --> h) (ByLabel "W" --> w) hLow wLow ky kx tA + + -- Building the rhs + kH <- newConstMap "kH" ky rclass + kW <- newConstMap "kW" kx rclass + + precondition [kH] $ \[k] -> k .>= 0 + precondition [kW] $ \[k] -> k .>= 0 + + sH' <- combineMap "sH'" (\[a', k'] -> symIte (a' .>= k') a' k') [h, kH] + sW' <- combineMap "sW'" (\[a', k'] -> symIte (a' .>= k') a' k') [w, kW] + + dH <- combineMap "dH" (\[m, a'] -> m - a') [sH', h] + dW <- combineMap "dW" (\[m, a'] -> m - a') [sW', w] + + -- Deterministic split on RHS: low = floor(d/2), high = d - low + precondition [hLow, dH] $ \[l, d] -> (l + l) .<= d .&& d .<= (l + l + 1) + hHigh <- combineMap "hHigh" (\[d, l] -> d - l) [dH, hLow] + + precondition [wLow, dW] $ \[l, d] -> (l + l) .<= d .&& d .<= (l + l + 1) + wHigh <- combineMap "wHigh" (\[d, l] -> d - l) [dW, wLow] + + z <- newConstMap "zero" 0 rclass + + let hRef = ByLabel "H" + let wRef = ByLabel "W" + + rhs <- + pad tA (0 :: a) $ + Padding + { low = [hRef --> hLow, wRef --> wLow], + interior = [hRef --> z, wRef --> z], + high = [hRef --> hHigh, wRef --> wHigh] + } + + rewrite "enlarge(ky, kx, A) ⇒ Pad(A, 0, symmetric(ky on H, kx on W)))" lhs rhs + +main :: IO () +main = do + printTitle "######################## desugarEnlarge ########################" + verifyNumDSL desugarEnlarge diff --git a/rules/taso/ewadd/Main.hs b/rules/taso/ewadd/Main.hs new file mode 100644 index 0000000..a680ce1 --- /dev/null +++ b/rules/taso/ewadd/Main.hs @@ -0,0 +1,45 @@ +module Main (main) where + +import Grisette hiding ((-->)) +import TensorRight +import TensorRight.Internal.DSL.TASO (ewadd) + +desugar :: forall a. NumRule a -- Verify desugaring +desugar _ = do + rclass <- newRClass "rclass" + map <- newMap "map" rclass + tA <- newTensor @a "A" [rclass --> map] + tB <- newTensor @a "B" [rclass --> map] + lhs <- ewadd tA tB + rhs <- numBinOp Add tA tB + rewrite "ewadd(A, B) ⇒ Add(A, B)" lhs rhs + +associativity :: forall a. NumRule a -- Associativity +associativity _ = do + rclass <- newRClass "rclass" + map <- newMap "map" rclass + x <- newTensor @a "x" [rclass --> map] + y <- newTensor @a "y" [rclass --> map] + z <- newTensor @a "z" [rclass --> map] + lhs <- ewadd x $ ewadd y z + rhs <- ewadd (ewadd x y) z + rewrite "ewadd(x, ewadd(y, z)) ⇒ Add(ewadd(x, y), z)" lhs rhs + +commutativity :: forall a. NumRule a -- Verify commutative +commutativity _ = do + rclass <- newRClass "rclass" + map <- newMap "map" rclass + x <- newTensor @a "x" [rclass --> map] + y <- newTensor @a "y" [rclass --> map] + lhs <- ewadd x y + rhs <- ewadd y x + rewrite "ewadd(x, y) ⇒ ewadd(y, x)" lhs rhs + +main :: IO () +main = do + printTitle "############################## desugar ##############################" + verifyNumDSL desugar + printTitle "############################## associativity ##############################" + verifyNumDSL associativity + printTitle "############################## commutativity ##############################" + verifyNumDSL commutativity diff --git a/rules/taso/ewmul/Main.hs b/rules/taso/ewmul/Main.hs new file mode 100644 index 0000000..7ca7744 --- /dev/null +++ b/rules/taso/ewmul/Main.hs @@ -0,0 +1,74 @@ +module Main (main) where + +import Grisette hiding ((-->)) +import TensorRight +import TensorRight.Internal.DSL.DSL (newRClass) +import TensorRight.Internal.DSL.TASO (ewadd, ewmul) + +desugar :: forall a. NumRule a -- Verify desugaring +desugar _ = do + rclass <- newRClass "rclass" + map <- newMap "map" rclass + tA <- newTensor @a "A" [rclass --> map] + tB <- newTensor @a "B" [rclass --> map] + lhs <- ewmul tA tB + rhs <- numBinOp Mul tA tB + rewrite "ewmul(A, B) ⇒ Mul(A, B)" lhs rhs + +associativity :: forall a. NumRule a -- Associativity +associativity _ = do + rclass <- newRClass "rclass" + map <- newMap "map" rclass + x <- newTensor @a "x" [rclass --> map] + y <- newTensor @a "t" [rclass --> map] + z <- newTensor @a "z" [rclass --> map] + lhs <- ewmul x $ ewmul y z + rhs <- ewmul (ewmul x y) z + rewrite "ewmul(x, ewmul(y, z)) ⇒ mul(ewmul(x, y), z)" lhs rhs + +commutativity :: forall a. NumRule a -- Verify commutative +commutativity _ = do + rclass <- newRClass "rclass" + map <- newMap "map" rclass + x <- newTensor @a "x" [rclass --> map] + y <- newTensor @a "y" [rclass --> map] + lhs <- ewmul x y + rhs <- ewmul y x + rewrite "ewmul(x, y) ⇒ ewmul(y, x)" lhs rhs + +distributivity :: forall a. NumRule a -- Verify distributivity +distributivity _ = do + rclass <- newRClass "rclass" + map <- newMap "map" rclass + x <- newTensor @a "x" [rclass --> map] + y <- newTensor @a "y" [rclass --> map] + z <- newTensor @a "z" [rclass --> map] + lhs <- ewmul (ewadd x y) z + rhs <- ewadd (ewmul x z) (ewmul y z) + rewrite "ewmul(ewadd(x, y), z) ⇒ ewadd(ewmul(x, z), ewmul(y, z))" lhs rhs + +identity :: forall a. NumRule a -- Verify identity +identity _ = do + rclassN <- newRClass "rclassN" + sizeN <- newMap "sizeN" rclassN + + x <- newTensor @a "x" [rclassN --> sizeN] + ones <- constant @a 1 [rclassN --> sizeN] + + lhs <- ewmul x ones + let rhs = x + + rewrite "ewmul(x, I) ⇒ x" lhs rhs + +main :: IO () +main = do + printTitle "############################## desugar ##############################" + verifyNumDSL desugar + printTitle "############################## associativity ##############################" + verifyNumDSL associativity + printTitle "############################## commutativity ##############################" + verifyNumDSL commutativity + printTitle "############################## distributivity ##############################" + verifyNumDSL distributivity + printTitle "############################## identity ##############################" + verifyNumDSL identity diff --git a/rules/taso/matmul2D/Main.hs b/rules/taso/matmul2D/Main.hs new file mode 100644 index 0000000..1817ee8 --- /dev/null +++ b/rules/taso/matmul2D/Main.hs @@ -0,0 +1,204 @@ +module Main (main) where + +import Grisette hiding (dot, (-->)) +import TensorRight +import TensorRight.Internal.DSL.DSL (checkSIMap, monitorExprOnFailure, newSingletonRClass, newSingletonRClasses, siRelation) +import TensorRight.Internal.DSL.TASO (concat, ewadd, ewmul, matmul2D, relu, smul, transpose) +import Prelude hiding (concat) + +-- | Rule 1: Matrix multiplication associativity +-- ∀x, y, z. matmul(x, matmul(y, z)) = matmul(matmul(x, y), z) +matmulAssociativity :: forall a. NumRule a +matmulAssociativity _ = do + [rclassM, rclassK, rclassN, rclassP] <- newSingletonRClasses ["rclassM", "rclassK", "rclassN", "rclassP"] + sizeM <- newMap "sizeM" rclassM + sizeK <- newMap "sizeK" rclassK + sizeN <- newMap "sizeN" rclassN + sizeP <- newMap "sizeP" rclassP + + x <- newTensor @a "x" [rclassM --> sizeM, rclassK --> sizeK] + y <- newTensor @a "y" [rclassK --> sizeK, rclassN --> sizeN] -- Shared K and N + z <- newTensor @a "z" [rclassN --> sizeN, rclassP --> sizeP] -- Shared N + nL <- newMap "contractSI" rclassN + nR <- newMap "contractSI" rclassN + + kL <- newMap "contractSI" rclassK + kR <- newMap "contractSI" rclassK + + siRelation [kL, kR] $ \[l, r] -> l .== r + siRelation [nL, nR] $ \[l, r] -> l .== r + checkSIMap [kL, nL] [kR, nR] + + lhs <- matmul2D x (matmul2D y z [rclassN --> nL]) [rclassK --> kL] + rhs <- matmul2D (matmul2D x y [rclassK --> kR]) z [rclassN --> nR] + + rewrite "matmul(x, matmul(y, z)) ⇒ matmul(matmul(x, y), z)" lhs rhs + +-- | Rule 2: Matrix multiplication is linear (scalar multiplication) +-- ∀x, y, w. smul(matmul(x, y), w) = matmul(x, smul(y, w)) +matmulScalarLinear :: forall a. NumRule a +matmulScalarLinear _ = do + let w = ("w" :: a) + [rclassM, rclassK, rclassN] <- newSingletonRClasses ["rclassM", "rclassK", "rclassN"] + sizeM <- newMap "sizeM" rclassM + sizeK <- newMap "sizeK" rclassK + sizeN <- newMap "sizeN" rclassN + + x <- newTensor @a "x" [rclassM --> sizeM, rclassK --> sizeK] + y <- newTensor @a "y" [rclassK --> sizeK, rclassN --> sizeN] + + kL <- newMap "contractSI" rclassK + kR <- newMap "contractSI" rclassK + xy <- matmul2D x y [rclassK --> kL] + lhs <- smul xy w + yw <- smul y w + rhs <- matmul2D x yw [rclassK --> kR] + + siRelation [kL, kR] $ \[l, r] -> l .== r + checkSIMap [kL] [kR] + + rewrite "smul(matmul(x, y), w) ⇒ matmul(x, smul(y, w))" lhs rhs + +-- | Rule 3: Matrix multiplication distributes over addition +-- ∀x, y, z. matmul(x, ewadd(y, z)) = ewadd(matmul(x, y), matmul(x, z)) +matmulDistributive :: forall a. NumRule a +matmulDistributive _ = do + [rclassM, rclassK, rclassN] <- newSingletonRClasses ["rclassM", "rclassK", "rclassN"] + sizeM <- newMap "sizeM" rclassM + sizeK <- newMap "sizeK" rclassK + sizeN <- newMap "sizeN" rclassN + + x <- newTensor @a "x" [rclassM --> sizeM, rclassK --> sizeK] + y <- newTensor @a "y" [rclassK --> sizeK, rclassN --> sizeN] + z <- newTensor @a "z" [rclassK --> sizeK, rclassN --> sizeN] + + yz <- ewadd y z + kL <- newMap "contractSI" rclassK + kR1 <- newMap "contractSI" rclassK + kR2 <- newMap "contractSI" rclassK + lhs <- matmul2D x yz [rclassK --> kL] + xy <- matmul2D x y [rclassK --> kR1] + xz <- matmul2D x z [rclassK --> kR2] + rhs <- ewadd xy xz + + siRelation [kL, kR1] $ \[l, r] -> l .== r + siRelation [kL, kR2] $ \[l, r] -> l .== r + checkSIMap [kL] [kR1, kR2] + + rewrite "matmul(x, ewadd(y, z)) ⇒ ewadd(matmul(x, y), matmul(x, z))" lhs rhs + +-- | Rule 4: Matrix multiplication and transpose interaction +-- ∀x, y. transpose(matmul(x, y)) = matmul(transpose(y), transpose(x)) +matmulTranspose :: forall a. NumRule a +matmulTranspose _ = do + [rclassM, rclassK, rclassN] <- newSingletonRClasses ["rclassM", "rclassK", "rclassN"] + sizeM <- newMap "sizeM" rclassM + sizeK <- newMap "sizeK" rclassK + sizeN <- newMap "sizeN" rclassN + + -- Labelled inputs + x <- newTensor @a "x" [rclassM --> sizeM @@ "L", rclassK --> sizeK @@ "K"] + y <- newTensor @a "y" [rclassK --> sizeK @@ "K", rclassN --> sizeN @@ "R"] + + -- LHS: transpose(matmul(x, y)) + kL <- newMap "contractSI" rclassK + -- xy should be [rclassM @@ "L" rclassN @@ "R"] + xy <- matmul2D x y [ByLabel "K" --> kL] + -- lhs should be [rclassM @@ "R" rclassN @@ "L"] + lhs <- transpose xy + + -- RHS: matmul(transpose(y), transpose(x)) + yt <- transpose y -- [rclassK @@ "R", rclassN @@ "K"] + xt0 <- transpose x -- [rclassM @@ "K", rclassK @@ "L"] + xt <- relabel xt0 [ByLabel "L" --> ByLabel "R", ByLabel "K" --> ByLabel "K'"] -- [rclassM @@ "K'", rclassK @@ "R"] + kR <- newMap "contractSI" rclassK + rhs0 <- matmul2D yt xt [ByLabel "R" --> kR] -- [rclassN @@ "K", rclassM @@ "K'"] + -- rhs should now also be [rclassM @@ "R" rclassN @@ "L"] + rhs <- relabel rhs0 [ByLabel "K'" --> ByLabel "R", ByLabel "K" --> ByLabel "L"] + + siRelation [kL, kR] $ \[i, j] -> i .== j + checkSIMap [kL] [kR] + + rewrite "transpose(matmul(x, y)) ⇒ matmul(transpose(y), transpose(x))" lhs rhs + +-- -- | Rule 6: Concatenation and matrix multiplication (right distributive) +-- -- ∀x, y, z. concat(1, matmul(x, y), matmul(x, z)) = matmul(x, concat(1, y, z)) +matmulConcatRight :: forall a. NumRule a +matmulConcatRight _ = do + [rclassM, rclassK, rclassN] <- newSingletonRClasses ["rclassM", "rclassK", "rclassN"] + sizeM <- newMap "sizeM" rclassM + sizeK <- newMap "sizeK" rclassK + [sizeN1, sizeN2] <- newMaps ["sizeN1", "sizeN2"] rclassN + + x <- newTensor @a "x" [rclassM --> sizeM, rclassK --> sizeK] + y <- newTensor @a "y" [rclassK --> sizeK, rclassN --> sizeN1] + z <- newTensor @a "z" [rclassK --> sizeK, rclassN --> sizeN2] + + kL1 <- newMap "contractSI" rclassK + kL2 <- newMap "contractSI" rclassK + xy <- matmul2D x y [rclassK --> kL1] + xz <- matmul2D x z [rclassK --> kL2] + lhs <- concat (ByRClass rclassN) xy xz + yz <- concat (ByRClass rclassN) y z + kR <- newMap "contractSI" rclassK + rhs <- matmul2D x yz [rclassK --> kR] + + siRelation [kL1, kR] $ \[l, r] -> l .== r + siRelation [kL2, kR] $ \[l, r] -> l .== r + checkSIMap [kL1, kL2] [kR] + + rewrite "concat(1, matmul(x, y), matmul(x, z)) ⇒ matmul(x, concat(1, y, z))" lhs rhs + +-- | Rule 7: Concatenation and matrix multiplication (mixed) +-- ∀x, y, z, w. matmul(concat(1, x, z), concat(0, y, w)) = ewadd(matmul(x, y), matmul(z, w)) +-- matmulConcatMixed :: forall a. NumRule a +-- matmulConcatMixed _ = do +-- [rclassM, rclassK, rclassN] <- newSingletonRClasses ["rclassM", "rclassK", "rclassN"] +-- [sizeM1, sizeM2] <- newMaps ["sizeM1", "sizeM2"] rclassM +-- [sizeK1, sizeK2] <- newMaps ["sizeK1", "sizeK2"] rclassK +-- sizeN <- newMap "sizeN" rclassN + +-- x <- newTensor @a "x" [rclassM --> sizeM1, rclassK --> sizeK1] +-- y <- newTensor @a "y" [rclassK --> sizeK1, rclassN --> sizeN] +-- z <- newTensor @a "z" [rclassM --> sizeM2, rclassK --> sizeK2] +-- w <- newTensor @a "w" [rclassK --> sizeK2, rclassN --> sizeN] + +-- -- Dimensions must match appropriately for concatenation and matmul +-- precondition [sizeM1, sizeM2] $ \[m1, m2] -> m1 .== m2 +-- precondition [sizeK1, sizeK2] $ \[k1, k2] -> k1 .== k2 + +-- xz <- concat (ByRClass rclassM) x z +-- yw <- concat (ByRClass rclassK) y w +-- kL <- newMap "contractSI" rclassK +-- kR1 <- newMap "contractSI" rclassK +-- kR2 <- newMap "contractSI" rclassK +-- lhs <- matmul2D xz yw [rclassK --> kL] +-- xy <- matmul2D x y [rclassK --> kR1] +-- zw <- matmul2D z w [rclassK --> kR2] +-- rhs <- ewadd xy zw + +-- siRelation [kL, kR1] $ \[l, r] -> l .== r +-- siRelation [kL, kR2] $ \[l, r] -> l .== r +-- checkSIMap [kL] [kR1, kR2] + +-- rewrite "matmul(concat(1, x, z), concat(0, y, w)) ⇒ ewadd(matmul(x, y), matmul(z, w))" lhs rhs + +main :: IO () +main = do + printTitle "#################### matmulAssociativity ####################" + verifyNumDSL matmulAssociativity + + printTitle "#################### matmulScalarLinear #####################" + verifyNumDSL matmulScalarLinear + + printTitle "#################### matmulDistributive #####################" + verifyNumDSL matmulDistributive + + printTitle "###################### matmulTranspose ######################" + verifyNumDSL matmulTranspose + + printTitle "#################### matmulConcatRight ######################" + verifyNumDSL matmulConcatRight + +-- printTitle "##################### matmulConcatMixed #####################" +-- verifyNumDSL matmulConcatMixed diff --git a/rules/taso/matmul3D/Main.hs b/rules/taso/matmul3D/Main.hs new file mode 100644 index 0000000..681dd9c --- /dev/null +++ b/rules/taso/matmul3D/Main.hs @@ -0,0 +1,187 @@ +module Main (main) where + +import Grisette hiding (dot, (-->)) +import TensorRight +import TensorRight.Internal.DSL.DSL (checkSIMap, monitorExprOnFailure, newSingletonRClass, newSingletonRClasses, siRelation) +import TensorRight.Internal.DSL.TASO (concat, ewadd, ewmul, matmul3D, relu, smul, transpose) +import Prelude hiding (concat) + +-- | Rule 1: Matrix multiplication associativity (batched) +matmulAssociativity :: forall a. NumRule a +matmulAssociativity _ = do + [rclassB, rclassM, rclassK, rclassN, rclassP] <- newSingletonRClasses ["rclassB", "rclassM", "rclassK", "rclassN", "rclassP"] + sizeM <- newMap "sizeM" rclassM + sizeK <- newMap "sizeK" rclassK + sizeN <- newMap "sizeN" rclassN + sizeP <- newMap "sizeP" rclassP + [sizeB1, sizeB2, sizeB3] <- newMaps ["sizeB1", "sizeB2", "sizeB3"] rclassB + + -- x:[B,M,K], y:[B,K,N], z:[B,N,P] + x <- newTensor @a "x" [rclassB --> sizeB1, rclassM --> sizeM, rclassK --> sizeK] + y <- newTensor @a "y" [rclassB --> sizeB2, rclassK --> sizeK, rclassN --> sizeN] + z <- newTensor @a "z" [rclassB --> sizeB3, rclassN --> sizeN, rclassP --> sizeP] + + -- independent contraction maps for K and N + kL <- newMap "kL" rclassK + kR <- newMap "kR" rclassK + nL <- newMap "nL" rclassN + nR <- newMap "nR" rclassN + + siRelation [kL, kR] $ \[l, r] -> l .== r + siRelation [nL, nR] $ \[l, r] -> l .== r + checkSIMap [kL, nL] [kR, nR] + + -- LHS: x · (y · z), inner contracts N + yz <- matmul3D y z [rclassN --> nL] [ByRClass rclassB] + lhs <- matmul3D x yz [rclassK --> kL] [ByRClass rclassB] + + -- RHS: (x · y) · z, inner contracts K + xy <- matmul3D x y [rclassK --> kR] [ByRClass rclassB] + rhs <- matmul3D xy z [rclassN --> nR] [ByRClass rclassB] + + rewrite "∀x, y, z. matmul3D(x, matmul3D(y, z)) = matmul3D(matmul3D(x, y), z)" lhs rhs + +-- | Rule 2: Scalar linearity (batched) +matmulScalarLinear :: forall a. NumRule a +matmulScalarLinear _ = do + let w = ("w" :: a) + [rclassB, rclassM, rclassK, rclassN] <- newSingletonRClasses ["rclassB", "rclassM", "rclassK", "rclassN"] + sizeM <- newMap "sizeM" rclassM + sizeK <- newMap "sizeK" rclassK + sizeN <- newMap "sizeN" rclassN + [sizeB1, sizeB2] <- newMaps ["sizeB1", "sizeB2"] rclassB + + x <- newTensor @a "x" [rclassB --> sizeB1, rclassM --> sizeM, rclassK --> sizeK] + y <- newTensor @a "y" [rclassB --> sizeB2, rclassK --> sizeK, rclassN --> sizeN] + + kL <- newMap "kL" rclassK + kR <- newMap "kR" rclassK + + xy <- matmul3D x y [rclassK --> kL] [ByRClass rclassB] + lhs <- smul xy w + yw <- smul y w + rhs <- matmul3D x yw [rclassK --> kR] [ByRClass rclassB] + + siRelation [kL, kR] $ \[l, r] -> l .== r + checkSIMap [kL] [kR] + + rewrite "∀x, y, w. smul(matmul3D(x, y), w) = matmul3D(x, smul(y, w))" lhs rhs + +-- | Rule 3: Distributivity over addition (batched) +matmulDistributive :: forall a. NumRule a +matmulDistributive _ = do + [rclassB, rclassM, rclassK, rclassN] <- newSingletonRClasses ["rclassB", "rclassM", "rclassK", "rclassN"] + sizeM <- newMap "sizeM" rclassM + sizeK <- newMap "sizeK" rclassK + sizeN <- newMap "sizeN" rclassN + [sizeB1, sizeB2, sizeB3] <- newMaps ["sizeB1", "sizeB2", "sizeB3"] rclassB + + x <- newTensor @a "x" [rclassB --> sizeB1, rclassM --> sizeM, rclassK --> sizeK] + y <- newTensor @a "y" [rclassB --> sizeB2, rclassK --> sizeK, rclassN --> sizeN] + z <- newTensor @a "z" [rclassB --> sizeB3, rclassK --> sizeK, rclassN --> sizeN] + + yz <- ewadd y z + kL <- newMap "kL" rclassK + kR1 <- newMap "kR1" rclassK + kR2 <- newMap "kR2" rclassK + + lhs <- matmul3D x yz [rclassK --> kL] [ByRClass rclassB] + xy <- matmul3D x y [rclassK --> kR1] [ByRClass rclassB] + xz <- matmul3D x z [rclassK --> kR2] [ByRClass rclassB] + rhs <- ewadd xy xz + + siRelation [kL, kR1] $ \[l, r] -> l .== r + siRelation [kL, kR2] $ \[l, r] -> l .== r + checkSIMap [kL] [kR1, kR2] + + rewrite "∀x, y, z. matmul(x, ewadd(y, z)) = ewadd(matmul(x, y), matmul(x, z))" lhs rhs + +-- | Rule 4: Right concatenation (batched) +matmulConcatRight :: forall a. NumRule a +matmulConcatRight _ = do + [rclassB, rclassM, rclassK, rclassN] <- newSingletonRClasses ["rclassB", "rclassM", "rclassK", "rclassN"] + sizeM <- newMap "sizeM" rclassM + sizeK <- newMap "sizeK" rclassK + [sizeN1, sizeN2] <- newMaps ["sizeN1", "sizeN2"] rclassN + [sizeB1, sizeB2, sizeB3] <- newMaps ["sizeB1", "sizeB2", "sizeB3"] rclassB + + x <- newTensor @a "x" [rclassB --> sizeB1, rclassM --> sizeM, rclassK --> sizeK] + y <- newTensor @a "y" [rclassB --> sizeB2, rclassK --> sizeK, rclassN --> sizeN1] + z <- newTensor @a "z" [rclassB --> sizeB3, rclassK --> sizeK, rclassN --> sizeN2] + + kL1 <- newMap "kL1" rclassK + kL2 <- newMap "kL2" rclassK + xy <- matmul3D x y [rclassK --> kL1] [ByRClass rclassB] + xz <- matmul3D x z [rclassK --> kL2] [ByRClass rclassB] + lhs <- concat (ByRClass rclassN) xy xz + + yz <- concat (ByRClass rclassN) y z + kR <- newMap "kR" rclassK + rhs <- matmul3D x yz [rclassK --> kR] [ByRClass rclassB] + + siRelation [kL1, kR] $ \[l, r] -> l .== r + siRelation [kL2, kR] $ \[l, r] -> l .== r + checkSIMap [kL1, kL2] [kR] + + rewrite "concat along N moves through matmul3D(x, ·)" lhs rhs + +-- | Rule 5: Concatenation and matrix multiplication (mixed) +-- ∀x, y, z, w. matmul(concat(1, x, z), concat(0, y, w)) = ewadd(matmul(x, y), matmul(z, w)) + +-- | Rule 5: Mixed concatenation (batched) +matmulConcatMixed :: forall a. NumRule a +matmulConcatMixed _ = do + [rclassB, rclassM, rclassK, rclassN] <- newSingletonRClasses ["rclassB", "rclassM", "rclassK", "rclassN"] + [sizeM1, sizeM2] <- newMaps ["sizeM1", "sizeM2"] rclassM + [sizeK1, sizeK2] <- newMaps ["sizeK1", "sizeK2"] rclassK + sizeN <- newMap "sizeN" rclassN + [sizeB1, sizeB2, sizeB3, sizeB4] <- newMaps ["sizeB1", "sizeB2", "sizeB3", "sizeB4"] rclassB + + x <- newTensor @a "x" [rclassB --> sizeB1, rclassM --> sizeM1, rclassK --> sizeK1] + y <- newTensor @a "y" [rclassB --> sizeB2, rclassK --> sizeK1, rclassN --> sizeN] + z <- newTensor @a "z" [rclassB --> sizeB3, rclassM --> sizeM2, rclassK --> sizeK2] + w <- newTensor @a "w" [rclassB --> sizeB4, rclassK --> sizeK2, rclassN --> sizeN] + + -- All batch sizes must match so that B is a proper batch axis + precondition [sizeB1, sizeB2] $ \[b1, b2] -> b1 .== b2 + precondition [sizeB1, sizeB3] $ \[b1, b3] -> b1 .== b3 + precondition [sizeB1, sizeB4] $ \[b1, b4] -> b1 .== b4 + -- Mixed: equal splits on M and K so SI equality is valid on both branches + precondition [sizeM1, sizeM2] $ \[m1, m2] -> m1 .== m2 + precondition [sizeK1, sizeK2] $ \[k1, k2] -> k1 .== k2 + + xm <- concat (ByRClass rclassM) x z + yk <- concat (ByRClass rclassK) y w + + -- Contract along K with equal SI maps across sides + kL <- newMap "kL" rclassK + kR1 <- newMap "kR1" rclassK + kR2 <- newMap "kR2" rclassK + + lhs <- matmul3D xm yk [rclassK --> kL] [ByRClass rclassB] + xy <- matmul3D x y [rclassK --> kR1] [ByRClass rclassB] + zw <- matmul3D z w [rclassK --> kR2] [ByRClass rclassB] + rhs <- ewadd xy zw + + siRelation [kL, kR1] $ \[l, r] -> l .== r + siRelation [kL, kR2] $ \[l, r] -> l .== r + checkSIMap [kL] [kR1, kR2] + + rewrite "matmul3D(concat_M x z, concat_K y w) ⇒ ewadd(matmul3D(x,y), matmul3D(z,w))" lhs rhs + +main :: IO () +main = do + printTitle "#################### matmulAssociativity ####################" + verifyNumDSL matmulAssociativity + + printTitle "#################### matmulScalarLinear #####################" + verifyNumDSL matmulScalarLinear + + printTitle "#################### matmulDistributive #####################" + verifyNumDSL matmulDistributive + + printTitle "#################### matmulConcatRight ######################" + verifyNumDSL matmulConcatRight + +-- printTitle "##################### matmulConcatMixed #####################" +-- verifyNumDSL matmulConcatMixed diff --git a/rules/taso/relu/Main.hs b/rules/taso/relu/Main.hs new file mode 100644 index 0000000..9f107d0 --- /dev/null +++ b/rules/taso/relu/Main.hs @@ -0,0 +1,17 @@ +import Grisette hiding ((-->)) +import TensorRight +import TensorRight.Internal.DSL.TASO (relu) + +desugar :: forall a. NumRule a -- Verify desugaring +desugar _ = do + rclass <- newRClass "rclass" + map <- newMap "map" rclass + tA <- newTensor @a "A" [rclass --> map] + lhs <- relu @a tA + rhs <- clampScalar @a 0 tA posInf + rewrite "relu(A) ⇒ Clamp(0, A, inf)" lhs rhs + +main :: IO () +main = do + printTitle "############################## desugar ##############################" + verifyNumDSL desugar diff --git a/rules/taso/smul/Main.hs b/rules/taso/smul/Main.hs new file mode 100644 index 0000000..54a8f4d --- /dev/null +++ b/rules/taso/smul/Main.hs @@ -0,0 +1,59 @@ +module Main (main) where + +import Grisette hiding ((-->)) +import TensorRight +import TensorRight.Internal.DSL.TASO (ewadd, ewmul, smul) + +desugar :: forall a. NumRule a -- Verify desugaring +desugar _ = do + let s = ("s" :: a) + rclass <- newRClass "rclass" + map <- newMap "map" rclass + tA <- newTensor @a "A" [rclass --> map] + lhs <- smul tA s + rhs <- numBinScalarOp Mul tA s + rewrite "smul(A, s) ⇒ Mul(A, s)" lhs rhs + +associativity :: forall a. NumRule a -- Verify associativity +associativity _ = do + let w = ("w" :: a) + let y = ("y" :: a) + rclass <- newRClass "rclass" + map <- newMap "map" rclass + x <- newTensor @a "x" [rclass --> map] + lhs <- smul (smul x y) w + rhs <- smul x (y * w) -- Multiply the scalars first since smul (y, w) doesn't make sense + rewrite "smul(smul(x, y), w) ⇒ smul(x, smul(y, w))" lhs rhs + +distributivity :: forall a. NumRule a -- Distributivity +distributivity _ = do + let w = ("w" :: a) + rclass <- newRClass "rclass" + map <- newMap "map" rclass + x <- newTensor @a "x" [rclass --> map] + y <- newTensor @a "y" [rclass --> map] + lhs <- smul (ewadd x y) w + rhs <- ewadd (smul x w) (smul y w) + rewrite "smul(ewadd(x, y), w) ⇒ ewadd(smul(x, w), smul(y, w))" lhs rhs + +commutativity :: forall a. NumRule a -- Operator commutativity +commutativity _ = do + let w = ("w" :: a) + rclass <- newRClass "rclass" + map <- newMap "map" rclass + x <- newTensor @a "x" [rclass --> map] + y <- newTensor @a "y" [rclass --> map] + lhs <- smul (ewmul x y) w + rhs <- ewmul x (smul y w) + rewrite "smul(ewmul(x, y), w) ⇒ ewmul(x, smul(y, w))" lhs rhs + +main :: IO () +main = do + printTitle "############################## desugar ##############################" + verifyNumDSL desugar + printTitle "############################## associativity ##############################" + verifyNumDSL associativity + printTitle "############################## distributivity ##############################" + verifyNumDSL distributivity + printTitle "############################## commutativity ##############################" + verifyNumDSL commutativity diff --git a/rules/taso/split/Main.hs b/rules/taso/split/Main.hs new file mode 100644 index 0000000..2f83166 --- /dev/null +++ b/rules/taso/split/Main.hs @@ -0,0 +1,48 @@ +module Main (main) where + +import Grisette hiding (dot, (-->)) +import TensorRight +import TensorRight.Internal.DSL.DSL (checkSIMap, monitorExprOnFailure, newRClasses, siRelation) +import TensorRight.Internal.DSL.TASO (concat, split0, split1) +import Prelude hiding (concat) + +split0_desugar :: forall a. NumRule a -- Verify desugaring +split0_desugar _ = do + [rclassM, rclassN] <- newRClasses ["rclassM", "rclassN"] + sizeM1 <- newMap "sizeM1" rclassM + sizeM2 <- newMap "sizeM2" rclassM + + sizeN <- newMap "sizeN" rclassN + + x <- newTensor @a "x" [rclassM --> sizeM1, rclassN --> sizeN] + y <- newTensor @a "y" [rclassM --> sizeM2, rclassN --> sizeN] + + let concatAxis = ByRClass rclassN + lhs <- split0 concatAxis $ concat concatAxis x y + let rhs = x + + rewrite "split_0(a, concat(a, x, y)) ⇒ x" lhs rhs + +split1_desugar :: forall a. NumRule a -- Verify desugaring +split1_desugar _ = do + [rclassM, rclassN] <- newRClasses ["rclassM", "rclassN"] + sizeM1 <- newMap "sizeM1" rclassM + sizeM2 <- newMap "sizeM2" rclassM + + sizeN <- newMap "sizeN" rclassN + + x <- newTensor @a "x" [rclassM --> sizeM1, rclassN --> sizeN] + y <- newTensor @a "y" [rclassM --> sizeM2, rclassN --> sizeN] + + let concatAxis = ByRClass rclassN + lhs <- split1 concatAxis $ concat concatAxis x y + let rhs = y + + rewrite "split_1(a, concat(a, x, y)) ⇒ y" lhs rhs + +main :: IO () +main = do + printTitle "#################### split0 desguar ####################" + verifyNumDSL split0_desugar + printTitle "#################### split1 desugar ####################" + verifyNumDSL split1_desugar diff --git a/rules/taso/transpose/Main.hs b/rules/taso/transpose/Main.hs new file mode 100644 index 0000000..5646b57 --- /dev/null +++ b/rules/taso/transpose/Main.hs @@ -0,0 +1,110 @@ +module Main (main) where + +import Grisette hiding ((-->)) +import TensorRight +import TensorRight.Internal.DSL.DSL (newSingletonRClass) +import TensorRight.Internal.DSL.TASO (concat, ewadd, ewmul, relu, smul, transpose) +import Prelude hiding (concat) + +-- ############################# (Rewrite rules not enforcing singleton) ############################ +-- Desugaring for general TASO transpose +desugarTranspose :: forall a. AnyDTypeRule a +desugarTranspose _ = do + rclass <- newSingletonRClass "rclass" + s1 <- newMap "s1" rclass + s2 <- newMap "s2" rclass + tA <- newTensor @a "A" [rclass --> s1 @@ "L", rclass --> s2 @@ "R"] + lhs <- transpose tA + rhs <- relabel tA [ByLabel "L" --> ByLabel "R", ByLabel "R" --> ByLabel "L"] + rewrite "transpose(A) ⇒ relabel(A, swap)" lhs rhs + +inverse :: forall a. AnyDTypeRule a +inverse _ = do + rclass <- newSingletonRClass "rclass" + s1 <- newMap "s1" rclass + s2 <- newMap "s2" rclass + tA <- newTensor @a "A" [rclass --> s1 @@ "L", rclass --> s2 @@ "R"] + lhs <- transpose $ transpose tA + rewrite "transpose(transpose(A)) ⇒ A" lhs tA + +-- transpose(ewadd(x, y)) = ewadd(transpose(x), transpose(y)) +transposeEwadd :: forall a. NumRule a +transposeEwadd _ = do + r <- newSingletonRClass "r" + sL <- newMap "sL" r + sR <- newMap "sR" r + x <- newTensor @a "x" [r --> sL @@ "L", r --> sR @@ "R"] + y <- newTensor @a "y" [r --> sL @@ "L", r --> sR @@ "R"] + lhs <- transpose (ewadd x y) + rhs <- ewadd (transpose x) (transpose y) + rewrite "transpose(ewadd(x,y)) ⇒ ewadd(transpose(x), transpose(y))" lhs rhs + +-- transpose(ewmul(x, y)) = ewmul(transpose(x), transpose(y)) +transposeEwmul :: forall a. NumRule a +transposeEwmul _ = do + r <- newSingletonRClass "r" + sL <- newMap "sL" r + sR <- newMap "sR" r + x <- newTensor @a "x" [r --> sL @@ "L", r --> sR @@ "R"] + y <- newTensor @a "y" [r --> sL @@ "L", r --> sR @@ "R"] + lhs <- transpose (ewmul x y) + rhs <- ewmul (transpose x) (transpose y) + rewrite "transpose(ewmul(x,y)) ⇒ ewmul(transpose(x), transpose(y))" lhs rhs + +-- transpose(smul(x, w)) = smul(transpose(x), w) +transposeSmul :: forall a. NumRule a +transposeSmul _ = do + r <- newSingletonRClass "r" + sL <- newMap "sL" r + sR <- newMap "sR" r + x <- newTensor @a "x" [r --> sL @@ "L", r --> sR @@ "R"] + let w = ("w" :: a) + lhs <- transpose (smul x w) + rhs <- smul (transpose x) w + rewrite "transpose(smul(x,w)) ⇒ smul(transpose(x), w)" lhs rhs + +-- transpose(relu(x)) = relu(transpose(x)) +transposeRelu :: forall a. NumRule a +transposeRelu _ = do + r <- newSingletonRClass "r" + sL <- newMap "sL" r + sR <- newMap "sR" r + x <- newTensor @a "x" [r --> sL @@ "L", r --> sR @@ "R"] + lhs <- transpose (relu @a x) + rhs <- relu @a (transpose x) + rewrite "transpose(relu(x)) ⇒ relu(transpose(x))" lhs rhs + +-- concat(1, transpose(x), transpose(y)) = transpose(concat(0, x, y)) +transposeConcat :: forall a. AnyDTypeRule a +transposeConcat _ = do + -- Use same rclass with two labels for 2D + r <- newSingletonRClass "r" + sLx <- newMap "sLx" r + sRx <- newMap "sRx" r + sLy <- newMap "sLy" r + sRy <- newMap "sRy" r + x <- newTensor @a "x" [r --> sLx @@ "L", r --> sRx @@ "R"] + y <- newTensor @a "y" [r --> sLy @@ "L", r --> sRy @@ "R"] + -- concat along label "R" (axis 1), then transpose should move concat to axis 0 (label "L") + let axis1 = ByLabel "R" + let axis0 = ByLabel "L" + lhs <- concat axis1 (transpose x) (transpose y) + rhs <- transpose (concat axis0 x y) + rewrite "concat(1, transpose(x), transpose(y)) ⇒ transpose(concat(0, x, y))" lhs rhs + +main :: IO () +main = do + printTitle "######################## desugarTranspose ########################" + verifyNumDSL desugarTranspose + printTitle "######################## inverse #################################" + verifyNumDSL inverse + printTitle "######################## transposeEwadd ##########################" + verifyNumDSL transposeEwadd + printTitle "######################## transposeEwmul ##########################" + verifyNumDSL transposeEwmul + printTitle "######################## transposeSmul ###########################" + verifyNumDSL transposeSmul + printTitle "######################## transposeRelu ###########################" + verifyNumDSL transposeRelu + printTitle "######################## transposeConcat #########################" + verifyNumDSL transposeConcat diff --git a/rules/xla/add/Main.hs b/rules/xla/add/Main.hs index 722bbc7..45673bc 100644 --- a/rules/xla/add/Main.hs +++ b/rules/xla/add/Main.hs @@ -46,11 +46,11 @@ rule04 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyNumDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyNumDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyNumDSL rule04 diff --git a/rules/xla/broadcast/Main.hs b/rules/xla/broadcast/Main.hs index 4ab0a5c..444f4a9 100644 --- a/rules/xla/broadcast/Main.hs +++ b/rules/xla/broadcast/Main.hs @@ -103,19 +103,19 @@ rule08 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyNumDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyAnyDTypeDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyAnyDTypeDSL rule04 - print "############################## rule05 ##############################" + printTitle "############################## rule05 ##############################" verifyAnyDTypeDSL rule05 - print "############################## rule06 ##############################" + printTitle "############################## rule06 ##############################" verifyAnyDTypeDSL rule06 - print "############################## rule07 ##############################" + printTitle "############################## rule07 ##############################" verifyDSL rule07 - print "############################## rule08 ##############################" + printTitle "############################## rule08 ##############################" verifyAnyDTypeDSL rule08 diff --git a/rules/xla/clamp/Main.hs b/rules/xla/clamp/Main.hs index 7d3722f..0b31d79 100644 --- a/rules/xla/clamp/Main.hs +++ b/rules/xla/clamp/Main.hs @@ -52,9 +52,9 @@ rule03 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyNumDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyNumDSL rule03 diff --git a/rules/xla/compare/Main.hs b/rules/xla/compare/Main.hs index 2c0cad3..f00defc 100644 --- a/rules/xla/compare/Main.hs +++ b/rules/xla/compare/Main.hs @@ -33,15 +33,15 @@ rule06 = constructRule "Eqv(A, A) ⇒ True" Eqv True main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyNumDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyNumDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyNumDSL rule04 - print "############################## rule05 ##############################" + printTitle "############################## rule05 ##############################" verifyNumDSL rule05 - print "############################## rule06 ##############################" + printTitle "############################## rule06 ##############################" verifyNumDSL rule06 diff --git a/rules/xla/concat/Main.hs b/rules/xla/concat/Main.hs index 1b1e4aa..b2be1e9 100644 --- a/rules/xla/concat/Main.hs +++ b/rules/xla/concat/Main.hs @@ -143,17 +143,17 @@ rule07 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyAnyDTypeDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyAnyDTypeDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyAnyDTypeDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyAnyDTypeDSL rule04 - print "############################## rule05 ##############################" + printTitle "############################## rule05 ##############################" verifyAnyDTypeDSL rule05 - print "############################## rule06 ##############################" + printTitle "############################## rule06 ##############################" verifyAnyDTypeDSL rule06 - print "############################## rule07 ##############################" + printTitle "############################## rule07 ##############################" verifyAnyDTypeDSL rule07 diff --git a/rules/xla/conv/Main.hs b/rules/xla/conv/Main.hs index c73002e..4876bb1 100644 --- a/rules/xla/conv/Main.hs +++ b/rules/xla/conv/Main.hs @@ -459,9 +459,9 @@ rule03 _ = do main :: IO () main = do - print "############################## rule00 ##############################" + printTitle "############################## rule00 ##############################" verifyNumDSL rule00 - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyNumDSL rule03 diff --git a/rules/xla/divmod/Main.hs b/rules/xla/divmod/Main.hs index 8fcb094..c4f5a44 100644 --- a/rules/xla/divmod/Main.hs +++ b/rules/xla/divmod/Main.hs @@ -104,19 +104,19 @@ rule08 = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyDSL rule04 - print "############################## rule05 ##############################" + printTitle "############################## rule05 ##############################" verifyDSL rule05 - print "############################## rule06 ##############################" + printTitle "############################## rule06 ##############################" verifyDSL rule06 - print "############################## rule07 ##############################" + printTitle "############################## rule07 ##############################" verifyDSLWith (withTimeout 10000000 z3) rule07 - print "############################## rule08 ##############################" + printTitle "############################## rule08 ##############################" verifyDSLWith (withTimeout 10000000 z3) rule08 diff --git a/rules/xla/dot/Main.hs b/rules/xla/dot/Main.hs index 0a12434..3e2e00c 100644 --- a/rules/xla/dot/Main.hs +++ b/rules/xla/dot/Main.hs @@ -1,6 +1,5 @@ module Main (main) where -import Debug.Trace (traceShow) import Grisette hiding (dot, (-->)) import TensorRight @@ -207,15 +206,15 @@ rule06 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSLWith cvc5 rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyNumDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyNumDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyNumDSL rule04 - print "############################## rule05 ##############################" + printTitle "############################## rule05 ##############################" verifyNumDSL rule05 - print "############################## rule06 ##############################" + printTitle "############################## rule06 ##############################" verifyNumDSL rule06 diff --git a/rules/xla/dyslice/Main.hs b/rules/xla/dyslice/Main.hs index a1f9f91..3f32482 100644 --- a/rules/xla/dyslice/Main.hs +++ b/rules/xla/dyslice/Main.hs @@ -144,13 +144,13 @@ rule06 = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyAnyDTypeDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyAnyDTypeDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyAnyDTypeDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyAnyDTypeDSL rule04 - print "############################## rule05 ##############################" + printTitle "############################## rule05 ##############################" verifyAnyDTypeDSL rule05 diff --git a/rules/xla/dyupslice/Main.hs b/rules/xla/dyupslice/Main.hs index 3ef8446..f65cb21 100644 --- a/rules/xla/dyupslice/Main.hs +++ b/rules/xla/dyupslice/Main.hs @@ -98,11 +98,11 @@ rule04 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyAnyDTypeDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyAnyDTypeDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyAnyDTypeDSL rule04 diff --git a/rules/xla/iota/Main.hs b/rules/xla/iota/Main.hs index 4c9bc9f..a5f19c5 100644 --- a/rules/xla/iota/Main.hs +++ b/rules/xla/iota/Main.hs @@ -17,5 +17,5 @@ rule01 = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyDSL rule01 diff --git a/rules/xla/logical/Main.hs b/rules/xla/logical/Main.hs index 6a5c39a..1ac04c3 100644 --- a/rules/xla/logical/Main.hs +++ b/rules/xla/logical/Main.hs @@ -108,25 +108,25 @@ rule11 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyDSL rule04 - print "############################## rule05 ##############################" + printTitle "############################## rule05 ##############################" verifyDSL rule05 - print "############################## rule06 ##############################" + printTitle "############################## rule06 ##############################" verifyDSL rule06 - print "############################## rule07 ##############################" + printTitle "############################## rule07 ##############################" verifyDSL rule07 - print "############################## rule08 ##############################" + printTitle "############################## rule08 ##############################" verifyDSL rule08 - print "############################## rule09 ##############################" + printTitle "############################## rule09 ##############################" verifyNumDSL rule09 - print "############################## rule10 ##############################" + printTitle "############################## rule10 ##############################" verifyNumDSL rule10 - print "############################## rule11 ##############################" + printTitle "############################## rule11 ##############################" verifyNumDSL rule11 diff --git a/rules/xla/max/Main.hs b/rules/xla/max/Main.hs index eea3080..7a9a353 100644 --- a/rules/xla/max/Main.hs +++ b/rules/xla/max/Main.hs @@ -24,7 +24,7 @@ rule02 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyNumDSL rule02 diff --git a/rules/xla/mul/Main.hs b/rules/xla/mul/Main.hs index 5519bc0..2af2a82 100644 --- a/rules/xla/mul/Main.hs +++ b/rules/xla/mul/Main.hs @@ -109,23 +109,23 @@ rule10 = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyNumDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyNumDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyNumDSL rule04 - print "############################## rule05 ##############################" + printTitle "############################## rule05 ##############################" verifyNumDSL rule05 - print "############################## rule06 ##############################" + printTitle "############################## rule06 ##############################" verifyNumDSL rule06 - print "############################## rule07 ##############################" + printTitle "############################## rule07 ##############################" verifyNumDSL rule07 - print "############################## rule08 ##############################" + printTitle "############################## rule08 ##############################" verifyNumDSL rule08 - print "############################## rule09 ##############################" + printTitle "############################## rule09 ##############################" verifyNumDSL rule09 - print "############################## rule10 ##############################" + printTitle "############################## rule10 ##############################" verifyDSLWith cvc5 rule10 diff --git a/rules/xla/not/Main.hs b/rules/xla/not/Main.hs index 35fcdc6..7c23bbb 100644 --- a/rules/xla/not/Main.hs +++ b/rules/xla/not/Main.hs @@ -23,7 +23,7 @@ rule02 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyNumDSL rule02 diff --git a/rules/xla/pad/Main.hs b/rules/xla/pad/Main.hs index c1f681d..fbd80f7 100644 --- a/rules/xla/pad/Main.hs +++ b/rules/xla/pad/Main.hs @@ -174,11 +174,11 @@ rule04 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyAnyDTypeDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyAnyDTypeDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyAnyDTypeDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyAnyDTypeDSL rule04 diff --git a/rules/xla/reduce/Main.hs b/rules/xla/reduce/Main.hs index fd57526..7490b18 100644 --- a/rules/xla/reduce/Main.hs +++ b/rules/xla/reduce/Main.hs @@ -222,19 +222,19 @@ rule08 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyNumDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyNumDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyNumDSL rule04 - print "############################## rule05 ##############################" + printTitle "############################## rule05 ##############################" verifyNumDSL rule05 - print "############################## rule06 ##############################" + printTitle "############################## rule06 ##############################" verifyNumDSL rule06 - print "############################## rule07 ##############################" + printTitle "############################## rule07 ##############################" verifyNumDSL rule07 - print "############################## rule08 ##############################" + printTitle "############################## rule08 ##############################" verifyNumDSL rule08 diff --git a/rules/xla/relabel/Main.hs b/rules/xla/relabel/Main.hs index 86defbe..0b2ea8c 100644 --- a/rules/xla/relabel/Main.hs +++ b/rules/xla/relabel/Main.hs @@ -25,7 +25,7 @@ rule02 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyAnyDTypeDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyAnyDTypeDSL rule02 diff --git a/rules/xla/reverse/Main.hs b/rules/xla/reverse/Main.hs index c3a7074..aa19231 100644 --- a/rules/xla/reverse/Main.hs +++ b/rules/xla/reverse/Main.hs @@ -63,11 +63,11 @@ rule04 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyAnyDTypeDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyAnyDTypeDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyAnyDTypeDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyNumDSL rule04 diff --git a/rules/xla/select/Main.hs b/rules/xla/select/Main.hs index 498b055..3ecdc0d 100644 --- a/rules/xla/select/Main.hs +++ b/rules/xla/select/Main.hs @@ -48,11 +48,11 @@ rule04 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyAnyDTypeDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyAnyDTypeDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyAnyDTypeDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyAnyDTypeDSL rule04 diff --git a/rules/xla/slice/Main.hs b/rules/xla/slice/Main.hs index 9e30457..7968f53 100644 --- a/rules/xla/slice/Main.hs +++ b/rules/xla/slice/Main.hs @@ -408,29 +408,29 @@ rule13 _ = do main :: IO () main = do - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyAnyDTypeDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyAnyDTypeDSL rule02 - print "############################## rule03 ##############################" + printTitle "############################## rule03 ##############################" verifyAnyDTypeDSL rule03 - print "############################## rule04 ##############################" + printTitle "############################## rule04 ##############################" verifyAnyDTypeDSLWith cvc5 rule04 - print "############################## rule05 ##############################" + printTitle "############################## rule05 ##############################" verifyAnyDTypeDSL rule05 - print "############################## rule06 ##############################" + printTitle "############################## rule06 ##############################" verifyAnyDTypeDSL rule06 - print "############################## rule07 ##############################" + printTitle "############################## rule07 ##############################" verifyAnyDTypeDSL rule07 - print "############################## rule08 ##############################" + printTitle "############################## rule08 ##############################" verifyAnyDTypeDSL rule08 - print "############################## rule09 ##############################" + printTitle "############################## rule09 ##############################" verifyAnyDTypeDSL rule09 - print "############################## rule10 ##############################" + printTitle "############################## rule10 ##############################" verifyAnyDTypeDSL rule10 - print "############################## rule11 ##############################" + printTitle "############################## rule11 ##############################" verifyAnyDTypeDSL rule11 - print "############################## rule12 ##############################" + printTitle "############################## rule12 ##############################" verifyAnyDTypeDSL rule12 - print "############################## rule13 ##############################" + printTitle "############################## rule13 ##############################" verifyAnyDTypeDSLWith cvc5 rule13 diff --git a/rules/xla/sub/Main.hs b/rules/xla/sub/Main.hs index 3f47fb8..77fb35e 100644 --- a/rules/xla/sub/Main.hs +++ b/rules/xla/sub/Main.hs @@ -35,9 +35,9 @@ rule02 _ = do main :: IO () main = do - print "############################## rule00 ##############################" + printTitle "############################## rule00 ##############################" verifyNumDSL rule00 - print "############################## rule01 ##############################" + printTitle "############################## rule01 ##############################" verifyNumDSL rule01 - print "############################## rule02 ##############################" + printTitle "############################## rule02 ##############################" verifyNumDSL rule02 diff --git a/runall.sh b/runall.sh index dd92a72..5b447e3 100755 --- a/runall.sh +++ b/runall.sh @@ -10,30 +10,31 @@ run_and_capture() { LOCAL_SUCCESS=0 while IFS= read -r line; do echo "$line" 1>&2 - if [[ $line =~ ^\[SUCCESS\].* ]]; then + clean_line=$(sed -r 's/\x1B\[[0-9;]*[mK]//g' <<<"$line") + if [[ $clean_line =~ ^\[SUCCESS\].* ]]; then export LOCAL_SUCCESS=$((LOCAL_SUCCESS + 1)) - elif [[ $line =~ ^\[SUCCESS-Overall\].* ]]; then + elif [[ $clean_line =~ ^\[SUCCESS-Overall\].* ]]; then export LOCAL_SUCCESS=$((LOCAL_SUCCESS + 1)) - elif [[ $line =~ ^\[SUCCESS-.*\].* ]]; then + elif [[ $clean_line =~ ^\[SUCCESS-.*\].* ]]; then true - elif [[ $line =~ ^\[FAIL\].* ]]; then + elif [[ $clean_line =~ ^\[FAIL\].* ]]; then export LOCAL_FAILED=$((LOCAL_FAILED + 1)) - elif [[ $line =~ ^\[FAIL-Overall\].* ]]; then + elif [[ $clean_line =~ ^\[FAIL-Overall\].* ]]; then export LOCAL_FAILED=$((LOCAL_FAILED + 1)) - elif [[ $line =~ ^\[FAIL-.*\].* ]]; then + elif [[ $clean_line =~ ^\[FAIL-.*\].* ]]; then true - elif [[ $line =~ ^\[WARNING\].* ]]; then + elif [[ $clean_line =~ ^\[WARNING\].* ]]; then true - elif [[ $line =~ ^\[INFO-.*\].* ]]; then + elif [[ $clean_line =~ ^\[INFO-.*\].* ]]; then true - elif [[ $line =~ ^\[INFO\].* ]]; then + elif [[ $clean_line =~ ^\[INFO\].* ]]; then true - elif [[ $line =~ ^====\>.* ]]; then + elif [[ $clean_line =~ ^====\>.* ]]; then true - elif [[ $line =~ ^\>\>\>.* ]]; then + elif [[ $clean_line =~ ^\>\>\>.* ]]; then true else - echo "Unknown line: $line" + echo "Unknown line: $clean_line" exit 1 fi done diff --git a/src/TensorRight.hs b/src/TensorRight.hs index b5a2de3..572f895 100644 --- a/src/TensorRight.hs +++ b/src/TensorRight.hs @@ -524,6 +524,9 @@ module TensorRight nonInf, posInf, negInf, + + -- * Utils + printTitle, ) where @@ -534,3 +537,4 @@ import TensorRight.Internal.DSL.DSL import TensorRight.Internal.DSL.Expr import TensorRight.Internal.DSL.Syntax import TensorRight.Internal.DSL.Verify +import TensorRight.Internal.Util.Pretty diff --git a/src/TensorRight/Internal/Core/Tensor.hs b/src/TensorRight/Internal/Core/Tensor.hs index c0f3fb4..321fda3 100644 --- a/src/TensorRight/Internal/Core/Tensor.hs +++ b/src/TensorRight/Internal/Core/Tensor.hs @@ -164,6 +164,9 @@ instance ToElem TensorReal where instance ToElem SymBool where toElem = BoolElem . Typed.TensorElemVal +instance ToElem Elem where + toElem = id + tensorAccess :: (TensorOperand t) => t -> Indices -> ErrorEnv Elem tensorAccess to i = do t <- tensor to diff --git a/src/TensorRight/Internal/Core/Tensor/Typed.hs b/src/TensorRight/Internal/Core/Tensor/Typed.hs index 0db33b9..3da0089 100644 --- a/src/TensorRight/Internal/Core/Tensor/Typed.hs +++ b/src/TensorRight/Internal/Core/Tensor/Typed.hs @@ -631,9 +631,9 @@ sliceStartEndStrides to SliceArgs {..} = do restrictAxes diffDims (fromHashMap valMap) return $ unionAxisMap indices emptyIndices let defaultMap val = HM.fromList . map (,val) . HS.toList - filledStart <- checkAndFillInAxes "start" start $ (defaultMap 0 axes) + filledStart <- checkAndFillInAxes "start" start $ defaultMap 0 axes filledEnd <- checkAndFillInAxes "end" end $ asHashMap (tensorShape t) - filledStrides <- checkAndFillInAxes "strides" strides $ (defaultMap 1 axes) + filledStrides <- checkAndFillInAxes "strides" strides $ defaultMap 1 axes assert "start must be non-negative" $ symAll (.>= 0) $ asHashMap filledStart -- The original Rosette implementation may be buggy here. diff --git a/src/TensorRight/Internal/Core/Verify.hs b/src/TensorRight/Internal/Core/Verify.hs index 48d6145..ba99f32 100644 --- a/src/TensorRight/Internal/Core/Verify.hs +++ b/src/TensorRight/Internal/Core/Verify.hs @@ -66,6 +66,7 @@ import TensorRight.Internal.Core.Tensor ) import qualified TensorRight.Internal.Core.Tensor.Typed as Typed import TensorRight.Internal.Util.Error (Error, ErrorEnv, splitWithError) +import TensorRight.Internal.Util.Pretty (printWarning) getTensorWithValidityCondition :: forall a. @@ -304,11 +305,11 @@ verifyRule bil2r <- case soll2r of Left Unsat -> return True Left err -> do - putStrLn $ "[WARNING]: Verification for forall right si there do not exist multiple left si fails due to unexpected solver failure" <> show err + printWarning $ "Verification for forall right si there do not exist multiple left si fails due to unexpected solver failure" <> show err return False Right m -> do pprint m - putStrLn "[WARNING]: SI-relation is not bijective. (There exist multiple left SI for a right SI.)" + printWarning "SI-relation is not bijective. (There exist multiple left SI for a right SI.)" return False condr2l <- @@ -326,11 +327,11 @@ verifyRule bir2l <- case solr2l of Left Unsat -> return True Left err -> do - putStrLn $ "[WARNING]: Verification for forall left si there do not exist multiple right si fails due to unexpected solver failure" <> show err + printWarning $ "Verification for forall left si there do not exist multiple right si fails due to unexpected solver failure" <> show err return False Right m -> do pprint m - putStrLn "[WARNING]: SI-relation is not bijective. (There exist multiple right SI for a left SI.)" + printWarning "SI-relation is not bijective. (There exist multiple right SI for a left SI.)" return False if bil2r && bir2l @@ -351,13 +352,11 @@ verifyRule allokl <- case r of Left Unsat -> return True Left err -> do - putStrLn $ - "[WARNING]: Verification that all left si can be accessed fails due to unexpected solver failure" - <> show err + printWarning $ "Verification that all left si can be accessed fails due to unexpected solver failure" <> show err return False Right m -> do pprint m - putStrLn "[WARNING]: Some left si cannot be accessed." + printWarning "Some left si cannot be accessed." return False condr <- evaluate $ @@ -375,17 +374,16 @@ verifyRule allokr <- case r of Left Unsat -> return True Left err -> do - putStrLn $ - "[WARNING]: Verification that all right si can be accessed fails due to unexpected solver failure" - <> show err + printWarning $ "Verification that all right si can be accessed fails due to unexpected solver failure" <> show err return False Right m -> do pprint m - putStrLn "[WARNING]: Some right si cannot be accessed." + printWarning "Some right si cannot be accessed." return False unless (allokl && allokr) $ - putStrLn "[WARNING]: Some SI cannot be accessed." - else putStrLn "[WARNING]: SI-relation is not bijective." + printWarning + "Some SI cannot be accessed." + else printWarning "SI-relation is not bijective." cond1 <- evaluate $ diff --git a/src/TensorRight/Internal/DSL/BoundInference.hs b/src/TensorRight/Internal/DSL/BoundInference.hs index d30e9b7..63ba9a2 100644 --- a/src/TensorRight/Internal/DSL/BoundInference.hs +++ b/src/TensorRight/Internal/DSL/BoundInference.hs @@ -230,14 +230,14 @@ inferBound :: GrisetteSMTConfig -> VerifyTask -> HS.HashSet RClassIdentifier -> - HS.HashSet RClassIdentifier -> + HM.HashMap RClassIdentifier Int -> AbstractShape -> - IO (HM.HashMap RClassIdentifier Int) + IO (HM.HashMap RClassIdentifier (Int, Int)) inferBound solverConfig (VerifyTask _ lhs rhs pre siRelation _ _ _ _ _ _ _ _ _) - nonSingletonRClasses - singletonRClasses + nonFixedRClasses + rankConditions sp = do let preCond = pre when (preCond == con False) $ @@ -291,12 +291,13 @@ inferBound max 1 $ kFromAllAccesses rclass + numHasRClassInGroup rclass (HS.toList filteredConditions) - return $ - HM.fromList $ - ( (\rclass -> (rclass, kForRClass rclass)) - <$> HS.toList (nonSingletonRClasses `HS.difference` singletonRClasses) - ) - <> ((,1) <$> HS.toList singletonRClasses) + + let fixedRClasses = HM.keysSet rankConditions + let inferredBounds = + HM.fromList $ + (\rclass -> (rclass, (1, kForRClass rclass))) + <$> HS.toList (nonFixedRClasses `HS.difference` fixedRClasses) + return $ HM.union inferredBounds (HM.map (\k -> (k, k)) rankConditions) abstractShapeAccess :: AbstractShape -> Indices abstractShapeAccess AbstractShape {..} = do diff --git a/src/TensorRight/Internal/DSL/DSL.hs b/src/TensorRight/Internal/DSL/DSL.hs index 1d78ffb..1ce3665 100644 --- a/src/TensorRight/Internal/DSL/DSL.hs +++ b/src/TensorRight/Internal/DSL/DSL.hs @@ -36,6 +36,7 @@ module TensorRight.Internal.DSL.DSL newTensor, numBinOp, boolBinOp, + rankPrecondition, reduce, siRelation, precondition, @@ -60,6 +61,8 @@ module TensorRight.Internal.DSL.DSL boolUnaryOp, convBase, conv, + newSingletonRClass, + newSingletonRClasses, monitorExprOnFailure, monitorMapOnFailure, clamp, @@ -70,6 +73,8 @@ module TensorRight.Internal.DSL.DSL newConstMap, newConstMaps, combineMap, + twoRefsOf, + threeRefsOf, Padding (..), ConvConfig (..), ConvPadding (..), @@ -82,6 +87,8 @@ module TensorRight.Internal.DSL.DSL checkSIMap, reshapeDegenerate, numTensorAssumption, + ExprInContext, + liftInContext, ) where @@ -124,7 +131,7 @@ import TensorRight.Internal.DSL.Expr ConvPaddingArgsExpr (ConvPaddingArgsExpr, high, ldilation, low, rdilation), DSLContext, DySliceArgsExpr (DySliceArgsExpr, sizes, start), - Env (Env, lhsSIMaps, numTensorAssumptions), + Env (..), Expr, NumTensorAssumption (NumTensorAssumption), PaddingArgsExpr (PaddingArgsExpr, high, interior, low), @@ -178,7 +185,6 @@ import TensorRight.Internal.DSL.Expr rhsSIMaps, runDSLContext, siRelations, - singletonRClasses, tensorDTypes, tensorShapes, validTensorShape, @@ -204,7 +210,7 @@ import TensorRight.Internal.DSL.Shape restrictAbstractShape, toAbstractShape, ) -import TensorRight.Internal.Util.Error (assert) +import TensorRight.Internal.Util.Error (assert, tshow) -- | Create an integer element from a tensor int. intElem :: TensorInt -> Elem @@ -391,6 +397,34 @@ precondition :: DSLContext () precondition maps = precondition' maps . zipCondition +-- | Declare an exact rank for an RClass. Sets rankConditions[rclass] = k. +rankPrecondition :: + RClassIdentifier -> + Int -> + DSLContext () +rankPrecondition rclass k = do + assert "rankPrecondition: k must be >= 1" (k >= 1) + env <- get + case HM.lookup rclass (rankConditions env) of + Just k' -> assert "Conflicting rankPrecondition for the same RClass" (k' == k) + Nothing -> return () + put $ env {rankConditions = HM.insert rclass k (rankConditions env)} + +-- | Mark an existing RClass as singleton (exact rank 1). +markSingleton :: RClassIdentifier -> DSLContext () +markSingleton r = rankPrecondition r 1 + +-- | Create a new singleton RClass +newSingletonRClass :: T.Text -> DSLContext RClassIdentifier +newSingletonRClass label = do + rclass <- newRClass label + markSingleton rclass + return rclass + +-- | Create singleton RClasses +newSingletonRClasses :: [T.Text] -> DSLContext [RClassIdentifier] +newSingletonRClasses = traverse newSingletonRClass + -- | Add an SI relation to rewriting rule. -- It is similar to 'precondition', but it is used to specify the SI relations. siRelation' :: @@ -487,7 +521,7 @@ numBinScalarOp op lhs' rhs = do typeLhs <- typeOf lhs assert "lhs must be int or real" $ typeLhs `elem` [IntType, RealType] assert "lhs and rhs must have the same dtype" $ toDType rhs == typeLhs - return (shapeLhs, IntType) + return (shapeLhs, typeLhs) -- | Boolean binary operation. The lhs and rhs must have the same shape, and -- the dtype of lhs and rhs must be 'BoolType'. @@ -672,8 +706,7 @@ iota shapeDesc d = do validTensorShape shape let abstractShape = toAbstractShape shape rclass <- getRClassByRClassRef abstractShape d - env <- get - put $ env {singletonRClasses = HS.insert rclass (singletonRClasses env)} + markSingleton rclass return (abstractShape, IntType) -- | The named arguments to the 'slice' operation. @@ -1096,8 +1129,7 @@ concatTensor lhs' rhs' d = do assert "lhs and rhs must have the same rclasses" $ shapeLhs == shapeRhs assert "lhs and rhs must have the same type" $ tyLhs == tyRhs rclass <- getRClassByRClassRef shapeLhs d - env <- get - put $ env {singletonRClasses = HS.insert rclass (singletonRClasses env)} + markSingleton rclass return (shapeLhs, tyLhs) -- | Concatenate a list of tensors. @@ -1117,8 +1149,7 @@ concatTensorList exprs' d = do assert "All tensors in concatList must have the same RClasses" $ all (== head shapes) shapes assert "All tensors in concatList must have the same type" $ all (== head tys) tys rclass <- getRClassByRClassRef (head shapes) d - env <- get - put $ env {singletonRClasses = HS.insert rclass (singletonRClasses env)} + markSingleton rclass return (head shapes, head tys) -- | Relabel operation. @@ -1127,7 +1158,7 @@ relabel :: -- | The tensor to relabel. e -> -- | The relabel map. Should be @[rclass --> 'ByLabel' label]@ or - -- @['ByLabel' label -> 'ByLabel' label, ...]@. + -- @['ByLabel' label --> 'ByLabel' label, ...]@. [RelabelMapDesc] -> DSLContext Expr relabel expr' relabelMapDescs = do @@ -1187,8 +1218,9 @@ dot lhs rhs contractingSIMapsDesc batchRClasses = do let dotAllRefs = HM.keysSet contractingSIMaps <> HS.fromList batchRClasses let lhsAllRefs = abstractShapeAllRefs shapeLhs let rhsAllRefs = abstractShapeAllRefs shapeRhs + assert - ( "Contracion + batch rclasses must be exactly the interaction of lhs and " + ( "Contraction + batch rclasses must be exactly the interaction of lhs and " <> "rhs rclasses" ) $ dotAllRefs == HS.intersection lhsAllRefs rhsAllRefs @@ -1483,3 +1515,22 @@ checkSIMap lhs rhs = do { lhsSIMaps = HS.union lhsSet $ lhsSIMaps env, rhsSIMaps = HS.union rhsSet $ rhsSIMaps env } + +-- | Gets the two aggregated axes from a 2D tensor. Useful for 2D Transpose +twoRefsOf :: Expr -> DSLContext (RClassRef, RClassRef) +twoRefsOf e = do + shape <- shapeOf e + let refs = HS.toList $ abstractShapeAllRefs shape + assert ("Expected exactly 2 refs, got " <> tshow (length refs) <> ": " <> tshow refs) $ + length refs == 2 + let [a, b] = refs in return (a, b) + +-- | Helper function to get three aggregated axes from a 3D tensor. +-- Useful for 3D batched matrix multiplication. +threeRefsOf :: Expr -> DSLContext (RClassRef, RClassRef, RClassRef) +threeRefsOf e = do + shape <- shapeOf e + let refs = HS.toList $ abstractShapeAllRefs shape + assert ("Expected exactly 3 refs, got " <> tshow (length refs) <> ": " <> tshow refs) $ + length refs == 3 + let [a, b, c] = refs in return (a, b, c) diff --git a/src/TensorRight/Internal/DSL/Eval.hs b/src/TensorRight/Internal/DSL/Eval.hs index 1ec2fc2..d4a0951 100644 --- a/src/TensorRight/Internal/DSL/Eval.hs +++ b/src/TensorRight/Internal/DSL/Eval.hs @@ -176,7 +176,7 @@ freshMapBase name mapBase = HM.map ( \(SymInteger (SymTerm (TypedSymbol s :: TypedSymbol knd a))) -> SymInteger $ - SymTerm $ + SymTerm ( TypedSymbol $ mapIdentifier (mapMetadata (\m -> List [Atom "fresh", Atom $ T.pack name, m])) diff --git a/src/TensorRight/Internal/DSL/Expr.hs b/src/TensorRight/Internal/DSL/Expr.hs index 02d91dc..8dae974 100644 --- a/src/TensorRight/Internal/DSL/Expr.hs +++ b/src/TensorRight/Internal/DSL/Expr.hs @@ -617,10 +617,10 @@ data Env = Env exprAbstractShapes :: HM.HashMap Int AbstractShape, exprDTypes :: HM.HashMap Int DType, preConditions :: [Condition], + rankConditions :: HM.HashMap RClassIdentifier Int, numTensorAssumptions :: [NumTensorAssumption], siMaps :: HS.HashSet MapIdentifier, siRelations :: [Condition], - singletonRClasses :: HS.HashSet RClassIdentifier, monitoringExprs :: [(T.Text, Expr)], monitoringMaps :: [(T.Text, RClassRef, MapIdentifier)], lhsSIMaps :: HS.HashSet MapIdentifier, @@ -643,10 +643,10 @@ emptyEnv = HM.empty HM.empty [] + HM.empty [] HS.empty [] - HS.empty [] [] HS.empty diff --git a/src/TensorRight/Internal/DSL/Shape.hs b/src/TensorRight/Internal/DSL/Shape.hs index eaed752..f2ff3cb 100644 --- a/src/TensorRight/Internal/DSL/Shape.hs +++ b/src/TensorRight/Internal/DSL/Shape.hs @@ -4,6 +4,7 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} @@ -154,6 +155,7 @@ class TensorShapeLike a where toTensorShape :: (MonadError Error m) => a -> m TensorShape instance TensorShapeLike TensorShape where + toTensorShape :: (MonadError Error m) => TensorShape -> m TensorShape toTensorShape = return instance TensorShapeLike [TensorShapeDesc] where diff --git a/src/TensorRight/Internal/DSL/TASO.hs b/src/TensorRight/Internal/DSL/TASO.hs new file mode 100644 index 0000000..b57d6ec --- /dev/null +++ b/src/TensorRight/Internal/DSL/TASO.hs @@ -0,0 +1,436 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# OPTIONS_GHC -Wno-missing-import-lists #-} + +module TensorRight.Internal.DSL.TASO + ( ewadd, + ewmul, + smul, + relu, + concat, + split0, + split1, + transpose, + enlarge, + tasoConv, + matmul2D, + matmul3D, + PaddingMode (..), + Activation (..), + ) +where + +import Control.Monad.Except (MonadError (throwError)) +import Grisette (SymInteger, symIte, (.&&), (.<), (.<=), (.==), (.>=)) +import TensorRight (NumBinOp (Add, Mul), ToElem, posInf) +import TensorRight.Internal.Core.Tensor (ToDType) +import TensorRight.Internal.DSL.DSL + ( ConvConfig (..), + ConvPadding (..), + DSLContext, + Expr, + ExprInContext, + Padding (..), + RClassRef (..), + ValidNum, + clampScalar, + combineMap, + conv, + liftInContext, + newConstMap, + newNonNegMap, + numBinOp, + numBinScalarOp, + pad, + precondition, + twoRefsOf, + threeRefsOf, + relabel, + dot, + concatTensor, + ) +import TensorRight.Internal.DSL.Expr (checkMapHasRClass, getRClassByMap) +import qualified TensorRight.Internal.DSL.Expr as E +import TensorRight.Internal.DSL.Identifier (MapIdentifier) +import TensorRight.Internal.DSL.Parameters (ParamDesc (..)) +import TensorRight.Internal.DSL.Syntax (ArrowSyntax ((-->))) +import TensorRight.Internal.Util.Error (assert, tshow) +import Prelude hiding (concat) + +data Activation = Relu | None + +data PaddingMode = Same | Valid + +-- Helper function to get MapIdentifier from stride ParamDesc +getStrideMap :: ParamDesc -> MapIdentifier +getStrideMap (ParamDesc _ map) = map + +-- | TASO's ewadd operator. The lhs and rhs must have the same shape and the type must be either 'IntType' or 'RealType'. +ewadd :: + (ExprInContext lhs, ExprInContext rhs) => + -- | Lhs expression. + lhs -> + -- | Rhs expression. + rhs -> + DSLContext Expr +ewadd = numBinOp Add + +-- | TASO's ewmul operator. The lhs and rhs must have the same shape and the type must be either 'IntType' or 'RealType'. +ewmul :: + (ExprInContext lhs, ExprInContext rhs) => + -- | Lhs expression. + lhs -> + -- | Rhs expression. + rhs -> + DSLContext Expr +ewmul = numBinOp Mul + +-- | TASO's smul operator. The dtype of lhs must be 'IntType' or 'RealType'. +smul :: + (ExprInContext lhs, ToElem a, ToDType a) => + -- | Lhs expression. + lhs -> + -- | Rhs scalar. + a -> + DSLContext Expr +smul = numBinScalarOp Mul + +-- | TASO's relu operator +relu :: + forall a lhs. + (ExprInContext lhs, ValidNum a) => + -- | The tensor to clamp + lhs -> + DSLContext Expr +relu e = clampScalar @a 0 e posInf + +-- | TASO's concat operator +concat :: + (ExprInContext lhs, ExprInContext rhs) => + -- | The aggregated-axis to concat on. + RClassRef -> + -- | The left-hand side tensor. + lhs -> + -- | The right-hand side tensor. + rhs -> + DSLContext Expr +concat axis lhs' rhs' = concatTensor lhs' rhs' axis + +-- | TASO's transpose operator +transpose :: + (ExprInContext e) => + -- | The tensor to transpose + e -> + DSLContext Expr +transpose e' = do + e <- liftInContext e' + (a, b) <- twoRefsOf e + relabel e [a --> b, b --> a] + +-- TODO: Semantics of enlarge should be implemented in TensorRight/Internal/Core +-- TASO's enlarge operator! +-- Split policy: low = floor(d/2), high = d - low, where d = max(s, k) - s per axis. +enlarge :: + forall a e. + (ExprInContext e, ValidNum a) => + -- | Size descriptor for H axis: provides the axis ref and the existing size map of A along H. + ParamDesc -> + -- | Size descriptor for W axis: provides the axis ref and the existing size map of A along W. + ParamDesc -> + -- | Pre-allocated low padding map for H. + MapIdentifier -> + -- | Pre-allocated low padding map for W. + MapIdentifier -> + -- | Target size ky for H (abstract scalar). + SymInteger -> + -- | Target size kx for W (abstract scalar). + SymInteger -> + -- | The tensor to enlarge. + e -> + DSLContext Expr +enlarge (ParamDesc hRef sH) (ParamDesc wRef sW) hLow wLow ky kx e = do + rH <- getRClassByMap sH + rW <- getRClassByMap sW + -- Ensure provided low maps match rclasses + checkMapHasRClass rH hLow + checkMapHasRClass rW wLow + + -- Promote scalars + kH <- newConstMap "kH" ky rH + kW <- newConstMap "kW" kx rW + precondition [kH] $ \[k] -> k .>= 0 + precondition [kW] $ \[k] -> k .>= 0 + + -- Target sizes via max + sH' <- combineMap "sH'" (\[a, k] -> symIte (a .>= k) a k) [sH, kH] + sW' <- combineMap "sW'" (\[a, k] -> symIte (a .>= k) a k) [sW, kW] + + -- Differences + dH <- combineMap "dH" (\[m, a] -> m - a) [sH', sH] + dW <- combineMap "dW" (\[m, a] -> m - a) [sW', sW] + + -- Determine splits and construct high end based on them + precondition [hLow, dH] $ \[l, d] -> (l + l) .<= d .&& d .<= (l + l + 1) + precondition [wLow, dW] $ \[l, d] -> (l + l) .<= d .&& d .<= (l + l + 1) + hHigh <- combineMap "hHigh" (\[d, l] -> d - l) [dH, hLow] + wHigh <- combineMap "wHigh" (\[d, l] -> d - l) [dW, wLow] + + -- Zero interior paddings + zH <- newConstMap "zeroH" 0 rH + zW <- newConstMap "zeroW" 0 rW + + pad e (0 :: a) $ + Padding + { low = [hRef --> hLow, wRef --> wLow], + interior = [hRef --> zH, wRef --> zW], + high = [hRef --> hHigh, wRef --> wHigh] + } + +-- | TASO's 2D matrix multiplication operator +matmul2D :: + (ExprInContext lhs, ExprInContext rhs) => + -- | The left-hand side tensor (shape [M, K]) + lhs -> + -- | The right-hand side tensor (shape [K, N]) + rhs -> + -- | The contracting SI maps. + [ParamDesc] -> + DSLContext Expr +matmul2D lhs' rhs' contract = do + lhs <- liftInContext lhs' + rhs <- liftInContext rhs' + twoRefsOf lhs + twoRefsOf rhs + assert ("matmul2D: expected 1 contraction rclass, got " <> tshow (length contract)) $ + length contract == 1 + dot lhs rhs contract [] + +-- | TASO's 2D matrix multiplication operator +matmul3D :: + (ExprInContext lhs, ExprInContext rhs) => + -- | The left-hand side tensor (shape [M, K]) + lhs -> + -- | The right-hand side tensor (shape [K, N]) + rhs -> + -- | The contracting SI maps. + [ParamDesc] -> + -- | Batch RClasses + [RClassRef] -> + DSLContext Expr +matmul3D lhs' rhs' contract batch = do + lhs <- liftInContext lhs' + rhs <- liftInContext rhs' + -- Get the three axes from each tensor + threeRefsOf lhs -- B, M, K + threeRefsOf rhs -- B, K, N + assert ("matmul3D: expected 1 contraction rclass, got " <> tshow (length contract)) $ + length contract == 1 + assert ("matmul3D: expected 1 batch rclass, got " <> tshow (length batch)) $ + length batch == 1 + dot lhs rhs contract batch + +-- | TASO's 2D matrix multiplication operator +tasoConv :: + forall a input weights. + (ExprInContext input, ExprInContext weights, ValidNum a) => + -- | Convolution config + ConvConfig -> + -- | Padding config + PaddingMode -> + -- | Choice of activation function + Activation -> + -- | Input spatial size maps (per spatial RClass) + [ParamDesc] -> + -- | Kernel spatial size maps (per spatial RClass) + [ParamDesc] -> + -- | Input tensor + input -> + -- | The weights (kernel) tensor. + weights -> + DSLContext Expr +tasoConv config padConfig act inputSizePDs kernelSizePDs input weights = do + -- Determine spatial refs from the stride descriptors in the config + let strideRefs = + case config of + ConvConfig {strides = ss} -> [ref | ParamDesc ref _ <- ss] + + let toRClassId ref = case ref of + ByRClass r -> return r + ByLabel _ -> throwError "tasoConv requires strides specified with ByRClass refs" + + -- Build padding parameters per mode + (lowPDs, ldilPDs, highPDs, rdilPDs) <- case padConfig of + Valid -> do + -- VALID: low=0, high=0, ldilation=1, rdilation=1 + lowPDs <- + traverse + ( \ref -> do + r <- toRClassId ref + z <- newConstMap "low0" 0 r + return (ref --> z) + ) + strideRefs + highPDs <- + traverse + ( \ref -> do + r <- toRClassId ref + z <- newConstMap "high0" 0 r + return (ref --> z) + ) + strideRefs + ldilPDs <- + traverse + ( \ref -> do + r <- toRClassId ref + o <- newConstMap "ldilation1" 1 r + return (ref --> o) + ) + strideRefs + rdilPDs <- + traverse + ( \ref -> do + r <- toRClassId ref + o <- newConstMap "rdilation1" 1 r + return (ref --> o) + ) + strideRefs + return (lowPDs, ldilPDs, highPDs, rdilPDs) + Same -> do + -- SAME: compute padding using the formula: p_total = max(0, (ceil(n_in/s) - 1) * s + k - n_in) + -- For each spatial dimension, we need input size, kernel size, and stride + let strideMaps = [getStrideMap pd | pd <- strides config] + + -- Use provided input/kernel size maps (deterministic SAME) + let lookupSize :: RClassRef -> [ParamDesc] -> MapIdentifier + lookupSize ref pds = + case [m | ParamDesc r m <- pds, r == ref] of + (m : _) -> m + [] -> error "tasoConv(Same): missing spatial size map" + + let inputSizePairs = [(ref, lookupSize ref inputSizePDs) | ref <- strideRefs] + let kernelSizePairs = [(ref, lookupSize ref kernelSizePDs) | ref <- strideRefs] + + -- Create symbolic output size maps for SAME padding + -- For SAME padding, output_size = ceil(input_size / stride) + outputSizePairs <- + traverse + ( \((ref, inputSize), strideMap) -> do + r <- toRClassId ref + outputSize <- newNonNegMap "outputSize" r + -- Constrain: outputSize * stride >= inputSize (ceiling property) + precondition [outputSize, inputSize, strideMap] $ \[out, inp, str] -> out * str .>= inp + -- Constrain: (outputSize - 1) * stride < inputSize (minimal ceiling) + precondition [outputSize, inputSize, strideMap] $ \[out, inp, str] -> (out - 1) * str .< inp + return (ref, outputSize) + ) + (zip inputSizePairs strideMaps) + + -- Compute total padding using the SAME formula + -- p_total = max(0, (outputSize - 1) * stride + kernelSize - inputSize) + totalPaddingPairs <- + traverse + ( \((ref, outputSize), (_, kernelSize), ((_, inputSize), strideMap)) -> do + -- Compute total padding: (outputSize - 1) * stride + kernelSize - inputSize + totalPadding <- combineMap "totalPadding" (\[out, s, k, n] -> (out - 1) * s + k - n) [outputSize, strideMap, kernelSize, inputSize] + -- Constrain total padding to be non-negative + precondition [totalPadding] $ \[p] -> p .>= 0 + return (ref, totalPadding) + ) + (zip3 outputSizePairs kernelSizePairs (zip inputSizePairs strideMaps)) + + -- Split total padding into low and high: low = floor(p_total / 2), high = p_total - low + lowPairs <- + traverse + ( \(ref, totalPadding) -> do + r <- toRClassId ref + low <- newNonNegMap "sameLow" r + -- Constrain: low + low <= totalPadding <= low + low + 1 + precondition [low, totalPadding] $ \[l, p] -> (l + l) .<= p .&& p .<= (l + l + 1) + return (ref, low) + ) + totalPaddingPairs + + highPairs <- + traverse + ( \((ref, totalPadding), (_, low)) -> do + r <- toRClassId ref + high <- newNonNegMap "sameHigh" r + -- Constrain: low + high = totalPadding + precondition [low, high, totalPadding] $ \[l, h, p] -> l + h .== p + return (ref, high) + ) + (zip totalPaddingPairs lowPairs) + + let lowPDs = [ref --> l | (ref, l) <- lowPairs] + let highPDs = [ref --> h | (ref, h) <- highPairs] + + -- Unit dilations + ldilPDs <- + traverse + ( \ref -> do + r <- toRClassId ref + o <- newConstMap "ldilation1" 1 r + return (ref --> o) + ) + strideRefs + rdilPDs <- + traverse + ( \ref -> do + r <- toRClassId ref + o <- newConstMap "rdilation1" 1 r + return (ref --> o) + ) + strideRefs + return (lowPDs, ldilPDs, highPDs, rdilPDs) + + outExpr <- + conv + input + weights + config + ConvPadding + { low = lowPDs, + ldilation = ldilPDs, + high = highPDs, + rdilation = rdilPDs + } + + case act of + Relu -> relu @a outExpr + None -> return outExpr + +-- | TASO's split0 operator +split0 :: + (ExprInContext e) => + RClassRef -> + e -> + DSLContext Expr +split0 axis e' = do + e <- liftInContext e' + case e of + E.Concat _ l _ d | d == axis -> return l + E.Concat {} -> throwError "split0: expected Concat on the given axis" + _ -> throwError "split0: input is not a Concat" + +-- | TASO's split1 operator +split1 :: + (ExprInContext e) => + RClassRef -> + e -> + DSLContext Expr +split1 axis e' = do + e <- liftInContext e' + case e of + E.Concat _ _ r d | d == axis -> return r + E.Concat {} -> throwError "split1: expected Concat on the given axis" + _ -> throwError "split1: input is not a Concat" diff --git a/src/TensorRight/Internal/DSL/Verify.hs b/src/TensorRight/Internal/DSL/Verify.hs index bd2bece..3b6ff41 100644 --- a/src/TensorRight/Internal/DSL/Verify.hs +++ b/src/TensorRight/Internal/DSL/Verify.hs @@ -57,8 +57,8 @@ import TensorRight.Internal.DSL.DSL lhsSIMaps, numTensorAssumptions, preConditions, + rankConditions, rhsSIMaps, - singletonRClasses, tensorShapes ), ValidElem, @@ -97,6 +97,7 @@ import TensorRight.Internal.DSL.Identifier (RClassIdentifier) import TensorRight.Internal.DSL.Shape ( AbstractShape, ) +import TensorRight.Internal.Util.Pretty (printFailure, printSuccess, printTitle) verifyDSLWithNDim :: GrisetteSMTConfig -> @@ -210,6 +211,9 @@ verifyDSLWithNDim solverConfig rewrite Env {..} ndim = do else mempty ) maps + + let fixedRClasses = HM.keysSet rankConditions + let nonFixedRClasses = declaredRClasses `HS.difference` fixedRClasses return ( VerifyTask solverConfig @@ -226,8 +230,8 @@ verifyDSLWithNDim solverConfig rewrite Env {..} ndim = do otherSISymbols monitoringTensors monitoringSizes, - declaredRClasses `HS.difference` singletonRClasses, - singletonRClasses, + nonFixedRClasses, + fixedRClasses, exprAbstractShapes HM.! exprId (lhs rewrite) ) @@ -244,7 +248,7 @@ printRewriteNameLine :: DSLContext Rewrite -> IO () printRewriteNameLine rewrite = do case getRewriteName rewrite of Left err -> fail $ T.unpack err - Right name -> putStrLn $ "====> " <> T.unpack name + Right name -> printTitle $ "====> " <> T.unpack name data Result = Result { elapsedTime :: Double, @@ -257,21 +261,14 @@ instance Semigroup Result where printResult :: Maybe String -> Result -> IO () printResult subTheory Result {..} = - putStrLn $ - "[" - <> ( if isRight result - then "SUCCESS" - else "FAIL" - ) - <> maybe "" ("-" <>) subTheory - <> "]: [" - <> show elapsedTime - <> "s] Verification " - <> (if isRight result then "succeeded" else "failed") - <> ( case result of - Right () -> "." - Left e -> " with error: " <> show e - ) + if isRight result + then printSuccess theory $ time <> " Verification succeeded." + else printFailure theory $ time <> " Verification failed with error: " <> showError result + where + showError (Left e) = show e + showError (Right _) = "" + time = "[" <> show elapsedTime <> "s]" + theory = maybe "" ("-" <>) subTheory bracketFailure :: DSLContext Rewrite -> IO () -> IO Result @@ -302,38 +299,33 @@ verifyDSLWithImpl solverConfig theoryInfo rewrite = do Right (rewrite, env) -> do putStrLn $ "Verifying rule " <> T.unpack (name rewrite) let bound0 = baseRClassBound0 rewrite env - (task, nonSingletonRClasses, singletonRClasses, shape) <- + (task, nonSingletonRClasses, _singletonRClasses, shape) <- verifyDSLWithNDim solverConfig rewrite env bound0 - inferredBound <- - inferBound - solverConfig - task - nonSingletonRClasses - singletonRClasses - shape - putStrLn $ "Inferred bounds: " <> show inferredBound + inferredBounds <- + inferBound solverConfig task nonSingletonRClasses (rankConditions env) shape + putStrLn $ "Inferred bounds: " <> show inferredBounds putStrLn $ "[INFO" <> maybe "" ("-" <>) theoryInfo <> "]: Inferred bounds: " - <> show inferredBound + <> show inferredBounds putStrLn $ "[INFO" <> maybe "" ("-" <>) theoryInfo <> "]: Number of bounded verification tasks: " - <> show (product inferredBound) - let ndims = allNdims $ HM.toList inferredBound + <> show (product $ fmap (\(l, u) -> u - l + 1) inferredBounds) + let ndims = allNdims $ HM.toList inferredBounds let fst4 (a, _, _, _) = a traverse_ (verifyDSLWithNDim solverConfig rewrite env >=> verifyRule . fst4) ndims where - allNdims :: [(RClassIdentifier, Int)] -> [HM.HashMap RClassIdentifier Int] + allNdims :: [(RClassIdentifier, (Int, Int))] -> [HM.HashMap RClassIdentifier Int] allNdims inferredBoundList = HM.fromList <$> traverse - ( \(rclassIdent, bound) -> - [(rclassIdent, i) | i <- [1 .. bound]] + ( \(rclassIdent, (lower, upper)) -> + [(rclassIdent, i) | i <- [lower .. upper]] ) inferredBoundList diff --git a/src/TensorRight/Internal/Util/Error.hs b/src/TensorRight/Internal/Util/Error.hs index 5021d39..69e94ce 100644 --- a/src/TensorRight/Internal/Util/Error.hs +++ b/src/TensorRight/Internal/Util/Error.hs @@ -8,6 +8,7 @@ module TensorRight.Internal.Util.Error ErrorEnv, assert, splitWithError, + tshow, ) where @@ -32,6 +33,9 @@ assert :: (UnifiedBranching mode m, MonadError Error m) => Error -> GetBool mode -> m () assert err cond = mrgIf cond (return ()) $ throwError err +tshow :: Show a => a -> T.Text +tshow = T.pack . show + -- May introduce this into Grisette library in the future splitWithError :: forall a. (Mergeable a) => ExceptT Error Union a -> Maybe (SymBool, Union a) diff --git a/src/TensorRight/Internal/Util/Pretty.hs b/src/TensorRight/Internal/Util/Pretty.hs index 69548f1..3648f12 100644 --- a/src/TensorRight/Internal/Util/Pretty.hs +++ b/src/TensorRight/Internal/Util/Pretty.hs @@ -6,6 +6,10 @@ module TensorRight.Internal.Util.Pretty condEnclose, prettyWithConstructor, gprettyParen, + printTitle, + printSuccess, + printFailure, + printWarning, ) where @@ -33,3 +37,15 @@ gprettyParen b = condEnclose b "(" ")" prettyWithConstructor :: Int -> Doc ann -> [Doc ann] -> Doc ann prettyWithConstructor n c l = group $ condEnclose (n > 10) "(" ")" $ align $ nest 2 $ vsep (c : l) + +printTitle :: String -> IO () +printTitle s = putStrLn $ "\ESC[34m" <> s <> "\ESC[0m" + +printSuccess :: String -> String -> IO () +printSuccess theory s = putStrLn $ "\ESC[32m[SUCCESS" <> theory <> "]: " <> s <> "\ESC[0m" + +printFailure :: String -> String -> IO () +printFailure theory s = putStrLn $ "\ESC[31m[FAIL" <> theory <> "]: " <> s <> "\ESC[0m" + +printWarning :: String -> IO () +printWarning s = putStrLn $ "\ESC[33m[WARNING]: " <> s <> "\ESC[0m" diff --git a/tensor-right.cabal b/tensor-right.cabal index 351eae7..b2ab461 100644 --- a/tensor-right.cabal +++ b/tensor-right.cabal @@ -4,803 +4,1134 @@ cabal-version: 1.12 -- -- see: https://github.com/sol/hpack -name: tensor-right -version: 0.1.0.0 -synopsis: Automated Verification of Tensor Graph Rewrites -description: TensorRight is an automatic tool that can be used to verify - Tensor Graph Rewrites. -license: Apache-2.0 -license-file: LICENSE -build-type: Simple +name: tensor-right +version: 0.1.0.0 +synopsis: Automated Verification of Tensor Graph Rewrites +description: TensorRight is an automatic tool that can be used to verify + Tensor Graph Rewrites. +license: Apache-2.0 +license-file: LICENSE +build-type: Simple library exposed-modules: - TensorRight - TensorRight.Internal.Core.Axis - TensorRight.Internal.Core.Linearization - TensorRight.Internal.Core.Tensor - TensorRight.Internal.Core.Tensor.TensorInt - TensorRight.Internal.Core.Tensor.Typed - TensorRight.Internal.Core.Verify - TensorRight.Internal.DSL.BoundInference - TensorRight.Internal.DSL.Condition - TensorRight.Internal.DSL.DSL - TensorRight.Internal.DSL.Eval - TensorRight.Internal.DSL.Expr - TensorRight.Internal.DSL.Identifier - TensorRight.Internal.DSL.Parameters - TensorRight.Internal.DSL.RelabelMap - TensorRight.Internal.DSL.Shape - TensorRight.Internal.DSL.Syntax - TensorRight.Internal.DSL.Verify - TensorRight.Internal.Util.Error - TensorRight.Internal.Util.Pretty - other-modules: - Paths_tensor_right - hs-source-dirs: - src + TensorRight + TensorRight.Internal.Core.Axis + TensorRight.Internal.Core.Linearization + TensorRight.Internal.Core.Tensor + TensorRight.Internal.Core.Tensor.TensorInt + TensorRight.Internal.Core.Tensor.Typed + TensorRight.Internal.Core.Verify + TensorRight.Internal.DSL.BoundInference + TensorRight.Internal.DSL.Condition + TensorRight.Internal.DSL.DSL + TensorRight.Internal.DSL.Eval + TensorRight.Internal.DSL.Expr + TensorRight.Internal.DSL.Identifier + TensorRight.Internal.DSL.Parameters + TensorRight.Internal.DSL.RelabelMap + TensorRight.Internal.DSL.Shape + TensorRight.Internal.DSL.Syntax + TensorRight.Internal.DSL.TASO + TensorRight.Internal.DSL.Verify + TensorRight.Internal.Util.Error + TensorRight.Internal.Util.Pretty + other-modules: + Paths_tensor_right + hs-source-dirs: + src ghc-options: -Wextra -Wcompat -Widentities -Wincomplete-record-updates -Wmissing-export-lists -Wmissing-home-modules -Wmissing-import-lists -Wpartial-fields -Wunused-type-patterns -Wno-x-partial -Wno-unrecognised-warning-flags build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + text, + unordered-containers default-language: Haskell2010 executable rules-debug main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right + hs-source-dirs: + rules/debug + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-concat + main-is: Main.hs + other-modules: + Paths_tensor_right + hs-source-dirs: + rules/taso/concat + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-conv + main-is: Main.hs + other-modules: + Paths_tensor_right hs-source-dirs: - rules/debug + rules/taso/conv default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-enlarge + main-is: Main.hs + other-modules: + Paths_tensor_right + hs-source-dirs: + rules/taso/enlarge + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-ewadd + main-is: Main.hs + other-modules: + Paths_tensor_right + hs-source-dirs: + rules/taso/ewadd + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-ewmul + main-is: Main.hs + other-modules: + Paths_tensor_right + hs-source-dirs: + rules/taso/ewmul + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-matmul2D + main-is: Main.hs + other-modules: + Paths_tensor_right + hs-source-dirs: + rules/taso/matmul2D + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-matmul3D + main-is: Main.hs + other-modules: + Paths_tensor_right + hs-source-dirs: + rules/taso/matmul3D + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-relu + main-is: Main.hs + other-modules: + Paths_tensor_right + hs-source-dirs: + rules/taso/relu + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-smul + main-is: Main.hs + other-modules: + Paths_tensor_right + hs-source-dirs: + rules/taso/smul + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-split + main-is: Main.hs + other-modules: + Paths_tensor_right + hs-source-dirs: + rules/taso/split + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers + default-language: Haskell2010 + +executable rules-taso-transpose + main-is: Main.hs + other-modules: + Paths_tensor_right + hs-source-dirs: + rules/taso/transpose + default-extensions: + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-add main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/add + rules/xla/add default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-broadcast main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/broadcast + rules/xla/broadcast default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-clamp main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/clamp + rules/xla/clamp default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-compare main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/compare + rules/xla/compare default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-concat main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/concat + rules/xla/concat default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-conv main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/conv + rules/xla/conv default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-divmod main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/divmod + rules/xla/divmod default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-dot main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/dot + rules/xla/dot default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-dyslice main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/dyslice + rules/xla/dyslice default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-dyupslice main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/dyupslice + rules/xla/dyupslice default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-generalize main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/generalize + rules/xla/generalize default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-iota main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/iota + rules/xla/iota default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-logical main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/logical + rules/xla/logical default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-max main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/max + rules/xla/max default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-mul main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/mul + rules/xla/mul default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-not main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/not + rules/xla/not default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-pad main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/pad + rules/xla/pad default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-reduce main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/reduce + rules/xla/reduce default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-relabel main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/relabel + rules/xla/relabel default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-reverse main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/reverse + rules/xla/reverse default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-select main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/select + rules/xla/select default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-slice main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/slice + rules/xla/slice default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 executable rules-xla-sub main-is: Main.hs other-modules: - Paths_tensor_right + Paths_tensor_right hs-source-dirs: - rules/xla/sub + rules/xla/sub default-extensions: - DuplicateRecordFields - OverloadedStrings - TypeApplications - AllowAmbiguousTypes - ScopedTypeVariables - FlexibleContexts - RankNTypes + DuplicateRecordFields + OverloadedStrings + TypeApplications + AllowAmbiguousTypes + ScopedTypeVariables + FlexibleContexts + RankNTypes ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: - base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , text - , unordered-containers + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + text, + unordered-containers default-language: Haskell2010 test-suite spec type: exitcode-stdio-1.0 main-is: Main.hs other-modules: - Core.LinearizationTest - Core.TensorTest - TestUtil - Paths_tensor_right - hs-source-dirs: - test - ghc-options: -threaded -rtsopts -with-rtsopts=-N - build-depends: - HUnit >=1.6 - , QuickCheck - , base >=4.14 && <5 - , deepseq - , grisette ==0.11.* - , hashable - , mtl - , ordered-containers - , prettyprinter - , sbv - , template-haskell - , tensor-right - , test-framework >=0.8.2 && <0.9 - , test-framework-hunit >=0.3.0.2 && <0.4 - , test-framework-quickcheck2 >=0.3.0.5 && <0.4 - , text - , unordered-containers + Core.LinearizationTest + Core.TensorTest + TestUtil + Paths_tensor_right + hs-source-dirs: + test + ghc-options: -threaded -rtsopts -with-rtsopts=-N + build-depends: + HUnit >=1.6, + QuickCheck, + base >=4.14 && <5, + deepseq, + grisette ==0.11.*, + hashable, + mtl, + ordered-containers, + prettyprinter, + sbv, + template-haskell, + tensor-right, + test-framework >=0.8.2 && <0.9, + test-framework-hunit >=0.3.0.2 && <0.4, + test-framework-quickcheck2 >=0.3.0.5 && <0.4, + text, + unordered-containers default-language: Haskell2010 diff --git a/test/Core/TensorTest.hs b/test/Core/TensorTest.hs index 3e3d75b..50f00f6 100644 --- a/test/Core/TensorTest.hs +++ b/test/Core/TensorTest.hs @@ -17,7 +17,7 @@ import Grisette ITEOp (symIte), LogicalOp ((.&&)), SimpleMergeable, - Solvable (con, isym, ssym), + Solvable (con, ssym), SymBool, SymEq ((.==)), SymInteger,