diff --git a/src/saturn_engine/worker/executors/executable.py b/src/saturn_engine/worker/executors/executable.py index 06a6b5d7..2a010e5d 100644 --- a/src/saturn_engine/worker/executors/executable.py +++ b/src/saturn_engine/worker/executors/executable.py @@ -49,6 +49,7 @@ def __init__( self.output = output self.resources: dict[str, ResourceContext] = {} self.queue = queue + self.is_cancelled = False @property def id(self) -> str: @@ -93,6 +94,9 @@ def saturn_context(self) -> t.Iterator[None]: with job_context(self.queue.definition), message_context(self.message.message): yield + def cancel(self) -> None: + self.is_cancelled = True + class ExecutableQueue: def __init__( diff --git a/src/saturn_engine/worker/executors/queue.py b/src/saturn_engine/worker/executors/queue.py index d14e21b3..d368d5ba 100644 --- a/src/saturn_engine/worker/executors/queue.py +++ b/src/saturn_engine/worker/executors/queue.py @@ -22,6 +22,10 @@ from .executable import ExecutableMessage +class MessageCancelled(Exception): + pass + + class ExecutorQueue: CLOSE_TIMEOUT = datetime.timedelta(seconds=60) @@ -68,6 +72,8 @@ async def scope( xmsg: ExecutableMessage, ) -> PipelineResults: try: + if xmsg.is_cancelled: + raise MessageCancelled return await self.executor.process_message(xmsg) except Exception: exc_type, exc_value, exc_traceback = sys.exc_info()