diff --git a/spark/client/base/base.go b/spark/client/base/base.go index c5f01c7..fcebc5c 100644 --- a/spark/client/base/base.go +++ b/spark/client/base/base.go @@ -44,6 +44,12 @@ type SparkConnectClient interface { SameSemantics(ctx context.Context, plan1 *generated.Plan, plan2 *generated.Plan) (bool, error) SemanticHash(ctx context.Context, plan *generated.Plan) (int32, error) Config(ctx context.Context, configRequest *generated.ConfigRequest_Operation) (*generated.ConfigResponse, error) + // Interrupt asks the server to interrupt running operations in this session. When interruptType + // is INTERRUPT_TYPE_OPERATION_ID, operationIdOrTag must be the UUID returned via OperationId; + // when INTERRUPT_TYPE_TAG, it must match a tag previously attached to the operation; + // when INTERRUPT_TYPE_ALL, operationIdOrTag is ignored. + Interrupt(ctx context.Context, interruptType generated.InterruptRequest_InterruptType, + operationIdOrTag string) (*generated.InterruptResponse, error) } type ExecuteResponseStream interface { diff --git a/spark/client/client.go b/spark/client/client.go index a0d7a4c..84ffa10 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "io" + "time" "github.com/apache/spark-connect-go/spark/sql/utils" @@ -82,13 +83,13 @@ func (s *sparkConnectClientImpl) ExecuteCommand(ctx context.Context, plan *proto } // Append the other items to the request. - ctx = metadata.NewOutgoingContext(ctx, s.metadata) - c, err := s.client.ExecutePlan(ctx, request) + rpcCtx := metadata.NewOutgoingContext(ctx, s.metadata) + c, err := s.client.ExecutePlan(rpcCtx, request) if err != nil { return nil, nil, nil, sparkerrors.WithType( fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError) } - respHandler := NewExecuteResponseStream(c, s.sessionId, *request.OperationId, s.opts) + respHandler := newExecuteResponseStream(ctx, s, c, s.sessionId, *request.OperationId, s.opts) schema, table, err := respHandler.ToTable() if err != nil { return nil, nil, nil, err @@ -100,13 +101,38 @@ func (s *sparkConnectClientImpl) ExecutePlan(ctx context.Context, plan *proto.Pl request := s.newExecutePlanRequest(plan) // Append the other items to the request. - ctx = metadata.NewOutgoingContext(ctx, s.metadata) - c, err := s.client.ExecutePlan(ctx, request) + rpcCtx := metadata.NewOutgoingContext(ctx, s.metadata) + c, err := s.client.ExecutePlan(rpcCtx, request) if err != nil { return nil, sparkerrors.WithType(fmt.Errorf( "failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError) } - return NewExecuteResponseStream(c, s.sessionId, *request.OperationId, s.opts), nil + return newExecuteResponseStream(ctx, s, c, s.sessionId, *request.OperationId, s.opts), nil +} + +// Interrupt asks the server to cancel running operations on this session. +// See base.SparkConnectClient.Interrupt for the meaning of interruptType. +func (s *sparkConnectClientImpl) Interrupt(ctx context.Context, + interruptType proto.InterruptRequest_InterruptType, operationIdOrTag string, +) (*proto.InterruptResponse, error) { + request := &proto.InterruptRequest{ + SessionId: s.sessionId, + UserContext: &proto.UserContext{UserId: s.opts.UserId}, + ClientType: &s.opts.UserAgent, + InterruptType: interruptType, + } + switch interruptType { + case proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID: + request.Interrupt = &proto.InterruptRequest_OperationId{OperationId: operationIdOrTag} + case proto.InterruptRequest_INTERRUPT_TYPE_TAG: + request.Interrupt = &proto.InterruptRequest_OperationTag{OperationTag: operationIdOrTag} + } + ctx = metadata.NewOutgoingContext(ctx, s.metadata) + resp, err := s.client.Interrupt(ctx, request) + if se := sparkerrors.FromRPCError(err); se != nil { + return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError) + } + return resp, nil } // Creates a new AnalyzePlanRequest with the necessary metadata. @@ -348,10 +374,20 @@ type ExecutePlanClient struct { // The schema of the result of the operation. schema *types.StructType // The sessionId is ised to verify the server side session. - sessionId string - done bool - properties map[string]any - opts options.SparkClientOptions + sessionId string + // operationId identifies this execution on the server so we can send an InterruptRequest + // for it when the caller's context is cancelled. + operationId string + // callerCtx is the caller's original context (without the gRPC outgoing-metadata wrapping). + // When it is cancelled, ToTable fires Interrupt(OPERATION_ID, operationId) so the server + // stops the running query instead of waiting for the 5-minute idle timeout. + callerCtx context.Context + // interrupter is used to send the InterruptRequest. May be nil in tests; nil disables the + // cancellation-watcher path. + interrupter base.SparkConnectClient + done bool + properties map[string]any + opts options.SparkClientOptions } func (c *ExecutePlanClient) Properties() map[string]any { @@ -364,6 +400,25 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { var arrowSchema *arrow.Schema recordBatches = make([]arrow.Record, 0) + // When the caller cancels their context, also tell the server to interrupt the running + // operation. Without this the gRPC stream tears down locally but the server keeps executing + // the query until its idle timeout — see issue #126. + if c.callerCtx != nil && c.interrupter != nil && c.operationId != "" { + watcherDone := make(chan struct{}) + defer close(watcherDone) + go func() { + select { + case <-c.callerCtx.Done(): + // Use a detached context with a short deadline — the caller's ctx is already done. + killCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, _ = c.interrupter.Interrupt(killCtx, + proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID, c.operationId) + case <-watcherDone: + } + }() + } + // Explicitly needed when tracking re-attachble execution. c.done = false for { @@ -434,6 +489,9 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { } } +// NewExecuteResponseStream wraps a raw gRPC ExecutePlan stream. It does not arm the context +// cancellation watcher — callers that want server-side interrupt on ctx cancellation should +// use newExecuteResponseStream instead. func NewExecuteResponseStream( responseClient proto.SparkConnectService_ExecutePlanClient, sessionId string, @@ -443,6 +501,29 @@ func NewExecuteResponseStream( return &ExecutePlanClient{ responseStream: responseClient, sessionId: sessionId, + operationId: operationId, + done: false, + properties: make(map[string]any), + opts: opts, + } +} + +// newExecuteResponseStream wraps a gRPC ExecutePlan stream and remembers the caller's context +// plus a client back-reference so ToTable can fire an InterruptRequest when ctx is cancelled. +func newExecuteResponseStream( + callerCtx context.Context, + interrupter base.SparkConnectClient, + responseClient proto.SparkConnectService_ExecutePlanClient, + sessionId string, + operationId string, + opts options.SparkClientOptions, +) base.ExecuteResponseStream { + return &ExecutePlanClient{ + responseStream: responseClient, + sessionId: sessionId, + operationId: operationId, + callerCtx: callerCtx, + interrupter: interrupter, done: false, properties: make(map[string]any), opts: opts, diff --git a/spark/client/client_test.go b/spark/client/client_test.go index ca7019c..1c4804d 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -18,8 +18,13 @@ package client_test import ( "context" "testing" + "time" "github.com/google/uuid" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" proto "github.com/apache/spark-connect-go/internal/generated" "github.com/apache/spark-connect-go/spark/client" @@ -108,3 +113,125 @@ func Test_Execute_SchemaParsingFails(t *testing.T) { _, _, _, err := c.ExecuteCommand(ctx, sqlCommand) assert.ErrorIs(t, err, sparkerrors.ExecutionError) } + +// blockingStream is a mock ExecutePlan client whose Recv blocks until release is closed, +// then returns a Canceled status — emulating a long-running server-side query. +type blockingStream struct { + proto.SparkConnectService_ExecutePlanClient + release chan struct{} +} + +func (b *blockingStream) Recv() (*proto.ExecutePlanResponse, error) { + <-b.release + return nil, status.Error(codes.Canceled, "canceled") +} + +func (b *blockingStream) Header() (metadata.MD, error) { return nil, nil } +func (b *blockingStream) Trailer() metadata.MD { return nil } +func (b *blockingStream) CloseSend() error { return nil } +func (b *blockingStream) Context() context.Context { return context.Background() } +func (b *blockingStream) SendMsg(any) error { return nil } +func (b *blockingStream) RecvMsg(any) error { return nil } + +// interruptRecorder wraps the testutils mock and records Interrupt invocations. +type interruptRecorder struct { + proto.SparkConnectServiceClient + calls chan *proto.InterruptRequest + release chan struct{} +} + +func (i *interruptRecorder) Interrupt(ctx context.Context, in *proto.InterruptRequest, + opts ...grpc.CallOption, +) (*proto.InterruptResponse, error) { + i.calls <- in + // Unblock the streaming Recv so ToTable returns. + select { + case <-i.release: + default: + close(i.release) + } + return &proto.InterruptResponse{SessionId: in.SessionId}, nil +} + +// Regression test for issue #126: cancelling the caller's context during Collect/ExecutePlan +// must send a server-side InterruptRequest with the operation ID, not just tear down the +// gRPC stream locally. +func TestExecutePlanCancellingContextSendsInterrupt(t *testing.T) { + release := make(chan struct{}) + stream := &blockingStream{release: release} + + underlying := testutils.NewConnectServiceClientMock(stream, nil, nil, t) + recorder := &interruptRecorder{ + SparkConnectServiceClient: underlying, + calls: make(chan *proto.InterruptRequest, 1), + release: release, + } + + c := client.NewSparkExecutorFromClient(recorder, nil, mocks.MockSessionId) + + ctx, cancel := context.WithCancel(context.Background()) + stream2, err := c.ExecutePlan(ctx, &proto.Plan{}) + assert.NoError(t, err) + + done := make(chan error, 1) + go func() { + _, _, err := stream2.ToTable() + done <- err + }() + + // Give the watcher goroutine a moment to be wired up, then cancel. + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case req := <-recorder.calls: + assert.Equal(t, proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID, req.InterruptType) + assert.NotEmpty(t, req.GetOperationId()) + assert.Equal(t, mocks.MockSessionId, req.SessionId) + case <-time.After(2 * time.Second): + t.Fatal("Interrupt was not invoked within 2s of ctx cancellation") + } + + // ToTable should also unwind once Recv returns an error. + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("ToTable did not return after Interrupt") + } +} + +func TestInterruptAllCallsClient(t *testing.T) { + release := make(chan struct{}) + close(release) + recorder := &interruptRecorder{ + SparkConnectServiceClient: testutils.NewConnectServiceClientMock(nil, nil, nil, t), + calls: make(chan *proto.InterruptRequest, 1), + release: release, + } + c := client.NewSparkExecutorFromClient(recorder, nil, mocks.MockSessionId) + + resp, err := c.Interrupt(context.Background(), proto.InterruptRequest_INTERRUPT_TYPE_ALL, "") + assert.NoError(t, err) + assert.NotNil(t, resp) + req := <-recorder.calls + assert.Equal(t, proto.InterruptRequest_INTERRUPT_TYPE_ALL, req.InterruptType) + assert.Nil(t, req.Interrupt) +} + +func TestInterruptOperationCallsClient(t *testing.T) { + release := make(chan struct{}) + close(release) + recorder := &interruptRecorder{ + SparkConnectServiceClient: testutils.NewConnectServiceClientMock(nil, nil, nil, t), + calls: make(chan *proto.InterruptRequest, 1), + release: release, + } + c := client.NewSparkExecutorFromClient(recorder, nil, mocks.MockSessionId) + + opID := uuid.NewString() + _, err := c.Interrupt(context.Background(), proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID, opID) + assert.NoError(t, err) + req := <-recorder.calls + assert.Equal(t, proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID, req.InterruptType) + assert.Equal(t, opID, req.GetOperationId()) +} diff --git a/spark/mocks/mock_executor.go b/spark/mocks/mock_executor.go index 600e9e0..c3dd107 100644 --- a/spark/mocks/mock_executor.go +++ b/spark/mocks/mock_executor.go @@ -89,3 +89,9 @@ func (t *TestExecutor) SemanticHash(ctx context.Context, plan *generated.Plan) ( func (t *TestExecutor) Config(ctx context.Context, configRequest *generated.ConfigRequest_Operation) (*generated.ConfigResponse, error) { return nil, errors.New("not implemented") } + +func (t *TestExecutor) Interrupt(ctx context.Context, + interruptType generated.InterruptRequest_InterruptType, operationIdOrTag string, +) (*generated.InterruptResponse, error) { + return nil, errors.New("not implemented") +} diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go index a84bb61..d7189bf 100644 --- a/spark/sql/sparksession.go +++ b/spark/sql/sparksession.go @@ -48,6 +48,12 @@ type SparkSession interface { CreateDataFrameFromArrow(ctx context.Context, data arrow.Table) (DataFrame, error) CreateDataFrame(ctx context.Context, data [][]any, schema *types.StructType) (DataFrame, error) Config() client.RuntimeConfig + // InterruptAll cancels every running operation in this session. + InterruptAll(ctx context.Context) ([]string, error) + // InterruptTag cancels every running operation tagged with tag. + InterruptTag(ctx context.Context, tag string) ([]string, error) + // InterruptOperation cancels the operation with the given operation id. + InterruptOperation(ctx context.Context, operationId string) ([]string, error) } // NewSessionBuilder creates a new session builder for starting a new spark session @@ -167,6 +173,30 @@ func (s *sparkSessionImpl) Stop() error { return nil } +func (s *sparkSessionImpl) InterruptAll(ctx context.Context) ([]string, error) { + resp, err := s.client.Interrupt(ctx, proto.InterruptRequest_INTERRUPT_TYPE_ALL, "") + if err != nil { + return nil, sparkerrors.WithType(fmt.Errorf("failed to interrupt all: %w", err), sparkerrors.ExecutionError) + } + return resp.GetInterruptedIds(), nil +} + +func (s *sparkSessionImpl) InterruptTag(ctx context.Context, tag string) ([]string, error) { + resp, err := s.client.Interrupt(ctx, proto.InterruptRequest_INTERRUPT_TYPE_TAG, tag) + if err != nil { + return nil, sparkerrors.WithType(fmt.Errorf("failed to interrupt tag %s: %w", tag, err), sparkerrors.ExecutionError) + } + return resp.GetInterruptedIds(), nil +} + +func (s *sparkSessionImpl) InterruptOperation(ctx context.Context, operationId string) ([]string, error) { + resp, err := s.client.Interrupt(ctx, proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID, operationId) + if err != nil { + return nil, sparkerrors.WithType(fmt.Errorf("failed to interrupt operation %s: %w", operationId, err), sparkerrors.ExecutionError) + } + return resp.GetInterruptedIds(), nil +} + func (s *sparkSessionImpl) Table(name string) (DataFrame, error) { return s.Read().Table(name) }