From 59dd0b29ae7c938f7b881d04e32fc495c5befd5f Mon Sep 17 00:00:00 2001 From: 40u5 Date: Sun, 7 Jun 2026 23:55:35 -0400 Subject: [PATCH] [FEAT] Add SparkSession tag management for Interrupt() Surfaces SparkSession.AddTag / RemoveTag / GetTags / ClearTags so user code can tag operations and later cancel them by tag via InterruptTag. Tags are stored on the underlying SparkConnectClient (mutex-protected) and threaded into ExecutePlanRequest.Tags on every execute, mirroring PySpark's addTag / removeTag / getTags / clearTags API. Unlike PySpark, where tags are scoped to the calling thread, tags here are shared by every goroutine that uses the same client. Validation matches the Spark Connect contract: tags must be non-empty and must not contain ','. Together with the Interrupt() / InterruptTag / InterruptOperation surface added in #182, this closes #49. --- spark/client/base/base.go | 25 ++++++++++++ spark/client/client.go | 49 ++++++++++++++++++++++++ spark/client/client_test.go | 70 ++++++++++++++++++++++++++++++++++ spark/mocks/mock_executor.go | 37 ++++++++++++++++++ spark/sql/sparksession.go | 28 ++++++++++++++ spark/sql/sparksession_test.go | 23 +++++++++++ 6 files changed, 232 insertions(+) diff --git a/spark/client/base/base.go b/spark/client/base/base.go index c5f01c7..c7d59aa 100644 --- a/spark/client/base/base.go +++ b/spark/client/base/base.go @@ -17,6 +17,8 @@ package base import ( "context" + "fmt" + "strings" "github.com/apache/spark-connect-go/spark/sql/utils" @@ -25,6 +27,18 @@ import ( "github.com/apache/spark-connect-go/spark/sql/types" ) +// ValidateTag enforces the Spark Connect tag rules: a tag must be non-empty +// and must not contain ','. See ExecutePlanRequest.tags in spark/connect/base.proto. +func ValidateTag(tag string) error { + if tag == "" { + return fmt.Errorf("spark connect tag cannot be an empty string") + } + if strings.ContainsRune(tag, ',') { + return fmt.Errorf("spark connect tag cannot contain ',': %q", tag) + } + return nil +} + type SparkConnectRPCClient generated.SparkConnectServiceClient // SparkConnectClient is the interface for executing a plan in Spark. 
@@ -44,6 +58,17 @@ type SparkConnectClient interface { SameSemantics(ctx context.Context, plan1 *generated.Plan, plan2 *generated.Plan) (bool, error) SemanticHash(ctx context.Context, plan *generated.Plan) (int32, error) Config(ctx context.Context, configRequest *generated.ConfigRequest_Operation) (*generated.ConfigResponse, error) + + // AddTag attaches the given tag to every subsequent ExecutePlan request issued by this client. + // Tags follow the Spark Connect contract enforced by ValidateTag: non-empty and no commas. + AddTag(tag string) error + // RemoveTag removes the given tag from this client. Removing a tag that was never added is a no-op; an error is returned only if the tag itself fails ValidateTag. + RemoveTag(tag string) error + // GetTags returns a snapshot of the tags currently attached to this client. + // The slice is sorted lexicographically so callers and tests get a deterministic order. + GetTags() []string + // ClearTags drops every tag attached to this client. + ClearTags() } type ExecuteResponseStream interface { diff --git a/spark/client/client.go b/spark/client/client.go index a0d7a4c..77346f6 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -20,6 +20,8 @@ import ( "errors" "fmt" "io" + "sort" + "sync" "github.com/apache/spark-connect-go/spark/sql/utils" @@ -46,6 +48,52 @@ type sparkConnectClientImpl struct { metadata metadata.MD sessionId string opts options.SparkClientOptions + + tagsMu sync.RWMutex + tags map[string]struct{} +} + +func (s *sparkConnectClientImpl) AddTag(tag string) error { + if err := base.ValidateTag(tag); err != nil { + return sparkerrors.WithType(err, sparkerrors.InvalidArgumentError) + } + s.tagsMu.Lock() + defer s.tagsMu.Unlock() + if s.tags == nil { + s.tags = make(map[string]struct{}) + } + s.tags[tag] = struct{}{} + return nil +} + +func (s *sparkConnectClientImpl) RemoveTag(tag string) error { + if err := base.ValidateTag(tag); err != nil { + return sparkerrors.WithType(err, sparkerrors.InvalidArgumentError) + } + s.tagsMu.Lock() + defer 
s.tagsMu.Unlock() + delete(s.tags, tag) + return nil +} + +func (s *sparkConnectClientImpl) GetTags() []string { + s.tagsMu.RLock() + defer s.tagsMu.RUnlock() + if len(s.tags) == 0 { + return nil + } + out := make([]string, 0, len(s.tags)) + for t := range s.tags { + out = append(out, t) + } + sort.Strings(out) + return out +} + +func (s *sparkConnectClientImpl) ClearTags() { + s.tagsMu.Lock() + defer s.tagsMu.Unlock() + s.tags = nil } func (s *sparkConnectClientImpl) newExecutePlanRequest(plan *proto.Plan) *proto.ExecutePlanRequest { @@ -69,6 +117,7 @@ func (s *sparkConnectClientImpl) newExecutePlanRequest(plan *proto.Plan) *proto. }, }, }, + Tags: s.GetTags(), } } diff --git a/spark/client/client_test.go b/spark/client/client_test.go index ca7019c..4725ec4 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/uuid" + "google.golang.org/grpc" proto "github.com/apache/spark-connect-go/internal/generated" "github.com/apache/spark-connect-go/spark/client" @@ -108,3 +109,72 @@ func Test_Execute_SchemaParsingFails(t *testing.T) { _, _, _, err := c.ExecuteCommand(ctx, sqlCommand) assert.ErrorIs(t, err, sparkerrors.ExecutionError) } + +// executePlanRecorder wraps the testutils mock and captures the last ExecutePlanRequest +// passed to ExecutePlan so tests can inspect Tags / OperationId / SessionId. +type executePlanRecorder struct { + proto.SparkConnectServiceClient + lastRequest *proto.ExecutePlanRequest +} + +func (r *executePlanRecorder) ExecutePlan(ctx context.Context, in *proto.ExecutePlanRequest, + opts ...grpc.CallOption, +) (proto.SparkConnectService_ExecutePlanClient, error) { + r.lastRequest = in + return r.SparkConnectServiceClient.ExecutePlan(ctx, in, opts...) 
+} + +func TestAddTagRejectsInvalidInput(t *testing.T) { + c := client.NewSparkExecutorFromClient( + testutils.NewConnectServiceClientMock(nil, nil, nil, t), nil, mocks.MockSessionId) + + assert.ErrorIs(t, c.AddTag(""), sparkerrors.InvalidArgumentError) + assert.ErrorIs(t, c.AddTag("has,comma"), sparkerrors.InvalidArgumentError) + assert.Empty(t, c.GetTags(), "invalid tags must not be stored") +} + +func TestTagsRoundTripAddRemoveClear(t *testing.T) { + c := client.NewSparkExecutorFromClient( + testutils.NewConnectServiceClientMock(nil, nil, nil, t), nil, mocks.MockSessionId) + + assert.NoError(t, c.AddTag("beta")) + assert.NoError(t, c.AddTag("alpha")) + assert.NoError(t, c.AddTag("alpha")) // dedupes + assert.Equal(t, []string{"alpha", "beta"}, c.GetTags()) + + assert.NoError(t, c.RemoveTag("alpha")) + assert.Equal(t, []string{"beta"}, c.GetTags()) + + // Removing a tag that was never added is a no-op. + assert.NoError(t, c.RemoveTag("never-added")) + assert.Equal(t, []string{"beta"}, c.GetTags()) + + c.ClearTags() + assert.Empty(t, c.GetTags()) +} + +func TestExecutePlanRequestCarriesSessionTags(t *testing.T) { + ctx := context.Background() + responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) + recorder := &executePlanRecorder{ + SparkConnectServiceClient: testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), + } + c := client.NewSparkExecutorFromClient(recorder, nil, mocks.MockSessionId) + + // First call: no tags configured — request must not carry any. 
+ _, err := c.ExecutePlan(ctx, &proto.Plan{}) + assert.NoError(t, err) + assert.Empty(t, recorder.lastRequest.GetTags(), "untagged session must not send a Tags field") + + assert.NoError(t, c.AddTag("etl-job-42")) + assert.NoError(t, c.AddTag("priority-high")) + + _, err = c.ExecutePlan(ctx, &proto.Plan{}) + assert.NoError(t, err) + assert.Equal(t, []string{"etl-job-42", "priority-high"}, recorder.lastRequest.GetTags()) + + c.ClearTags() + _, err = c.ExecutePlan(ctx, &proto.Plan{}) + assert.NoError(t, err) + assert.Empty(t, recorder.lastRequest.GetTags(), "ClearTags must scrub tags from subsequent requests") +} diff --git a/spark/mocks/mock_executor.go b/spark/mocks/mock_executor.go index 600e9e0..c26e9e1 100644 --- a/spark/mocks/mock_executor.go +++ b/spark/mocks/mock_executor.go @@ -18,6 +18,7 @@ package mocks import ( "context" "errors" + "sort" "github.com/apache/spark-connect-go/spark/sql/utils" @@ -32,6 +33,7 @@ type TestExecutor struct { Client base.ExecuteResponseStream response *generated.AnalyzePlanResponse Err error + tags map[string]struct{} } func (t *TestExecutor) ExecutePlan(ctx context.Context, plan *generated.Plan) (base.ExecuteResponseStream, error) { @@ -89,3 +91,38 @@ func (t *TestExecutor) SemanticHash(ctx context.Context, plan *generated.Plan) ( func (t *TestExecutor) Config(ctx context.Context, configRequest *generated.ConfigRequest_Operation) (*generated.ConfigResponse, error) { return nil, errors.New("not implemented") } + +func (t *TestExecutor) AddTag(tag string) error { + if err := base.ValidateTag(tag); err != nil { + return err + } + if t.tags == nil { + t.tags = make(map[string]struct{}) + } + t.tags[tag] = struct{}{} + return nil +} + +func (t *TestExecutor) RemoveTag(tag string) error { + if err := base.ValidateTag(tag); err != nil { + return err + } + delete(t.tags, tag) + return nil +} + +func (t *TestExecutor) GetTags() []string { + if len(t.tags) == 0 { + return nil + } + out := make([]string, 0, len(t.tags)) + for tag := range 
t.tags { + out = append(out, tag) + } + sort.Strings(out) + return out +} + +func (t *TestExecutor) ClearTags() { + t.tags = nil +} diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go index a84bb61..12e59d2 100644 --- a/spark/sql/sparksession.go +++ b/spark/sql/sparksession.go @@ -48,6 +48,18 @@ type SparkSession interface { CreateDataFrameFromArrow(ctx context.Context, data arrow.Table) (DataFrame, error) CreateDataFrame(ctx context.Context, data [][]any, schema *types.StructType) (DataFrame, error) Config() client.RuntimeConfig + + // AddTag attaches a tag to every operation started by this session afterwards. The tag is + // sent as part of ExecutePlanRequest.tags and can later be used with InterruptTag to cancel + // every running operation that carries it. Tag must be non-empty and must not contain ','. + AddTag(tag string) error + // RemoveTag removes a tag previously added via AddTag. Removing a tag that was never added + // is a no-op. Returns an error only if the tag itself is invalid. + RemoveTag(tag string) error + // GetTags returns the tags currently attached to this session, sorted lexicographically. + GetTags() []string + // ClearTags removes every tag currently attached to this session. 
+ ClearTags() } // NewSessionBuilder creates a new session builder for starting a new spark session @@ -167,6 +179,22 @@ func (s *sparkSessionImpl) Stop() error { return nil } +func (s *sparkSessionImpl) AddTag(tag string) error { + return s.client.AddTag(tag) +} + +func (s *sparkSessionImpl) RemoveTag(tag string) error { + return s.client.RemoveTag(tag) +} + +func (s *sparkSessionImpl) GetTags() []string { + return s.client.GetTags() +} + +func (s *sparkSessionImpl) ClearTags() { + s.client.ClearTags() +} + func (s *sparkSessionImpl) Table(name string) (DataFrame, error) { return s.Read().Table(name) } diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go index 11539af..a8d863a 100644 --- a/spark/sql/sparksession_test.go +++ b/spark/sql/sparksession_test.go @@ -104,6 +104,29 @@ func TestNewSessionBuilderFailsIfConnectionStringIsInvalid(t *testing.T) { assert.Nil(t, spark) } +func TestSparkSessionTagsRoundTripThroughClient(t *testing.T) { + s := testutils.NewConnectServiceClientMock(nil, nil, nil, t) + c := client.NewSparkExecutorFromClient(s, nil, "") + session := &sparkSessionImpl{client: c} + + assert.Empty(t, session.GetTags()) + + assert.NoError(t, session.AddTag("etl")) + assert.NoError(t, session.AddTag("nightly")) + assert.Equal(t, []string{"etl", "nightly"}, session.GetTags()) + + // Invalid tags must be rejected at the session layer too. + assert.ErrorIs(t, session.AddTag(""), sparkerrors.InvalidArgumentError) + assert.ErrorIs(t, session.AddTag("a,b"), sparkerrors.InvalidArgumentError) + assert.Equal(t, []string{"etl", "nightly"}, session.GetTags(), "invalid tags must not be stored") + + assert.NoError(t, session.RemoveTag("etl")) + assert.Equal(t, []string{"nightly"}, session.GetTags()) + + session.ClearTags() + assert.Empty(t, session.GetTags()) +} + func TestWriteResultStreamsArrowResultToCollector(t *testing.T) { ctx := context.Background()