diff --git a/src/main/scala/higherkindness/rules_scala/common/worker/FutureTaskWaitOnCancel.scala b/src/main/scala/higherkindness/rules_scala/common/worker/FutureTaskWaitOnCancel.scala index 2b4e36d3..aa07eb85 100644 --- a/src/main/scala/higherkindness/rules_scala/common/worker/FutureTaskWaitOnCancel.scala +++ b/src/main/scala/higherkindness/rules_scala/common/worker/FutureTaskWaitOnCancel.scala @@ -60,6 +60,12 @@ private class CallableLockedWhileRunning[S](callable: Callable[S]) extends Calla override def call(): S = { isRunning.lock() try { + // Clear any stale interrupt flag from a previous task. `FutureTask` deliberately doesn't clear the interrupt flag + // once task cancellation is complete because thread interruption could be used as an "an independent mechanism + // for a task to communicate with its caller". + // https://github.com/openjdk/jdk/blame/jdk-21%2B35/src/java.base/share/classes/java/util/concurrent/FutureTask.java#L390 + Thread.interrupted() + callable.call() } finally { isRunning.unlock() diff --git a/tests/cancellation/BUILD b/tests/cancellation/BUILD index ecb51cfd..a9ceefac 100644 --- a/tests/cancellation/BUILD +++ b/tests/cancellation/BUILD @@ -31,3 +31,16 @@ scala_test( "@rules_scala_annex//third_party/bazel/src/main/protobuf:worker_protocol_java_proto", ], ) + +scala_test( + name = "cancellabletask-spec", + srcs = ["CancellableTaskSpec.scala"], + scala_toolchain_name = "test_zinc_2_13", + tags = ["manual"], + deps = [ + "@annex_test//:org_scalactic_scalactic_2_13", + "@annex_test//:org_scalatest_scalatest_core_2_13", + "@annex_test//:org_scalatest_scalatest_flatspec_2_13", + "@rules_scala_annex//src/main/scala/higherkindness/rules_scala/common/worker", + ], +) diff --git a/tests/cancellation/CancellableTaskSpec.scala b/tests/cancellation/CancellableTaskSpec.scala new file mode 100644 index 00000000..9abc1371 --- /dev/null +++ b/tests/cancellation/CancellableTaskSpec.scala @@ -0,0 +1,30 @@ +package anx.cancellation + +import higherkindness.rules_scala.common.worker.CancellableTask +import org.scalatest.flatspec.AnyFlatSpec +import java.util.concurrent.ForkJoinPool +import scala.concurrent.{Await, ExecutionContext} +import scala.concurrent.duration.Duration + +class CancellableTaskSpec extends AnyFlatSpec { + "CancellableTask" should "not leak a stale interrupt flag to the next task on the same thread" in { + val threadPool = new ForkJoinPool(1) + val executionContext = ExecutionContext.fromExecutor(threadPool) + + try { + val task1 = CancellableTask((_: () => Boolean) => Thread.currentThread().interrupt()) + val task2 = CancellableTask { (_: () => Boolean) => + if (Thread.interrupted()) { + throw new InterruptedException("Stale interrupt flag leaked from a previous task!") + } + } + + task1.execute(executionContext) + task2.execute(executionContext) + + Await.result(task2.future, Duration.Inf) + } finally { + threadPool.shutdown() + } + } +} diff --git a/tests/cancellation/test b/tests/cancellation/test index 9ce6985a..75fa6a1a 100755 --- a/tests/cancellation/test +++ b/tests/cancellation/test @@ -1,4 +1,4 @@ #!/bin/bash -e . "$(dirname "$0")"/../common.sh -bazel test :cancel-spec +bazel test :cancel-spec :cancellabletask-spec