diff --git a/src/main/scala/higherkindness/rules_scala/common/worker/CancellableTask.scala b/src/main/scala/higherkindness/rules_scala/common/worker/CancellableTask.scala index 8237f595..e24e8d64 100644 --- a/src/main/scala/higherkindness/rules_scala/common/worker/CancellableTask.scala +++ b/src/main/scala/higherkindness/rules_scala/common/worker/CancellableTask.scala @@ -1,7 +1,7 @@ package higherkindness.rules_scala.common.worker import java.util.concurrent.Callable -import scala.concurrent.{ExecutionContext, ExecutionException, Future, Promise} +import scala.concurrent.{ExecutionContext, Future, Promise} import scala.util.Try /** @@ -14,6 +14,9 @@ import scala.util.Try * * Heavily inspired by the following: https://github.com/NthPortal/cancellable-task/tree/master * https://stackoverflow.com/a/39986418/6442597 + * + * Note that for complicated reasons explained in its implementation, `CancellableTask` wraps all exceptions thrown + * within the task in an `ExecutionException`, so be sure to unwrap them. */ class CancellableTask[S] private (fn: Function1[Function0[Boolean], S]) { private val promise = Promise[S]() @@ -25,13 +28,12 @@ class CancellableTask[S] private (fn: Function1[Function0[Boolean], S]) { private val task = new FutureTaskWaitOnCancel[S](fnCallable) { override def done() = promise.complete { - Try(get()).recover { - // FutureTask wraps exceptions in an ExecutionException. We want to re-throw the underlying - // error because Scala's Future handles things like fatal exception in a special way that - // we miss out on if they're wrapped in that ExecutionException. Put another way: leaving - // them wrapped in the ExecutionException breaks the contract that Scala Future users expect. - case e: ExecutionException => throw e.getCause() - } + // `FutureTask` wraps exceptions in an `ExecutionException`. Although we'd like `FutureTask` to function exactly + // like a `Future` (which doesn't wrap exceptions like this), we can't unwrap fatal exceptions. That's because + // `promise.complete` will just re-wrap fatal exceptions in an `ExecutionException`. To be consistent about how we + // handle various exceptions, we leave all exceptions unwrapped and declare it the responsibility of the user to + // unwrap the exceptions they wish to handle. + Try(get()) } } diff --git a/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala b/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala index 060a9004..740a0904 100644 --- a/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala +++ b/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala @@ -183,31 +183,31 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream writeResponse(requestId, maybeOutStream, Some(code)) logVerbose(s"WorkResponse for request id: $requestId sent with code $code") - case Failure(e: ExecutionException) => - e.getCause() match { - // Task successfully cancelled - case cancelError: InterruptedException => - flushOut() - writeResponse(requestId, None, None, wasCancelled = true) - logVerbose( - s"Cancellation WorkResponse sent for request id: $requestId in response to an" + - " InterruptedException", - ) - // Work task threw a non-fatal error - case e => - maybeOut.map(e.printStackTrace(_)) - flushOut() - writeResponse(requestId, maybeOutStream, Some(-1)) - logVerbose( - "Encountered an uncaught exception that was wrapped in an ExecutionException while" + - s" proccessing the Future for WorkRequest id: $requestId. This usually means a non-fatal" + - " error was thrown in the Future.", - ) - e.printStackTrace(System.err) - } + // `CancellableTask` wraps all exceptions in `ExecutionException`, so we need to unwrap them here + case Failure(e: ExecutionException) + if e.getCause().isInstanceOf[CancellationException] + || e.getCause().isInstanceOf[ClosedByInterruptException] + || e.getCause().isInstanceOf[InterruptedException] => + flushOut() + writeResponse(requestId, None, None, wasCancelled = true) + logVerbose( + s"Cancellation WorkResponse sent for request id: $requestId in response to a ${e.getCause().getClass.getCanonicalName}", + ) + + // Work task threw an uncaught exception + case Failure(e: ExecutionException) if e.getCause() != null => + maybeOut.map(e.getCause().printStackTrace(_)) + flushOut() + writeResponse(requestId, maybeOutStream, Some(-1)) + logVerbose( + s"Uncaught exception in Future while proccessing WorkRequest id: $requestId\nType: ${e.getCause().getClass.getCanonicalName}", + ) + e.getCause().printStackTrace(System.err) // Task successfully cancelled - case Failure(e @ (_: CancellationException | _: ClosedByInterruptException)) => + case Failure( + e @ (_: CancellationException | _: ClosedByInterruptException | _: InterruptedException), + ) => flushOut() writeResponse(requestId, None, None, wasCancelled = true) logVerbose( @@ -215,7 +215,8 @@ abstract class WorkerMain[S](stdin: InputStream = System.in, stdout: PrintStream e.getClass.getCanonicalName, ) - // Work task threw an uncaught exception + // Work task threw an uncaught exception. This branch should never be activated because of the + // exception wrapping described above, but it never hurts to be defensive ¯\_(ツ)_/¯ case Failure(e) => maybeOut.map(e.printStackTrace(_)) flushOut() diff --git a/tests/cancellation/AlwaysCancelWorker.scala b/tests/cancellation/AlwaysCancelWorker.scala new file mode 100644 index 00000000..6af614b4 --- /dev/null +++ b/tests/cancellation/AlwaysCancelWorker.scala @@ -0,0 +1,18 @@ +package anx.cancellation + +import higherkindness.rules_scala.common.worker.{WorkerMain, WorkTask} + +import java.io.{InputStream, PrintStream} + +/** + * A worker that immediately throws an InterruptedException. This simulates what happens when a worker calls + * [[higherkindness.rules_scala.common.interrupt.InterruptUtil.throwIfInterrupted]] and finds the thread has been + * interrupted. + */ +class AlwaysCancelWorker(stdin: InputStream, stdout: PrintStream) + extends WorkerMain[Unit](stdin, stdout) { + override def init(args: Option[Array[String]]): Unit = () + override def work(task: WorkTask[Unit]): Unit = { + throw new InterruptedException("WorkRequest was cancelled via Thread interruption.") + } +} diff --git a/tests/cancellation/BUILD b/tests/cancellation/BUILD index c7c3d9e2..ecb51cfd 100644 --- a/tests/cancellation/BUILD +++ b/tests/cancellation/BUILD @@ -3,6 +3,7 @@ load("@rules_scala_annex//rules:scala.bzl", "scala_library", "scala_test") scala_library( name = "cancel-spec-worker", srcs = [ + "AlwaysCancelWorker.scala", "RunnerForCancelSpec.scala", ], scala_toolchain_name = "test_zinc_2_13", diff --git a/tests/cancellation/CancelSpec.scala b/tests/cancellation/CancelSpec.scala index 2a3e3829..4a7f7118 100644 --- a/tests/cancellation/CancelSpec.scala +++ b/tests/cancellation/CancelSpec.scala @@ -139,4 +139,22 @@ class CancelSpec extends AnyFlatSpec { } } } -} \ No newline at end of file + + it should "treat an InterruptedException thrown by the worker as a cancellation" in { + val requestId = 1 + val workRequest = WorkerTestUtil.getWorkRequest(requestId) + + WorkerTestUtil.withIOStreams { (testOut, testIn, workerStdOut, workerStdIn) => + val worker = new AlwaysCancelWorker(workerStdIn, workerStdOut) + + Future(worker.main(Array("--persistent_worker")))(ExecutionContext.global) + + workRequest.writeDelimitedTo(testOut) + + val response = WorkerProtocol.WorkResponse.parseDelimitedFrom(testIn) + + assert(response.getRequestId() == requestId) + assert(response.getWasCancelled()) + } + } +}