From ce9a251d1b2672acaa8ea02dfd0b9b364667b08a Mon Sep 17 00:00:00 2001 From: Jaden Peterson Date: Thu, 21 May 2026 10:47:35 -0400 Subject: [PATCH 1/2] Moved TestTaskExecutor to a new file --- .../rules_scala/common/sbt-testing/Test.scala | 26 ++----------------- .../common/sbt-testing/TestTaskExecutor.scala | 26 +++++++++++++++++++ .../zinc/test/TestFrameworkRunner.scala | 1 - 3 files changed, 28 insertions(+), 25 deletions(-) create mode 100644 src/main/scala/higherkindness/rules_scala/common/sbt-testing/TestTaskExecutor.scala diff --git a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/Test.scala b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/Test.scala index 64f49c96..2e311bc0 100644 --- a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/Test.scala +++ b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/Test.scala @@ -2,7 +2,7 @@ package higherkindness.rules_scala.common.sbt_testing import java.nio.file.{Path, Paths} import play.api.libs.json.{Format, Json} -import sbt.testing.{Event, Framework, Logger, Runner, Status, Task, TaskDef, TestWildcardSelector} +import sbt.testing.{Framework, Logger, Runner, Task, TaskDef, TestWildcardSelector} import scala.collection.mutable import scala.util.control.NonFatal @@ -37,7 +37,7 @@ class TestFrameworkLoader(loader: ClassLoader) { (Some(`class`.getDeclaredConstructor().newInstance()), loadedJar) } catch { case _: ClassNotFoundException => (None, None) - case NonFatal(e) => throw new Exception(s"Failed to load framework $className", e) + case NonFatal(e) => throw new Exception(s"Failed to load framework $className", e) } framework.map { case framework: Framework => @@ -101,25 +101,3 @@ class TestReporter(logger: Logger) { def preTask(task: Task) = logger.info(task.taskDef.fullyQualifiedName) } - -class TestTaskExecutor(logger: Logger) { - def execute(task: Task, failures: mutable.Set[String]): mutable.ListBuffer[Event] = { - var events = new mutable.ListBuffer[Event]() - def execute(task: Task): Unit = { - val tasks = task.execute( - event => { - events += event - event.status match { - case Status.Failure | Status.Error => - failures += task.taskDef.fullyQualifiedName - case _ => - } - }, - Array(new PrefixedTestingLogger(logger, " ")), - ) - tasks.foreach(execute) - } - execute(task) - events - } -} diff --git a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/TestTaskExecutor.scala b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/TestTaskExecutor.scala new file mode 100644 index 00000000..26d8bc6b --- /dev/null +++ b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/TestTaskExecutor.scala @@ -0,0 +1,26 @@ +package higherkindness.rules_scala.common.sbt_testing + +import sbt.testing.{Event, Logger, Status, Task} +import scala.collection.mutable + +class TestTaskExecutor(logger: Logger) { + def execute(task: Task, failures: mutable.Set[String]): mutable.ListBuffer[Event] = { + var events = new mutable.ListBuffer[Event]() + def execute(task: Task): Unit = { + val tasks = task.execute( + event => { + events += event + event.status match { + case Status.Failure | Status.Error => + failures += task.taskDef.fullyQualifiedName + case _ => + } + }, + Array(new PrefixedTestingLogger(logger, " ")), + ) + tasks.foreach(execute) + } + execute(task) + events + } +} diff --git a/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestFrameworkRunner.scala b/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestFrameworkRunner.scala index 1e8eccc2..0f3861a0 100644 --- a/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestFrameworkRunner.scala +++ b/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestFrameworkRunner.scala @@ -98,7 +98,6 @@ class ProcessTestRunner( } } - val taskExecutor = new TestTaskExecutor(logger) val failures = mutable.Set[String]() tests.foreach { test => val process = new ProcessBuilder((command.executable +: command.arguments): _*) From 8a0c1b2e4aa78515fb330e6aa0dad7bc9858a30e Mon Sep 17 00:00:00 2001 From: Jaden Peterson Date: Thu, 21 May 2026 14:34:07 -0400 Subject: [PATCH 2/2] Implemented concurrent test execution --- .bazelrc_shared | 2 +- rules/private/phases/phase_test_launcher.bzl | 2 + rules/scala.bzl | 4 + .../common/sbt-testing/BufferedLogger.scala | 37 +++++ .../common/sbt-testing/JUnitXmlReporter.scala | 87 ++++++------ .../common/sbt-testing/SubprocessRunner.scala | 21 +-- .../rules_scala/common/sbt-testing/Test.scala | 2 +- .../common/sbt-testing/TestTaskExecutor.scala | 133 +++++++++++++++--- .../zinc/test/TestFrameworkRunner.scala | 83 ++++++----- .../workers/zinc/test/TestRunner.scala | 44 ++++-- tests/test-frameworks/concurrency/BUILD.bazel | 16 +++ tests/test-frameworks/concurrency/Spec1.scala | 9 ++ tests/test-frameworks/concurrency/Spec2.scala | 9 ++ tests/test-frameworks/concurrency/Spec3.scala | 9 ++ tests/test-frameworks/concurrency/Spec4.scala | 9 ++ .../test-frameworks/concurrency/package.scala | 21 +++ tests/test-frameworks/concurrency/test | 74 ++++++++++ 17 files changed, 442 insertions(+), 120 deletions(-) create mode 100644 src/main/scala/higherkindness/rules_scala/common/sbt-testing/BufferedLogger.scala create mode 100644 tests/test-frameworks/concurrency/BUILD.bazel create mode 100644 tests/test-frameworks/concurrency/Spec1.scala create mode 100644 tests/test-frameworks/concurrency/Spec2.scala create mode 100644 tests/test-frameworks/concurrency/Spec3.scala create mode 100644 tests/test-frameworks/concurrency/Spec4.scala create mode 100644 tests/test-frameworks/concurrency/package.scala create mode 100755 tests/test-frameworks/concurrency/test diff --git a/.bazelrc_shared b/.bazelrc_shared index 9ee97bf7..f581e739 100644 --- a/.bazelrc_shared +++ b/.bazelrc_shared @@ -6,7 +6,7 @@ build --tool_java_language_version="21" build --tool_java_runtime_version="remotejdk_21" # Other options -build --experimental_use_hermetic_linux_sandbox +# build --experimental_use_hermetic_linux_sandbox build --experimental_worker_cancellation build --experimental_worker_multiplex_sandboxing build --experimental_worker_sandbox_hardening diff --git a/rules/private/phases/phase_test_launcher.bzl b/rules/private/phases/phase_test_launcher.bzl index c6f28774..82eaba91 100644 --- a/rules/private/phases/phase_test_launcher.bzl +++ b/rules/private/phases/phase_test_launcher.bzl @@ -61,6 +61,8 @@ def phase_test_launcher(ctx, g): files.append(subprocess_executable) args.add("--isolation", "process") args.add("--subprocess_exec", subprocess_executable.short_path) + if ctx.attr.sequential: + args.add("--sequential") args.add_all("--", test_jars, map_each = _short_path) args.set_param_file_format("multiline") args_file = ctx.actions.declare_file("{}/test.params".format(ctx.label.name)) diff --git a/rules/scala.bzl b/rules/scala.bzl index b6bf8a52..81e956e8 100644 --- a/rules/scala.bzl +++ b/rules/scala.bzl @@ -382,6 +382,10 @@ def make_scala_test(*extras): cfg = _scala_outgoing_transition, default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/workers/zinc/test", ), + "sequential": attr.bool( + default = False, + doc = "Whether to run test classes sequentially. If false, they'll be run concurrently.", + ), "subprocess_runner": attr.label( cfg = _scala_outgoing_transition, default = "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/sbt-testing:subprocess", diff --git a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/BufferedLogger.scala b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/BufferedLogger.scala new file mode 100644 index 00000000..70a2554a --- /dev/null +++ b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/BufferedLogger.scala @@ -0,0 +1,37 @@ +package higherkindness.rules_scala.common.sbt_testing + +import sbt.testing.Logger + +import scala.collection.mutable + +private sealed trait SbtLogEntry +private object SbtLogEntry { + case class Error(message: String) extends SbtLogEntry + case class Warn(message: String) extends SbtLogEntry + case class Info(message: String) extends SbtLogEntry + case class Debug(message: String) extends SbtLogEntry + case class Trace(throwable: Throwable) extends SbtLogEntry +} + +class BufferedLogger(underlying: Logger) extends Logger { + private val buffer = mutable.ArrayBuffer.empty[SbtLogEntry] + + override def ansiCodesSupported(): Boolean = underlying.ansiCodesSupported() + override def error(message: String): Unit = buffer.addOne(SbtLogEntry.Error(message)) + override def warn(message: String): Unit = buffer.addOne(SbtLogEntry.Warn(message)) + override def info(message: String): Unit = buffer.addOne(SbtLogEntry.Info(message)) + override def debug(message: String): Unit = buffer.addOne(SbtLogEntry.Debug(message)) + override def trace(throwable: Throwable): Unit = buffer.addOne(SbtLogEntry.Trace(throwable)) + + def flush(): Unit = { + buffer.foreach { + case SbtLogEntry.Error(message) => underlying.error(message) + case SbtLogEntry.Warn(message) => underlying.warn(message) + case SbtLogEntry.Info(message) => underlying.info(message) + case SbtLogEntry.Debug(message) => underlying.debug(message) + case SbtLogEntry.Trace(throwable) => underlying.trace(throwable) + } + + buffer.clear() + } +} diff --git a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/JUnitXmlReporter.scala b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/JUnitXmlReporter.scala index 523cc9d1..65c3cb9e 100644 --- a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/JUnitXmlReporter.scala +++ b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/JUnitXmlReporter.scala @@ -2,11 +2,10 @@ package higherkindness.rules_scala.common.sbt_testing import java.io.{PrintWriter, StringWriter} import sbt.testing.Status.{Canceled, Error, Failure, Ignored, Pending, Skipped} -import sbt.testing.{Event, Status, TestSelector} -import scala.collection.mutable.ListBuffer +import sbt.testing.{Event, Status, Task, TestSelector} import scala.xml.{Elem, Utility, XML} -class JUnitXmlReporter(tasksAndEvents: ListBuffer[(String, ListBuffer[Event])]) { +class JUnitXmlReporter(taskEvents: Map[Task, Array[Event]]) { private def escape(info: String): String = info match { case str: String => Utility.escape(str) case null => "" @@ -14,58 +13,60 @@ class JUnitXmlReporter(tasksAndEvents: ListBuffer[(String, ListBuffer[Event])]) def result: Elem = XML.loadString(s""" - ${(for ((name, events) <- tasksAndEvents) - yield s""" e.status == Ignored || e.status == Skipped || e.status == Pending || e.status == Canceled) + .toString}" time="${(events.map(_.duration).sum / 1000d).toString}"> ${(for (e <- events) - yield s""" escape(selector.testName) + case _ => "Error occurred outside of a test case." + }}" time="${(e.duration / 1000d).toString}"> ${val stringWriter = new StringWriter() - if (e.throwable.isDefined) { - val writer = new PrintWriter(stringWriter) - e.throwable.get.printStackTrace(writer) - writer.flush() - } - val trace: String = stringWriter.toString - e.status match { - case Status.Error if e.throwable.isDefined => - val t = e.throwable.get - s"""${escape( - trace, - )}""" - case Status.Error => - s"""""" - case Status.Failure if e.throwable.isDefined => - val t = e.throwable.get - s"""${escape( - trace, - )}""" - case Status.Failure => - s"""""" - case Status.Canceled if e.throwable.isDefined => - val t = e.throwable.get - s"""${escape( - trace, - )}""" - case Status.Canceled => - s"""""" - case Status.Ignored | Status.Skipped | Status.Pending => - "" - case _ => - }} + if (e.throwable.isDefined) { + val writer = new PrintWriter(stringWriter) + e.throwable.get.printStackTrace(writer) + writer.flush() + } + val trace: String = stringWriter.toString + e.status match { + case Status.Error if e.throwable.isDefined => + val t = e.throwable.get + s"""${escape( + trace, + )}""" + case Status.Error => + s"""""" + case Status.Failure if e.throwable.isDefined => + val t = e.throwable.get + s"""${escape( + trace, + )}""" + case Status.Failure => + s"""""" + case Status.Canceled if e.throwable.isDefined => + val t = e.throwable.get + s"""${escape( + trace, + )}""" + case Status.Canceled => + s"""""" + case Status.Ignored | Status.Skipped | Status.Pending => + "" + case _ => + }} """).mkString("")} diff --git a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/SubprocessRunner.scala b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/SubprocessRunner.scala index 7358349b..6b9ee2c7 100644 --- a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/SubprocessRunner.scala +++ b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/SubprocessRunner.scala @@ -3,7 +3,8 @@ package higherkindness.rules_scala.common.sbt_testing import higherkindness.rules_scala.common.classloaders.ClassLoaders import java.io.ObjectInputStream import java.nio.file.Paths -import scala.collection.mutable +import scala.concurrent.Await +import scala.concurrent.duration.Duration object SubprocessTestRunner { @@ -20,14 +21,16 @@ object SubprocessTestRunner { val tasks = runner.tasks(Array(TestHelper.taskDef(request.test, request.scopeAndTestName))) tasks.length == 0 || { val reporter = new TestReporter(request.logger) - val taskExecutor = new TestTaskExecutor(request.logger) - val failures = mutable.Set[String]() - tasks.foreach { task => - reporter.preTask(task) - taskExecutor.execute(task, failures) - reporter.postTask() - } - !failures.nonEmpty + + // We're only running a single test class, so there's not much of a point in using the + // `ConcurrentTestTaskExecutor` + val taskExecutor = new SequentialTestTaskExecutor(request.logger) + + tasks.foreach(taskExecutor.submitTask) + + val result = Await.result(taskExecutor.waitForTasks(), Duration.Inf) + + !result.failures.nonEmpty } } } diff --git a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/Test.scala b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/Test.scala index 2e311bc0..909a5cfd 100644 --- a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/Test.scala +++ b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/Test.scala @@ -37,7 +37,7 @@ class TestFrameworkLoader(loader: ClassLoader) { (Some(`class`.getDeclaredConstructor().newInstance()), loadedJar) } catch { case _: ClassNotFoundException => (None, None) - case NonFatal(e) => throw new Exception(s"Failed to load framework $className", e) + case NonFatal(e) => throw new Exception(s"Failed to load framework $className", e) } framework.map { case framework: Framework => diff --git a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/TestTaskExecutor.scala b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/TestTaskExecutor.scala index 26d8bc6b..e113fc66 100644 --- a/src/main/scala/higherkindness/rules_scala/common/sbt-testing/TestTaskExecutor.scala +++ b/src/main/scala/higherkindness/rules_scala/common/sbt-testing/TestTaskExecutor.scala @@ -1,26 +1,119 @@ package higherkindness.rules_scala.common.sbt_testing +import java.util.concurrent.ConcurrentLinkedQueue import sbt.testing.{Event, Logger, Status, Task} -import scala.collection.mutable - -class TestTaskExecutor(logger: Logger) { - def execute(task: Task, failures: mutable.Set[String]): mutable.ListBuffer[Event] = { - var events = new mutable.ListBuffer[Event]() - def execute(task: Task): Unit = { - val tasks = task.execute( - event => { - events += event - event.status match { - case Status.Failure | Status.Error => - failures += task.taskDef.fullyQualifiedName - case _ => - } - }, - Array(new PrefixedTestingLogger(logger, " ")), - ) - tasks.foreach(execute) +import scala.collection.{concurrent, mutable} +import scala.concurrent.{blocking, ExecutionContext, Future} +import scala.jdk.CollectionConverters.* + +case class TaskExecutorResult(taskEvents: Map[Task, Array[Event]], failures: Array[String]) + +private object TaskExecutorResult { + private[sbt_testing] case class Mutable( + taskEvents: concurrent.TrieMap[Task, ConcurrentLinkedQueue[Event]], + failures: ConcurrentLinkedQueue[String], + ) { + def clear(): Unit = { + taskEvents.clear() + failures.clear() } - execute(task) - events + + def toTaskExecutorResult: TaskExecutorResult = TaskExecutorResult( + taskEvents.view.map { case task -> events => task -> events.asScala.toArray }.toMap, + failures.asScala.toArray, + ) + } + + private[sbt_testing] object Mutable { + def empty: Mutable = apply(concurrent.TrieMap.empty, new ConcurrentLinkedQueue()) + } +} + +trait TestTaskExecutor { + def submitTask(task: Task): Unit + def waitForTasks(): Future[TaskExecutorResult] +} + +class ConcurrentTestTaskExecutor(logger: Logger) extends TestTaskExecutor { + private val activeTasks = new ConcurrentLinkedQueue[Future[Unit]]() + private val currentResult = TaskExecutorResult.Mutable.empty + + override def submitTask(task: Task): Unit = activeTasks.add( + Future { + blocking { + val bufferedLogger = new BufferedLogger(logger) + val reporter = new TestReporter(bufferedLogger) + + reporter.preTask(task) + + val additionalTasks = task.execute( + event => { + currentResult.synchronized { + currentResult.taskEvents.getOrElseUpdate(task, new ConcurrentLinkedQueue()).add(event) + + event.status match { + case Status.Failure | Status.Error => currentResult.failures.add(task.taskDef.fullyQualifiedName) + case _ => + } + } + }, + Array(new PrefixedTestingLogger(bufferedLogger, " ")), + ) + + additionalTasks.foreach(submitTask) + + reporter.postTask() + + // Only one task should write to stderr/stdout at a time. Of course, the task implementation could write to + // stdout/stderr directly, but that's out of our control. + synchronized { + bufferedLogger.flush() + } + } + }(ExecutionContext.global), + ) + + override def waitForTasks(): Future[TaskExecutorResult] = { + given ExecutionContext = ExecutionContext.global + + Future + .sequence(activeTasks.asScala) + .map { _ => + activeTasks.clear() + currentResult.toTaskExecutorResult + }(ExecutionContext.global) + } +} + +class SequentialTestTaskExecutor(logger: Logger) extends TestTaskExecutor { + private val currentResult = TaskExecutorResult.Mutable.empty + + override def submitTask(task: Task): Unit = { + val reporter = new TestReporter(logger) + + reporter.preTask(task) + + val additionalTasks = task.execute( + event => { + currentResult.taskEvents.getOrElseUpdate(task, new ConcurrentLinkedQueue()).add(event) + + event.status match { + case Status.Failure | Status.Error => currentResult.failures.add(task.taskDef.fullyQualifiedName) + case _ => + } + }, + Array(new PrefixedTestingLogger(logger, " ")), + ) + + additionalTasks.foreach(submitTask) + reporter.postTask() + } + + override def waitForTasks(): Future[TaskExecutorResult] = { + val result = currentResult.toTaskExecutorResult + + currentResult.clear() + + Future.successful(result) } } diff --git a/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestFrameworkRunner.scala b/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestFrameworkRunner.scala index 0f3861a0..94c13180 100644 --- a/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestFrameworkRunner.scala +++ b/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestFrameworkRunner.scala @@ -10,38 +10,48 @@ import higherkindness.rules_scala.common.sbt_testing.TestRequest import higherkindness.rules_scala.common.sbt_testing.TestTaskExecutor import java.io.ObjectOutputStream import java.nio.file.Path -import sbt.testing.{Event, Framework, Logger} +import sbt.testing.{Framework, Logger} import scala.collection.mutable +import scala.concurrent.{blocking, ExecutionContext, Future} -class BasicTestRunner(framework: Framework, classLoader: ClassLoader, logger: Logger) extends TestFrameworkRunner { - def execute(tests: List[TestDefinition], scopeAndTestName: String, arguments: List[String]) = { - var tasksAndEvents = new mutable.ListBuffer[(String, mutable.ListBuffer[Event])]() +class BasicTestRunner( + framework: Framework, + classLoader: ClassLoader, + logger: Logger, + testTaskExecutor: TestTaskExecutor, +) extends TestFrameworkRunner { + def execute(tests: List[TestDefinition], scopeAndTestName: String, arguments: List[String]): Future[Boolean] = { ClassLoaders.withContextClassLoader(classLoader) { TestHelper.withRunner(framework, scopeAndTestName, classLoader, arguments) { runner => val reporter = new TestReporter(logger) val tasks = runner.tasks(tests.map(TestHelper.taskDef(_, scopeAndTestName)).toArray) reporter.pre(framework, tasks) - val taskExecutor = new TestTaskExecutor(logger) - val failures = mutable.Set[String]() - tasks.foreach { task => - reporter.preTask(task) - val events = taskExecutor.execute(task, failures) - reporter.postTask() - tasksAndEvents += ((task.taskDef.fullyQualifiedName, events)) - } - reporter.post(failures) - val xmlReporter = new JUnitXmlReporter(tasksAndEvents) - xmlReporter.write() - !failures.nonEmpty + tasks.foreach(testTaskExecutor.submitTask) + testTaskExecutor + .waitForTasks() + .map { result => + blocking { + reporter.post(result.failures) + + val xmlReporter = new JUnitXmlReporter(result.taskEvents) + + xmlReporter.write() + + !result.failures.nonEmpty + } + }(ExecutionContext.global) } } } } -class ClassLoaderTestRunner(framework: Framework, classLoaderProvider: () => ClassLoader, logger: Logger) - extends TestFrameworkRunner { - def execute(tests: List[TestDefinition], scopeAndTestName: String, arguments: List[String]) = { - var tasksAndEvents = new mutable.ListBuffer[(String, mutable.ListBuffer[Event])]() +class ClassLoaderTestRunner( + framework: Framework, + classLoaderProvider: () => ClassLoader, + logger: Logger, + testTaskExecutor: TestTaskExecutor, +) extends TestFrameworkRunner { + def execute(tests: List[TestDefinition], scopeAndTestName: String, arguments: List[String]): Future[Boolean] = { val reporter = new TestReporter(logger) val classLoader = framework.getClass.getClassLoader @@ -52,27 +62,30 @@ class ClassLoaderTestRunner(framework: Framework, classLoaderProvider: () => Cla } } - val taskExecutor = new TestTaskExecutor(logger) - val failures = mutable.Set[String]() tests.foreach { test => val classLoader = classLoaderProvider() val isolatedFramework = new TestFrameworkLoader(classLoader).load(framework.getClass.getName).get TestHelper.withRunner(isolatedFramework, scopeAndTestName, classLoader, arguments) { runner => ClassLoaders.withContextClassLoader(classLoader) { val tasks = runner.tasks(Array(TestHelper.taskDef(test, scopeAndTestName))) - tasks.foreach { task => - reporter.preTask(task) - val events = taskExecutor.execute(task, failures) - reporter.postTask() - tasksAndEvents += ((task.taskDef.fullyQualifiedName, events)) - } + tasks.foreach(testTaskExecutor.submitTask) } } } - reporter.post(failures) - val xmlReporter = new JUnitXmlReporter(tasksAndEvents) - xmlReporter.write() - !failures.nonEmpty + + testTaskExecutor + .waitForTasks() + .map { result => + blocking { + reporter.post(result.failures) + + val xmlReporter = new JUnitXmlReporter(result.taskEvents) + + xmlReporter.write() + + !result.failures.nonEmpty + } + }(ExecutionContext.global) } } @@ -87,7 +100,7 @@ class ProcessTestRunner( command: ProcessCommand, logger: Logger with Serializable, ) extends TestFrameworkRunner { - def execute(tests: List[TestDefinition], scopeAndTestName: String, arguments: List[String]) = { + def execute(tests: List[TestDefinition], scopeAndTestName: String, arguments: List[String]): Future[Boolean] = { val reporter = new TestReporter(logger) val classLoader = framework.getClass.getClassLoader @@ -122,10 +135,10 @@ class ProcessTestRunner( } finally process.destroy } reporter.post(failures) - !failures.nonEmpty + Future.successful(!failures.nonEmpty) // We don't yet support concurrent execution with process isolation } } trait TestFrameworkRunner { - def execute(tests: List[TestDefinition], scopeAndTestName: String, arguments: List[String]): Boolean + def execute(tests: List[TestDefinition], scopeAndTestName: String, arguments: List[String]): Future[Boolean] } diff --git a/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestRunner.scala b/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestRunner.scala index d50e9027..36953fdd 100644 --- a/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestRunner.scala +++ b/src/main/scala/higherkindness/rules_scala/workers/zinc/test/TestRunner.scala @@ -4,7 +4,8 @@ import higherkindness.rules_scala.common.args.ArgsUtil.PathArgumentType import higherkindness.rules_scala.common.args.implicits.* import higherkindness.rules_scala.common.classloaders.ClassLoaders import higherkindness.rules_scala.common.sandbox.SandboxUtil -import higherkindness.rules_scala.common.sbt_testing.{AnnexTestingLogger, TestDefinition, TestFrameworkLoader, TestsFileData, Verbosity} +import higherkindness.rules_scala.common.sbt_testing.{AnnexTestingLogger, ConcurrentTestTaskExecutor, SequentialTestTaskExecutor, TestDefinition, TestFrameworkLoader, TestsFileData, Verbosity} +import higherkindness.rules_scala.workers.zinc.test.TestRunner.Isolation import java.io.FileInputStream import java.net.URLClassLoader import java.nio.file.attribute.FileTime @@ -16,6 +17,8 @@ import net.sourceforge.argparse4j.ArgumentParsers import net.sourceforge.argparse4j.impl.Arguments import net.sourceforge.argparse4j.inf.{ArgumentParser, Namespace} import play.api.libs.json.Json +import scala.concurrent.Await +import scala.concurrent.duration.Duration import scala.jdk.CollectionConverters.* import scala.util.Using @@ -83,9 +86,10 @@ object TestRunner { } private class TestRunnerRequest private ( - val subprocessExecutable: Option[Path], val isolation: Isolation, + val sequential: Boolean, val sharedClasspath: List[Path], + val subprocessExecutable: Option[Path], val testClasspath: List[Path], val testsFile: Path, ) @@ -93,10 +97,11 @@ object TestRunner { private object TestRunnerRequest { def apply(runPath: Path, namespace: Namespace): TestRunnerRequest = { new TestRunnerRequest( - subprocessExecutable = - Option(namespace.get[Path]("subprocess_exec")).map(SandboxUtil.getSandboxPath(runPath, _)), isolation = Isolation(namespace.getString("isolation")), + sequential = namespace.getBoolean("sequential"), sharedClasspath = SandboxUtil.getSandboxPaths(runPath, namespace.getList[Path]("shared_classpath")), + subprocessExecutable = + Option(namespace.get[Path]("subprocess_exec")).map(SandboxUtil.getSandboxPath(runPath, _)), testClasspath = SandboxUtil.getSandboxPaths(runPath, namespace.getList[Path]("classpath")), testsFile = namespace.get[Path]("tests_file"), ) @@ -105,15 +110,15 @@ object TestRunner { private val testArgParser: ArgumentParser = { val parser = ArgumentParsers.newFor("test").addHelp(true).build() - parser - .addArgument("--subprocess_exec") - .help("Executable for SubprocessTestRunner") - .`type`(PathArgumentType.apply()) parser .addArgument("--isolation") .choices(Isolation.values.keys.toSeq: _*) .help("Test isolation") .setDefault_(Isolation.None.level) + parser + .addArgument("--sequential") + .help("If passed, run test classes sequentially instead of concurrently.") + .action(Arguments.storeTrue()) parser .addArgument("--shared_classpath") .help("Classpath to share between tests") @@ -121,6 +126,10 @@ object TestRunner { .nargs("*") .`type`(PathArgumentType.apply()) .setDefault_(Collections.emptyList) + parser + .addArgument("--subprocess_exec") + .help("Executable for SubprocessTestRunner") + .`type`(PathArgumentType.apply()) parser .addArgument("--tests_file") .help("File containing discovered tests.") @@ -188,11 +197,21 @@ object TestRunner { } } filteredTests.isEmpty || { + if (testRunnerRequest.sequential && testRunnerRequest.isolation == Isolation.Process) { + throw new Exception("Process isolation isn't yet compatible with sequential execution.") + } + + val testTaskExecutor = if (testRunnerRequest.sequential) { + new SequentialTestTaskExecutor(logger) + } else { + new ConcurrentTestTaskExecutor(logger) + } + val runner = testRunnerRequest.isolation match { case Isolation.ClassLoader => val urls = testClasspath.filterNot(sharedClasspath.toSet).map(_.toUri.toURL).toArray def classLoaderProvider() = new URLClassLoader(urls, sharedClassLoader) - new ClassLoaderTestRunner(framework, classLoaderProvider _, logger) + new ClassLoaderTestRunner(framework, classLoaderProvider _, logger, testTaskExecutor) case Isolation.Process => val executable = testRunnerRequest.subprocessExecutable.map(_.toString).getOrElse { throw new Exception("Subprocess executable missing for test ran in process isolation mode.") @@ -203,11 +222,14 @@ object TestRunner { new ProcessCommand(executable, testRunnerArgs.subprocessArgs), logger, ) - case Isolation.None => new BasicTestRunner(framework, classLoader, logger) + case Isolation.None => new BasicTestRunner(framework, classLoader, logger, testTaskExecutor) } try { - runner.execute(filteredTests.toList, testScopeAndName.getOrElse(""), testRunnerArgs.frameworkArgs) + Await.result( + runner.execute(filteredTests.toList, testScopeAndName.getOrElse(""), testRunnerArgs.frameworkArgs), + Duration.Inf, + ) } catch { case e: Throwable => e.printStackTrace() diff --git a/tests/test-frameworks/concurrency/BUILD.bazel b/tests/test-frameworks/concurrency/BUILD.bazel new file mode 100644 index 00000000..f982f982 --- /dev/null +++ b/tests/test-frameworks/concurrency/BUILD.bazel @@ -0,0 +1,16 @@ +load("@rules_scala_annex//rules:scala.bzl", "scala_test") + +scala_test( + name = "concurrent", + srcs = glob(["*.scala"]), + jvm_flags = ["-DTEST_MODE=concurrent"], + deps = ["@annex_test//:org_specs2_specs2_core_2_13"], +) + +scala_test( + name = "sequential", + srcs = glob(["*.scala"]), + jvm_flags = ["-DTEST_MODE=sequential"], + sequential = True, + deps = ["@annex_test//:org_specs2_specs2_core_2_13"], +) diff --git a/tests/test-frameworks/concurrency/Spec1.scala b/tests/test-frameworks/concurrency/Spec1.scala new file mode 100644 index 00000000..312c9c03 --- /dev/null +++ b/tests/test-frameworks/concurrency/Spec1.scala @@ -0,0 +1,9 @@ +package annex.concurrency + +import org.specs2.mutable.Specification + +class Spec1 extends Specification { + println("Spec1 executing...") + + maybeWaitForOthers() +} diff --git a/tests/test-frameworks/concurrency/Spec2.scala b/tests/test-frameworks/concurrency/Spec2.scala new file mode 100644 index 00000000..8b02fa3a --- /dev/null +++ b/tests/test-frameworks/concurrency/Spec2.scala @@ -0,0 +1,9 @@ +package annex.concurrency + +import org.specs2.mutable.Specification + +class Spec2 extends Specification { + println("Spec2 executing...") + + maybeWaitForOthers() +} diff --git a/tests/test-frameworks/concurrency/Spec3.scala b/tests/test-frameworks/concurrency/Spec3.scala new file mode 100644 index 00000000..8b1da528 --- /dev/null +++ b/tests/test-frameworks/concurrency/Spec3.scala @@ -0,0 +1,9 @@ +package annex.concurrency + +import org.specs2.mutable.Specification + +class Spec3 extends Specification { + println("Spec3 executing...") + + maybeWaitForOthers() +} diff --git a/tests/test-frameworks/concurrency/Spec4.scala b/tests/test-frameworks/concurrency/Spec4.scala new file mode 100644 index 00000000..d3350612 --- /dev/null +++ b/tests/test-frameworks/concurrency/Spec4.scala @@ -0,0 +1,9 @@ +package annex.concurrency + +import org.specs2.mutable.Specification + +class Spec4 extends Specification { + println("Spec4 executing...") + + maybeWaitForOthers() +} diff --git a/tests/test-frameworks/concurrency/package.scala b/tests/test-frameworks/concurrency/package.scala new file mode 100644 index 00000000..58e78ea2 --- /dev/null +++ b/tests/test-frameworks/concurrency/package.scala @@ -0,0 +1,21 @@ +package annex + +import java.util.concurrent.CountDownLatch + +package object concurrency { + private val testLatch = new CountDownLatch(4) + + /** + * All of the tests call this method to wait for each other to have started and printed a message before exiting. + * + * This allows us to verify with certainty that the tests' output is buffered, because if it weren't, each would + * print its "pre" message (the test name), followed by a "SpecX executing..." message, followed by its + * "post" message. The *only* way to get all the "pre" message printed first and in sequence is if the output is + * buffered and the "pre" and "post" messages are flushed at the end of the test, after all the + * "SpecX executing..." messages have been printed. + */ + def maybeWaitForOthers(): Unit = if (Option(System.getProperty("TEST_MODE")).contains[String]("concurrent")) { + System.out.flush() + testLatch.countDown() + } +} \ No newline at end of file diff --git a/tests/test-frameworks/concurrency/test b/tests/test-frameworks/concurrency/test new file mode 100755 index 00000000..b8f683f5 --- /dev/null +++ b/tests/test-frameworks/concurrency/test @@ -0,0 +1,74 @@ +#!/bin/bash -e +. "$(dirname "$0")"/../../common.sh + +# That the test completes at all is proof that the test classes are executed concurrently, because each waits for all +# the others to start. This means that if the test doesn't complete, then the test classes are being executed +# sequentially. You may find this helpful for debugging. +concurrent_output="$(bazel run :concurrent)" + +# The logs for each test class should be buffered. See `tests/test-frameworks/concurrency/package.scala` to understand +# why these messages being printed in sequence is proof of that. +echo "$concurrent_output" | grep 'Spec\d executing... +Spec\d executing... +Spec\d executing... +Spec\d executing... +' + +# The output for each test class shouldn't be interleaved +for i in {1..4}; do + echo "$concurrent_output" | grep "com.lucidchart.example.Spec$i + Spec$i + + + + Total for specification Spec$i +" +done + +sequential_output="$(bazel run :sequential)" + +# The test classes should run in order +echo "$sequential_output" | grep ' +com\.lucidchart\.example\.Spec1 +Spec1 executing\.\.\. + Spec1 + + + + Total for specification Spec1 + Finished in .* +0 example, 0 failure, 0 error + + +com\.lucidchart\.example\.Spec2 +Spec2 executing\.\.\. + Spec2 + + + + Total for specification Spec2 + Finished in .* +0 example, 0 failure, 0 error + + +com\.lucidchart\.example\.Spec3 +Spec3 executing\.\.\. + Spec3 + + + + Total for specification Spec3 + Finished in .* +0 example, 0 failure, 0 error + + +com\.lucidchart\.example\.Spec4 +Spec4 executing\.\.\. + Spec4 + + + + Total for specification Spec4 + Finished in .* +0 example, 0 failure, 0 error +'