Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions spark/client/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ type SparkConnectClient interface {
SameSemantics(ctx context.Context, plan1 *generated.Plan, plan2 *generated.Plan) (bool, error)
SemanticHash(ctx context.Context, plan *generated.Plan) (int32, error)
Config(ctx context.Context, configRequest *generated.ConfigRequest_Operation) (*generated.ConfigResponse, error)
// Interrupt asks the server to interrupt running operations in this session. When interruptType
// is INTERRUPT_TYPE_OPERATION_ID, operationIdOrTag must be the UUID returned via OperationId;
// when INTERRUPT_TYPE_TAG, it must match a tag previously attached to the operation;
// when INTERRUPT_TYPE_ALL, operationIdOrTag is ignored.
Interrupt(ctx context.Context, interruptType generated.InterruptRequest_InterruptType,
operationIdOrTag string) (*generated.InterruptResponse, error)
}

type ExecuteResponseStream interface {
Expand Down
101 changes: 91 additions & 10 deletions spark/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"io"
"time"

"github.com/apache/spark-connect-go/spark/sql/utils"

Expand Down Expand Up @@ -82,13 +83,13 @@ func (s *sparkConnectClientImpl) ExecuteCommand(ctx context.Context, plan *proto
}

// Append the other items to the request.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
c, err := s.client.ExecutePlan(ctx, request)
rpcCtx := metadata.NewOutgoingContext(ctx, s.metadata)
c, err := s.client.ExecutePlan(rpcCtx, request)
if err != nil {
return nil, nil, nil, sparkerrors.WithType(
fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
}
respHandler := NewExecuteResponseStream(c, s.sessionId, *request.OperationId, s.opts)
respHandler := newExecuteResponseStream(ctx, s, c, s.sessionId, *request.OperationId, s.opts)
schema, table, err := respHandler.ToTable()
if err != nil {
return nil, nil, nil, err
Expand All @@ -100,13 +101,38 @@ func (s *sparkConnectClientImpl) ExecutePlan(ctx context.Context, plan *proto.Pl
request := s.newExecutePlanRequest(plan)

// Append the other items to the request.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
c, err := s.client.ExecutePlan(ctx, request)
rpcCtx := metadata.NewOutgoingContext(ctx, s.metadata)
c, err := s.client.ExecutePlan(rpcCtx, request)
if err != nil {
return nil, sparkerrors.WithType(fmt.Errorf(
"failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
}
return NewExecuteResponseStream(c, s.sessionId, *request.OperationId, s.opts), nil
return newExecuteResponseStream(ctx, s, c, s.sessionId, *request.OperationId, s.opts), nil
}

// Interrupt asks the server to cancel running operations on this session.
// See base.SparkConnectClient.Interrupt for the meaning of interruptType.
func (s *sparkConnectClientImpl) Interrupt(ctx context.Context,
interruptType proto.InterruptRequest_InterruptType, operationIdOrTag string,
) (*proto.InterruptResponse, error) {
request := &proto.InterruptRequest{
SessionId: s.sessionId,
UserContext: &proto.UserContext{UserId: s.opts.UserId},
ClientType: &s.opts.UserAgent,
InterruptType: interruptType,
}
switch interruptType {
case proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID:
request.Interrupt = &proto.InterruptRequest_OperationId{OperationId: operationIdOrTag}
case proto.InterruptRequest_INTERRUPT_TYPE_TAG:
request.Interrupt = &proto.InterruptRequest_OperationTag{OperationTag: operationIdOrTag}
}
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
resp, err := s.client.Interrupt(ctx, request)
if se := sparkerrors.FromRPCError(err); se != nil {
return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
return resp, nil
}

// Creates a new AnalyzePlanRequest with the necessary metadata.
Expand Down Expand Up @@ -348,10 +374,20 @@ type ExecutePlanClient struct {
// The schema of the result of the operation.
schema *types.StructType
// The sessionId is ised to verify the server side session.
sessionId string
done bool
properties map[string]any
opts options.SparkClientOptions
sessionId string
// operationId identifies this execution on the server so we can send an InterruptRequest
// for it when the caller's context is cancelled.
operationId string
// callerCtx is the caller's original context (without the gRPC outgoing-metadata wrapping).
// When it is cancelled, ToTable fires Interrupt(OPERATION_ID, operationId) so the server
// stops the running query instead of waiting for the 5-minute idle timeout.
callerCtx context.Context
// interrupter is used to send the InterruptRequest. May be nil in tests; nil disables the
// cancellation-watcher path.
interrupter base.SparkConnectClient
done bool
properties map[string]any
opts options.SparkClientOptions
}

func (c *ExecutePlanClient) Properties() map[string]any {
Expand All @@ -364,6 +400,25 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) {
var arrowSchema *arrow.Schema
recordBatches = make([]arrow.Record, 0)

// When the caller cancels their context, also tell the server to interrupt the running
// operation. Without this the gRPC stream tears down locally but the server keeps executing
// the query until its idle timeout — see issue #126.
if c.callerCtx != nil && c.interrupter != nil && c.operationId != "" {
watcherDone := make(chan struct{})
defer close(watcherDone)
go func() {
select {
case <-c.callerCtx.Done():
// Use a detached context with a short deadline — the caller's ctx is already done.
killCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
_, _ = c.interrupter.Interrupt(killCtx,
proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID, c.operationId)
case <-watcherDone:
}
}()
}

// Explicitly needed when tracking re-attachble execution.
c.done = false
for {
Expand Down Expand Up @@ -434,6 +489,9 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) {
}
}

// NewExecuteResponseStream wraps a raw gRPC ExecutePlan stream. It does not arm the context
// cancellation watcher — callers that want server-side interrupt on ctx cancellation should
// use newExecuteResponseStream instead.
func NewExecuteResponseStream(
responseClient proto.SparkConnectService_ExecutePlanClient,
sessionId string,
Expand All @@ -443,6 +501,29 @@ func NewExecuteResponseStream(
return &ExecutePlanClient{
responseStream: responseClient,
sessionId: sessionId,
operationId: operationId,
done: false,
properties: make(map[string]any),
opts: opts,
}
}

// newExecuteResponseStream wraps a gRPC ExecutePlan stream and remembers the caller's context
// plus a client back-reference so ToTable can fire an InterruptRequest when ctx is cancelled.
func newExecuteResponseStream(
callerCtx context.Context,
interrupter base.SparkConnectClient,
responseClient proto.SparkConnectService_ExecutePlanClient,
sessionId string,
operationId string,
opts options.SparkClientOptions,
) base.ExecuteResponseStream {
return &ExecutePlanClient{
responseStream: responseClient,
sessionId: sessionId,
operationId: operationId,
callerCtx: callerCtx,
interrupter: interrupter,
done: false,
properties: make(map[string]any),
opts: opts,
Expand Down
127 changes: 127 additions & 0 deletions spark/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@ package client_test
import (
"context"
"testing"
"time"

"github.com/google/uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

proto "github.com/apache/spark-connect-go/internal/generated"
"github.com/apache/spark-connect-go/spark/client"
Expand Down Expand Up @@ -108,3 +113,125 @@ func Test_Execute_SchemaParsingFails(t *testing.T) {
_, _, _, err := c.ExecuteCommand(ctx, sqlCommand)
assert.ErrorIs(t, err, sparkerrors.ExecutionError)
}

// blockingStream is a mock ExecutePlan client whose Recv blocks until release is closed,
// then returns a Canceled status — emulating a long-running server-side query.
type blockingStream struct {
proto.SparkConnectService_ExecutePlanClient
release chan struct{}
}

func (b *blockingStream) Recv() (*proto.ExecutePlanResponse, error) {
<-b.release
return nil, status.Error(codes.Canceled, "canceled")
}

func (b *blockingStream) Header() (metadata.MD, error) { return nil, nil }
func (b *blockingStream) Trailer() metadata.MD { return nil }
func (b *blockingStream) CloseSend() error { return nil }
func (b *blockingStream) Context() context.Context { return context.Background() }
func (b *blockingStream) SendMsg(any) error { return nil }
func (b *blockingStream) RecvMsg(any) error { return nil }

// interruptRecorder wraps the testutils mock and records Interrupt invocations.
type interruptRecorder struct {
proto.SparkConnectServiceClient
calls chan *proto.InterruptRequest
release chan struct{}
}

func (i *interruptRecorder) Interrupt(ctx context.Context, in *proto.InterruptRequest,
opts ...grpc.CallOption,
) (*proto.InterruptResponse, error) {
i.calls <- in
// Unblock the streaming Recv so ToTable returns.
select {
case <-i.release:
default:
close(i.release)
}
return &proto.InterruptResponse{SessionId: in.SessionId}, nil
}

// Regression test for issue #126: cancelling the caller's context during Collect/ExecutePlan
// must send a server-side InterruptRequest with the operation ID, not just tear down the
// gRPC stream locally.
func TestExecutePlanCancellingContextSendsInterrupt(t *testing.T) {
release := make(chan struct{})
stream := &blockingStream{release: release}

underlying := testutils.NewConnectServiceClientMock(stream, nil, nil, t)
recorder := &interruptRecorder{
SparkConnectServiceClient: underlying,
calls: make(chan *proto.InterruptRequest, 1),
release: release,
}

c := client.NewSparkExecutorFromClient(recorder, nil, mocks.MockSessionId)

ctx, cancel := context.WithCancel(context.Background())
stream2, err := c.ExecutePlan(ctx, &proto.Plan{})
assert.NoError(t, err)

done := make(chan error, 1)
go func() {
_, _, err := stream2.ToTable()
done <- err
}()

// Give the watcher goroutine a moment to be wired up, then cancel.
time.Sleep(50 * time.Millisecond)
cancel()

select {
case req := <-recorder.calls:
assert.Equal(t, proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID, req.InterruptType)
assert.NotEmpty(t, req.GetOperationId())
assert.Equal(t, mocks.MockSessionId, req.SessionId)
case <-time.After(2 * time.Second):
t.Fatal("Interrupt was not invoked within 2s of ctx cancellation")
}

// ToTable should also unwind once Recv returns an error.
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("ToTable did not return after Interrupt")
}
}

func TestInterruptAllCallsClient(t *testing.T) {
release := make(chan struct{})
close(release)
recorder := &interruptRecorder{
SparkConnectServiceClient: testutils.NewConnectServiceClientMock(nil, nil, nil, t),
calls: make(chan *proto.InterruptRequest, 1),
release: release,
}
c := client.NewSparkExecutorFromClient(recorder, nil, mocks.MockSessionId)

resp, err := c.Interrupt(context.Background(), proto.InterruptRequest_INTERRUPT_TYPE_ALL, "")
assert.NoError(t, err)
assert.NotNil(t, resp)
req := <-recorder.calls
assert.Equal(t, proto.InterruptRequest_INTERRUPT_TYPE_ALL, req.InterruptType)
assert.Nil(t, req.Interrupt)
}

func TestInterruptOperationCallsClient(t *testing.T) {
release := make(chan struct{})
close(release)
recorder := &interruptRecorder{
SparkConnectServiceClient: testutils.NewConnectServiceClientMock(nil, nil, nil, t),
calls: make(chan *proto.InterruptRequest, 1),
release: release,
}
c := client.NewSparkExecutorFromClient(recorder, nil, mocks.MockSessionId)

opID := uuid.NewString()
_, err := c.Interrupt(context.Background(), proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID, opID)
assert.NoError(t, err)
req := <-recorder.calls
assert.Equal(t, proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID, req.InterruptType)
assert.Equal(t, opID, req.GetOperationId())
}
6 changes: 6 additions & 0 deletions spark/mocks/mock_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,9 @@ func (t *TestExecutor) SemanticHash(ctx context.Context, plan *generated.Plan) (
func (t *TestExecutor) Config(ctx context.Context, configRequest *generated.ConfigRequest_Operation) (*generated.ConfigResponse, error) {
return nil, errors.New("not implemented")
}

func (t *TestExecutor) Interrupt(ctx context.Context,
interruptType generated.InterruptRequest_InterruptType, operationIdOrTag string,
) (*generated.InterruptResponse, error) {
return nil, errors.New("not implemented")
}
30 changes: 30 additions & 0 deletions spark/sql/sparksession.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ type SparkSession interface {
CreateDataFrameFromArrow(ctx context.Context, data arrow.Table) (DataFrame, error)
CreateDataFrame(ctx context.Context, data [][]any, schema *types.StructType) (DataFrame, error)
Config() client.RuntimeConfig
// InterruptAll cancels every running operation in this session.
InterruptAll(ctx context.Context) ([]string, error)
// InterruptTag cancels every running operation tagged with tag.
InterruptTag(ctx context.Context, tag string) ([]string, error)
// InterruptOperation cancels the operation with the given operation id.
InterruptOperation(ctx context.Context, operationId string) ([]string, error)
}

// NewSessionBuilder creates a new session builder for starting a new spark session
Expand Down Expand Up @@ -167,6 +173,30 @@ func (s *sparkSessionImpl) Stop() error {
return nil
}

func (s *sparkSessionImpl) InterruptAll(ctx context.Context) ([]string, error) {
resp, err := s.client.Interrupt(ctx, proto.InterruptRequest_INTERRUPT_TYPE_ALL, "")
if err != nil {
return nil, sparkerrors.WithType(fmt.Errorf("failed to interrupt all: %w", err), sparkerrors.ExecutionError)
}
return resp.GetInterruptedIds(), nil
}

func (s *sparkSessionImpl) InterruptTag(ctx context.Context, tag string) ([]string, error) {
resp, err := s.client.Interrupt(ctx, proto.InterruptRequest_INTERRUPT_TYPE_TAG, tag)
if err != nil {
return nil, sparkerrors.WithType(fmt.Errorf("failed to interrupt tag %s: %w", tag, err), sparkerrors.ExecutionError)
}
return resp.GetInterruptedIds(), nil
}

func (s *sparkSessionImpl) InterruptOperation(ctx context.Context, operationId string) ([]string, error) {
resp, err := s.client.Interrupt(ctx, proto.InterruptRequest_INTERRUPT_TYPE_OPERATION_ID, operationId)
if err != nil {
return nil, sparkerrors.WithType(fmt.Errorf("failed to interrupt operation %s: %w", operationId, err), sparkerrors.ExecutionError)
}
return resp.GetInterruptedIds(), nil
}

func (s *sparkSessionImpl) Table(name string) (DataFrame, error) {
return s.Read().Table(name)
}
Expand Down