From 23157a3d05570e75c29315d18476970ccce13b6a Mon Sep 17 00:00:00 2001 From: Jaden Peterson Date: Mon, 23 Mar 2026 17:22:50 -0400 Subject: [PATCH] Correctly handle cancellation in WorkerMain Although we intended to unwrap fatal exceptions like `InterruptedException` in `CancellableTask`, we were actually throwing them outside of the underlying `Future` instead of failing the `Future`. That meant that `InterruptedException`s (and other fatal errors) floated up the call stack to the thread pool where we execute work requests, crashing the worker. Here's a detailed breakdown of what would happen: 1. `InterruptedException` is thrown inside `WorkerMain#work` 2. That failure is caught inside the `recover` call in `CancellableTask` 3. The exception is re-thrown, but because `recover` only catches non-fatal exceptions, the exception floats up 4. Eventually, it reaches the top of the call stack and the worker dies This commit fixes that issue. I've verified that Bazel acknowledges actions' being cancelled from the worker's end when these changes, and an intentional `throw new InterruptedException(???)` in `ZincRunner`, are applied. --- .../common/worker/CancellableTask.scala | 18 ++++--- .../common/worker/WorkerMain.scala | 49 ++++++++++--------- tests/cancellation/AlwaysCancelWorker.scala | 18 +++++++ tests/cancellation/BUILD | 1 + tests/cancellation/CancelSpec.scala | 20 +++++++- 5 files changed, 73 insertions(+), 33 deletions(-) create mode 100644 tests/cancellation/AlwaysCancelWorker.scala diff --git a/src/main/scala/higherkindness/rules_scala/common/worker/CancellableTask.scala b/src/main/scala/higherkindness/rules_scala/common/worker/CancellableTask.scala index 8237f595..e24e8d64 100644 --- a/src/main/scala/higherkindness/rules_scala/common/worker/CancellableTask.scala +++ b/src/main/scala/higherkindness/rules_scala/common/worker/CancellableTask.scala @@ -1,7 +1,7 @@ package higherkindness.rules_scala.common.worker import java.util.concurrent.Callable -import scala.concurrent.{ExecutionContext, ExecutionException, Future, Promise} +import scala.concurrent.{ExecutionContext, Future, Promise} import scala.util.Try /** @@ -14,6 +14,9 @@ import scala.util.Try * * Heavily inspired by the following: https://github.com/NthPortal/cancellable-task/tree/master * https://stackoverflow.com/a/39986418/6442597 + * + * Note that for complicated reasons explained in its implementation, `CancellableTask` wraps all exceptions thrown + * within the task in an `ExecutionException`, so be sure to unwrap them. */ class CancellableTask[S] private (fn: Function1[Function0[Boolean], S]) { private val promise = Promise[S]() @@ -25,13 +28,12 @@ class CancellableTask[S] private (fn: Function1[Function0[Boolean], S]) { private val task = new FutureTaskWaitOnCancel[S](fnCallable) { override def done() = promise.complete { - Try(get()).recover { - // FutureTask wraps exceptions in an ExecutionException. We want to re-throw the underlying - // error because Scala's Future handles things like fatal exception in a special way that - // we miss out on if they're wrapped in that ExecutionException. Put another way: leaving - // them wrapped in the ExecutionException breaks the contract that Scala Future users expect. - case e: ExecutionException => throw e.getCause() - } + // `FutureTask` wraps exceptions in an `ExecutionException`. Although we'd like `FutureTask` to function exactly + // like a `Future` (which doesn't wrap exceptions like this), we can't unwrap fatal exceptions. That's because + // `promise.complete` will just re-wrap fatal exceptions in an `ExecutionException`. To be consistent about how we + // handle various exceptions, we leave all exceptions unwrapped and declare it the responsibility of the user to + // unwrap the exceptions they wish to handle. + Try(get()) } } diff --git a/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala b/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala index 060a9004..740a0904 100644 --- a/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala +++ b/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala @@ -183,31 +183,31 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream writeResponse(requestId, maybeOutStream, Some(code)) logVerbose(s"WorkResponse for request id: $requestId sent with code $code") - case Failure(e: ExecutionException) => - e.getCause() match { - // Task successfully cancelled - case cancelError: InterruptedException => - flushOut() - writeResponse(requestId, None, None, wasCancelled = true) - logVerbose( - s"Cancellation WorkResponse sent for request id: $requestId in response to an" + - " InterruptedException", - ) - // Work task threw a non-fatal error - case e => - maybeOut.map(e.printStackTrace(_)) - flushOut() - writeResponse(requestId, maybeOutStream, Some(-1)) - logVerbose( - "Encountered an uncaught exception that was wrapped in an ExecutionException while" + - s" proccessing the Future for WorkRequest id: $requestId. This usually means a non-fatal" + - " error was thrown in the Future.", - ) - e.printStackTrace(System.err) - } + // `CancellableTask` wraps all exceptions in `ExecutionException`, so we need to unwrap them here + case Failure(e: ExecutionException) + if e.getCause().isInstanceOf[CancellationException] + || e.getCause().isInstanceOf[ClosedByInterruptException] + || e.getCause().isInstanceOf[InterruptedException] => + flushOut() + writeResponse(requestId, None, None, wasCancelled = true) + logVerbose( + s"Cancellation WorkResponse sent for request id: $requestId in response to a ${e.getCause().getClass.getCanonicalName}", + ) + + // Work task threw an uncaught exception + case Failure(e: ExecutionException) if e.getCause() != null => + maybeOut.map(e.getCause().printStackTrace(_)) + flushOut() + writeResponse(requestId, maybeOutStream, Some(-1)) + logVerbose( + s"Uncaught exception in Future while proccessing WorkRequest id: $requestId\nType: ${e.getCause().getClass.getCanonicalName}", + ) + e.getCause().printStackTrace(System.err) // Task successfully cancelled - case Failure(e @ (_: CancellationException | _: ClosedByInterruptException)) => + case Failure( + e @ (_: CancellationException | _: ClosedByInterruptException | _: InterruptedException), + ) => flushOut() writeResponse(requestId, None, None, wasCancelled = true) logVerbose( @@ -215,7 +215,8 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream e.getClass.getCanonicalName, ) - // Work task threw an uncaught exception + // Work task threw an uncaught exception. This branch should never be activated because of the + // exception wrapping described above, but it never hurts to be defensive ¯\_(ツ)_/¯ case Failure(e) => maybeOut.map(e.printStackTrace(_)) flushOut() diff --git a/tests/cancellation/AlwaysCancelWorker.scala b/tests/cancellation/AlwaysCancelWorker.scala new file mode 100644 index 00000000..6af614b4 --- /dev/null +++ b/tests/cancellation/AlwaysCancelWorker.scala @@ -0,0 +1,18 @@ +package anx.cancellation + +import higherkindness.rules_scala.common.worker.{WorkerMain, WorkTask} + +import java.io.{InputStream, PrintStream} + +/** + * A worker that immediately throws an InterruptedException. This simulates what happens when a worker calls + * [[higherkindness.rules_scala.common.interrupt.InterruptUtil.throwIfInterrupted]] and finds the thread has been + * interrupted. + */ +class AlwaysCancelWorker(stdin: InputStream, stdout: PrintStream) + extends WorkerMain[Unit](stdin, stdout) { + override def init(args: Option[Array[String]]): Unit = () + override def work(task: WorkTask[Unit]): Unit = { + throw new InterruptedException("WorkRequest was cancelled via Thread interruption.") + } +} diff --git a/tests/cancellation/BUILD b/tests/cancellation/BUILD index c7c3d9e2..ecb51cfd 100644 --- a/tests/cancellation/BUILD +++ b/tests/cancellation/BUILD @@ -3,6 +3,7 @@ load("@rules_scala_annex//rules:scala.bzl", "scala_library", "scala_test") scala_library( name = "cancel-spec-worker", srcs = [ + "AlwaysCancelWorker.scala", "RunnerForCancelSpec.scala", ], scala_toolchain_name = "test_zinc_2_13", diff --git a/tests/cancellation/CancelSpec.scala b/tests/cancellation/CancelSpec.scala index 2a3e3829..4a7f7118 100644 --- a/tests/cancellation/CancelSpec.scala +++ b/tests/cancellation/CancelSpec.scala @@ -139,4 +139,22 @@ class CancelSpec extends AnyFlatSpec { } } } -} \ No newline at end of file + + it should "treat an InterruptedException thrown by the worker as a cancellation" in { + val requestId = 1 + val workRequest = WorkerTestUtil.getWorkRequest(requestId) + + WorkerTestUtil.withIOStreams { (testOut, testIn, workerStdOut, workerStdIn) => + val worker = new AlwaysCancelWorker(workerStdIn, workerStdOut) + + Future(worker.main(Array("--persistent_worker")))(ExecutionContext.global) + + workRequest.writeDelimitedTo(testOut) + + val response = WorkerProtocol.WorkResponse.parseDelimitedFrom(testIn) + + assert(response.getRequestId() == requestId) + assert(response.getWasCancelled()) + } + } +}