diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadTest.java index 6e12f882c4b9..26a0d26f0fa4 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadTest.java @@ -28,6 +28,7 @@ import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; @@ -67,7 +68,7 @@ public class MultithreadTest extends CommonTestMethodBase { private static final Logger LOG = LoggerFactory.getLogger(MultithreadTest.class); @Test(timeout = 10000) - public void testConcurrentInsertionsFewObjectsManyThreads() { + public void testConcurrentInsertionsFewObjectsManyThreads() throws InterruptedException { final String drl = "import org.drools.compiler.integrationtests.MultithreadTest.Bean\n" + "\n" + "rule \"R\"\n" + @@ -79,7 +80,7 @@ public void testConcurrentInsertionsFewObjectsManyThreads() { } @Test(timeout = 10000) - public void testConcurrentInsertionsManyObjectsFewThreads() { + public void testConcurrentInsertionsManyObjectsFewThreads() throws InterruptedException { final String drl = "import org.drools.compiler.integrationtests.MultithreadTest.Bean\n" + "\n" + "rule \"R\"\n" + @@ -91,7 +92,7 @@ public void testConcurrentInsertionsManyObjectsFewThreads() { } @Test(timeout = 10000) - public void testConcurrentInsertionsNewSessionEachThreadObjectTypeNode() { + public void testConcurrentInsertionsNewSessionEachThreadObjectTypeNode() throws InterruptedException { final String drl = "import org.drools.compiler.integrationtests.MultithreadTest.Bean\n" + " query existsBeanSeed5More() \n" + " Bean( seed > 5 ) \n" + @@ -112,7 +113,7 @@ public void testConcurrentInsertionsNewSessionEachThreadObjectTypeNode() { } @Test(timeout = 10000) - public void testConcurrentInsertionsNewSessionEachThread() { + public void testConcurrentInsertionsNewSessionEachThread() throws InterruptedException { final String drl = "import org.drools.compiler.integrationtests.MultithreadTest.Bean\n" + " query existsBeanSeed5More() \n" + " Bean( seed > 5 ) \n" + @@ -144,11 +145,11 @@ public void testConcurrentInsertionsNewSessionEachThread() { } private void testConcurrentInsertions(final String drl, final int objectCount, final int threadCount, - final boolean newSessionForEachThread, final boolean updateFacts) { + final boolean newSessionForEachThread, final boolean updateFacts) throws InterruptedException { final KieBase kieBase = new KieHelper().addContent(drl, ResourceType.DRL).build(); - Executor executor = Executors.newCachedThreadPool(new ThreadFactory() { + ExecutorService executor = Executors.newCachedThreadPool(new ThreadFactory() { public Thread newThread(Runnable r) { Thread t = new Thread(r); t.setDaemon(true); @@ -157,37 +158,44 @@ public Thread newThread(Runnable r) { }); KieSession ksession = null; - Callable[] tasks = new Callable[threadCount]; - if (newSessionForEachThread) { - for (int i = 0; i < threadCount; i++) { - tasks[i] = getTask( objectCount, kieBase, updateFacts ); - } - } else { - ksession = kieBase.newKieSession(); - for (int i = 0; i < threadCount; i++) { - tasks[i] = getTask( objectCount, ksession, false , updateFacts ); + try { + Callable[] tasks = new Callable[threadCount]; + if (newSessionForEachThread) { + for (int i = 0; i < threadCount; i++) { + tasks[i] = getTask(objectCount, kieBase, updateFacts); + } + } else { + ksession = kieBase.newKieSession(); + for (int i = 0; i < threadCount; i++) { + tasks[i] = getTask(objectCount, ksession, false, updateFacts); + } } - } - CompletionService ecs = new ExecutorCompletionService(executor); - for (Callable task : tasks) { - ecs.submit( task ); - } + CompletionService ecs = new ExecutorCompletionService(executor); + for (Callable task : tasks) { + ecs.submit(task); + } - int successCounter = 0; - for (int i = 0; i < threadCount; i++) { - try { - if ( ecs.take().get() ) { - successCounter++; + int successCounter = 0; + for (int i = 0; i < threadCount; i++) { + try { + if (ecs.take().get()) { + successCounter++; + } + } catch (Exception e) { + throw new RuntimeException(e); } - } catch (Exception e) { - throw new RuntimeException(e); } - } - assertEquals(threadCount, successCounter); - if (ksession != null) { - ksession.dispose(); + assertEquals(threadCount, successCounter); + if (ksession != null) { + ksession.dispose(); + } + } finally { + executor.shutdown(); + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow(); + } } } diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java new file mode 100644 index 000000000000..024fe82a1318 --- /dev/null +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/MultithreadedSubnetworkTest.java @@ -0,0 +1,267 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.compiler.integrationtests; + + +import org.drools.compiler.CommonTestMethodBase; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import org.kie.api.KieBase; +import org.kie.api.io.ResourceType; +import org.kie.api.runtime.KieSession; +import org.kie.api.runtime.rule.FactHandle; +import org.kie.internal.utils.KieHelper; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +@RunWith(Parameterized.class) +public class MultithreadedSubnetworkTest extends CommonTestMethodBase { + + protected final String drl; + + @Parameters(name = "DRL={0}") + public static List getTestParameter() { + return Arrays.asList( + new String[] {"sharedSubnetworkRule", sharedSubnetworkRule}, + new String[] {"noSharingSubnetworkRule", noSharingSubnetworkRule}, + new String[] {"notSubnetworkRule", notSubnetworkRule}, + new String[] {"existsSubnetworkRule", existsSubnetworkRule}); + } + + public MultithreadedSubnetworkTest(String drlName, final String drl) { + this.drl = drl; + } + + final static String sharedSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1y when\n" + + " AtomicInteger() \n" + + " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )" + + " Long()\n" + + "then\n" + + " System.out.println(\"R1y\");" + + "end\n" + + "\n" + + "rule R1x when\n" + + " AtomicInteger( get() == 1 ) \n" + + " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )\n" + + "then\n" + + " System.out.println(\"R1x\");" + + "end\n" + + "" + + "rule R2 when\n" + + " $i : AtomicInteger( get() < 3 )\n" + + "then\n" + + " System.out.println(\"R2\");" + + " $i.incrementAndGet();" + + " update($i);" + + "end\n"; + + final static String noSharingSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1y when\n" + + " AtomicInteger() \n" + + " Number() from accumulate ( AtomicInteger() and $s : String( this == \"test_1\" ) ; count($s) )" + + " Long()\n" + + "then\n" + + " System.out.println(\"R1y\");" + + "end\n" + + "\n" + + "rule R1x when\n" + + " AtomicInteger() \n" + + " Number() from accumulate ( $i : AtomicInteger( get() == 1) and String( this == \"test_2\" ) ; count($i) )\n" + + "then\n" + + " System.out.println(\"R1x\");" + + "end\n" + + "" + + "rule R2 when\n" + + " $i : AtomicInteger( get() < 3 )\n" + + "then\n" + + " System.out.println(\"R2\");" + + " $i.incrementAndGet();" + + " update($i);" + + "end\n"; + + final static String notSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1 when\n" + + " AtomicInteger() \n" + + " not(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R1\");" + + "end\n" + + "\n" + + "rule R2 when\n" + + " AtomicInteger() \n" + + " String( this != \"test_2\" ) \n" + + " not(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R2\");" + + "end\n"; + + final static String existsSubnetworkRule = + "import " + AtomicInteger.class.getCanonicalName() + ";\n" + + "rule R1 when\n" + + " AtomicInteger() \n" + + " exists(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R1\");" + + "end\n" + + "\n" + + "rule R2 when\n" + + " AtomicInteger() \n" + + " String( this != \"test_2\" ) \n" + + " exists(AtomicInteger( get() == 1 ) and String( this == \"test_1\" )) \n" + + "then\n" + + " System.out.println(\"R2\");" + + "end\n"; + + @Test(timeout = 10000) + public void testConcurrentInsertionsFewObjectsManyThreads() throws InterruptedException { + testConcurrentInsertions(drl, 1, 1000, false, false); + } + + @Test(timeout = 10000) + public void testConcurrentInsertionsManyObjectsFewThreads() throws InterruptedException { + testConcurrentInsertions(drl, 500, 4, false, false); + } + + @Test(timeout = 10000) + public void testConcurrentInsertionsManyObjectsSingleThread() throws InterruptedException { + testConcurrentInsertions(drl, 1000, 1, false, false); + } + + @Test(timeout = 10000) + public void testConcurrentInsertionsNewSessionEachThread() throws InterruptedException { + testConcurrentInsertions(drl, 10, 1000, true, false); + } + + @Test(timeout = 10000) + public void testConcurrentInsertionsNewSessionEachThreadUpdate() throws InterruptedException { + testConcurrentInsertions(drl, 10, 1000, true, true); + } + + private void testConcurrentInsertions(final String drl, final int objectCount, final int threadCount, + final boolean newSessionForEachThread, final boolean updateFacts) throws InterruptedException { + + final KieBase kieBase = new KieHelper().addContent(drl, ResourceType.DRL).build(); + + ExecutorService executor = Executors.newCachedThreadPool(new ThreadFactory() { + public Thread newThread(Runnable r) { + Thread t = new Thread(r); + t.setDaemon(true); + return t; + } + }); + + KieSession ksession = null; + try { + Callable[] tasks = new Callable[threadCount]; + if (newSessionForEachThread) { + for (int i = 0; i < threadCount; i++) { + tasks[i] = getTask(objectCount, kieBase, updateFacts); + } + } else { + ksession = kieBase.newKieSession(); + for (int i = 0; i < threadCount; i++) { + tasks[i] = getTask(objectCount, ksession, false, updateFacts); + } + } + + CompletionService ecs = new ExecutorCompletionService(executor); + for (Callable task : tasks) { + ecs.submit(task); + } + + int successCounter = 0; + for (int i = 0; i < threadCount; i++) { + try { + if (ecs.take().get()) { + successCounter++; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + assertEquals(threadCount, successCounter); + if (ksession != null) { + ksession.dispose(); + } + } finally { + executor.shutdown(); + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow(); + } + } + } + + private Callable getTask(final int objectCount, final KieBase kieBase, final boolean updateFacts) { + return getTask(objectCount, kieBase.newKieSession(), true, updateFacts); + } + + private Callable getTask( + final int objectCount, + final KieSession ksession, + final boolean disposeSession, + final boolean updateFacts) { + return new Callable() { + public Boolean call() throws Exception { + try { + for (int j = 0; j < 10; j++) { + FactHandle[] facts = new FactHandle[objectCount]; + FactHandle[] stringFacts = new FactHandle[objectCount]; + for (int i = 0; i < objectCount; i++) { + facts[i] = ksession.insert(new AtomicInteger(i)); + stringFacts[i] = ksession.insert("test_" + i); + } + if (updateFacts) { + for (int i = 0; i < objectCount; i++) { + ksession.update(facts[i], new AtomicInteger(-i)); + ksession.update(stringFacts[i], "updated_test_" + i); + } + } + for (int i = 0; i < objectCount; i++) { + ksession.delete(facts[i]); + ksession.delete(stringFacts[i]); + } + ksession.fireAllRules(); + } + return true; + } catch (Exception e) { + e.printStackTrace(); + return false; + } finally { + if (disposeSession) { + ksession.dispose(); + } + } + } + }; + } +} diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractConcurrentSessionsTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractConcurrentSessionsTest.java index 26f53a261fd8..287c3e6b0a4a 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractConcurrentSessionsTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractConcurrentSessionsTest.java @@ -47,19 +47,26 @@ public abstract class AbstractConcurrentSessionsTest { protected final boolean enforcedJitting; protected final boolean serializeKieBase; + protected final boolean sharedKieBase; - @Parameterized.Parameters(name = "Enforced jitting={0}, Serialize KieBase={1}") + @Parameterized.Parameters(name = "Enforced jitting={0}, Serialize KieBase={1}, Share KieBase={2}") public static List getTestParameters() { return Arrays.asList( - new Boolean[]{false, false}, - new Boolean[]{false, true}, - new Boolean[]{true, false}, - new Boolean[]{true, true}); + new Boolean[]{false, false, false}, + new Boolean[]{false, true, false}, + new Boolean[]{true, false, false}, + new Boolean[]{true, true, false}, + new Boolean[]{false, false, true}, + new Boolean[]{false, true, true}, + new Boolean[]{true, false, true}, + new Boolean[]{true, true, true}); } - public AbstractConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase) { + public AbstractConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase, + final boolean sharedKieBase) { this.enforcedJitting = enforcedJitting; this.serializeKieBase = serializeKieBase; + this.sharedKieBase = sharedKieBase; } interface KieSessionExecutor { @@ -103,10 +110,22 @@ public Thread newThread( final Runnable r ) { for (int i = 0; i < threadCount; i++) { final int counter = i; + + KieBase kieBaseLocal; + if (sharedKieBase) { + kieBaseLocal = kieBase; + } else { + if (serializeKieBase) { + kieBaseLocal = serializeAndDeserializeKieBase(kieHelper.build(kieBaseOptions)); + } else { + kieBaseLocal = kieHelper.build(kieBaseOptions); + } + } + tasks[i] = new Callable() { @Override public Boolean call() throws Exception { - return kieSessionExecutor.execute(kieBase.newKieSession(), counter); + return kieSessionExecutor.execute(kieBaseLocal.newKieSession(), counter); } }; } diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java new file mode 100644 index 000000000000..3c8fd4a74e8e --- /dev/null +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/AbstractParallelTest.java @@ -0,0 +1,232 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.compiler.integrationtests.session; + +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.kie.api.KieBase; +import org.kie.api.conf.KieBaseOption; +import org.kie.api.io.ResourceType; +import org.kie.api.runtime.KieSession; +import org.kie.internal.conf.ConstraintJittingThresholdOption; +import org.kie.internal.utils.KieHelper; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@RunWith(Parameterized.class) +public abstract class AbstractParallelTest { + + protected final boolean enforcedJitting; + protected final boolean serializeKieBase; + + @Parameterized.Parameters(name = "Enforced jitting={0}, Serialize KieBase={1}") + public static List getTestParameters() { + return Arrays.asList( + new Boolean[] {false, false}, + new Boolean[] {false, true}, + new Boolean[] {true, false}, + new Boolean[] {true, true}); + } + + public AbstractParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + this.enforcedJitting = enforcedJitting; + this.serializeKieBase = serializeKieBase; + } + + public void parallelTest(int numberOfThreads, ParallelTestExecutor executor) throws InterruptedException { + + Callable[] tasks; + + final ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads, new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + final Thread t = new Thread(r); + t.setDaemon(true); + return t; + } + }); + + tasks = getTasks(numberOfThreads, executor); + final CompletionService completionService = new ExecutorCompletionService(executorService); + + try { + for (Callable task : tasks) { + completionService.submit(task); + } + + int successCounter = 0; + for (int i = 0; i < numberOfThreads; i++) { + try { + if (completionService.take().get()) { + successCounter++; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + assertEquals(numberOfThreads, successCounter); + + } finally { + executorService.shutdown(); + if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { + executorService.shutdownNow(); + } + } + + } + + public interface ParallelTestExecutor { + public boolean execute(int counter) throws InterruptedException; + } + + private Callable[] getTasks(int numberOfThreads, ParallelTestExecutor executor) { + Callable[] tasks = new Callable[numberOfThreads]; + for (int i = 0; i < numberOfThreads; i++) { + final int counter = i; + tasks[counter] = new Callable() { + @Override + public Boolean call() throws Exception { + return executor.execute(counter); + } + }; + } + return tasks; + } + + private static synchronized KieHelper getKieHelper(String... drls) { + KieHelper kieHelper = new KieHelper(); + for (String drl : drls) { + kieHelper.addContent(drl, ResourceType.DRL); + } + return kieHelper; + } + + protected synchronized KieBase getKieBase(String... drls) { + KieBaseOption[] kieBaseOptions = (enforcedJitting) ? new KieBaseOption[] {ConstraintJittingThresholdOption.get(0)} + : new KieBaseOption[] {}; + KieBase kieBase = getKieHelper(drls).build(kieBaseOptions); + if (serializeKieBase) { + kieBase = serializeAndDeserializeKieBase(kieBase); + } + return kieBase; + } + + private KieBase serializeAndDeserializeKieBase(final KieBase kieBase) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(baos); + try { + out.writeObject(kieBase); + out.flush(); + } finally { + out.close(); + } + + ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray())); + try { + return (KieBase) in.readObject(); + } finally { + in.close(); + } + + } catch (IOException e) { + throw new RuntimeException(e); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + protected static void disposeSession(KieSession session) { + if (session != null) { + session.dispose(); + } + } + + protected void checkList(int end, List list) { + checkList(0, end, list); + } + + protected void checkList(int start, int end, List list) { + int expectedSize = end - start; + checkList(start, end, list, expectedSize); + } + + protected void checkList(int start, int end, List list, int expectedSize) { + assertEquals(expectedSize, list.size()); + for (int i = start; i < end; i++) { + assertTrue(list.contains("" + i)); + } + } + + public static class BeanA { + int seed; + + public BeanA() { + this.seed = 1; + } + + public BeanA(int seed) { + this.seed = seed; + } + + public int getSeed() { + return this.seed; + } + + public void setSeed(int seed) { + this.seed = seed; + } + } + + public static class BeanB { + int seed; + + public BeanB() { + this.seed = 1; + } + + public BeanB(int seed) { + this.seed = seed; + } + + public int getSeed() { + return this.seed; + } + + public void setSeed(int seed) { + this.seed = seed; + } + } + +} diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java new file mode 100644 index 000000000000..cd0ab8985af8 --- /dev/null +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/ConcurrentBasesParallelTest.java @@ -0,0 +1,788 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.compiler.integrationtests.session; + +import org.junit.Test; +import org.kie.api.KieBase; +import org.kie.api.runtime.KieSession; +import org.kie.api.runtime.rule.QueryResults; +import org.kie.api.runtime.rule.QueryResultsRow; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class ConcurrentBasesParallelTest extends AbstractParallelTest { + + public ConcurrentBasesParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + super(enforcedJitting, serializeKieBase); + } + + @Test + public void testDifferentRuleset1() throws InterruptedException { + int numberOfThreads = 100; + int numberOfObjects = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule Rule_" + counter + " " + + "when " + + " BeanA( seed == " + counter + ") " + + "then " + + "end"; + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + for (int i = 0; i < numberOfObjects; i++) { + session.insert(new BeanA(i)); + } + return session.fireAllRules() == 1; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset2() throws InterruptedException { + int numberOfThreads = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule Rule_" + counter + " " + + "when " + + " BeanA( seed == " + counter + ") " + + "then " + + "end"; + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + for (int i = 0; i < numberOfThreads; i++) { + if (i != counter) { + session.insert(new BeanA(i)); + } + } + return session.fireAllRules() == 0; + + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset3() throws InterruptedException { + int numberOfThreads = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global " + AtomicInteger.class.getCanonicalName() + " result;\n" + + "rule Rule_" + counter + " " + + "when " + + " BeanA()" + + "then " + + " result.set(" + counter + ");" + + "end"; + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + AtomicInteger r = new AtomicInteger(0); + session.setGlobal("result", r); + assertEquals(1, session.fireAllRules()); + assertEquals(counter, r.get()); + return true; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset4() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule ${ruleName} " + + "when " + + "${className}()" + + "then " + + " list.add(\"${className}\");" + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + session.insert(new BeanB()); + List list = new ArrayList<>(); + session.setGlobal("list", list); + int rulesFired = session.fireAllRules(); + assertEquals(1, list.size()); + assertEquals(className, list.get(0)); + return rulesFired == 1; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset5() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + "${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanB()); + } else { + session.insert(new BeanA()); + } + int rulesFired = session.fireAllRules(); + assertEquals(0, rulesFired); + return true; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRuleset6() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + "${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA()); + } else { + session.insert(new BeanB()); + } + int rulesFired = session.fireAllRules(); + assertEquals(1, rulesFired); + return true; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRulesetNot() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + " not ${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA()); + } else { + session.insert(new BeanB()); + } + int rulesFired = session.fireAllRules(); + assertEquals(0, rulesFired); + return true; + } finally { + disposeSession(session); + } + + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRulesetExists() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + " exists ${className}()" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + + KieBase base = getKieBase(rule); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA()); + } else { + session.insert(new BeanB()); + } + int rulesFired = session.fireAllRules(); + assertEquals(1, rulesFired); + return true; + } finally { + disposeSession(session); + } + + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testDifferentRulesetSharedSubnetwork() throws InterruptedException { + int numberOfThreads = 100; + + String ruleTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule ${ruleName} " + + "when " + + " $bean : ${className}() \n" + + "then " + + "end"; + + String subnetworkRuleTemplate = "rule Rule_subnetwork " + + "when " + + " $bean : ${className}() \n" + + " Number( doubleValue > 0) from" + + " accumulate ( BeanA() and $s : String(), count($s) )" + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + String className = (counter % 2 == 0) ? "BeanA" : "BeanB"; + String ruleName = "Rule_" + className + "_" + counter; + String rule = ruleTemplate.replace("${ruleName}", ruleName).replace("${className}", className); + String subnetworkRule = subnetworkRuleTemplate.replace("${className}", className); + + KieBase base = getKieBase(rule, subnetworkRule); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert("test"); + if (counter % 2 == 0) { + session.insert(new BeanA()); + assertEquals(2, session.fireAllRules()); + } else { + session.insert(new BeanB()); + assertEquals(1, session.fireAllRules()); + } + return true; + } finally { + disposeSession(session); + } + + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset1() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleA " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleB " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + + "then " + + "end"; + + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA(counter)); + return session.fireAllRules() == 1; + } else { + return session.fireAllRules() == 0; + } + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset2() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleA " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleB " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + session.insert(new BeanB()); + return session.fireAllRules() == 2; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset3() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleA " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanA(), count($bean)) " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleB " + + "when " + + " $n : Number( doubleValue == 1 ) from accumulate($bean : BeanB(), count($bean)) " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA(counter)); + } else { + session.insert(new BeanB(counter)); + } + return session.fireAllRules() == 1; + } finally { + disposeSession(session); + } + + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset4() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleNotA " + + "when " + + " not BeanA() " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleNotB " + + "when " + + " not BeanB() " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA(counter)); + } else { + session.insert(new BeanB(counter)); + } + return session.fireAllRules() == 1; + } finally { + disposeSession(session); + } + + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testSameRuleset5() throws InterruptedException { + int numberOfThreads = 100; + + String ruleA = "import " + BeanA.class.getCanonicalName() + ";\n" + + "rule RuleNotA " + + "when " + + " not BeanA() " + + "then " + + "end"; + + String ruleB = "import " + BeanB.class.getCanonicalName() + ";\n" + + "rule RuleNotB " + + "when " + + " not BeanB() " + + "then " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + KieBase base = getKieBase(ruleA, ruleB); + KieSession session = null; + + try { + session = base.newKieSession(); + if (counter % 2 == 0) { + session.insert(new BeanA(counter)); + session.insert(new BeanB(counter)); + return session.fireAllRules() == 0; + } else { + return session.fireAllRules() == 2; + } + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testFunctions() throws InterruptedException { + int numberOfThreads = 100; + + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;" + + "rule Rule " + + "when " + + " BeanA() " + + "then " + + " addToList(list);" + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + String function = "import java.util.List;\n" + + "function void addToList(List list) { \n" + + " list.add( \"" + counter + "\" );\n" + + "}\n"; + + KieBase base = getKieBase(rule, function); + KieSession session = null; + + try { + session = base.newKieSession(); + session.insert(new BeanA()); + List list = new ArrayList<>(); + session.setGlobal("list", list); + int rulesFired = session.fireAllRules(); + assertEquals(1, list.size()); + assertEquals(""+counter, list.get(0)); + return rulesFired == 1; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testFunctions2() throws InterruptedException { + int numberOfThreads = 100; + int objectCount = 100; + + String rule = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;" + + "rule Rule " + + "when " + + " BeanA() " + + "then " + + " addToList(list);" + + "end"; + + String functionTemplate = "import java.util.List;\n" + + "function void addToList(List list) { \n" + + " list.add( \"${identifier}\" );\n" + + "}\n"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + + String identifier = (counter%2 == 0) ? "even" : "odd"; + String otherIdentifier = (counter%2 == 0) ? "odd" : "even"; + String functionRule = functionTemplate.replace("${identifier}", identifier); + + KieBase base = getKieBase(rule, functionRule); + KieSession session = null; + + try { + session = base.newKieSession(); + List list = new ArrayList<>(); + session.setGlobal("list", list); + int rulesFired = 0; + for (int i = 0; i < objectCount; i++) { + session.insert(new BeanA(i)); + rulesFired += session.fireAllRules(); + } + assertEquals(objectCount, list.size()); + assertTrue(list.contains(identifier)); + assertTrue(!list.contains(otherIdentifier)); + return rulesFired == objectCount; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testQueries() throws InterruptedException { + int numberOfThreads = 100; + int numberOfObjects = 100; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + String query = "import " + BeanA.class.getCanonicalName() + ";\n" + + "query Query " + + " bean : BeanA( seed == "+ counter +" ) " + + "end"; + + KieBase base = getKieBase(query); + KieSession session = null; + + try { + session = base.newKieSession(); + BeanA bean = new BeanA(counter); + session.insert(bean); + for (int i = 0; i < numberOfObjects; i++) { + if (i != counter) { + session.insert(new BeanA(i)); + } + } + QueryResults results = session.getQueryResults("Query"); + assertEquals(1, results.size()); + for (QueryResultsRow row : results) { + assertEquals(bean, (BeanA) row.get("bean")); + } + return true; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } + + @Test + public void testQueries2() throws InterruptedException { + int numberOfThreads = 100; + int numberOfObjects = 100; + + String queryTemplate = "import " + BeanA.class.getCanonicalName() + ";\n" + + "query Query " + + " bean : BeanA( seed == ${seed} ) " + + "end"; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + int seed = counter % 2; + String seedString = "" + seed; + String queryDrl = queryTemplate.replace("${seed}", seedString); + KieBase base = getKieBase(queryDrl); + KieSession session = null; + + try { + session = base.newKieSession(); + for (int i = 0; i < numberOfObjects; i++) { + session.insert(new BeanA(seed)); + } + QueryResults results = session.getQueryResults("Query"); + assertEquals(numberOfObjects, results.size()); + for (QueryResultsRow row : results) { + BeanA bean = (BeanA) row.get("bean"); + assertEquals(seed, bean.getSeed()); + } + return true; + } finally { + disposeSession(session); + } + } + }; + + parallelTest(numberOfThreads, exec); + } +} diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationConcurrentSessionsTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationConcurrentSessionsTest.java index 1ccf2f9f1484..1cb93ab03675 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationConcurrentSessionsTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationConcurrentSessionsTest.java @@ -35,8 +35,9 @@ public class DataTypeEvaluationConcurrentSessionsTest extends AbstractConcurrentSessionsTest { - public DataTypeEvaluationConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase) { - super(enforcedJitting, serializeKieBase); + public DataTypeEvaluationConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase, + final boolean sharedKieBase) { + super(enforcedJitting, serializeKieBase, sharedKieBase); } @Test diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java new file mode 100644 index 000000000000..3c1b0709c547 --- /dev/null +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/DataTypeEvaluationSharedSessionParallelTest.java @@ -0,0 +1,174 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.compiler.integrationtests.session; + +import org.drools.compiler.integrationtests.facts.AnEnum; +import org.drools.compiler.integrationtests.facts.FactWithBigDecimal; +import org.drools.compiler.integrationtests.facts.FactWithBoolean; +import org.drools.compiler.integrationtests.facts.FactWithByte; +import org.drools.compiler.integrationtests.facts.FactWithCharacter; +import org.drools.compiler.integrationtests.facts.FactWithDouble; +import org.drools.compiler.integrationtests.facts.FactWithEnum; +import org.drools.compiler.integrationtests.facts.FactWithFloat; +import org.drools.compiler.integrationtests.facts.FactWithInteger; +import org.drools.compiler.integrationtests.facts.FactWithLong; +import org.drools.compiler.integrationtests.facts.FactWithShort; +import org.drools.compiler.integrationtests.facts.FactWithString; +import org.junit.Test; +import org.kie.api.runtime.KieSession; + +import java.math.BigDecimal; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; + +public class DataTypeEvaluationSharedSessionParallelTest extends AbstractParallelTest { + + public DataTypeEvaluationSharedSessionParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + super(enforcedJitting, serializeKieBase); + } + + @Test + public void testBooleanPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithBoolean: FactWithBoolean(booleanValue == false) \n", new FactWithBoolean(false)); + } + + @Test + public void testBoolean() throws InterruptedException { + testFactAttributeType(" $factWithBoolean: FactWithBoolean(booleanObjectValue == false) \n", new FactWithBoolean(false)); + } + + @Test + public void testBytePrimitive() throws InterruptedException { + testFactAttributeType(" $factWithByte: FactWithByte(byteValue == 15) \n", new FactWithByte((byte) 15)); + } + + @Test + public void testByte() throws InterruptedException { + testFactAttributeType(" $factWithByte: FactWithByte(byteObjectValue == 15) \n", new FactWithByte((byte) 15)); + } + + @Test + public void testShortPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithShort: FactWithShort(shortValue == 15) \n", new FactWithShort((short) 15)); + } + + @Test + public void testShort() throws InterruptedException { + testFactAttributeType(" $factWithShort: FactWithShort(shortObjectValue == 15) \n", new FactWithShort((short) 15)); + } + + @Test + public void testIntPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithInt: FactWithInteger(intValue == 15) \n", new FactWithInteger(15)); + } + + @Test + public void testInteger() throws InterruptedException { + testFactAttributeType(" $factWithInteger: FactWithInteger(integerValue == 15) \n", new FactWithInteger(15)); + } + + @Test + public void testLongPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithLong: FactWithLong(longValue == 15) \n", new FactWithLong(15)); + } + + @Test + public void testLong() throws InterruptedException { + testFactAttributeType(" $factWithLong: FactWithLong(longObjectValue == 15) \n", new FactWithLong(15)); + } + + @Test + public void testFloatPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithFloat: FactWithFloat(floatValue == 15.1) \n", new FactWithFloat(15.1f)); + } + + @Test + public void testFloat() throws InterruptedException { + testFactAttributeType(" $factWithFloat: FactWithFloat(floatObjectValue == 15.1) \n", new FactWithFloat(15.1f)); + } + + @Test + public void testDoublePrimitive() throws InterruptedException { + testFactAttributeType(" $factWithDouble: FactWithDouble(doubleValue == 15.1) \n", new FactWithDouble(15.1d)); + } + + @Test + public void testDouble() throws InterruptedException { + testFactAttributeType(" $factWithDouble: FactWithDouble(doubleObjectValue == 15.1) \n", new FactWithDouble(15.1d)); + } + + @Test + public void testBigDecimal() throws InterruptedException { + testFactAttributeType(" $factWithBigDecimal: FactWithBigDecimal(bigDecimalValue == 10) \n", new FactWithBigDecimal(BigDecimal.TEN)); + } + + @Test + public void testCharPrimitive() throws InterruptedException { + testFactAttributeType(" $factWithChar: FactWithCharacter(charValue == 'a') \n", new FactWithCharacter('a')); + } + + @Test + public void testCharacter() throws InterruptedException { + testFactAttributeType(" $factWithChar: FactWithCharacter(characterValue == 'a') \n", new FactWithCharacter('a')); + } + + @Test + public void testString() throws InterruptedException { + testFactAttributeType(" $factWithString: FactWithString(stringValue == \"test\") \n", new FactWithString("test")); + } + + @Test + public void testEnum() throws InterruptedException { + testFactAttributeType(" $factWithEnum: FactWithEnum(enumValue == AnEnum.FIRST) \n", new FactWithEnum(AnEnum.FIRST)); + } + + private void testFactAttributeType(final String ruleConstraint, final Object factInserted) throws InterruptedException { + int numberOfThreads = 10; + + + final String drl = + " import org.drools.compiler.integrationtests.facts.*;\n" + + " global " + AtomicInteger.class.getCanonicalName() + " numberOfFirings;\n" + + " rule R1 \n" + + " when \n" + + ruleConstraint + + " then \n" + + " numberOfFirings.incrementAndGet(); \n" + + " end "; + + KieSession kieSession = getKieBase(drl).newKieSession(); + + AtomicInteger numberOfFirings = new AtomicInteger(0); + + parallelTest(numberOfThreads, new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + if (kieSession.getGlobal("numberOfFirings") == null) { + kieSession.setGlobal("numberOfFirings", numberOfFirings); + } + ; + kieSession.insert(factInserted); + kieSession.fireAllRules(); + return true; + } + }); + disposeSession(kieSession); + assertEquals(1, numberOfFirings.get()); + } + +} diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/JoinsConcurrentSessionsTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/JoinsConcurrentSessionsTest.java index e32f65f297da..159d2a540043 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/JoinsConcurrentSessionsTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/JoinsConcurrentSessionsTest.java @@ -29,8 +29,9 @@ public class JoinsConcurrentSessionsTest extends AbstractConcurrentSessionsTest { - public JoinsConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase) { - super(enforcedJitting, serializeKieBase); + public JoinsConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase, + final boolean sharedKieBase) { + super(enforcedJitting, serializeKieBase, sharedKieBase); } @Test diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java new file mode 100644 index 000000000000..b5174ab32a53 --- /dev/null +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SharedSessionParallelTest.java @@ -0,0 +1,488 @@ +/* + * Copyright 2017 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.compiler.integrationtests.session; + +import org.junit.Test; +import org.kie.api.runtime.KieSession; +import org.kie.api.runtime.rule.FactHandle; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; + +public class SharedSessionParallelTest extends AbstractParallelTest { + + public SharedSessionParallelTest(final boolean enforcedJitting, final boolean serializeKieBase) { + super(enforcedJitting, serializeKieBase); + } + + @Test + public void testNoExceptions() throws InterruptedException { + String drl = "rule R1 when String() then end"; + + int repetitions = 100; + int numberOfObjects = 1000; + int countOfThreads = 100; + + for (int i = 0; i < repetitions; i++) { + + KieSession kieSession = getKieBase(drl).newKieSession(); + + parallelTest(countOfThreads, new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + for (int j = 0; j < numberOfObjects; j++) { + kieSession.insert("test_" + numberOfObjects); + } + kieSession.fireAllRules(); + return true; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }); + + disposeSession(kieSession); + } + } + + @Test + public void testCheckOneThreadOnly() throws InterruptedException { + int threadCount = 100; + List list = new ArrayList<>(); + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule R1 " + + "when " + + " BeanA($n : seed) " + + "then " + + " list.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + CountDownLatch latch = new CountDownLatch(threadCount); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) throws InterruptedException { + kieSession.setGlobal("list", list); + kieSession.insert(new BeanA(counter)); + latch.countDown(); + + if (counter == 0) { + latch.await(); + return kieSession.fireAllRules() == threadCount; + } + return true; + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + + assertEquals(threadCount, list.size()); + for (int i = 0; i < threadCount; i++) { + assertTrue(list.contains("" + i)); + } + } + + @Test + public void testCorrectFirings() throws InterruptedException { + int threadCount = 100; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List globalList;\n" + + "rule R1 " + + "when " + + " BeanA($n : seed) " + + "then " + + " globalList.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + kieSession.setGlobal("globalList", list); + kieSession.insert(new BeanA(counter)); + kieSession.fireAllRules(); + return true; + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(threadCount, list); + } + + @Test + public void testCorrectFirings2() throws InterruptedException { + int threadCount = 100; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule R1 " + + "when " + + " BeanA($n : seed, seed == 0) " + + "then " + + " list.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + kieSession.setGlobal("list", list); + kieSession.insert(new BeanA(counter % 2)); + kieSession.fireAllRules(); + return true; + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + assertTrue(list.contains("" + 0)); + assertFalse(list.contains("" + 1)); + int expectedListSize = ((threadCount - 1) / 2) + 1; + assertEquals(expectedListSize, list.size()); + } + + + @Test + public void testLongRunningRule() throws InterruptedException { + int threadCount = 100; + int seed = threadCount + 200; + int objectCount = 1000; + + String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule longRunning " + + "when " + + " $bean : BeanA($n : seed, seed > " + threadCount + ") " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + " Thread.sleep(5);" + + "end"; + + String listDrl = "global java.util.List list2;\n" + + "rule listRule " + + "when " + + " BeanA($n : seed, seed < " + threadCount + ") " + + "then " + + " list2.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(longRunningDrl, listDrl).newKieSession(); + + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + List list2 = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + kieSession.setGlobal("list2", list2); + kieSession.insert(new BeanA(seed)); + barrier.await(); + kieSession.fireAllRules(); + return true; + } else { + barrier.await(); + Thread.sleep(100); + for (int i = 0; i < objectCount; i++) { + kieSession.insert(new BeanA(counter)); + } + kieSession.fireAllRules(); + return true; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(threadCount, seed, list); + checkList(1, threadCount, list2, (threadCount - 1) * objectCount); + } + + @Test + public void testLongRunningRule2() throws InterruptedException { + int threadCount = 100; + int seed = 1000; + + String waitingRule = "rule waitingRule " + + "when " + + " String( this == \"wait\" ) " + + "then " + + " Thread.sleep(10);" + + "end"; + + String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule longRunning " + + "when " + + " $bean : BeanA($n : seed, seed > 0 ) " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + KieSession kieSession = getKieBase(longRunningDrl, waitingRule).newKieSession(); + + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + kieSession.insert("wait"); + kieSession.insert(new BeanA(seed)); + barrier.await(); + return kieSession.fireAllRules() == seed * threadCount + 1; + } else { + barrier.await(); + Thread.sleep(10); + kieSession.insert(new BeanA(seed)); + return kieSession.fireAllRules() == 0; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(0, seed, list, seed * threadCount); + } + + @Test + public void testLongRunningRule3() throws InterruptedException { + int threadCount = 10; + int seed = threadCount + 50; + int objectCount = 1000; + + String longRunningDrl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule longRunning " + + "when " + + " $bean : BeanA($n : seed, seed > " + threadCount + ") " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + String listDrl = "global java.util.List list2;\n" + + "rule listRule " + + "when " + + " BeanA($n : seed, seed < " + threadCount + ") " + + "then " + + " list2.add(\"\" + $n);" + + "end"; + + KieSession kieSession = getKieBase(longRunningDrl, listDrl).newKieSession(); + + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + List list2 = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter % 2 == 0) { + kieSession.setGlobal("list", list); + kieSession.setGlobal("list2", list2); + kieSession.insert(new BeanA(seed)); + barrier.await(); + kieSession.fireAllRules(); + return true; + } else { + barrier.await(); + Thread.sleep(100); + for (int i = 0; i < objectCount; i++) { + kieSession.insert(new BeanA(counter)); + } + kieSession.fireAllRules(); + return true; + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + + int listExpectedSize = (threadCount / 2 + threadCount % 2) * (seed - threadCount); + int list2ExpectedSize = threadCount / 2 * objectCount; + for (int i = 0; i < threadCount; i++) { + if (i % 2 == 1) { + list2.contains("" + i); + } + } + assertEquals(listExpectedSize, list.size()); + assertEquals(list2ExpectedSize, list2.size()); + } + + @Test + public void testCountdownBean() throws InterruptedException { + int threadCount = 100; + int seed = 1000; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule countdown " + + "when " + + " $bean : BeanA($n : seed, seed > 0 ) " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + CyclicBarrier barrier = new CyclicBarrier(threadCount); + List list = new ArrayList<>(); + BeanA bean = new BeanA(seed); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + kieSession.insert(bean); + } + barrier.await(); + kieSession.fireAllRules(); + return true; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(seed, list); + assertEquals(0, bean.getSeed()); + } + + @Test + public void testCountdownBean2() throws InterruptedException { + int threadCount = 100; + int seed = 1000; + + String drl = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule countdown " + + "when " + + " $bean : BeanA($n : seed, seed > 0 ) " + + "then " + + " modify($bean) { setSeed($n-1) };" + + " list.add(\"\" + $bean.getSeed());" + + "end"; + + KieSession kieSession = getKieBase(drl).newKieSession(); + List list = new ArrayList<>(); + BeanA[] beans = new BeanA[threadCount]; + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + BeanA bean = new BeanA(seed); + beans[counter] = bean; + try { + if (counter == 0) { + kieSession.setGlobal("list", list); + } + kieSession.insert(bean); + kieSession.fireAllRules(); + return true; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + + checkList(0, seed, list, seed * threadCount); + for (BeanA bean : beans) { + assertEquals(0, bean.getSeed()); + } + } + + @Test + public void testOneRulePerThread() throws InterruptedException { + int threadCount = 1000; + + String[] drls = new String[threadCount]; + for (int i = 0; i < threadCount; i++) { + drls[i] = "import " + BeanA.class.getCanonicalName() + ";\n" + + "global java.util.List list;\n" + + "rule R" + i + " " + + "when " + + " $bean : BeanA( seed == " + i + " ) " + + "then " + + " list.add(\"" + i + "\");" + + "end"; + } + + KieSession kieSession = getKieBase(drls).newKieSession(); + List list = new ArrayList<>(); + + ParallelTestExecutor exec = new ParallelTestExecutor() { + @Override + public boolean execute(int counter) { + kieSession.setGlobal("list", list); + kieSession.insert(new BeanA(counter)); + kieSession.fireAllRules(); + return true; + } + }; + + parallelTest(threadCount, exec); + disposeSession(kieSession); + checkList(threadCount, list); + } +} diff --git a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SubnetworkConcurrentSessionsTest.java b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SubnetworkConcurrentSessionsTest.java index 49f30045a112..377689436e74 100644 --- a/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SubnetworkConcurrentSessionsTest.java +++ b/drools-compiler/src/test/java/org/drools/compiler/integrationtests/session/SubnetworkConcurrentSessionsTest.java @@ -21,8 +21,9 @@ public class SubnetworkConcurrentSessionsTest extends AbstractConcurrentSessionsTest { - public SubnetworkConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase) { - super(enforcedJitting, serializeKieBase); + public SubnetworkConcurrentSessionsTest(final boolean enforcedJitting, final boolean serializeKieBase, + final boolean sharedKieBase) { + super(enforcedJitting, serializeKieBase, sharedKieBase); } @Test(timeout = 5000)