From 3327132a55427bcb8904f2d17fa30f5b6d4f267b Mon Sep 17 00:00:00 2001
From: Kannan Rajah
Date: Tue, 17 Feb 2026 18:32:05 -0800
Subject: [PATCH] Add ActivityCommandTask dispatch via Nexus

- Add activityCommandTaskDispatcher to dispatch tasks to workers
- Use typed Nexus service definition for worker communication
- Add proper error handling for DispatchNexusTask response
- Release workflow lock before making RPC call
- Add functional test for dispatch flow
---
 .../activity_command_task_dispatcher.go | 218 ++++++++++++++++++
 .../outbound_queue_active_task_executor.go | 30 ++-
 ...utbound_queue_active_task_executor_test.go | 2 +
 service/history/outbound_queue_factory.go | 4 +
 tests/activity_command_task_test.go | 149 ++++++++++++
 5 files changed, 396 insertions(+), 7 deletions(-)
 create mode 100644 service/history/activity_command_task_dispatcher.go
 create mode 100644 tests/activity_command_task_test.go

diff --git a/service/history/activity_command_task_dispatcher.go b/service/history/activity_command_task_dispatcher.go
new file mode 100644
index 0000000000..a092abbb95
--- /dev/null
+++ b/service/history/activity_command_task_dispatcher.go
@@ -0,0 +1,218 @@
+package history
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	enumspb "go.temporal.io/api/enums/v1"
+	nexuspb "go.temporal.io/api/nexus/v1"
+	taskqueuepb "go.temporal.io/api/taskqueue/v1"
+	workerpb "go.temporal.io/api/worker/v1"
+	enumsspb "go.temporal.io/server/api/enums/v1"
+	"go.temporal.io/server/api/matchingservice/v1"
+	tokenspb "go.temporal.io/server/api/token/v1"
+	"go.temporal.io/server/common"
+	"go.temporal.io/server/common/debug"
+	"go.temporal.io/server/common/log"
+	"go.temporal.io/server/common/log/tag"
+	"go.temporal.io/server/common/metrics"
+	"go.temporal.io/server/common/payload"
+	"go.temporal.io/server/common/resource"
+	"go.temporal.io/server/service/history/configs"
+	historyi "go.temporal.io/server/service/history/interfaces"
+	"go.temporal.io/server/service/history/tasks"
+	wcache "go.temporal.io/server/service/history/workflow/cache"
+)
+
+// activityCommandTaskTimeout bounds a single execute call; the same deadline
+// covers both building the task tokens and the dispatch RPC.
+const activityCommandTaskTimeout = time.Second * 10 * debug.TimeoutMultiplier
+
+// activityCommandTaskDispatcher delivers activity command tasks to workers.
+type activityCommandTaskDispatcher struct {
+	// Used to locate and load the workflow the task refers to.
+	shardContext historyi.ShardContext
+	cache        wcache.Cache
+	// Client used to hand the Nexus request to matching.
+	matchingRawClient resource.MatchingRawClient
+	// Source of the EnableActivityCancellationNexusTask switch.
+	config         *configs.Config
+	metricsHandler metrics.Handler
+	logger         log.Logger
+}
+
+// newActivityCommandTaskDispatcher wires a dispatcher from its collaborators.
+// The arguments are stored as given; nothing is validated or copied.
+func newActivityCommandTaskDispatcher(
+	shardContext historyi.ShardContext,
+	cache wcache.Cache,
+	matchingRawClient resource.MatchingRawClient,
+	config *configs.Config,
+	metricsHandler metrics.Handler,
+	logger log.Logger,
+) *activityCommandTaskDispatcher {
+	dispatcher := new(activityCommandTaskDispatcher)
+	dispatcher.shardContext = shardContext
+	dispatcher.cache = cache
+	dispatcher.matchingRawClient = matchingRawClient
+	dispatcher.config = config
+	dispatcher.metricsHandler = metricsHandler
+	dispatcher.logger = logger
+	return dispatcher
+}
+
+// execute processes one ActivityCommandTask. It is a no-op (nil) when the
+// feature switch is off, when the task lists no scheduled event IDs, or when
+// no activity yields a task token; otherwise it returns the dispatch outcome.
+func (d *activityCommandTaskDispatcher) execute(
+	ctx context.Context,
+	task *tasks.ActivityCommandTask,
+) error {
+	switch {
+	case !d.config.EnableActivityCancellationNexusTask():
+		return nil
+	case len(task.ScheduledEventIDs) == 0:
+		return nil
+	}
+
+	boundedCtx, cancel := context.WithTimeout(ctx, activityCommandTaskTimeout)
+	defer cancel()
+
+	tokens, err := d.buildTaskTokens(boundedCtx, task)
+	if err != nil || len(tokens) == 0 {
+		// err is nil here when there is simply nothing to dispatch.
+		return err
+	}
+	return d.dispatchToWorker(boundedCtx, task, tokens)
+}
+
+// buildTaskTokens loads mutable state and builds task tokens for activities that need commands.
+// Lock is acquired and released within this method.
+func (d *activityCommandTaskDispatcher) buildTaskTokens(
+	ctx context.Context,
+	task *tasks.ActivityCommandTask,
+) ([][]byte, error) {
+	weContext, release, err := getWorkflowExecutionContextForTask(ctx, d.shardContext, d.cache, task)
+	if err != nil {
+		return nil, err
+	}
+	// Deferred so the workflow context is released when this method returns,
+	// which is before the caller performs the dispatch RPC.
+	defer release(nil)
+
+	mutableState, err := weContext.LoadMutableState(ctx, d.shardContext)
+	if err != nil {
+		return nil, err
+	}
+	if mutableState == nil || !mutableState.IsWorkflowExecutionRunning() {
+		return nil, nil
+	}
+
+	// At most one token per scheduled event ID, so size the slice once.
+	taskTokens := make([][]byte, 0, len(task.ScheduledEventIDs))
+	for _, scheduledEventID := range task.ScheduledEventIDs {
+		// Skip activities that are missing from mutable state or carry no started event.
+		ai, ok := mutableState.GetActivityInfo(scheduledEventID)
+		if !ok || ai.StartedEventId == common.EmptyEventID {
+			continue
+		}
+		// A cancel command only applies to activities whose cancellation was requested.
+		if task.CommandType == enumsspb.ACTIVITY_COMMAND_TYPE_CANCEL && !ai.CancelRequested {
+			continue
+		}
+
+		taskToken := &tokenspb.Task{
+			NamespaceId:      task.NamespaceID,
+			WorkflowId:       task.WorkflowID,
+			RunId:            task.RunID,
+			ScheduledEventId: scheduledEventID,
+			Attempt:          ai.Attempt,
+			ActivityId:       ai.ActivityId,
+			StartedEventId:   ai.StartedEventId,
+			Version:          ai.Version,
+		}
+		taskTokenBytes, err := taskToken.Marshal()
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal activity task token for scheduled event %d: %w", scheduledEventID, err)
+		}
+		taskTokens = append(taskTokens, taskTokenBytes)
+	}
+	return taskTokens, nil
+}
+
+// dispatchToWorker wraps the task tokens in an ActivityNotificationRequest,
+// sends it as a Nexus StartOperation request to the task queue named by
+// task.Destination, and returns the result of handleDispatchResponse.
+// Encoding and RPC errors are returned to the caller.
+func (d *activityCommandTaskDispatcher) dispatchToWorker(
+	ctx context.Context,
+	task *tasks.ActivityCommandTask,
+	taskTokens [][]byte,
+) error {
+	notificationRequest := &workerpb.ActivityNotificationRequest{
+		// NOTE(review): this is a raw numeric cast between two enum types; it
+		// presumably relies on both enums sharing values - confirm, or replace
+		// with an explicit mapping.
+		NotificationType: workerpb.ActivityNotificationType(task.CommandType),
+		TaskTokens:       taskTokens,
+	}
+	requestPayload, err := payload.Encode(notificationRequest)
+	if err != nil {
+		return fmt.Errorf("failed to encode activity command request: %w", err)
+	}
+
+	nexusRequest := &nexuspb.Request{
+		Header: map[string]string{},
+		Variant: &nexuspb.Request_StartOperation{
+			StartOperation: &nexuspb.StartOperationRequest{
+				Service:   workerpb.WorkerService.ServiceName,
+				Operation: workerpb.WorkerService.NotifyActivity.Name(),
+				Payload:   requestPayload,
+			},
+		},
+	}
+
+	resp, err := d.matchingRawClient.DispatchNexusTask(ctx, &matchingservice.DispatchNexusTaskRequest{
+		NamespaceId: task.NamespaceID,
+		TaskQueue: &taskqueuepb.TaskQueue{
+			Name: task.Destination,
+			Kind: enumspb.TASK_QUEUE_KIND_NORMAL,
+		},
+		Request: nexusRequest,
+	})
+	if err != nil {
+		d.logger.Warn("Failed to dispatch activity command to worker",
+			tag.NewStringTag("control_queue", task.Destination),
+			tag.Error(err))
+		return err
+	}
+
+	return d.handleDispatchResponse(resp, task.Destination)
+}
+
+// handleDispatchResponse turns a DispatchNexusTask response into the task
+// result. It returns an error when the request timed out or the worker handler
+// reported a failure, and nil otherwise. An operation-level failure is logged
+// and still yields nil.
+func (d *activityCommandTaskDispatcher) handleDispatchResponse(
+	resp *matchingservice.DispatchNexusTaskResponse,
+	controlQueue string,
+) error {
+	// Check for timeout (no worker polling)
+	if resp.GetRequestTimeout() != nil {
+		d.logger.Warn("No worker polling control queue for activity command",
+			tag.NewStringTag("control_queue", controlQueue))
+		return fmt.Errorf("no worker polling control queue %q", controlQueue)
+	}
+
+	// Check for worker handler failure
+	if failure := resp.GetFailure(); failure != nil {
+		d.logger.Warn("Worker handler failed for activity command",
+			tag.NewStringTag("control_queue", controlQueue),
+			tag.NewStringTag("failure_message", failure.GetMessage()))
+		return fmt.Errorf("worker handler failed: %s", failure.GetMessage())
+	}
+
+	// Check for operation failure (terminal - don't retry): log it and fall
+	// through to nil. Generated protobuf getters accept nil receivers, so the
+	// chain is safe when the response or its StartOperation variant is absent.
+	if opFailure := resp.GetResponse().GetStartOperation().GetFailure(); opFailure != nil {
+		d.logger.Warn("Activity command operation failure",
+			tag.NewStringTag("control_queue", controlQueue),
+			tag.NewStringTag("failure_message", opFailure.GetMessage()))
+	}
+
+	return nil
+}
+
diff --git a/service/history/outbound_queue_active_task_executor.go
b/service/history/outbound_queue_active_task_executor.go index bc3af18822..a33bf9066c 100644 --- a/service/history/outbound_queue_active_task_executor.go +++ b/service/history/outbound_queue_active_task_executor.go @@ -10,6 +10,8 @@ import ( "go.temporal.io/server/common/debug" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/resource" + "go.temporal.io/server/service/history/configs" "go.temporal.io/server/service/history/consts" historyi "go.temporal.io/server/service/history/interfaces" "go.temporal.io/server/service/history/queues" @@ -24,7 +26,8 @@ const ( type outboundQueueActiveTaskExecutor struct { stateMachineEnvironment - chasmEngine chasm.Engine + chasmEngine chasm.Engine + activityCommandTaskDispatcher *activityCommandTaskDispatcher } var _ queues.Executor = &outboundQueueActiveTaskExecutor{} @@ -35,17 +38,28 @@ func newOutboundQueueActiveTaskExecutor( logger log.Logger, metricsHandler metrics.Handler, chasmEngine chasm.Engine, + matchingRawClient resource.MatchingRawClient, + config *configs.Config, ) *outboundQueueActiveTaskExecutor { + scopedMetricsHandler := metricsHandler.WithTags( + metrics.OperationTag(metrics.OperationOutboundQueueProcessorScope), + ) return &outboundQueueActiveTaskExecutor{ stateMachineEnvironment: stateMachineEnvironment{ - shardContext: shardCtx, - cache: workflowCache, - logger: logger, - metricsHandler: metricsHandler.WithTags( - metrics.OperationTag(metrics.OperationOutboundQueueProcessorScope), - ), + shardContext: shardCtx, + cache: workflowCache, + logger: logger, + metricsHandler: scopedMetricsHandler, }, chasmEngine: chasmEngine, + activityCommandTaskDispatcher: newActivityCommandTaskDispatcher( + shardCtx, + workflowCache, + matchingRawClient, + config, + scopedMetricsHandler, + logger, + ), } } @@ -92,6 +106,8 @@ func (e *outboundQueueActiveTaskExecutor) Execute( return respond(e.executeStateMachineTask(ctx, task)) case *tasks.ChasmTask: return 
respond(e.executeChasmSideEffectTask(ctx, task)) + case *tasks.ActivityCommandTask: + return respond(e.activityCommandTaskDispatcher.execute(ctx, task)) } return respond(queueserrors.NewUnprocessableTaskError(fmt.Sprintf("unknown task type '%T'", task))) diff --git a/service/history/outbound_queue_active_task_executor_test.go b/service/history/outbound_queue_active_task_executor_test.go index 511b4e2810..dfcc860973 100644 --- a/service/history/outbound_queue_active_task_executor_test.go +++ b/service/history/outbound_queue_active_task_executor_test.go @@ -115,6 +115,8 @@ func (s *outboundQueueActiveTaskExecutorSuite) SetupTest() { s.logger, s.metricsHandler, s.mockChasmEngine, + nil, // matchingRawClient - not used in these tests + nil, // config - not used in these tests ) } diff --git a/service/history/outbound_queue_factory.go b/service/history/outbound_queue_factory.go index 77dffccef6..9dcd094405 100644 --- a/service/history/outbound_queue_factory.go +++ b/service/history/outbound_queue_factory.go @@ -10,6 +10,7 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/quotas" + "go.temporal.io/server/common/resource" ctasks "go.temporal.io/server/common/tasks" "go.temporal.io/server/common/telemetry" "go.temporal.io/server/service/history/circuitbreakerpool" @@ -31,6 +32,7 @@ type outboundQueueFactoryParams struct { QueueFactoryBaseParams CircuitBreakerPool *circuitbreakerpool.OutboundQueueCircuitBreakerPool + MatchingRawClient resource.MatchingRawClient } type groupLimiter struct { @@ -227,6 +229,8 @@ func (f *outboundQueueFactory) CreateQueue( logger, metricsHandler, f.ChasmEngine, + f.MatchingRawClient, + shardContext.GetConfig(), ) standbyExecutor := newOutboundQueueStandbyTaskExecutor( diff --git a/tests/activity_command_task_test.go b/tests/activity_command_task_test.go new file mode 100644 index 0000000000..fbcd5dac9c --- /dev/null +++ b/tests/activity_command_task_test.go @@ -0,0 +1,149 
@@
+package tests
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	commandpb "go.temporal.io/api/command/v1"
+	commonpb "go.temporal.io/api/common/v1"
+	enumspb "go.temporal.io/api/enums/v1"
+	taskqueuepb "go.temporal.io/api/taskqueue/v1"
+	workerpb "go.temporal.io/api/worker/v1"
+	"go.temporal.io/api/workflowservice/v1"
+	"go.temporal.io/server/common/dynamicconfig"
+	"go.temporal.io/server/tests/testcore"
+	"google.golang.org/protobuf/types/known/durationpb"
+)
+
+// TestDispatchCancelToWorker tests that when an activity cancellation is requested,
+// the server dispatches an ActivityCommandTask to the worker's control queue via Nexus.
+//
+// Flow: start a workflow, schedule and start one activity (polling with a worker
+// instance key and control queue), request workflow cancellation, answer the next
+// workflow task with RequestCancelActivityTask, then poll the control queue until
+// the notify-activity Nexus request shows up.
+func TestDispatchCancelToWorker(t *testing.T) {
+	env := testcore.NewEnv(t, testcore.WithDynamicConfig(dynamicconfig.EnableActivityCancellationNexusTask, true))
+
+	ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
+	defer cancel()
+
+	tv := env.Tv()
+	poller := env.TaskPoller()
+
+	// Get the control queue name from test vars
+	controlQueueName := tv.ControlQueueName(env.Namespace().String())
+	t.Logf("WorkerInstanceKey: %s", tv.WorkerInstanceKey())
+	t.Logf("ControlQueueName: %s", controlQueueName)
+
+	// Start the workflow
+	startResp, err := env.FrontendClient().StartWorkflowExecution(ctx, &workflowservice.StartWorkflowExecutionRequest{
+		RequestId:                tv.Any().String(),
+		Namespace:                env.Namespace().String(),
+		WorkflowId:               tv.WorkflowID(),
+		WorkflowType:             tv.WorkflowType(),
+		TaskQueue:                tv.TaskQueue(),
+		WorkflowExecutionTimeout: durationpb.New(60 * time.Second),
+		WorkflowTaskTimeout:      durationpb.New(10 * time.Second),
+	})
+	env.NoError(err)
+	t.Logf("Started workflow: %s/%s", tv.WorkflowID(), startResp.RunId)
+
+	// Poll and complete first workflow task - schedule the activity
+	_, err = poller.PollAndHandleWorkflowTask(tv,
+		func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+			return &workflowservice.RespondWorkflowTaskCompletedRequest{
+				Commands: []*commandpb.Command{
+					{
+						CommandType: enumspb.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK,
+						Attributes: &commandpb.Command_ScheduleActivityTaskCommandAttributes{
+							ScheduleActivityTaskCommandAttributes: &commandpb.ScheduleActivityTaskCommandAttributes{
+								ActivityId:             tv.ActivityID(),
+								ActivityType:           tv.ActivityType(),
+								TaskQueue:              tv.TaskQueue(),
+								ScheduleToCloseTimeout: durationpb.New(60 * time.Second),
+								StartToCloseTimeout:    durationpb.New(60 * time.Second),
+							},
+						},
+					},
+				},
+			}, nil
+		})
+	env.NoError(err)
+	t.Log("Scheduled activity")
+
+	// Poll for activity task and start running the activity.
+	activityPollResp, err := env.FrontendClient().PollActivityTaskQueue(ctx, &workflowservice.PollActivityTaskQueueRequest{
+		Namespace:              env.Namespace().String(),
+		TaskQueue:              tv.TaskQueue(),
+		Identity:               tv.WorkerIdentity(),
+		WorkerInstanceKey:      tv.WorkerInstanceKey(),
+		WorkerControlTaskQueue: controlQueueName,
+	})
+	env.NoError(err)
+	env.NotNil(activityPollResp)
+	env.NotEmpty(activityPollResp.TaskToken)
+	t.Log("Activity started with WorkerInstanceKey")
+
+	// Request workflow cancellation
+	t.Log("Requesting workflow cancellation...")
+	_, err = env.FrontendClient().RequestCancelWorkflowExecution(ctx, &workflowservice.RequestCancelWorkflowExecutionRequest{
+		Namespace: env.Namespace().String(),
+		WorkflowExecution: &commonpb.WorkflowExecution{
+			WorkflowId: tv.WorkflowID(),
+			RunId:      startResp.RunId,
+		},
+	})
+	env.NoError(err)
+
+	// Simulate what the SDK does when a workflow is cancelled.
+	// Poll and complete the workflow task with RequestCancelActivityTask command.
+	// This sets CancelRequested=true and triggers the dispatch of ActivityCommandTask.
+	_, err = poller.PollAndHandleWorkflowTask(tv,
+		func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+			// Find the scheduled event ID
+			var scheduledEventID int64
+			for _, event := range task.History.Events {
+				if event.EventType == enumspb.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED {
+					scheduledEventID = event.EventId
+					break
+				}
+			}
+			return &workflowservice.RespondWorkflowTaskCompletedRequest{
+				Commands: []*commandpb.Command{
+					{
+						CommandType: enumspb.COMMAND_TYPE_REQUEST_CANCEL_ACTIVITY_TASK,
+						Attributes: &commandpb.Command_RequestCancelActivityTaskCommandAttributes{
+							RequestCancelActivityTaskCommandAttributes: &commandpb.RequestCancelActivityTaskCommandAttributes{
+								ScheduledEventId: scheduledEventID,
+							},
+						},
+					},
+				},
+			}, nil
+		})
+	env.NoError(err)
+	t.Log("Workflow task completed with RequestCancelActivityTask command")
+
+	// Poll Nexus control queue until we receive the notification request.
+	// The wait stays inside the 90s parent context (every poll context derives
+	// from it, so polls fail immediately once it expires) and matches the 60s
+	// the workflow itself is given; waiting longer would only delay the failure.
+	// NOTE(review): the received Nexus task is never responded to here, so the
+	// server-side dispatch presumably ends in a request timeout - confirm that
+	// is acceptable for this test.
+	var nexusPollResp *workflowservice.PollNexusTaskQueueResponse
+	env.Eventually(func() bool {
+		pollCtx, pollCancel := context.WithTimeout(ctx, 5*time.Second)
+		defer pollCancel()
+		resp, err := env.FrontendClient().PollNexusTaskQueue(pollCtx, &workflowservice.PollNexusTaskQueueRequest{
+			Namespace: env.Namespace().String(),
+			TaskQueue: &taskqueuepb.TaskQueue{Name: controlQueueName, Kind: enumspb.TASK_QUEUE_KIND_NORMAL},
+			Identity:  tv.WorkerIdentity(),
+		})
+		if err == nil && resp != nil && resp.Request != nil {
+			nexusPollResp = resp
+			return true
+		}
+		return false
+	}, 60*time.Second, 100*time.Millisecond, "Timed out waiting for Nexus task")
+
+	// Verify we received the notification request on the control queue
+	env.NotNil(nexusPollResp.Request, "Expected to receive Nexus request on control queue")
+
+	startOp := nexusPollResp.Request.GetStartOperation()
+	env.NotNil(startOp, "Expected StartOperation in Nexus request")
+	env.Equal(workerpb.WorkerService.ServiceName, startOp.Service, "Expected WorkerService")
+	env.Equal(workerpb.WorkerService.NotifyActivity.Name(), startOp.Operation, "Expected notify-activity operation")
+	t.Log("SUCCESS: Received notify-activity Nexus request on control queue")
+}