From 7ae770e052b8ee54b2e3538871f834765e891aea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Ma=C4=8Dkay?= Date: Tue, 24 Oct 2017 12:02:51 +0200 Subject: [PATCH 1/4] Add MultithreadedSubnetworkTest --- .../MultithreadedSubnetworkTest.java | 267 ++++++++++++++++++ 1 file changed, 267 insertions(+) create mode 100644 drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java new file mode 100644 index 000000000000..f262423120ec --- /dev/null +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java @@ -0,0 +1,267 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.compiler.integrationtests; + + +import org.drools.compiler.CommonTestMethodBase; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import org.kie.api.KieBase; +import org.kie.api.io.ResourceType; +import org.kie.api.runtime.KieSession; +import org.kie.api.runtime.rule.FactHandle; +import org.kie.internal.utils.KieHelper; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +@RunWith(Parameterized.class) +public class MultithreadedSubnetworkTest extends CommonTestMethodBase { + + protected final String drl; + + @Parameters(name = "DRL={0}") + public static List getTestParameter() { + return Arrays.asList( + new String[] {"sharedSubnetworkRule", sharedSubnetworkRule}, + new String[] {"noSharingSubnetworkRule", noSharingSubnetworkRule}, + new String[] {"notSubnetworkRule", notSubnetworkRule}, + new String[] {"existsSubnetworkRule", existsSubnetworkRule}); + } + + public MultithreadedSubnetworkTest(String drlName, final String drl) { + this.drl = drl; + } + + final static String sharedSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1y when\n" + + " AtomicInteger() \n" + + " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )" + + " Long()\n" + + "then\n" + + " System.out.println(\"R1y\");" + + "end\n" + + "\n" + + "rule R1x when\n" + + " AtomicInteger( get() == 1 ) \n" + + " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )\n" + + "then\n" + + " System.out.println(\"R1x\");" + + "end\n" + + "" + + "rule R2 when\n" + + " $i : AtomicInteger( get() < 3 )\n" + + "then\n" + + " System.out.println(\"R2\");" + + " $i.incrementAndGet();" + + " update($i);" + + "end\n"; + + final static String noSharingSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1y when\n" + + " AtomicInteger() \n" + + " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )" + + " Long()\n" + + "then\n" + + " System.out.println(\"R1y\");" + + "end\n" + + "\n" + + "rule R1x when\n" + + " AtomicInteger() \n" + + " Number() from accumulate ( $i : AtomicInteger( get() == 1) and String( this == \"test_2\" ) ; count($i) )\n" + + "then\n" + + " System.out.println(\"R1x\");" + + "end\n" + + "" + + "rule R2 when\n" + + " $i : AtomicInteger( get() < 3 )\n" + + "then\n" + + " System.out.println(\"R2\");" + + " $i.incrementAndGet();" + + " update($i);" + + "end\n"; + + final static String notSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1 when\n" + + " AtomicInteger() \n" + + " not(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R1\");" + + "end\n" + + "\n" + + "rule R2 when\n" + + " AtomicInteger() \n" + + " String( this != \"test_2\" ) \n" + + " not(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R2\");" + + "end\n"; + + final static String existsSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1 when\n" + + " AtomicInteger() \n" + + " exists(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R1\");" + + "end\n" + + "\n" + + "rule R2 when\n" + + " AtomicInteger() \n" + + " String( this != \"test_2\" ) \n" + + " exists(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R2\");" + + "end\n"; + + @Test(timeout = 10000) + public void testConcurrentInsertionsFewObjectsManyThreads() throws InterruptedException { + testConcurrentInsertions(drl, 1, 1000, false, false); + } + + @Test(timeout = 10000) + public void testConcurrentInsertionsManyObjectsFewThreads() throws InterruptedException { + testConcurrentInsertions(drl, 500, 4, false, false); + } + + @Test(timeout = 10000) + public void testConcurrentInsertionsManyObjectsSingleThread() throws InterruptedException { + testConcurrentInsertions(drl, 1000, 1, false, false); + } + + @Test(timeout = 10000) + public void testConcurrentInsertionsNewSessionEachThread() throws InterruptedException { + testConcurrentInsertions(drl, 10, 1000, true, false); + } + + @Test(timeout = 10000) + public void testConcurrentInsertionsNewSessionEachThreadUpdate() throws InterruptedException { + testConcurrentInsertions(drl, 10, 1000, true, true); + } + + private void testConcurrentInsertions(final String drl, final int objectCount, final int threadCount, + final boolean newSessionForEachThread, final boolean updateFacts) throws InterruptedException { + + final KieBase kieBase = new KieHelper().addContent(drl, ResourceType.DRL).build(); + + ExecutorService executor = Executors.newCachedThreadPool(new ThreadFactory() { + public Thread newThread(Runnable r) { + Thread t = new Thread(r); + t.setDaemon(true); + return t; + } + }); + + KieSession ksession = null; + try { + Callable[] tasks = new Callable[threadCount]; + if (newSessionForEachThread) { + for (int i = 0; i < threadCount; i++) { + tasks[i] = getTask(objectCount, kieBase, updateFacts); + } + } else { + ksession = kieBase.newKieSession(); + for (int i = 0; i < threadCount; i++) { + tasks[i] = getTask(objectCount, ksession, false, updateFacts); + } + } + + CompletionService ecs = new ExecutorCompletionService(executor); + for (Callable task : tasks) { + ecs.submit(task); + } + + int successCounter = 0; + for (int i = 0; i < threadCount; i++) { + try { + if (ecs.take().get()) { + successCounter++; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + assertEquals(threadCount, successCounter); + if (ksession != null) { + ksession.dispose(); + } + } finally { + executor.shutdown(); + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow(); + } + } + } + + private Callable getTask( final int objectCount, final KieBase kieBase, final boolean updateFacts) { + return getTask(objectCount, kieBase.newKieSession(), true, updateFacts); + } + + private Callable getTask( + final int objectCount, + final KieSession ksession, + final boolean disposeSession, + final boolean updateFacts) { + return new Callable() { + public Boolean call() throws Exception { + try { + for (int j = 0; j < 10; j++) { + FactHandle[] facts = new FactHandle[objectCount]; + FactHandle[] stringFacts = new FactHandle[objectCount]; + for (int i = 0; i < objectCount; i++) { + facts[i] = ksession.insert(new AtomicInteger(i)); + stringFacts[i] = ksession.insert("test_" + i); + } + if (updateFacts) { + for (int i = 0; i < objectCount; i++) { + ksession.update(facts[i], new AtomicInteger(-i)); + ksession.update(stringFacts[i], "updated_test_" + i); + } + } + for (int i = 0; i < objectCount; i++) { + ksession.delete(facts[i]); + ksession.delete(stringFacts[i]); + } + ksession.fireAllRules(); + } + return true; + } catch (Exception e) { + e.printStackTrace(); + return false; + } finally { + if (disposeSession) { + ksession.dispose(); + } + } + } + }; + } +} From 198beb3f3c991b0efe6da3a71939621dfe961302 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Ma=C4=8Dkay?= Date: Wed, 8 Nov 2017 10:21:04 +0100 Subject: [PATCH 2/4] Improve test coverage for concurrent KieSessions and KieBases --- .../integrationtests/MultithreadTest.java | 70 +- .../MultithreadedSubnetworkTest.java | 398 +++++------ .../AbstractConcurrentSessionsTest.java | 33 +- .../session/AbstractParallelTest.java | 224 ++++++ .../session/ConcurrentBasesParallelTest.java | 669 ++++++++++++++++++ ...aTypeEvaluationConcurrentSessionsTest.java | 5 +- ...peEvaluationSharedSessionParallelTest.java | 173 +++++ .../session/JoinsConcurrentSessionsTest.java | 5 +- .../session/SharedSessionParallelTest.java | 488 +++++++++++++ .../SubnetworkConcurrentSessionsTest.java | 5 +- 10 files changed, 1827 insertions(+), 243 deletions(-) create mode 100644 drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java create mode 100644 drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java create mode 100644 drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java create mode 100644 drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadTest.java index 6e12f882c4b9..26a0d26f0fa4 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadTest.java @@ -28,6 +28,7 @@ import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; @@ -67,7 +68,7 @@ public class MultithreadTest extends CommonTestMethodBase { private static final Logger LOG = LoggerFactory.getLogger(MultithreadTest.class); @Test(timeout = 10000) - public void testConcurrentInsertionsFewObjectsManyThreads() { + public void testConcurrentInsertionsFewObjectsManyThreads() throws InterruptedException { final String drl = "import org.drools.compiler.integrationtests.MultithreadTest.Bean\n" + "\n" + "rule \"R\"\n" + @@ -79,7 +80,7 @@ public void testConcurrentInsertionsFewObjectsManyThreads() { } @Test(timeout = 10000) - public void testConcurrentInsertionsManyObjectsFewThreads() { + public void testConcurrentInsertionsManyObjectsFewThreads() throws InterruptedException { final String drl = "import org.drools.compiler.integrationtests.MultithreadTest.Bean\n" + "\n" + "rule \"R\"\n" + @@ -91,7 +92,7 @@ public void testConcurrentInsertionsManyObjectsFewThreads() { } @Test(timeout = 10000) - public void testConcurrentInsertionsNewSessionEachThreadObjectTypeNode() { + public void testConcurrentInsertionsNewSessionEachThreadObjectTypeNode() throws InterruptedException { final String drl = "import org.drools.compiler.integrationtests.MultithreadTest.Bean\n" + " query existsBeanSeed5More() \n" + " Bean( seed > 5 ) \n" + @@ -112,7 +113,7 @@ public void testConcurrentInsertionsNewSessionEachThreadObjectTypeNode() { } @Test(timeout = 10000) - public void testConcurrentInsertionsNewSessionEachThread() { + public void testConcurrentInsertionsNewSessionEachThread() throws InterruptedException { final String drl = "import org.drools.compiler.integrationtests.MultithreadTest.Bean\n" + " query existsBeanSeed5More() \n" + " Bean( seed > 5 ) \n" + @@ -144,11 +145,11 @@ public void testConcurrentInsertionsNewSessionEachThread() { } private void testConcurrentInsertions(final String drl, final int objectCount, final int threadCount, - final boolean newSessionForEachThread, final boolean updateFacts) { + final boolean newSessionForEachThread, final boolean updateFacts) throws InterruptedException { final KieBase kieBase = new KieHelper().addContent(drl, ResourceType.DRL).build(); - Executor executor = Executors.newCachedThreadPool(new ThreadFactory() { + ExecutorService executor = Executors.newCachedThreadPool(new ThreadFactory() { public Thread newThread(Runnable r) { Thread t = new Thread(r); t.setDaemon(true); @@ -157,37 +158,44 @@ public Thread newThread(Runnable r) { }); KieSession ksession = null; - Callable[] tasks = new Callable[threadCount]; - if (newSessionForEachThread) { - for (int i = 0; i < threadCount; i++) { - tasks[i] = getTask( objectCount, kieBase, updateFacts ); - } - } else { - ksession = kieBase.newKieSession(); - for (int i = 0; i < threadCount; i++) { - tasks[i] = getTask( objectCount, ksession, false , updateFacts ); + try { + Callable[] tasks = new Callable[threadCount]; + if (newSessionForEachThread) { + for (int i = 0; i < threadCount; i++) { + tasks[i] = getTask(objectCount, kieBase, updateFacts); + } + } else { + ksession = kieBase.newKieSession(); + for (int i = 0; i < threadCount; i++) { + tasks[i] = getTask(objectCount, ksession, false, updateFacts); + } } - } - CompletionService ecs = new ExecutorCompletionService(executor); - for (Callable task : tasks) { - ecs.submit( task ); - } + CompletionService ecs = new ExecutorCompletionService(executor); + for (Callable task : tasks) { + ecs.submit(task); + } - int successCounter = 0; - for (int i = 0; i < threadCount; i++) { - try { - if ( ecs.take().get() ) { - successCounter++; + int successCounter = 0; + for (int i = 0; i < threadCount; i++) { + try { + if (ecs.take().get()) { + successCounter++; + } + } catch (Exception e) { + throw new RuntimeException(e); } - } catch (Exception e) { - throw new RuntimeException(e); } - } - assertEquals(threadCount, successCounter); - if (ksession != null) { - ksession.dispose(); + assertEquals(threadCount, successCounter); + if (ksession != null) { + ksession.dispose(); + } + } finally { + executor.shutdown(); + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow(); + } } } diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java index f262423120ec..024fe82a1318 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java @@ -42,226 +42,226 @@ @RunWith(Parameterized.class) public class MultithreadedSubnetworkTest extends CommonTestMethodBase { - protected final String drl; + protected final String drl; - @Parameters(name = "DRL={0}") - public static List getTestParameter() { - return Arrays.asList( - new String[] {"sharedSubnetworkRule", sharedSubnetworkRule}, - new String[] {"noSharingSubnetworkRule", noSharingSubnetworkRule}, - new String[] {"notSubnetworkRule", notSubnetworkRule}, - new String[] {"existsSubnetworkRule", existsSubnetworkRule}); - } + @Parameters(name = "DRL={0}") + public static List getTestParameter() { + return Arrays.asList( + new String[] {"sharedSubnetworkRule", sharedSubnetworkRule}, + new String[] {"noSharingSubnetworkRule", noSharingSubnetworkRule}, + new String[] {"notSubnetworkRule", notSubnetworkRule}, + new String[] {"existsSubnetworkRule", existsSubnetworkRule}); + } - public MultithreadedSubnetworkTest(String drlName, final String drl) { - this.drl = drl; - } + public MultithreadedSubnetworkTest(String drlName, final String drl) { + this.drl = drl; + } - final static String sharedSubnetworkRule = - "import " + AtomicInteger.class.getCanonicalName() + ";\n" + - "rule R1y when\n" + - " AtomicInteger() \n" + - " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )" + - " Long()\n" + - "then\n" + - " System.out.println(\"R1y\");" + - "end\n" + - "\n" + - "rule R1x when\n" + - " AtomicInteger( get() == 1 ) \n" + - " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )\n" + - "then\n" + - " System.out.println(\"R1x\");" + - "end\n" + - "" + - "rule R2 when\n" + - " $i : AtomicInteger( get() < 3 )\n" + - "then\n" + - " System.out.println(\"R2\");" + - " $i.incrementAndGet();" + - " update($i);" + - "end\n"; + final static String sharedSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1y when\n" + + " AtomicInteger() \n" + + " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )" + + " Long()\n" + + "then\n" + + " System.out.println(\"R1y\");" + + "end\n" + + "\n" + + "rule R1x when\n" + + " AtomicInteger( get() == 1 ) \n" + + " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )\n" + + "then\n" + + " System.out.println(\"R1x\");" + + "end\n" + + "" + + "rule R2 when\n" + + " $i : AtomicInteger( get() < 3 )\n" + + "then\n" + + " System.out.println(\"R2\");" + + " $i.incrementAndGet();" + + " update($i);" + + "end\n"; - final static String noSharingSubnetworkRule = - "import " + AtomicInteger.class.getCanonicalName() + ";\n" + - "rule R1y when\n" + - " AtomicInteger() \n" + - " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )" + - " Long()\n" + - "then\n" + - " System.out.println(\"R1y\");" + - "end\n" + - "\n" + - "rule R1x when\n" + - " AtomicInteger() \n" + - " Number() from accumulate ( $i : AtomicInteger( get() == 1) and String( this == \"test_2\" ) ; count($i) )\n" + - "then\n" + - " System.out.println(\"R1x\");" + - "end\n" + - "" + - "rule R2 when\n" + - " $i : AtomicInteger( get() < 3 )\n" + - "then\n" + - " System.out.println(\"R2\");" + - " $i.incrementAndGet();" + - " update($i);" + - "end\n"; + final static String noSharingSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1y when\n" + + " AtomicInteger() \n" + + " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )" + + " Long()\n" + + "then\n" + + " System.out.println(\"R1y\");" + + "end\n" + + "\n" + + "rule R1x when\n" + + " AtomicInteger() \n" + + " Number() from accumulate ( $i : AtomicInteger( get() == 1) and String( this == \"test_2\" ) ; count($i) )\n" + + "then\n" + + " System.out.println(\"R1x\");" + + "end\n" + + "" + + "rule R2 when\n" + + " $i : AtomicInteger( get() < 3 )\n" + + "then\n" + + " System.out.println(\"R2\");" + + " $i.incrementAndGet();" + + " update($i);" + + "end\n"; - final static String notSubnetworkRule = - "import " + AtomicInteger.class.getCanonicalName() + ";\n" + - "rule R1 when\n" + - " AtomicInteger() \n" + - " not(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + - "then\n" + - " System.out.println(\"R1\");" + - "end\n" + - "\n" + - "rule R2 when\n" + - " AtomicInteger() \n" + - " String( this != \"test_2\" ) \n" + - " not(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + - "then\n" + - " System.out.println(\"R2\");" + - "end\n"; + final static String notSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1 when\n" + + " AtomicInteger() \n" + + " not(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R1\");" + + "end\n" + + "\n" + + "rule R2 when\n" + + " AtomicInteger() \n" + + " String( this != \"test_2\" ) \n" + + " not(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R2\");" + + "end\n"; - final static String existsSubnetworkRule = - "import " + AtomicInteger.class.getCanonicalName() + ";\n" + - "rule R1 when\n" + - " AtomicInteger() \n" + - " exists(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + - "then\n" + - " System.out.println(\"R1\");" + - "end\n" + - "\n" + - "rule R2 when\n" + - " AtomicInteger() \n" + - " String( this != \"test_2\" ) \n" + - " exists(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + - "then\n" + - " System.out.println(\"R2\");" + - "end\n"; + final static String existsSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1 when\n" + + " AtomicInteger() \n" + + " exists(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R1\");" + + "end\n" + + "\n" + + "rule R2 when\n" + + " AtomicInteger() \n" + + " String( this != \"test_2\" ) \n" + + " exists(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R2\");" + + "end\n"; - @Test(timeout = 10000) - public void testConcurrentInsertionsFewObjectsManyThreads() throws InterruptedException { - testConcurrentInsertions(drl, 1, 1000, false, false); - } + @Test(timeout = 10000) + public void testConcurrentInsertionsFewObjectsManyThreads() throws InterruptedException { + testConcurrentInsertions(drl, 1, 1000, false, false); + } - @Test(timeout = 10000) - public void testConcurrentInsertionsManyObjectsFewThreads() throws InterruptedException { - testConcurrentInsertions(drl, 500, 4, false, false); - } + @Test(timeout = 10000) + public void testConcurrentInsertionsManyObjectsFewThreads() throws InterruptedException { + testConcurrentInsertions(drl, 500, 4, false, false); + } - @Test(timeout = 10000) - public void testConcurrentInsertionsManyObjectsSingleThread() throws InterruptedException { - testConcurrentInsertions(drl, 1000, 1, false, false); - } + @Test(timeout = 10000) + public void testConcurrentInsertionsManyObjectsSingleThread() throws InterruptedException { + testConcurrentInsertions(drl, 1000, 1, false, false); + } - @Test(timeout = 10000) - public void testConcurrentInsertionsNewSessionEachThread() throws InterruptedException { - testConcurrentInsertions(drl, 10, 1000, true, false); - } + @Test(timeout = 10000) + public void testConcurrentInsertionsNewSessionEachThread() throws InterruptedException { + testConcurrentInsertions(drl, 10, 1000, true, false); + } - @Test(timeout = 10000) - public void testConcurrentInsertionsNewSessionEachThreadUpdate() throws InterruptedException { - testConcurrentInsertions(drl, 10, 1000, true, true); - } + @Test(timeout = 10000) + public void testConcurrentInsertionsNewSessionEachThreadUpdate() throws InterruptedException { + testConcurrentInsertions(drl, 10, 1000, true, true); + } - private void testConcurrentInsertions(final String drl, final int objectCount, final int threadCount, - final boolean newSessionForEachThread, final boolean updateFacts) throws InterruptedException { + private void testConcurrentInsertions(final String drl, final int objectCount, final int threadCount, + final boolean newSessionForEachThread, final boolean updateFacts) throws InterruptedException { - final KieBase kieBase = new KieHelper().addContent(drl, ResourceType.DRL).build(); + final KieBase kieBase = new KieHelper().addContent(drl, ResourceType.DRL).build(); - ExecutorService executor = Executors.newCachedThreadPool(new ThreadFactory() { - public Thread newThread(Runnable r) { - Thread t = new Thread(r); - t.setDaemon(true); - return t; - } - }); - - KieSession ksession = null; - try { - Callable[] tasks = new Callable[threadCount]; - if (newSessionForEachThread) { - for (int i = 0; i < threadCount; i++) { - tasks[i] = getTask(objectCount, kieBase, updateFacts); + ExecutorService executor = Executors.newCachedThreadPool(new ThreadFactory() { + public Thread newThread(Runnable r) { + Thread t = new Thread(r); + t.setDaemon(true); + return t; } - } else { - ksession = kieBase.newKieSession(); - for (int i = 0; i < threadCount; i++) { - tasks[i] = getTask(objectCount, ksession, false, updateFacts); + }); + + KieSession ksession = null; + try { + Callable[] tasks = new Callable[threadCount]; + if (newSessionForEachThread) { + for (int i = 0; i < threadCount; i++) { + tasks[i] = getTask(objectCount, kieBase, updateFacts); + } + } else { + ksession = kieBase.newKieSession(); + for (int i = 0; i < threadCount; i++) { + tasks[i] = getTask(objectCount, ksession, false, updateFacts); + } } - } - CompletionService ecs = new ExecutorCompletionService(executor); - for (Callable task : tasks) { - ecs.submit(task); - } + CompletionService ecs = new ExecutorCompletionService(executor); + for (Callable task : tasks) { + ecs.submit(task); + } - int successCounter = 0; - for (int i = 0; i < threadCount; i++) { - try { - if (ecs.take().get()) { - successCounter++; - } - } catch (Exception e) { - throw new RuntimeException(e); + int successCounter = 0; + for (int i = 0; i < threadCount; i++) { + try { + if (ecs.take().get()) { + successCounter++; + } + } catch (Exception e) { + throw new RuntimeException(e); + } } - } - assertEquals(threadCount, successCounter); - if (ksession != null) { - ksession.dispose(); - } - } finally { - executor.shutdown(); - if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { - executor.shutdownNow(); - } - } - } + assertEquals(threadCount, successCounter); + if (ksession != null) { + ksession.dispose(); + } + } finally { + executor.shutdown(); + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow(); + } + } + } - private Callable getTask( final int objectCount, final KieBase kieBase, final boolean updateFacts) { - return getTask(objectCount, kieBase.newKieSession(), true, updateFacts); - } + private Callable getTask(final int objectCount, final KieBase kieBase, final boolean updateFacts) { + return getTask(objectCount, kieBase.newKieSession(), true, updateFacts); + } - private Callable getTask( - final int objectCount, - final KieSession ksession, - final boolean disposeSession, - final boolean updateFacts) { - return new Callable() { - public Boolean call() throws Exception { - try { - for (int j = 0; j < 10; j++) { - FactHandle[] facts = new FactHandle[objectCount]; - FactHandle[] stringFacts = new FactHandle[objectCount]; - for (int i = 0; i < objectCount; i++) { - facts[i] = ksession.insert(new AtomicInteger(i)); - stringFacts[i] = ksession.insert("test_" + i); - } - if (updateFacts) { - for (int i = 0; i < objectCount; i++) { - ksession.update(facts[i], new AtomicInteger(-i)); - ksession.update(stringFacts[i], "updated_test_" + i); - } - } - for (int i = 0; i < objectCount; i++) { - ksession.delete(facts[i]); - ksession.delete(stringFacts[i]); - } - ksession.fireAllRules(); - } - return true; - } catch (Exception e) { - e.printStackTrace(); - return false; - } finally { - if (disposeSession) { - ksession.dispose(); - } + private Callable getTask( + final int objectCount, + final KieSession ksession, + final boolean disposeSession, + final boolean updateFacts) { + return new Callable() { + public Boolean call() throws Exception { + try { + for (int j = 0; j < 10; j++) { + FactHandle[] facts = new FactHandle[objectCount]; + FactHandle[] stringFacts = new FactHandle[objectCount]; + for (int i = 0; i < objectCount; i++) { + facts[i] = ksession.insert(new AtomicInteger(i)); + stringFacts[i] = ksession.insert("test_" + i); + } + if (updateFacts) { + for (int i = 0; i < objectCount; i++) { + ksession.update(facts[i], new AtomicInteger(-i)); + ksession.update(stringFacts[i], "updated_test_" + i); + } + } + for (int i = 0; i < objectCount; i++) { + ksession.delete(facts[i]); + ksession.delete(stringFacts[i]); + } + ksession.fireAllRules(); + } + return true; + } catch (Exception e) { + e.printStackTrace(); + return false; + } finally { + if (disposeSession) { + ksession.dispose(); + } + } } - } - }; - } + }; + } } diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractConcurrentSessionsTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractConcurrentSessionsTest.java index 26f53a261fd8..287c3e6b0a4a 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractConcurrentSessionsTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractConcurrentSessionsTest.java @@ -47,19 +47,26 @@ public abstract class AbstractConcurrentSessionsTest { protected final boolean enforcedJitting; protected final boolean serializeKieBase; + protected final boolean sharedKieBase; - @Parameterized.Parameters(name = "Enforced jitting={0}, Serialize KieBase={1}") + @Parameterized.Parameters(name = "Enforced jitting={0}, Serialize KieBase={1}, Share KieBase={2}") public static List getTestParameters() { return Arrays.asList( - new Boolean[]{false, false}, - new Boolean[]{false, true}, - new Boolean[]{true, false}, - new Boolean[]{true, true}); + new Boolean[]{false, false, false}, + new Boolean[]{false, true, false}, + new Boolean[]{true, false, false}, + new Boolean[]{true, true, false}, + new Boolean[]{false, false, true}, + new Boolean[]{false, true, true}, + new Boolean[]{true, false, true}, + new Boolean[]{true, true, true}); } - public AbstractConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase) { + public AbstractConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase, + final boolean sharedKieBase) { this.enforcedJitting = enforcedJitting; this.serializeKieBase = serializeKieBase; + this.sharedKieBase = sharedKieBase; } interface KieSessionExecutor { @@ -103,10 +110,22 @@ public Thread newThread( final Runnable r ) { for (int i = 0; i < threadCount; i++) { final int counter = i; + + KieBase kieBaseLocal; + if (sharedKieBase) { + kieBaseLocal = kieBase; + } else { + if (serializeKieBase) { + kieBaseLocal = serializeAndDeserializeKieBase(kieHelper.build(kieBaseOptions)); + } else { + kieBaseLocal = kieHelper.build(kieBaseOptions); + } + } + tasks[i] = new Callable() { @Override public Boolean call() throws Exception { - return kieSessionExecutor.execute(kieBase.newKieSession(), counter); + return kieSessionExecutor.execute(kieBaseLocal.newKieSession(), counter); } }; } diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java new file mode 100644 index 000000000000..2967cdda6553 --- /dev/null +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java @@ -0,0 +1,224 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.compiler.integrationtests.session; + +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.kie.api.KieBase; +import org.kie.api.conf.KieBaseOption; +import org.kie.api.io.ResourceType; +import org.kie.api.runtime.KieSession; +import org.kie.internal.conf.ConstraintJittingThresholdOption; +import org.kie.internal.utils.KieHelper; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@RunWith(Parameterized.class) +public abstract class AbstractParallelTest { + + protected final boolean enforcedJitting; + protected final boolean serializeKieBase; + + @Parameterized.Parameters(name = "Enforced jitting={0}, Serialize KieBase={1}") + public static List getTestParameters() { + return Arrays.asList( + new Boolean[]{false, false}, + new Boolean[]{false, true}, + new Boolean[]{true, false}, + new Boolean[]{true, true}); + } + + public AbstractParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + this.enforcedJitting = enforcedJitting; + this.serializeKieBase = serializeKieBase; + } + + public void parallelTest(int numberOfThreads, ParallelTestExecutor executor) throws InterruptedException { + + Callable[] tasks; + + final ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads, new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + final Thread t = new Thread(r); + t.setDaemon(true); + return t; + } + }); + + tasks = getTasks(numberOfThreads, executor); + final CompletionService completionService = new ExecutorCompletionService(executorService); + + try { + for (Callable task : tasks) { + completionService.submit(task); + } + + int successCounter = 0; + for (int i = 0; i < numberOfThreads; i++) { + try { + if (completionService.take().get()) { + successCounter ++; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + assertEquals(numberOfThreads, successCounter); + + } finally { + executorService.shutdown(); + if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { + executorService.shutdownNow(); + } + } + + } + + public interface ParallelTestExecutor { + public boolean execute(int counter) throws InterruptedException; + } + + private Callable[] getTasks(int numberOfThreads, ParallelTestExecutor executor) { + Callable[] tasks = new Callable[numberOfThreads]; + for (int i = 0; i < numberOfThreads; i++) { + final int counter = i; + tasks[counter] = new Callable() { + @Override + public Boolean call() throws Exception { + return executor.execute(counter); + } + }; + } + return tasks; + } + + private static synchronized KieHelper getKieHelper(String... drls) { + KieHelper kieHelper = new KieHelper(); + for (String drl : drls) { + kieHelper.addContent(drl, ResourceType.DRL); + } + return kieHelper; + } + + protected synchronized KieBase getKieBase(String... drls) { + KieBaseOption[] kieBaseOptions = (enforcedJitting) ? new KieBaseOption[]{ConstraintJittingThresholdOption.get(0)} + : new KieBaseOption[]{}; + KieBase kieBase = getKieHelper(drls).build(kieBaseOptions); + if (serializeKieBase) { + kieBase = serializeAndDeserializeKieBase(kieBase); + } + return kieBase; + } + + private KieBase serializeAndDeserializeKieBase(final KieBase kieBase) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream( baos ); + try { + out.writeObject( kieBase ); + out.flush(); + } finally { + out.close(); + } + + ObjectInputStream in = new ObjectInputStream( new ByteArrayInputStream( baos.toByteArray() ) ); + try { + return (KieBase) in.readObject(); + } finally { + in.close(); + } + + } catch (IOException e) { + throw new RuntimeException(e); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + protected static void disposeSession(KieSession session) { + if (session != null) { + session.dispose(); + } + } + + protected void checkList(int end, List list) { + checkList(0, end, list); + } + + protected void checkList(int start, int end, List list) { + int expectedSize = end-start; + checkList(start, end, list, expectedSize); + } + + protected void checkList(int start, int end, List list, int expectedSize) { + assertEquals(expectedSize, list.size()); + for (int i = start; i < end; i++) { + assertTrue(list.contains(""+i)); + } + } + + public static class BeanA { + int seed; + public BeanA() { + this.seed = 1; + } + public BeanA(int seed) { + this.seed = seed; + } + public int getSeed() { + return this.seed; + } + public void setSeed(int seed) { + this.seed = seed; + } + } + + public static class BeanB { + int seed; + public BeanB() { + this.seed = 1; + } + public BeanB(int seed) { + this.seed = seed; + } + public int getSeed() { + return this.seed; + } + public void setSeed(int seed) { + this.seed = seed; + } + } + +} diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java new file mode 100644 index 000000000000..2b2d27239ace --- /dev/null +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java @@ -0,0 +1,669 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.compiler.integrationtests.session; + +import org.junit.Test; +import org.kie.api.KieBase; +import org.kie.api.runtime.KieSession; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; + +public class ConcurrentBasesParallelTest extends AbstractParallelTest { + + public ConcurrentBasesParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + super(enforcedJitting, serializeKieBase); + } + + @Test + public void testDifferentRuleset1() throws InterruptedException { + int numberOfThreads = 100; + int numberOfObjects = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule Rule_" + counter + " " + + "when " + + " BeanA( seed == " + counter + ") " + + "then " + + "end"; + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + for (int i = 0; i < numberOfObjects; i++) { + session.insert(new BeanA(i)); + } + return session.fireAllRules() == 1; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset2() throws InterruptedException { + int numberOfThreads = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule Rule_" + counter + " " + + "when " + + " BeanA( seed == " + counter + ") " + + "then " + + "end"; + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + for (int i = 0; i < numberOfThreads; i++) { + if (i != counter) { + session.insert(new BeanA(i)); + } + } + return session.fireAllRules() == 0; + + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset3() throws InterruptedException { + int numberOfThreads = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global "+ AtomicInteger.class.getCanonicalName() + " result;\n" + + "rule Rule_" + counter + " " + + "when " + + " BeanA()" + + "then " + + " result.set(" + counter + ");" + + "end"; + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + AtomicInteger r = new AtomicInteger(0); + session.setGlobal("result", r); + assertEquals(1, session.fireAllRules()); + assertEquals(counter, r.get()); + return true; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset4() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule ${ruleName} " + + "when " + + "${className}()" + + "then " + + " list.add(\"${className}\");" + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + session.insert(new BeanB()); + List list = new ArrayList<>(); + session.setGlobal("list", list); + int rulesFired = session.fireAllRules(); + assertEquals(1, list.size()); + assertEquals(className, list.get(0)); + return rulesFired == 1; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset5() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + "${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanB()); + } else { + session.insert(new BeanA()); + } + int rulesFired = session.fireAllRules(); + assertEquals(0, rulesFired); + return true; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset6() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + "${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA()); + } else { + session.insert(new BeanB()); + } + int rulesFired = session.fireAllRules(); + assertEquals(1, rulesFired); + return true; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRulesetNot() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + " not ${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA()); + } else { + session.insert(new BeanB()); + } + int rulesFired = session.fireAllRules(); + assertEquals(0, rulesFired); + return true; + } finally { + disposeSession(session); + } + + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRulesetExists() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + " exists ${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA()); + } else { + session.insert(new BeanB()); + } + int rulesFired = session.fireAllRules(); + assertEquals(1, rulesFired); + return true; + } finally { + disposeSession(session); + } + + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRulesetSharedSubnetwork() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + " $bean : ${className}() \n" + + "then " + + "end"; + + String subnetworkRuleTemplate = "rule Rule_subnetwork " + + "when " + + " $bean : ${className}() \n" + + " Number( doubleValue > 0) from" + + " accumulate ( BeanA() and $s : String(), count($s) )" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + String subnetworkRule = subnetworkRuleTemplate.replace("${className}", className); + + KieBase base = getKieBase(rule, subnetworkRule); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert("test"); + if (counter % 2 == 0) { + session.insert(new BeanA()); + assertEquals(2, session.fireAllRules()); + } else { + session.insert(new BeanB()); + assertEquals(1, session.fireAllRules()); + } + return true; + } finally { + disposeSession(session); + } + + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset1() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleA " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleB " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + + "then " + + "end"; + + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter %2 == 0) { + session.insert(new BeanA(counter)); + return session.fireAllRules() == 1; + } else { + return session.fireAllRules() == 0; + } + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset2() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleA " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleB " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + session.insert(new BeanB()); + return session.fireAllRules() == 2; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset3() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleA " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleB " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter %2 == 0) { + session.insert(new BeanA(counter)); + } else { + session.insert(new BeanB(counter)); + } + return session.fireAllRules() == 1; + } finally { + disposeSession(session); + } + + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset4() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleNotA " + + "when " + + " not BeanA() " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleNotB " + + "when " + + " not BeanB() " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter %2 == 0) { + session.insert(new BeanA(counter)); + } else { + session.insert(new BeanB(counter)); + } + return session.fireAllRules() == 1; + } finally { + disposeSession(session); + } + + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset5() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleNotA " + + "when " + + " not BeanA() " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleNotB " + + "when " + + " not BeanB() " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter %2 == 0) { + session.insert(new BeanA(counter)); + session.insert(new BeanB(counter)); + return session.fireAllRules() == 0; + } else { + return session.fireAllRules() == 2; + } + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + public void parallelTest(int numberOfThreads, ParallelTestExecutor executor) throws InterruptedException { + + Callable[] tasks = new Callable[numberOfThreads]; + + final ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads, new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + final Thread t = new Thread(r); + t.setDaemon(true); + return t; + } + }); + + try { + for (int i = 0; i < numberOfThreads; i++) { + final int counter = i; + tasks[counter] = new Callable() { + @Override + public Boolean call() throws Exception { + return executor.execute(counter); + } + }; + } + + final CompletionService completionService = new ExecutorCompletionService(executorService); + + for (Callable task : tasks) { + completionService.submit(task); + } + + int successCounter = 0; + for (int i = 0; i < numberOfThreads; i++) { + try { + if (completionService.take().get()) { + successCounter ++; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + assertEquals(numberOfThreads, successCounter); + + } finally { + executorService.shutdown(); + if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { + executorService.shutdownNow(); + } + } + } +} diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationConcurrentSessionsTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationConcurrentSessionsTest.java index 1ccf2f9f1484..1cb93ab03675 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationConcurrentSessionsTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationConcurrentSessionsTest.java @@ -35,8 +35,9 @@ public class DataTypeEvaluationConcurrentSessionsTest extends AbstractConcurrentSessionsTest { - public DataTypeEvaluationConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase) { - super(enforcedJitting, serializeKieBase); + public DataTypeEvaluationConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase, + final boolean sharedKieBase) { + super(enforcedJitting, serializeKieBase, sharedKieBase); } @Test diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java new file mode 100644 index 000000000000..c59f2cace0ff --- /dev/null +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java @@ -0,0 +1,173 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.compiler.integrationtests.session; + +import org.drools.compiler.integrationtests.facts.AnEnum; +import org.drools.compiler.integrationtests.facts.FactWithBigDecimal; +import org.drools.compiler.integrationtests.facts.FactWithBoolean; +import org.drools.compiler.integrationtests.facts.FactWithByte; +import org.drools.compiler.integrationtests.facts.FactWithCharacter; +import org.drools.compiler.integrationtests.facts.FactWithDouble; +import org.drools.compiler.integrationtests.facts.FactWithEnum; +import org.drools.compiler.integrationtests.facts.FactWithFloat; +import org.drools.compiler.integrationtests.facts.FactWithInteger; +import org.drools.compiler.integrationtests.facts.FactWithLong; +import org.drools.compiler.integrationtests.facts.FactWithShort; +import org.drools.compiler.integrationtests.facts.FactWithString; +import org.junit.Test; +import org.kie.api.runtime.KieSession; + +import java.math.BigDecimal; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; + +public class DataTypeEvaluationSharedSessionParallelTest extends AbstractParallelTest { + + public DataTypeEvaluationSharedSessionParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + super(enforcedJitting, serializeKieBase); + } + + @Test + public void testBooleanPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithBoolean: FactWithBoolean(booleanValue == false) \n", new FactWithBoolean(false)); + } + + @Test + public void testBoolean() throws InterruptedException { + testFactAttributeType(" $factWithBoolean: FactWithBoolean(booleanObjectValue == false) \n", new FactWithBoolean(false)); + } + + @Test + public void testBytePrimitive() throws InterruptedException { + testFactAttributeType(" $factWithByte: FactWithByte(byteValue == 15) \n", new FactWithByte((byte) 15)); + } + + @Test + public void testByte() throws InterruptedException { + testFactAttributeType(" $factWithByte: FactWithByte(byteObjectValue == 15) \n", new FactWithByte((byte) 15)); + } + + @Test + public void testShortPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithShort: FactWithShort(shortValue == 15) \n", new FactWithShort((short) 15)); + } + + @Test + public void testShort() throws InterruptedException { + testFactAttributeType(" $factWithShort: FactWithShort(shortObjectValue == 15) \n", new FactWithShort((short) 15)); + } + + @Test + public void testIntPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithInt: FactWithInteger(intValue == 15) \n", new FactWithInteger(15)); + } + + @Test + public void testInteger() throws InterruptedException { + testFactAttributeType(" $factWithInteger: FactWithInteger(integerValue == 15) \n", new FactWithInteger(15)); + } + + @Test + public void testLongPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithLong: FactWithLong(longValue == 15) \n", new FactWithLong(15)); + } + + @Test + public void testLong() throws InterruptedException { + testFactAttributeType(" $factWithLong: FactWithLong(longObjectValue == 15) \n", new FactWithLong(15)); + } + + @Test + public void testFloatPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithFloat: FactWithFloat(floatValue == 15.1) \n", new FactWithFloat(15.1f)); + } + + @Test + public void testFloat() throws InterruptedException { + testFactAttributeType(" $factWithFloat: FactWithFloat(floatObjectValue == 15.1) \n", new FactWithFloat(15.1f)); + } + + @Test + public void testDoublePrimitive() throws InterruptedException { + testFactAttributeType(" $factWithDouble: FactWithDouble(doubleValue == 15.1) \n", new FactWithDouble(15.1d)); + } + + @Test + public void testDouble() throws InterruptedException { + testFactAttributeType(" $factWithDouble: FactWithDouble(doubleObjectValue == 15.1) \n", new FactWithDouble(15.1d)); + } + + @Test + public void testBigDecimal() throws InterruptedException { + testFactAttributeType(" $factWithBigDecimal: FactWithBigDecimal(bigDecimalValue == 10) \n", new FactWithBigDecimal(BigDecimal.TEN)); + } + + @Test + public void testCharPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithChar: FactWithCharacter(charValue == 'a') \n", new FactWithCharacter('a')); + } + + @Test + public void testCharacter() throws InterruptedException { + testFactAttributeType(" $factWithChar: FactWithCharacter(characterValue == 'a') \n", new FactWithCharacter('a')); + } + + @Test + public void testString() throws InterruptedException { + testFactAttributeType(" $factWithString: FactWithString(stringValue == \"test\") \n", new FactWithString("test")); + } + + @Test + public void testEnum() throws InterruptedException { + testFactAttributeType(" $factWithEnum: FactWithEnum(enumValue == AnEnum.FIRST) \n", new FactWithEnum(AnEnum.FIRST)); + } + + private void testFactAttributeType(final String ruleConstraint, final Object factInserted) throws InterruptedException { + int numberOfThreads = 10; + + + final String drl = + " import org.drools.compiler.integrationtests.facts.*;\n" + + " global " + AtomicInteger.class.getCanonicalName() + " numberOfFirings;\n" + + " rule R1 \n" + + " when \n" + + ruleConstraint + + " then \n" + + " numberOfFirings.incrementAndGet(); \n" + + " end "; + + KieSession kieSession = getKieBase(drl).newKieSession(); + + AtomicInteger numberOfFirings = new AtomicInteger(0); + + parallelTest(numberOfThreads, new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + if (kieSession.getGlobal("numberOfFirings") == null) { + kieSession.setGlobal("numberOfFirings", numberOfFirings); + }; + kieSession.insert(factInserted); + kieSession.fireAllRules(); + return true; + } + }); + disposeSession(kieSession); + assertEquals(1, numberOfFirings.get()); + } + +} diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/JoinsConcurrentSessionsTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/JoinsConcurrentSessionsTest.java index e32f65f297da..159d2a540043 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/JoinsConcurrentSessionsTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/JoinsConcurrentSessionsTest.java @@ -29,8 +29,9 @@ public class JoinsConcurrentSessionsTest extends AbstractConcurrentSessionsTest { - public JoinsConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase) { - super(enforcedJitting, serializeKieBase); + public JoinsConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase, + final boolean sharedKieBase) { + super(enforcedJitting, serializeKieBase, sharedKieBase); } @Test diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java new file mode 100644 index 000000000000..ec8fe1aee744 --- /dev/null +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java @@ -0,0 +1,488 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.compiler.integrationtests.session; + +import org.junit.Test; +import org.kie.api.runtime.KieSession; +import org.kie.api.runtime.rule.FactHandle; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; + +public class SharedSessionParallelTest extends AbstractParallelTest { + + public SharedSessionParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + super(enforcedJitting, serializeKieBase); + } + + @Test + public void testNoExceptions() throws InterruptedException { + String drl = "rule R1 when String() then end"; + + int repetitions = 100; + int numberOfObjects = 1000; + int countOfThreads = 100; + + for (int i = 0; i < repetitions; i++) { + + KieSession kieSession = getKieBase(drl).newKieSession(); + + parallelTest(countOfThreads, new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + for (int j = 0; j < numberOfObjects; j++) { + kieSession.insert("test_" + numberOfObjects); + } + kieSession.fireAllRules(); + return true; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }); + + disposeSession(kieSession); + } + } + + @Test + public void testCheckOneThreadOnly() throws InterruptedException { + int threadCount = 100; + List list = new ArrayList<>(); + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule R1 " + + "when " + + " BeanA($n : seed) " + + "then " + + " list.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + CountDownLatch latch = new CountDownLatch(threadCount); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + kieSession.setGlobal("list", list); + kieSession.insert(new BeanA(counter)); + latch.countDown(); + + if (counter == 0) { + latch.await(); + return kieSession.fireAllRules() == threadCount; + } + return true; + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + + assertEquals(threadCount, list.size()); + for (int i = 0; i < threadCount; i++) { + assertTrue(list.contains(""+i)); + } + } + + @Test + public void testCorrectFirings() throws InterruptedException { + int threadCount = 100; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List globalList;\n" + + "rule R1 " + + "when " + + " BeanA($n : seed) " + + "then " + + " globalList.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + kieSession.setGlobal("globalList", list); + kieSession.insert(new BeanA(counter)); + kieSession.fireAllRules(); + return true; + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(threadCount, list); + } + + @Test + public void testCorrectFirings2() throws InterruptedException { + int threadCount = 100; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule R1 " + + "when " + + " BeanA($n : seed, seed == 0) " + + "then " + + " list.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + kieSession.setGlobal("list", list); + kieSession.insert(new BeanA(counter%2)); + kieSession.fireAllRules(); + return true; + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + assertTrue(list.contains("" + 0)); + assertFalse(list.contains("" + 1)); + int expectedListSize = ((threadCount-1) / 2) + 1; + assertEquals(expectedListSize, list.size()); + } + + + @Test + public void testLongRunningRule() throws InterruptedException { + int threadCount = 100; + int seed = threadCount + 200; + int objectCount = 1000; + + String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule longRunning " + + "when " + + " $bean : BeanA($n : seed, seed > " + threadCount + ") " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + " Thread.sleep(5);" + + "end"; + + String listDrl = "global java.util.List list2;\n" + + "rule listRule " + + "when " + + " BeanA($n : seed, seed < " + threadCount + ") " + + "then " + + " list2.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(longRunningDrl, listDrl).newKieSession(); + + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + List list2 = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + kieSession.setGlobal("list2", list2); + kieSession.insert(new BeanA(seed)); + barrier.await(); + kieSession.fireAllRules(); + return true; + } else { + barrier.await(); + Thread.sleep(100); + for (int i = 0; i < objectCount; i++) { + kieSession.insert(new BeanA(counter)); + } + kieSession.fireAllRules(); + return true; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(threadCount, seed, list); + checkList(1, threadCount, list2, (threadCount-1)*objectCount); + } + + @Test + public void testLongRunningRule2() throws InterruptedException { + int threadCount = 100; + int seed = 1000; + + String waitingRule = "rule waitingRule " + + "when " + + " String( this == \"wait\" ) " + + "then " + + " Thread.sleep(10);" + + "end"; + + String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule longRunning " + + "when " + + " $bean : BeanA($n : seed, seed > 0 ) " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + KieSession kieSession = getKieBase(longRunningDrl, waitingRule).newKieSession(); + + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + kieSession.insert("wait"); + kieSession.insert(new BeanA(seed)); + barrier.await(); + return kieSession.fireAllRules() == seed * threadCount + 1; + } else { + barrier.await(); + Thread.sleep(10); + kieSession.insert(new BeanA(seed)); + return kieSession.fireAllRules() == 0; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(0, seed, list, seed * threadCount); + } + + @Test + public void testLongRunningRule3() throws InterruptedException { + int threadCount = 10; + int seed = threadCount + 50; + int objectCount = 1000; + + String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule longRunning " + + "when " + + " $bean : BeanA($n : seed, seed > " + threadCount + ") " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + String listDrl = "global java.util.List list2;\n" + + "rule listRule " + + "when " + + " BeanA($n : seed, seed < " + threadCount + ") " + + "then " + + " list2.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(longRunningDrl, listDrl).newKieSession(); + + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + List list2 = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter % 2 == 0) { + kieSession.setGlobal("list", list); + kieSession.setGlobal("list2", list2); + kieSession.insert(new BeanA(seed)); + barrier.await(); + kieSession.fireAllRules(); + return true; + } else { + barrier.await(); + Thread.sleep(100); + for (int i = 0; i < objectCount; i++) { + kieSession.insert(new BeanA(counter)); + } + kieSession.fireAllRules(); + return true; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + + int listExpectedSize = (threadCount/2 + threadCount%2) * (seed-threadCount); + int list2ExpectedSize = threadCount/2 * objectCount; + for (int i = 0; i < threadCount; i++) { + if (i % 2 == 1) { + list2.contains("" + i); + } + } + assertEquals(listExpectedSize, list.size()); + assertEquals(list2ExpectedSize, list2.size()); + } + + @Test + public void testCountdownBean() throws InterruptedException { + int threadCount = 100; + int seed = 1000; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule countdown " + + "when " + + " $bean : BeanA($n : seed, seed > 0 ) " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + BeanA bean = new BeanA(seed); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + kieSession.insert(bean); + } + barrier.await(); + kieSession.fireAllRules(); + return true; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(seed, list); + assertEquals(0, bean.getSeed()); + } + + @Test + public void testCountdownBean2() throws InterruptedException { + int threadCount = 100; + int seed = 1000; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule countdown " + + "when " + + " $bean : BeanA($n : seed, seed > 0 ) " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + List list = new ArrayList<>(); + BeanA[] beans = new BeanA[threadCount]; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + BeanA bean = new BeanA(seed); + beans[counter] = bean; + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + } + kieSession.insert(bean); + kieSession.fireAllRules(); + return true; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + + checkList(0, seed, list, seed * threadCount); + for (BeanA bean : beans) { + assertEquals(0, bean.getSeed()); + } + } + + @Test + public void testOneRulePerThread() throws InterruptedException { + int threadCount = 1000; + + String[] drls = new String[threadCount]; + for (int i = 0; i < threadCount; i++) { + drls[i] = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule R" + i + " " + + "when " + + " $bean : BeanA( seed == " + i + " ) " + + "then " + + " list.add(\"" + i + "\");" + + "end"; + } + + KieSession kieSession = getKieBase(drls).newKieSession(); + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + kieSession.setGlobal("list", list); + kieSession.insert(new BeanA(counter)); + kieSession.fireAllRules(); + return true; + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(threadCount, list); + } +} diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SubnetworkConcurrentSessionsTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SubnetworkConcurrentSessionsTest.java index 49f30045a112..377689436e74 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SubnetworkConcurrentSessionsTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SubnetworkConcurrentSessionsTest.java @@ -21,8 +21,9 @@ public class SubnetworkConcurrentSessionsTest extends AbstractConcurrentSessionsTest { - public SubnetworkConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase) { - super(enforcedJitting, serializeKieBase); + public SubnetworkConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase, + final boolean sharedKieBase) { + super(enforcedJitting, serializeKieBase, sharedKieBase); } @Test(timeout = 5000) From fab299cce4c55894a3d8495f16dd0c3f48c7077a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Ma=C4=8Dkay?= Date: Thu, 16 Nov 2017 11:52:34 +0100 Subject: [PATCH 3/4] Add tests for functions to ConcurrentBasesParallelTest --- .../session/AbstractParallelTest.java | 342 ++--- .../session/ConcurrentBasesParallelTest.java | 1273 +++++++++-------- ...peEvaluationSharedSessionParallelTest.java | 263 ++-- .../session/SharedSessionParallelTest.java | 876 ++++++------ 4 files changed, 1400 insertions(+), 1354 deletions(-) diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java index 2967cdda6553..3c8fd4a74e8e 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java @@ -46,179 +46,187 @@ @RunWith(Parameterized.class) public abstract class AbstractParallelTest { - protected final boolean enforcedJitting; - protected final boolean serializeKieBase; - - @Parameterized.Parameters(name = "Enforced jitting={0}, Serialize KieBase={1}") - public static List getTestParameters() { - return Arrays.asList( - new Boolean[]{false, false}, - new Boolean[]{false, true}, - new Boolean[]{true, false}, - new Boolean[]{true, true}); - } - - public AbstractParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { - this.enforcedJitting = enforcedJitting; - this.serializeKieBase = serializeKieBase; - } - - public void parallelTest(int numberOfThreads, ParallelTestExecutor executor) throws InterruptedException { - - Callable[] tasks; - - final ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads, new ThreadFactory() { - @Override - public Thread newThread(Runnable r) { - final Thread t = new Thread(r); - t.setDaemon(true); - return t; - } - }); - - tasks = getTasks(numberOfThreads, executor); - final CompletionService completionService = new ExecutorCompletionService(executorService); - - try { - for (Callable task : tasks) { - completionService.submit(task); - } - - int successCounter = 0; - for (int i = 0; i < numberOfThreads; i++) { - try { - if (completionService.take().get()) { - successCounter ++; - } - } catch (Exception ex) { - throw new RuntimeException(ex); - } - } + protected final boolean enforcedJitting; + protected final boolean serializeKieBase; - assertEquals(numberOfThreads, successCounter); + @Parameterized.Parameters(name = "Enforced jitting={0}, Serialize KieBase={1}") + public static List getTestParameters() { + return Arrays.asList( + new Boolean[] {false, false}, + new Boolean[] {false, true}, + new Boolean[] {true, false}, + new Boolean[] {true, true}); + } - } finally { - executorService.shutdown(); - if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { - executorService.shutdownNow(); - } - } + public AbstractParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + this.enforcedJitting = enforcedJitting; + this.serializeKieBase = serializeKieBase; + } - } + public void parallelTest(int numberOfThreads, ParallelTestExecutor executor) throws InterruptedException { - public interface ParallelTestExecutor { - public boolean execute(int counter) throws InterruptedException; - } + Callable[] tasks; - private Callable[] getTasks(int numberOfThreads, ParallelTestExecutor executor) { - Callable[] tasks = new Callable[numberOfThreads]; - for (int i = 0; i < numberOfThreads; i++) { - final int counter = i; - tasks[counter] = new Callable() { + final ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads, new ThreadFactory() { @Override - public Boolean call() throws Exception { - return executor.execute(counter); + public Thread newThread(Runnable r) { + final Thread t = new Thread(r); + t.setDaemon(true); + return t; + } + }); + + tasks = getTasks(numberOfThreads, executor); + final CompletionService completionService = new ExecutorCompletionService(executorService); + + try { + for (Callable task : tasks) { + completionService.submit(task); + } + + int successCounter = 0; + for (int i = 0; i < numberOfThreads; i++) { + try { + if (completionService.take().get()) { + successCounter++; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + assertEquals(numberOfThreads, successCounter); + + } finally { + executorService.shutdown(); + if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { + executorService.shutdownNow(); + } + } + + } + + public interface ParallelTestExecutor { + public boolean execute(int counter) throws InterruptedException; + } + + private Callable[] getTasks(int numberOfThreads, ParallelTestExecutor executor) { + Callable[] tasks = new Callable[numberOfThreads]; + for (int i = 0; i < numberOfThreads; i++) { + final int counter = i; + tasks[counter] = new Callable() { + @Override + public Boolean call() throws Exception { + return executor.execute(counter); + } + }; + } + return tasks; + } + + private static synchronized KieHelper getKieHelper(String... drls) { + KieHelper kieHelper = new KieHelper(); + for (String drl : drls) { + kieHelper.addContent(drl, ResourceType.DRL); + } + return kieHelper; + } + + protected synchronized KieBase getKieBase(String... drls) { + KieBaseOption[] kieBaseOptions = (enforcedJitting) ? new KieBaseOption[] {ConstraintJittingThresholdOption.get(0)} + : new KieBaseOption[] {}; + KieBase kieBase = getKieHelper(drls).build(kieBaseOptions); + if (serializeKieBase) { + kieBase = serializeAndDeserializeKieBase(kieBase); + } + return kieBase; + } + + private KieBase serializeAndDeserializeKieBase(final KieBase kieBase) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(baos); + try { + out.writeObject(kieBase); + out.flush(); + } finally { + out.close(); } - }; - } - return tasks; - } - - private static synchronized KieHelper getKieHelper(String... drls) { - KieHelper kieHelper = new KieHelper(); - for (String drl : drls) { - kieHelper.addContent(drl, ResourceType.DRL); - } - return kieHelper; - } - - protected synchronized KieBase getKieBase(String... drls) { - KieBaseOption[] kieBaseOptions = (enforcedJitting) ? new KieBaseOption[]{ConstraintJittingThresholdOption.get(0)} - : new KieBaseOption[]{}; - KieBase kieBase = getKieHelper(drls).build(kieBaseOptions); - if (serializeKieBase) { - kieBase = serializeAndDeserializeKieBase(kieBase); - } - return kieBase; - } - - private KieBase serializeAndDeserializeKieBase(final KieBase kieBase) { - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream out = new ObjectOutputStream( baos ); - try { - out.writeObject( kieBase ); - out.flush(); - } finally { - out.close(); - } - - ObjectInputStream in = new ObjectInputStream( new ByteArrayInputStream( baos.toByteArray() ) ); - try { - return (KieBase) in.readObject(); - } finally { - in.close(); - } - - } catch (IOException e) { - throw new RuntimeException(e); - } catch (ClassNotFoundException e) { - throw new RuntimeException(e); - } - } - - protected static void disposeSession(KieSession session) { - if (session != null) { - session.dispose(); - } - } - - protected void checkList(int end, List list) { - checkList(0, end, list); - } - - protected void checkList(int start, int end, List list) { - int expectedSize = end-start; - checkList(start, end, list, expectedSize); - } - - protected void checkList(int start, int end, List list, int expectedSize) { - assertEquals(expectedSize, list.size()); - for (int i = start; i < end; i++) { - assertTrue(list.contains(""+i)); - } - } - - public static class BeanA { - int seed; - public BeanA() { - this.seed = 1; - } - public BeanA(int seed) { - this.seed = seed; - } - public int getSeed() { - return this.seed; - } - public void setSeed(int seed) { - this.seed = seed; - } - } - - public static class BeanB { - int seed; - public BeanB() { - this.seed = 1; - } - public BeanB(int seed) { - this.seed = seed; - } - public int getSeed() { - return this.seed; - } - public void setSeed(int seed) { - this.seed = seed; - } - } + + ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray())); + try { + return (KieBase) in.readObject(); + } finally { + in.close(); + } + + } catch (IOException e) { + throw new RuntimeException(e); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + protected static void disposeSession(KieSession session) { + if (session != null) { + session.dispose(); + } + } + + protected void checkList(int end, List list) { + checkList(0, end, list); + } + + protected void checkList(int start, int end, List list) { + int expectedSize = end - start; + checkList(start, end, list, expectedSize); + } + + protected void checkList(int start, int end, List list, int expectedSize) { + assertEquals(expectedSize, list.size()); + for (int i = start; i < end; i++) { + assertTrue(list.contains("" + i)); + } + } + + public static class BeanA { + int seed; + + public BeanA() { + this.seed = 1; + } + + public BeanA(int seed) { + this.seed = seed; + } + + public int getSeed() { + return this.seed; + } + + public void setSeed(int seed) { + this.seed = seed; + } + } + + public static class BeanB { + int seed; + + public BeanB() { + this.seed = 1; + } + + public BeanB(int seed) { + this.seed = seed; + } + + public int getSeed() { + return this.seed; + } + + public void setSeed(int seed) { + this.seed = seed; + } + } } diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java index 2b2d27239ace..9a1ec0f25d71 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java @@ -22,648 +22,685 @@ import java.util.ArrayList; import java.util.List; -import java.util.concurrent.Callable; -import java.util.concurrent.CompletionService; -import java.util.concurrent.ExecutorCompletionService; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; public class ConcurrentBasesParallelTest extends AbstractParallelTest { - public ConcurrentBasesParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { - super(enforcedJitting, serializeKieBase); - } - - @Test - public void testDifferentRuleset1() throws InterruptedException { - int numberOfThreads = 100; - int numberOfObjects = 100; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + - "rule Rule_" + counter + " " + - "when " + - " BeanA( seed == " + counter + ") " + - "then " + - "end"; - - KieBase base = getKieBase(rule); - KieSession session = null; - - try { - session = base.newKieSession(); - for (int i = 0; i < numberOfObjects; i++) { - session.insert(new BeanA(i)); - } - return session.fireAllRules() == 1; - } finally { - disposeSession(session); + public ConcurrentBasesParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + super(enforcedJitting, serializeKieBase); + } + + @Test + public void testDifferentRuleset1() throws InterruptedException { + int numberOfThreads = 100; + int numberOfObjects = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule Rule_" + counter + " " + + "when " + + " BeanA( seed == " + counter + ") " + + "then " + + "end"; + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + for (int i = 0; i < numberOfObjects; i++) { + session.insert(new BeanA(i)); + } + return session.fireAllRules() == 1; + } finally { + disposeSession(session); + } } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testDifferentRuleset2() throws InterruptedException { - int numberOfThreads = 100; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + - "rule Rule_" + counter + " " + - "when " + - " BeanA( seed == " + counter + ") " + - "then " + - "end"; - - KieBase base = getKieBase(rule); - KieSession session = null; - - try { - session = base.newKieSession(); - for (int i = 0; i < numberOfThreads; i++) { - if (i != counter) { - session.insert(new BeanA(i)); - } - } - return session.fireAllRules() == 0; - - } finally { - disposeSession(session); + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset2() throws InterruptedException { + int numberOfThreads = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule Rule_" + counter + " " + + "when " + + " BeanA( seed == " + counter + ") " + + "then " + + "end"; + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + for (int i = 0; i < numberOfThreads; i++) { + if (i != counter) { + session.insert(new BeanA(i)); + } + } + return session.fireAllRules() == 0; + + } finally { + disposeSession(session); + } } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testDifferentRuleset3() throws InterruptedException { - int numberOfThreads = 100; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + - "global "+ AtomicInteger.class.getCanonicalName() + " result;\n" + - "rule Rule_" + counter + " " + - "when " + - " BeanA()" + - "then " + - " result.set(" + counter + ");" + - "end"; - - KieBase base = getKieBase(rule); - KieSession session = null; - - try { - session = base.newKieSession(); - session.insert(new BeanA()); - AtomicInteger r = new AtomicInteger(0); - session.setGlobal("result", r); - assertEquals(1, session.fireAllRules()); - assertEquals(counter, r.get()); - return true; - } finally { - disposeSession(session); + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset3() throws InterruptedException { + int numberOfThreads = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global " + AtomicInteger.class.getCanonicalName() + " result;\n" + + "rule Rule_" + counter + " " + + "when " + + " BeanA()" + + "then " + + " result.set(" + counter + ");" + + "end"; + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + AtomicInteger r = new AtomicInteger(0); + session.setGlobal("result", r); + assertEquals(1, session.fireAllRules()); + assertEquals(counter, r.get()); + return true; + } finally { + disposeSession(session); + } } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testDifferentRuleset4() throws InterruptedException { - int numberOfThreads = 100; - - String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + - "import " + BeanB.class.getCanonicalName() + ";\n" + - "global java.util.List list;\n" + - "rule ${ruleName} " + - "when " + - "${className}()" + - "then " + - " list.add(\"${className}\");" + - "end"; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - - String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; - String ruleName = "Rule_" + className + "_" + counter; - String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); - - KieBase base = getKieBase(rule); - KieSession session = null; - - try { - session = base.newKieSession(); - session.insert(new BeanA()); - session.insert(new BeanB()); - List list = new ArrayList<>(); - session.setGlobal("list", list); - int rulesFired = session.fireAllRules(); - assertEquals(1, list.size()); - assertEquals(className, list.get(0)); - return rulesFired == 1; - } finally { - disposeSession(session); + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset4() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule ${ruleName} " + + "when " + + "${className}()" + + "then " + + " list.add(\"${className}\");" + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + session.insert(new BeanB()); + List list = new ArrayList<>(); + session.setGlobal("list", list); + int rulesFired = session.fireAllRules(); + assertEquals(1, list.size()); + assertEquals(className, list.get(0)); + return rulesFired == 1; + } finally { + disposeSession(session); + } } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testDifferentRuleset5() throws InterruptedException { - int numberOfThreads = 100; - - String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + - "import " + BeanB.class.getCanonicalName() + ";\n" + - "rule ${ruleName} " + - "when " + - "${className}()" + - "then " + - "end"; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - - String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; - String ruleName = "Rule_" + className + "_" + counter; - String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); - - KieBase base = getKieBase(rule); - KieSession session = null; - - try { - session = base.newKieSession(); - if (counter % 2 == 0) { - session.insert(new BeanB()); - } else { - session.insert(new BeanA()); - } - int rulesFired = session.fireAllRules(); - assertEquals(0, rulesFired); - return true; - } finally { - disposeSession(session); + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset5() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + "${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanB()); + } else { + session.insert(new BeanA()); + } + int rulesFired = session.fireAllRules(); + assertEquals(0, rulesFired); + return true; + } finally { + disposeSession(session); + } } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testDifferentRuleset6() throws InterruptedException { - int numberOfThreads = 100; - - String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + - "import " + BeanB.class.getCanonicalName() + ";\n" + - "rule ${ruleName} " + - "when " + - "${className}()" + - "then " + - "end"; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; - String ruleName = "Rule_" + className + "_" + counter; - String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); - - KieBase base = getKieBase(rule); - KieSession session = null; - - try { - session = base.newKieSession(); - if (counter % 2 == 0) { - session.insert(new BeanA()); - } else { - session.insert(new BeanB()); - } - int rulesFired = session.fireAllRules(); - assertEquals(1, rulesFired); - return true; - } finally { - disposeSession(session); - } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testDifferentRulesetNot() throws InterruptedException { - int numberOfThreads = 100; - - String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + - "import " + BeanB.class.getCanonicalName() + ";\n" + - "rule ${ruleName} " + - "when " + - " not ${className}()" + - "then " + - "end"; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - - String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; - String ruleName = "Rule_" + className + "_" + counter; - String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); - - KieBase base = getKieBase(rule); - KieSession session = null; - - try { - session = base.newKieSession(); - if (counter % 2 == 0) { - session.insert(new BeanA()); - } else { - session.insert(new BeanB()); - } - int rulesFired = session.fireAllRules(); - assertEquals(0, rulesFired); - return true; - } finally { - disposeSession(session); + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset6() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + "${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA()); + } else { + session.insert(new BeanB()); + } + int rulesFired = session.fireAllRules(); + assertEquals(1, rulesFired); + return true; + } finally { + disposeSession(session); + } } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRulesetNot() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + " not ${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA()); + } else { + session.insert(new BeanB()); + } + int rulesFired = session.fireAllRules(); + assertEquals(0, rulesFired); + return true; + } finally { + disposeSession(session); + } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testDifferentRulesetExists() throws InterruptedException { - int numberOfThreads = 100; - - String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + - "import " + BeanB.class.getCanonicalName() + ";\n" + - "rule ${ruleName} " + - "when " + - " exists ${className}()" + - "then " + - "end"; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - - String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; - String ruleName = "Rule_" + className + "_" + counter; - String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); - - KieBase base = getKieBase(rule); - KieSession session = null; - - try { - session = base.newKieSession(); - if (counter % 2 == 0) { - session.insert(new BeanA()); - } else { - session.insert(new BeanB()); - } - int rulesFired = session.fireAllRules(); - assertEquals(1, rulesFired); - return true; - } finally { - disposeSession(session); } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRulesetExists() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + " exists ${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA()); + } else { + session.insert(new BeanB()); + } + int rulesFired = session.fireAllRules(); + assertEquals(1, rulesFired); + return true; + } finally { + disposeSession(session); + } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testDifferentRulesetSharedSubnetwork() throws InterruptedException { - int numberOfThreads = 100; - - String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + - "import " + BeanB.class.getCanonicalName() + ";\n" + - "rule ${ruleName} " + - "when " + - " $bean : ${className}() \n" + - "then " + - "end"; - - String subnetworkRuleTemplate = "rule Rule_subnetwork " + - "when " + - " $bean : ${className}() \n" + - " Number( doubleValue > 0) from" + - " accumulate ( BeanA() and $s : String(), count($s) )" + - "then " + - "end"; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - - String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; - String ruleName = "Rule_" + className + "_" + counter; - String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); - String subnetworkRule = subnetworkRuleTemplate.replace("${className}", className); - - KieBase base = getKieBase(rule, subnetworkRule); - KieSession session = null; - - try { - session = base.newKieSession(); - session.insert("test"); - if (counter % 2 == 0) { - session.insert(new BeanA()); - assertEquals(2, session.fireAllRules()); - } else { - session.insert(new BeanB()); - assertEquals(1, session.fireAllRules()); - } - return true; - } finally { - disposeSession(session); } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRulesetSharedSubnetwork() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + " $bean : ${className}() \n" + + "then " + + "end"; + + String subnetworkRuleTemplate = "rule Rule_subnetwork " + + "when " + + " $bean : ${className}() \n" + + " Number( doubleValue > 0) from" + + " accumulate ( BeanA() and $s : String(), count($s) )" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + String subnetworkRule = subnetworkRuleTemplate.replace("${className}", className); + + KieBase base = getKieBase(rule, subnetworkRule); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert("test"); + if (counter % 2 == 0) { + session.insert(new BeanA()); + assertEquals(2, session.fireAllRules()); + } else { + session.insert(new BeanB()); + assertEquals(1, session.fireAllRules()); + } + return true; + } finally { + disposeSession(session); + } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testSameRuleset1() throws InterruptedException { - int numberOfThreads = 100; - - String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + - "rule RuleA " + - "when " + - " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + - "then " + - "end"; - - String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + - "rule RuleB " + - "when " + - " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + - "then " + - "end"; - - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - - KieBase base = getKieBase(ruleA, ruleB); - KieSession session = null; - - try { - session = base.newKieSession(); - if (counter %2 == 0) { - session.insert(new BeanA(counter)); - return session.fireAllRules() == 1; - } else { - return session.fireAllRules() == 0; - } - } finally { - disposeSession(session); } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testSameRuleset2() throws InterruptedException { - int numberOfThreads = 100; - - String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + - "rule RuleA " + - "when " + - " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + - "then " + - "end"; - - String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + - "rule RuleB " + - "when " + - " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + - "then " + - "end"; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - KieBase base = getKieBase(ruleA, ruleB); - KieSession session = null; - - try { - session = base.newKieSession(); - session.insert(new BeanA()); - session.insert(new BeanB()); - return session.fireAllRules() == 2; - } finally { - disposeSession(session); + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset1() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleA " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleB " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + + "then " + + "end"; + + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA(counter)); + return session.fireAllRules() == 1; + } else { + return session.fireAllRules() == 0; + } + } finally { + disposeSession(session); + } } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testSameRuleset3() throws InterruptedException { - int numberOfThreads = 100; - - String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + - "rule RuleA " + - "when " + - " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + - "then " + - "end"; - - String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + - "rule RuleB " + - "when " + - " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + - "then " + - "end"; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - KieBase base = getKieBase(ruleA, ruleB); - KieSession session = null; - - try { - session = base.newKieSession(); - if (counter %2 == 0) { - session.insert(new BeanA(counter)); - } else { - session.insert(new BeanB(counter)); - } - return session.fireAllRules() == 1; - } finally { - disposeSession(session); + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset2() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleA " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleB " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + session.insert(new BeanB()); + return session.fireAllRules() == 2; + } finally { + disposeSession(session); + } } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset3() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleA " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleB " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA(counter)); + } else { + session.insert(new BeanB(counter)); + } + return session.fireAllRules() == 1; + } finally { + disposeSession(session); + } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testSameRuleset4() throws InterruptedException { - int numberOfThreads = 100; - - String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + - "rule RuleNotA " + - "when " + - " not BeanA() " + - "then " + - "end"; - - String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + - "rule RuleNotB " + - "when " + - " not BeanB() " + - "then " + - "end"; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - KieBase base = getKieBase(ruleA, ruleB); - KieSession session = null; - - try { - session = base.newKieSession(); - if (counter %2 == 0) { - session.insert(new BeanA(counter)); - } else { - session.insert(new BeanB(counter)); - } - return session.fireAllRules() == 1; - } finally { - disposeSession(session); } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset4() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleNotA " + + "when " + + " not BeanA() " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleNotB " + + "when " + + " not BeanB() " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA(counter)); + } else { + session.insert(new BeanB(counter)); + } + return session.fireAllRules() == 1; + } finally { + disposeSession(session); + } - } - }; - - parallelTest(numberOfThreads, exec); - } - - @Test - public void testSameRuleset5() throws InterruptedException { - int numberOfThreads = 100; - - String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + - "rule RuleNotA " + - "when " + - " not BeanA() " + - "then " + - "end"; - - String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + - "rule RuleNotB " + - "when " + - " not BeanB() " + - "then " + - "end"; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - KieBase base = getKieBase(ruleA, ruleB); - KieSession session = null; - - try { - session = base.newKieSession(); - if (counter %2 == 0) { - session.insert(new BeanA(counter)); - session.insert(new BeanB(counter)); - return session.fireAllRules() == 0; - } else { - return session.fireAllRules() == 2; - } - } finally { - disposeSession(session); } - } - }; - - parallelTest(numberOfThreads, exec); - } - - public void parallelTest(int numberOfThreads, ParallelTestExecutor executor) throws InterruptedException { - - Callable[] tasks = new Callable[numberOfThreads]; - - final ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads, new ThreadFactory() { - @Override - public Thread newThread(Runnable r) { - final Thread t = new Thread(r); - t.setDaemon(true); - return t; - } - }); - - try { - for (int i = 0; i < numberOfThreads; i++) { - final int counter = i; - tasks[counter] = new Callable() { - @Override - public Boolean call() throws Exception { - return executor.execute(counter); - } - }; - } - - final CompletionService completionService = new ExecutorCompletionService(executorService); - - for (Callable task : tasks) { - completionService.submit(task); - } - - int successCounter = 0; - for (int i = 0; i < numberOfThreads; i++) { - try { - if (completionService.take().get()) { - successCounter ++; - } - } catch (Exception ex) { - throw new RuntimeException(ex); + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset5() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleNotA " + + "when " + + " not BeanA() " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleNotB " + + "when " + + " not BeanB() " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA(counter)); + session.insert(new BeanB(counter)); + return session.fireAllRules() == 0; + } else { + return session.fireAllRules() == 2; + } + } finally { + disposeSession(session); + } } - } - - assertEquals(numberOfThreads, successCounter); + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testFunctions() throws InterruptedException { + int numberOfThreads = 100; + + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;" + + "rule Rule " + + "when " + + " BeanA() " + + "then " + + " addToList(list);" + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + String function = "import java.util.List;\n" + + "function void addToList(List list) { \n" + + " list.add( \"" + counter + "\" );\n" + + "}\n"; + + KieBase base = getKieBase(rule, function); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + List list = new ArrayList<>(); + session.setGlobal("list", list); + int rulesFired = session.fireAllRules(); + assertEquals(1, list.size()); + assertEquals(""+counter, list.get(0)); + return rulesFired == 1; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testFunctions2() throws InterruptedException { + int numberOfThreads = 100; + int objectCount = 100; + + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;" + + "rule Rule " + + "when " + + " BeanA() " + + "then " + + " addToList(list);" + + "end"; + + String functionTemplate = "import java.util.List;\n" + + "function void addToList(List list) { \n" + + " list.add( \"${identifier}\" );\n" + + "}\n"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + + String identifier = (counter%2 == 0) ? "even" : "odd"; + String otherIdentifier = (counter%2 == 0) ? "odd" : "even"; + String functionRule = functionTemplate.replace("${identifier}", identifier); + + KieBase base = getKieBase(rule, functionRule); + KieSession session = null; + + try { + session = base.newKieSession(); + List list = new ArrayList<>(); + session.setGlobal("list", list); + int rulesFired = 0; + for (int i = 0; i < objectCount; i++) { + session.insert(new BeanA(i)); + rulesFired += session.fireAllRules(); + } + assertEquals(objectCount, list.size()); + assertTrue(list.contains(identifier)); + assertTrue(!list.contains(otherIdentifier)); + return rulesFired == objectCount; + } finally { + disposeSession(session); + } + } + }; - } finally { - executorService.shutdown(); - if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { - executorService.shutdownNow(); - } - } - } + parallelTest(numberOfThreads, exec); + } } diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java index c59f2cace0ff..3c1b0709c547 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java @@ -38,136 +38,137 @@ public class DataTypeEvaluationSharedSessionParallelTest extends AbstractParallelTest { - public DataTypeEvaluationSharedSessionParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { - super(enforcedJitting, serializeKieBase); - } - - @Test - public void testBooleanPrimitive() throws InterruptedException { - testFactAttributeType(" $factWithBoolean: FactWithBoolean(booleanValue == false) \n", new FactWithBoolean(false)); - } - - @Test - public void testBoolean() throws InterruptedException { - testFactAttributeType(" $factWithBoolean: FactWithBoolean(booleanObjectValue == false) \n", new FactWithBoolean(false)); - } - - @Test - public void testBytePrimitive() throws InterruptedException { - testFactAttributeType(" $factWithByte: FactWithByte(byteValue == 15) \n", new FactWithByte((byte) 15)); - } - - @Test - public void testByte() throws InterruptedException { - testFactAttributeType(" $factWithByte: FactWithByte(byteObjectValue == 15) \n", new FactWithByte((byte) 15)); - } - - @Test - public void testShortPrimitive() throws InterruptedException { - testFactAttributeType(" $factWithShort: FactWithShort(shortValue == 15) \n", new FactWithShort((short) 15)); - } - - @Test - public void testShort() throws InterruptedException { - testFactAttributeType(" $factWithShort: FactWithShort(shortObjectValue == 15) \n", new FactWithShort((short) 15)); - } - - @Test - public void testIntPrimitive() throws InterruptedException { - testFactAttributeType(" $factWithInt: FactWithInteger(intValue == 15) \n", new FactWithInteger(15)); - } - - @Test - public void testInteger() throws InterruptedException { - testFactAttributeType(" $factWithInteger: FactWithInteger(integerValue == 15) \n", new FactWithInteger(15)); - } - - @Test - public void testLongPrimitive() throws InterruptedException { - testFactAttributeType(" $factWithLong: FactWithLong(longValue == 15) \n", new FactWithLong(15)); - } - - @Test - public void testLong() throws InterruptedException { - testFactAttributeType(" $factWithLong: FactWithLong(longObjectValue == 15) \n", new FactWithLong(15)); - } - - @Test - public void testFloatPrimitive() throws InterruptedException { - testFactAttributeType(" $factWithFloat: FactWithFloat(floatValue == 15.1) \n", new FactWithFloat(15.1f)); - } - - @Test - public void testFloat() throws InterruptedException { - testFactAttributeType(" $factWithFloat: FactWithFloat(floatObjectValue == 15.1) \n", new FactWithFloat(15.1f)); - } - - @Test - public void testDoublePrimitive() throws InterruptedException { - testFactAttributeType(" $factWithDouble: FactWithDouble(doubleValue == 15.1) \n", new FactWithDouble(15.1d)); - } - - @Test - public void testDouble() throws InterruptedException { - testFactAttributeType(" $factWithDouble: FactWithDouble(doubleObjectValue == 15.1) \n", new FactWithDouble(15.1d)); - } - - @Test - public void testBigDecimal() throws InterruptedException { - testFactAttributeType(" $factWithBigDecimal: FactWithBigDecimal(bigDecimalValue == 10) \n", new FactWithBigDecimal(BigDecimal.TEN)); - } - - @Test - public void testCharPrimitive() throws InterruptedException { - testFactAttributeType(" $factWithChar: FactWithCharacter(charValue == 'a') \n", new FactWithCharacter('a')); - } - - @Test - public void testCharacter() throws InterruptedException { - testFactAttributeType(" $factWithChar: FactWithCharacter(characterValue == 'a') \n", new FactWithCharacter('a')); - } - - @Test - public void testString() throws InterruptedException { - testFactAttributeType(" $factWithString: FactWithString(stringValue == \"test\") \n", new FactWithString("test")); - } - - @Test - public void testEnum() throws InterruptedException { - testFactAttributeType(" $factWithEnum: FactWithEnum(enumValue == AnEnum.FIRST) \n", new FactWithEnum(AnEnum.FIRST)); - } - - private void testFactAttributeType(final String ruleConstraint, final Object factInserted) throws InterruptedException { - int numberOfThreads = 10; - - - final String drl = - " import org.drools.compiler.integrationtests.facts.*;\n" + - " global " + AtomicInteger.class.getCanonicalName() + " numberOfFirings;\n" + - " rule R1 \n" + - " when \n" + - ruleConstraint + - " then \n" + - " numberOfFirings.incrementAndGet(); \n" + - " end "; - - KieSession kieSession = getKieBase(drl).newKieSession(); - - AtomicInteger numberOfFirings = new AtomicInteger(0); - - parallelTest(numberOfThreads, new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - if (kieSession.getGlobal("numberOfFirings") == null) { - kieSession.setGlobal("numberOfFirings", numberOfFirings); - }; - kieSession.insert(factInserted); - kieSession.fireAllRules(); - return true; - } - }); - disposeSession(kieSession); - assertEquals(1, numberOfFirings.get()); - } + public DataTypeEvaluationSharedSessionParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + super(enforcedJitting, serializeKieBase); + } + + @Test + public void testBooleanPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithBoolean: FactWithBoolean(booleanValue == false) \n", new FactWithBoolean(false)); + } + + @Test + public void testBoolean() throws InterruptedException { + testFactAttributeType(" $factWithBoolean: FactWithBoolean(booleanObjectValue == false) \n", new FactWithBoolean(false)); + } + + @Test + public void testBytePrimitive() throws InterruptedException { + testFactAttributeType(" $factWithByte: FactWithByte(byteValue == 15) \n", new FactWithByte((byte) 15)); + } + + @Test + public void testByte() throws InterruptedException { + testFactAttributeType(" $factWithByte: FactWithByte(byteObjectValue == 15) \n", new FactWithByte((byte) 15)); + } + + @Test + public void testShortPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithShort: FactWithShort(shortValue == 15) \n", new FactWithShort((short) 15)); + } + + @Test + public void testShort() throws InterruptedException { + testFactAttributeType(" $factWithShort: FactWithShort(shortObjectValue == 15) \n", new FactWithShort((short) 15)); + } + + @Test + public void testIntPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithInt: FactWithInteger(intValue == 15) \n", new FactWithInteger(15)); + } + + @Test + public void testInteger() throws InterruptedException { + testFactAttributeType(" $factWithInteger: FactWithInteger(integerValue == 15) \n", new FactWithInteger(15)); + } + + @Test + public void testLongPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithLong: FactWithLong(longValue == 15) \n", new FactWithLong(15)); + } + + @Test + public void testLong() throws InterruptedException { + testFactAttributeType(" $factWithLong: FactWithLong(longObjectValue == 15) \n", new FactWithLong(15)); + } + + @Test + public void testFloatPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithFloat: FactWithFloat(floatValue == 15.1) \n", new FactWithFloat(15.1f)); + } + + @Test + public void testFloat() throws InterruptedException { + testFactAttributeType(" $factWithFloat: FactWithFloat(floatObjectValue == 15.1) \n", new FactWithFloat(15.1f)); + } + + @Test + public void testDoublePrimitive() throws InterruptedException { + testFactAttributeType(" $factWithDouble: FactWithDouble(doubleValue == 15.1) \n", new FactWithDouble(15.1d)); + } + + @Test + public void testDouble() throws InterruptedException { + testFactAttributeType(" $factWithDouble: FactWithDouble(doubleObjectValue == 15.1) \n", new FactWithDouble(15.1d)); + } + + @Test + public void testBigDecimal() throws InterruptedException { + testFactAttributeType(" $factWithBigDecimal: FactWithBigDecimal(bigDecimalValue == 10) \n", new FactWithBigDecimal(BigDecimal.TEN)); + } + + @Test + public void testCharPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithChar: FactWithCharacter(charValue == 'a') \n", new FactWithCharacter('a')); + } + + @Test + public void testCharacter() throws InterruptedException { + testFactAttributeType(" $factWithChar: FactWithCharacter(characterValue == 'a') \n", new FactWithCharacter('a')); + } + + @Test + public void testString() throws InterruptedException { + testFactAttributeType(" $factWithString: FactWithString(stringValue == \"test\") \n", new FactWithString("test")); + } + + @Test + public void testEnum() throws InterruptedException { + testFactAttributeType(" $factWithEnum: FactWithEnum(enumValue == AnEnum.FIRST) \n", new FactWithEnum(AnEnum.FIRST)); + } + + private void testFactAttributeType(final String ruleConstraint, final Object factInserted) throws InterruptedException { + int numberOfThreads = 10; + + + final String drl = + " import org.drools.compiler.integrationtests.facts.*;\n" + + " global " + AtomicInteger.class.getCanonicalName() + " numberOfFirings;\n" + + " rule R1 \n" + + " when \n" + + ruleConstraint + + " then \n" + + " numberOfFirings.incrementAndGet(); \n" + + " end "; + + KieSession kieSession = getKieBase(drl).newKieSession(); + + AtomicInteger numberOfFirings = new AtomicInteger(0); + + parallelTest(numberOfThreads, new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + if (kieSession.getGlobal("numberOfFirings") == null) { + kieSession.setGlobal("numberOfFirings", numberOfFirings); + } + ; + kieSession.insert(factInserted); + kieSession.fireAllRules(); + return true; + } + }); + disposeSession(kieSession); + assertEquals(1, numberOfFirings.get()); + } } diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java index ec8fe1aee744..b5174ab32a53 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java @@ -33,456 +33,456 @@ public class SharedSessionParallelTest extends AbstractParallelTest { - public SharedSessionParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { - super(enforcedJitting, serializeKieBase); - } + public SharedSessionParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + super(enforcedJitting, serializeKieBase); + } + + @Test + public void testNoExceptions() throws InterruptedException { + String drl = "rule R1 when String() then end"; + + int repetitions = 100; + int numberOfObjects = 1000; + int countOfThreads = 100; + + for (int i = 0; i < repetitions; i++) { + + KieSession kieSession = getKieBase(drl).newKieSession(); + + parallelTest(countOfThreads, new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + for (int j = 0; j < numberOfObjects; j++) { + kieSession.insert("test_" + numberOfObjects); + } + kieSession.fireAllRules(); + return true; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }); + + disposeSession(kieSession); + } + } + + @Test + public void testCheckOneThreadOnly() throws InterruptedException { + int threadCount = 100; + List list = new ArrayList<>(); + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule R1 " + + "when " + + " BeanA($n : seed) " + + "then " + + " list.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + CountDownLatch latch = new CountDownLatch(threadCount); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + kieSession.setGlobal("list", list); + kieSession.insert(new BeanA(counter)); + latch.countDown(); + + if (counter == 0) { + latch.await(); + return kieSession.fireAllRules() == threadCount; + } + return true; + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); - @Test - public void testNoExceptions() throws InterruptedException { - String drl = "rule R1 when String() then end"; + assertEquals(threadCount, list.size()); + for (int i = 0; i < threadCount; i++) { + assertTrue(list.contains("" + i)); + } + } - int repetitions = 100; - int numberOfObjects = 1000; - int countOfThreads = 100; + @Test + public void testCorrectFirings() throws InterruptedException { + int threadCount = 100; - for (int i = 0; i < repetitions; i++) { + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List globalList;\n" + + "rule R1 " + + "when " + + " BeanA($n : seed) " + + "then " + + " globalList.add(\"\" + $n);" + + "end"; - KieSession kieSession = getKieBase(drl).newKieSession(); + KieSession kieSession = getKieBase(drl).newKieSession(); - parallelTest(countOfThreads, new ParallelTestExecutor() { + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { @Override public boolean execute(int counter) { - try { - for (int j = 0; j < numberOfObjects; j++) { - kieSession.insert("test_" + numberOfObjects); - } - kieSession.fireAllRules(); - return true; - } catch (Exception ex) { - throw new RuntimeException(ex); - } + kieSession.setGlobal("globalList", list); + kieSession.insert(new BeanA(counter)); + kieSession.fireAllRules(); + return true; } - }); - - disposeSession(kieSession); - } - } - - @Test - public void testCheckOneThreadOnly() throws InterruptedException { - int threadCount = 100; - List list = new ArrayList<>(); - - String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + - "global java.util.List list;\n" + - "rule R1 " + - "when " + - " BeanA($n : seed) " + - "then " + - " list.add(\"\" + $n);" + - "end"; - - KieSession kieSession = getKieBase(drl).newKieSession(); - CountDownLatch latch = new CountDownLatch(threadCount); - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) throws InterruptedException { - kieSession.setGlobal("list", list); - kieSession.insert(new BeanA(counter)); - latch.countDown(); - - if (counter == 0) { - latch.await(); - return kieSession.fireAllRules() == threadCount; + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(threadCount, list); + } + + @Test + public void testCorrectFirings2() throws InterruptedException { + int threadCount = 100; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule R1 " + + "when " + + " BeanA($n : seed, seed == 0) " + + "then " + + " list.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + kieSession.setGlobal("list", list); + kieSession.insert(new BeanA(counter % 2)); + kieSession.fireAllRules(); + return true; } - return true; - } - }; - - parallelTest(threadCount, exec); - disposeSession(kieSession); - - assertEquals(threadCount, list.size()); - for (int i = 0; i < threadCount; i++) { - assertTrue(list.contains(""+i)); - } - } - - @Test - public void testCorrectFirings() throws InterruptedException { - int threadCount = 100; - - String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + - "global java.util.List globalList;\n" + - "rule R1 " + - "when " + - " BeanA($n : seed) " + - "then " + - " globalList.add(\"\" + $n);" + - "end"; - - KieSession kieSession = getKieBase(drl).newKieSession(); - - List list = new ArrayList<>(); - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - kieSession.setGlobal("globalList", list); - kieSession.insert(new BeanA(counter)); - kieSession.fireAllRules(); - return true; - } - }; - - parallelTest(threadCount, exec); - disposeSession(kieSession); - checkList(threadCount, list); - } - - @Test - public void testCorrectFirings2() throws InterruptedException { - int threadCount = 100; - - String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + - "global java.util.List list;\n" + - "rule R1 " + - "when " + - " BeanA($n : seed, seed == 0) " + - "then " + - " list.add(\"\" + $n);" + - "end"; - - KieSession kieSession = getKieBase(drl).newKieSession(); - List list = new ArrayList<>(); - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - kieSession.setGlobal("list", list); - kieSession.insert(new BeanA(counter%2)); - kieSession.fireAllRules(); - return true; - } - }; - - parallelTest(threadCount, exec); - disposeSession(kieSession); - assertTrue(list.contains("" + 0)); - assertFalse(list.contains("" + 1)); - int expectedListSize = ((threadCount-1) / 2) + 1; - assertEquals(expectedListSize, list.size()); - } - - - @Test - public void testLongRunningRule() throws InterruptedException { - int threadCount = 100; - int seed = threadCount + 200; - int objectCount = 1000; - - String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + - "global java.util.List list;\n" + - "rule longRunning " + - "when " + - " $bean : BeanA($n : seed, seed > " + threadCount + ") " + - "then " + - " modify($bean) { setSeed($n-1) };" + - " list.add(\"\" + $bean.getSeed());" + - " Thread.sleep(5);" + - "end"; - - String listDrl = "global java.util.List list2;\n" + - "rule listRule " + - "when " + - " BeanA($n : seed, seed < " + threadCount + ") " + - "then " + - " list2.add(\"\" + $n);" + - "end"; - - KieSession kieSession = getKieBase(longRunningDrl, listDrl).newKieSession(); - - CyclicBarrier barrier = new CyclicBarrier(threadCount); - List list = new ArrayList<>(); - List list2 = new ArrayList<>(); - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - try { - if (counter == 0) { - kieSession.setGlobal("list", list); - kieSession.setGlobal("list2", list2); - kieSession.insert(new BeanA(seed)); - barrier.await(); - kieSession.fireAllRules(); - return true; - } else { - barrier.await(); - Thread.sleep(100); - for (int i = 0; i < objectCount; i++) { - kieSession.insert(new BeanA(counter)); - } - kieSession.fireAllRules(); - return true; - } - } catch (Exception ex) { - throw new RuntimeException(ex); + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + assertTrue(list.contains("" + 0)); + assertFalse(list.contains("" + 1)); + int expectedListSize = ((threadCount - 1) / 2) + 1; + assertEquals(expectedListSize, list.size()); + } + + + @Test + public void testLongRunningRule() throws InterruptedException { + int threadCount = 100; + int seed = threadCount + 200; + int objectCount = 1000; + + String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule longRunning " + + "when " + + " $bean : BeanA($n : seed, seed > " + threadCount + ") " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + " Thread.sleep(5);" + + "end"; + + String listDrl = "global java.util.List list2;\n" + + "rule listRule " + + "when " + + " BeanA($n : seed, seed < " + threadCount + ") " + + "then " + + " list2.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(longRunningDrl, listDrl).newKieSession(); + + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + List list2 = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + kieSession.setGlobal("list2", list2); + kieSession.insert(new BeanA(seed)); + barrier.await(); + kieSession.fireAllRules(); + return true; + } else { + barrier.await(); + Thread.sleep(100); + for (int i = 0; i < objectCount; i++) { + kieSession.insert(new BeanA(counter)); + } + kieSession.fireAllRules(); + return true; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } } - } - }; - - parallelTest(threadCount, exec); - disposeSession(kieSession); - checkList(threadCount, seed, list); - checkList(1, threadCount, list2, (threadCount-1)*objectCount); - } - - @Test - public void testLongRunningRule2() throws InterruptedException { - int threadCount = 100; - int seed = 1000; - - String waitingRule = "rule waitingRule " + - "when " + - " String( this == \"wait\" ) " + - "then " + - " Thread.sleep(10);" + - "end"; - - String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + - "global java.util.List list;\n" + - "rule longRunning " + - "when " + - " $bean : BeanA($n : seed, seed > 0 ) " + - "then " + - " modify($bean) { setSeed($n-1) };" + - " list.add(\"\" + $bean.getSeed());" + - "end"; - - KieSession kieSession = getKieBase(longRunningDrl, waitingRule).newKieSession(); - - CyclicBarrier barrier = new CyclicBarrier(threadCount); - List list = new ArrayList<>(); - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - try { - if (counter == 0) { - kieSession.setGlobal("list", list); - kieSession.insert("wait"); - kieSession.insert(new BeanA(seed)); - barrier.await(); - return kieSession.fireAllRules() == seed * threadCount + 1; - } else { - barrier.await(); - Thread.sleep(10); - kieSession.insert(new BeanA(seed)); - return kieSession.fireAllRules() == 0; - } - } catch (Exception ex) { - throw new RuntimeException(ex); + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(threadCount, seed, list); + checkList(1, threadCount, list2, (threadCount - 1) * objectCount); + } + + @Test + public void testLongRunningRule2() throws InterruptedException { + int threadCount = 100; + int seed = 1000; + + String waitingRule = "rule waitingRule " + + "when " + + " String( this == \"wait\" ) " + + "then " + + " Thread.sleep(10);" + + "end"; + + String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule longRunning " + + "when " + + " $bean : BeanA($n : seed, seed > 0 ) " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + KieSession kieSession = getKieBase(longRunningDrl, waitingRule).newKieSession(); + + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + kieSession.insert("wait"); + kieSession.insert(new BeanA(seed)); + barrier.await(); + return kieSession.fireAllRules() == seed * threadCount + 1; + } else { + barrier.await(); + Thread.sleep(10); + kieSession.insert(new BeanA(seed)); + return kieSession.fireAllRules() == 0; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } } - } - }; - - parallelTest(threadCount, exec); - disposeSession(kieSession); - checkList(0, seed, list, seed * threadCount); - } - - @Test - public void testLongRunningRule3() throws InterruptedException { - int threadCount = 10; - int seed = threadCount + 50; - int objectCount = 1000; - - String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + - "global java.util.List list;\n" + - "rule longRunning " + - "when " + - " $bean : BeanA($n : seed, seed > " + threadCount + ") " + - "then " + - " modify($bean) { setSeed($n-1) };" + - " list.add(\"\" + $bean.getSeed());" + - "end"; - - String listDrl = "global java.util.List list2;\n" + - "rule listRule " + - "when " + - " BeanA($n : seed, seed < " + threadCount + ") " + - "then " + - " list2.add(\"\" + $n);" + - "end"; - - KieSession kieSession = getKieBase(longRunningDrl, listDrl).newKieSession(); - - CyclicBarrier barrier = new CyclicBarrier(threadCount); - List list = new ArrayList<>(); - List list2 = new ArrayList<>(); - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - try { - if (counter % 2 == 0) { - kieSession.setGlobal("list", list); - kieSession.setGlobal("list2", list2); - kieSession.insert(new BeanA(seed)); - barrier.await(); - kieSession.fireAllRules(); - return true; - } else { - barrier.await(); - Thread.sleep(100); - for (int i = 0; i < objectCount; i++) { - kieSession.insert(new BeanA(counter)); - } - kieSession.fireAllRules(); - return true; - } - } catch (Exception ex) { - throw new RuntimeException(ex); + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(0, seed, list, seed * threadCount); + } + + @Test + public void testLongRunningRule3() throws InterruptedException { + int threadCount = 10; + int seed = threadCount + 50; + int objectCount = 1000; + + String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule longRunning " + + "when " + + " $bean : BeanA($n : seed, seed > " + threadCount + ") " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + String listDrl = "global java.util.List list2;\n" + + "rule listRule " + + "when " + + " BeanA($n : seed, seed < " + threadCount + ") " + + "then " + + " list2.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(longRunningDrl, listDrl).newKieSession(); + + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + List list2 = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter % 2 == 0) { + kieSession.setGlobal("list", list); + kieSession.setGlobal("list2", list2); + kieSession.insert(new BeanA(seed)); + barrier.await(); + kieSession.fireAllRules(); + return true; + } else { + barrier.await(); + Thread.sleep(100); + for (int i = 0; i < objectCount; i++) { + kieSession.insert(new BeanA(counter)); + } + kieSession.fireAllRules(); + return true; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } } - } - }; - - parallelTest(threadCount, exec); - disposeSession(kieSession); - - int listExpectedSize = (threadCount/2 + threadCount%2) * (seed-threadCount); - int list2ExpectedSize = threadCount/2 * objectCount; - for (int i = 0; i < threadCount; i++) { - if (i % 2 == 1) { - list2.contains("" + i); - } - } - assertEquals(listExpectedSize, list.size()); - assertEquals(list2ExpectedSize, list2.size()); - } - - @Test - public void testCountdownBean() throws InterruptedException { - int threadCount = 100; - int seed = 1000; - - String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + - "global java.util.List list;\n" + - "rule countdown " + - "when " + - " $bean : BeanA($n : seed, seed > 0 ) " + - "then " + - " modify($bean) { setSeed($n-1) };" + - " list.add(\"\" + $bean.getSeed());" + - "end"; - - KieSession kieSession = getKieBase(drl).newKieSession(); - CyclicBarrier barrier = new CyclicBarrier(threadCount); - List list = new ArrayList<>(); - BeanA bean = new BeanA(seed); - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - try { - if (counter == 0) { - kieSession.setGlobal("list", list); - kieSession.insert(bean); - } - barrier.await(); - kieSession.fireAllRules(); - return true; - } catch (Exception ex) { - throw new RuntimeException(ex); + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + + int listExpectedSize = (threadCount / 2 + threadCount % 2) * (seed - threadCount); + int list2ExpectedSize = threadCount / 2 * objectCount; + for (int i = 0; i < threadCount; i++) { + if (i % 2 == 1) { + list2.contains("" + i); } - } - }; - - parallelTest(threadCount, exec); - disposeSession(kieSession); - checkList(seed, list); - assertEquals(0, bean.getSeed()); - } - - @Test - public void testCountdownBean2() throws InterruptedException { - int threadCount = 100; - int seed = 1000; - - String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + - "global java.util.List list;\n" + - "rule countdown " + - "when " + - " $bean : BeanA($n : seed, seed > 0 ) " + - "then " + - " modify($bean) { setSeed($n-1) };" + - " list.add(\"\" + $bean.getSeed());" + - "end"; - - KieSession kieSession = getKieBase(drl).newKieSession(); - List list = new ArrayList<>(); - BeanA[] beans = new BeanA[threadCount]; - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - BeanA bean = new BeanA(seed); - beans[counter] = bean; - try { - if (counter == 0) { - kieSession.setGlobal("list", list); - } - kieSession.insert(bean); - kieSession.fireAllRules(); - return true; - } catch (Exception ex) { - throw new RuntimeException(ex); + } + assertEquals(listExpectedSize, list.size()); + assertEquals(list2ExpectedSize, list2.size()); + } + + @Test + public void testCountdownBean() throws InterruptedException { + int threadCount = 100; + int seed = 1000; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule countdown " + + "when " + + " $bean : BeanA($n : seed, seed > 0 ) " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + BeanA bean = new BeanA(seed); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + kieSession.insert(bean); + } + barrier.await(); + kieSession.fireAllRules(); + return true; + } catch (Exception ex) { + throw new RuntimeException(ex); + } } - } - }; - - parallelTest(threadCount, exec); - disposeSession(kieSession); - - checkList(0, seed, list, seed * threadCount); - for (BeanA bean : beans) { - assertEquals(0, bean.getSeed()); - } - } - - @Test - public void testOneRulePerThread() throws InterruptedException { - int threadCount = 1000; - - String[] drls = new String[threadCount]; - for (int i = 0; i < threadCount; i++) { - drls[i] = "import " + BeanA.class.getCanonicalName() + ";\n" + - "global java.util.List list;\n" + - "rule R" + i + " " + - "when " + - " $bean : BeanA( seed == " + i + " ) " + - "then " + - " list.add(\"" + i + "\");" + - "end"; - } - - KieSession kieSession = getKieBase(drls).newKieSession(); - List list = new ArrayList<>(); - - ParallelTestExecutor exec = new ParallelTestExecutor() { - @Override - public boolean execute(int counter) { - kieSession.setGlobal("list", list); - kieSession.insert(new BeanA(counter)); - kieSession.fireAllRules(); - return true; - } - }; - - parallelTest(threadCount, exec); - disposeSession(kieSession); - checkList(threadCount, list); - } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(seed, list); + assertEquals(0, bean.getSeed()); + } + + @Test + public void testCountdownBean2() throws InterruptedException { + int threadCount = 100; + int seed = 1000; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule countdown " + + "when " + + " $bean : BeanA($n : seed, seed > 0 ) " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + List list = new ArrayList<>(); + BeanA[] beans = new BeanA[threadCount]; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + BeanA bean = new BeanA(seed); + beans[counter] = bean; + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + } + kieSession.insert(bean); + kieSession.fireAllRules(); + return true; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + + checkList(0, seed, list, seed * threadCount); + for (BeanA bean : beans) { + assertEquals(0, bean.getSeed()); + } + } + + @Test + public void testOneRulePerThread() throws InterruptedException { + int threadCount = 1000; + + String[] drls = new String[threadCount]; + for (int i = 0; i < threadCount; i++) { + drls[i] = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule R" + i + " " + + "when " + + " $bean : BeanA( seed == " + i + " ) " + + "then " + + " list.add(\"" + i + "\");" + + "end"; + } + + KieSession kieSession = getKieBase(drls).newKieSession(); + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + kieSession.setGlobal("list", list); + kieSession.insert(new BeanA(counter)); + kieSession.fireAllRules(); + return true; + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(threadCount, list); + } } From 9c85a859c032908a0f355bc93df1d678011c4d35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Ma=C4=8Dkay?= Date: Thu, 16 Nov 2017 15:15:43 +0100 Subject: [PATCH 4/4] Add tests for queries to ConcurrentBasesParallelTest --- .../session/ConcurrentBasesParallelTest.java | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java index 9a1ec0f25d71..cd0ab8985af8 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java @@ -19,6 +19,8 @@ import org.junit.Test; import org.kie.api.KieBase; import org.kie.api.runtime.KieSession; +import org.kie.api.runtime.rule.QueryResults; +import org.kie.api.runtime.rule.QueryResultsRow; import java.util.ArrayList; import java.util.List; @@ -703,4 +705,84 @@ public boolean execute(int counter) throws InterruptedException { parallelTest(numberOfThreads, exec); } + + @Test + public void testQueries() throws InterruptedException { + int numberOfThreads = 100; + int numberOfObjects = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + String query = "import " + BeanA.class.getCanonicalName() + ";\n" + + "query Query " + + " bean : BeanA( seed == "+ counter +" ) " + + "end"; + + KieBase base = getKieBase(query); + KieSession session = null; + + try { + session = base.newKieSession(); + BeanA bean = new BeanA(counter); + session.insert(bean); + for (int i = 0; i < numberOfObjects; i++) { + if (i != counter) { + session.insert(new BeanA(i)); + } + } + QueryResults results = session.getQueryResults("Query"); + assertEquals(1, results.size()); + for (QueryResultsRow row : results) { + assertEquals(bean, (BeanA) row.get("bean")); + } + return true; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testQueries2() throws InterruptedException { + int numberOfThreads = 100; + int numberOfObjects = 100; + + String queryTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "query Query " + + " bean : BeanA( seed == ${seed} ) " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + int seed = counter % 2; + String seedString = "" + seed; + String queryDrl = queryTemplate.replace("${seed}", seedString); + KieBase base = getKieBase(queryDrl); + KieSession session = null; + + try { + session = base.newKieSession(); + for (int i = 0; i < numberOfObjects; i++) { + session.insert(new BeanA(seed)); + } + QueryResults results = session.getQueryResults("Query"); + assertEquals(numberOfObjects, results.size()); + for (QueryResultsRow row : results) { + BeanA bean = (BeanA) row.get("bean"); + assertEquals(seed, bean.getSeed()); + } + return true; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } }