Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions spark/client/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package base

import (
"context"
"fmt"
"strings"

"github.com/apache/spark-connect-go/spark/sql/utils"

Expand All @@ -25,6 +27,18 @@ import (
"github.com/apache/spark-connect-go/spark/sql/types"
)

// ValidateTag enforces the Spark Connect tag rules: a tag must be non-empty
// and must not contain ','. See ExecutePlanRequest.tags in spark/connect/base.proto.
func ValidateTag(tag string) error {
if tag == "" {
return fmt.Errorf("spark connect tag cannot be an empty string")
}
if strings.ContainsRune(tag, ',') {
return fmt.Errorf("spark connect tag cannot contain ',': %q", tag)
}
return nil
}

type SparkConnectRPCClient generated.SparkConnectServiceClient

// SparkConnectClient is the interface for executing a plan in Spark.
Expand All @@ -44,6 +58,17 @@ type SparkConnectClient interface {
SameSemantics(ctx context.Context, plan1 *generated.Plan, plan2 *generated.Plan) (bool, error)
SemanticHash(ctx context.Context, plan *generated.Plan) (int32, error)
Config(ctx context.Context, configRequest *generated.ConfigRequest_Operation) (*generated.ConfigResponse, error)

// AddTag attaches the given tag to every subsequent ExecutePlan request issued by this client.
// Tags follow the Spark Connect contract enforced by ValidateTag: non-empty and no commas.
AddTag(tag string) error
// RemoveTag removes the given tag from this client. Removing a tag that was never added is a no-op.
RemoveTag(tag string) error
// GetTags returns a snapshot of the tags currently attached to this client.
// The slice is sorted lexicographically so callers and tests get a deterministic order.
GetTags() []string
// ClearTags drops every tag attached to this client.
ClearTags()
}

type ExecuteResponseStream interface {
Expand Down
49 changes: 49 additions & 0 deletions spark/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"errors"
"fmt"
"io"
"sort"
"sync"

"github.com/apache/spark-connect-go/spark/sql/utils"

Expand All @@ -46,6 +48,52 @@ type sparkConnectClientImpl struct {
metadata metadata.MD
sessionId string
opts options.SparkClientOptions

tagsMu sync.RWMutex
tags map[string]struct{}
}

func (s *sparkConnectClientImpl) AddTag(tag string) error {
if err := base.ValidateTag(tag); err != nil {
return sparkerrors.WithType(err, sparkerrors.InvalidArgumentError)
}
s.tagsMu.Lock()
defer s.tagsMu.Unlock()
if s.tags == nil {
s.tags = make(map[string]struct{})
}
s.tags[tag] = struct{}{}
return nil
}

func (s *sparkConnectClientImpl) RemoveTag(tag string) error {
if err := base.ValidateTag(tag); err != nil {
return sparkerrors.WithType(err, sparkerrors.InvalidArgumentError)
}
s.tagsMu.Lock()
defer s.tagsMu.Unlock()
delete(s.tags, tag)
return nil
}

func (s *sparkConnectClientImpl) GetTags() []string {
s.tagsMu.RLock()
defer s.tagsMu.RUnlock()
if len(s.tags) == 0 {
return nil
}
out := make([]string, 0, len(s.tags))
for t := range s.tags {
out = append(out, t)
}
sort.Strings(out)
return out
}

func (s *sparkConnectClientImpl) ClearTags() {
s.tagsMu.Lock()
defer s.tagsMu.Unlock()
s.tags = nil
}

func (s *sparkConnectClientImpl) newExecutePlanRequest(plan *proto.Plan) *proto.ExecutePlanRequest {
Expand All @@ -69,6 +117,7 @@ func (s *sparkConnectClientImpl) newExecutePlanRequest(plan *proto.Plan) *proto.
},
},
},
Tags: s.GetTags(),
}
}

Expand Down
70 changes: 70 additions & 0 deletions spark/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/google/uuid"
"google.golang.org/grpc"

proto "github.com/apache/spark-connect-go/internal/generated"
"github.com/apache/spark-connect-go/spark/client"
Expand Down Expand Up @@ -108,3 +109,72 @@ func Test_Execute_SchemaParsingFails(t *testing.T) {
_, _, _, err := c.ExecuteCommand(ctx, sqlCommand)
assert.ErrorIs(t, err, sparkerrors.ExecutionError)
}

// executePlanRecorder wraps the testutils mock and captures the last ExecutePlanRequest
// passed to ExecutePlan so tests can inspect Tags / OperationId / SessionId.
type executePlanRecorder struct {
proto.SparkConnectServiceClient
lastRequest *proto.ExecutePlanRequest
}

func (r *executePlanRecorder) ExecutePlan(ctx context.Context, in *proto.ExecutePlanRequest,
opts ...grpc.CallOption,
) (proto.SparkConnectService_ExecutePlanClient, error) {
r.lastRequest = in
return r.SparkConnectServiceClient.ExecutePlan(ctx, in, opts...)
}

func TestAddTagRejectsInvalidInput(t *testing.T) {
c := client.NewSparkExecutorFromClient(
testutils.NewConnectServiceClientMock(nil, nil, nil, t), nil, mocks.MockSessionId)

assert.ErrorIs(t, c.AddTag(""), sparkerrors.InvalidArgumentError)
assert.ErrorIs(t, c.AddTag("has,comma"), sparkerrors.InvalidArgumentError)
assert.Empty(t, c.GetTags(), "invalid tags must not be stored")
}

func TestTagsRoundTripAddRemoveClear(t *testing.T) {
c := client.NewSparkExecutorFromClient(
testutils.NewConnectServiceClientMock(nil, nil, nil, t), nil, mocks.MockSessionId)

assert.NoError(t, c.AddTag("beta"))
assert.NoError(t, c.AddTag("alpha"))
assert.NoError(t, c.AddTag("alpha")) // dedupes
assert.Equal(t, []string{"alpha", "beta"}, c.GetTags())

assert.NoError(t, c.RemoveTag("alpha"))
assert.Equal(t, []string{"beta"}, c.GetTags())

// Removing a tag that was never added is a no-op.
assert.NoError(t, c.RemoveTag("never-added"))
assert.Equal(t, []string{"beta"}, c.GetTags())

c.ClearTags()
assert.Empty(t, c.GetTags())
}

func TestExecutePlanRequestCarriesSessionTags(t *testing.T) {
ctx := context.Background()
responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF)
recorder := &executePlanRecorder{
SparkConnectServiceClient: testutils.NewConnectServiceClientMock(responseStream, nil, nil, t),
}
c := client.NewSparkExecutorFromClient(recorder, nil, mocks.MockSessionId)

// First call: no tags configured — request must not carry any.
_, err := c.ExecutePlan(ctx, &proto.Plan{})
assert.NoError(t, err)
assert.Empty(t, recorder.lastRequest.GetTags(), "untagged session must not send a Tags field")

assert.NoError(t, c.AddTag("etl-job-42"))
assert.NoError(t, c.AddTag("priority-high"))

_, err = c.ExecutePlan(ctx, &proto.Plan{})
assert.NoError(t, err)
assert.Equal(t, []string{"etl-job-42", "priority-high"}, recorder.lastRequest.GetTags())

c.ClearTags()
_, err = c.ExecutePlan(ctx, &proto.Plan{})
assert.NoError(t, err)
assert.Empty(t, recorder.lastRequest.GetTags(), "ClearTags must scrub tags from subsequent requests")
}
37 changes: 37 additions & 0 deletions spark/mocks/mock_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package mocks
import (
"context"
"errors"
"sort"

"github.com/apache/spark-connect-go/spark/sql/utils"

Expand All @@ -32,6 +33,7 @@ type TestExecutor struct {
Client base.ExecuteResponseStream
response *generated.AnalyzePlanResponse
Err error
tags map[string]struct{}
}

func (t *TestExecutor) ExecutePlan(ctx context.Context, plan *generated.Plan) (base.ExecuteResponseStream, error) {
Expand Down Expand Up @@ -89,3 +91,38 @@ func (t *TestExecutor) SemanticHash(ctx context.Context, plan *generated.Plan) (
func (t *TestExecutor) Config(ctx context.Context, configRequest *generated.ConfigRequest_Operation) (*generated.ConfigResponse, error) {
return nil, errors.New("not implemented")
}

func (t *TestExecutor) AddTag(tag string) error {
if err := base.ValidateTag(tag); err != nil {
return err
}
if t.tags == nil {
t.tags = make(map[string]struct{})
}
t.tags[tag] = struct{}{}
return nil
}

func (t *TestExecutor) RemoveTag(tag string) error {
if err := base.ValidateTag(tag); err != nil {
return err
}
delete(t.tags, tag)
return nil
}

func (t *TestExecutor) GetTags() []string {
if len(t.tags) == 0 {
return nil
}
out := make([]string, 0, len(t.tags))
for tag := range t.tags {
out = append(out, tag)
}
sort.Strings(out)
return out
}

func (t *TestExecutor) ClearTags() {
t.tags = nil
}
28 changes: 28 additions & 0 deletions spark/sql/sparksession.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ type SparkSession interface {
CreateDataFrameFromArrow(ctx context.Context, data arrow.Table) (DataFrame, error)
CreateDataFrame(ctx context.Context, data [][]any, schema *types.StructType) (DataFrame, error)
Config() client.RuntimeConfig

// AddTag attaches a tag to every operation started by this session afterwards. The tag is
// sent as part of ExecutePlanRequest.tags and can later be used with InterruptTag to cancel
// every running operation that carries it. Tag must be non-empty and must not contain ','.
AddTag(tag string) error
// RemoveTag removes a tag previously added via AddTag. Removing a tag that was never added
// is a no-op. Returns an error only if the tag itself is invalid.
RemoveTag(tag string) error
// GetTags returns the tags currently attached to this session, sorted lexicographically.
GetTags() []string
// ClearTags removes every tag currently attached to this session.
ClearTags()
}

// NewSessionBuilder creates a new session builder for starting a new spark session
Expand Down Expand Up @@ -167,6 +179,22 @@ func (s *sparkSessionImpl) Stop() error {
return nil
}

func (s *sparkSessionImpl) AddTag(tag string) error {
return s.client.AddTag(tag)
}

func (s *sparkSessionImpl) RemoveTag(tag string) error {
return s.client.RemoveTag(tag)
}

func (s *sparkSessionImpl) GetTags() []string {
return s.client.GetTags()
}

func (s *sparkSessionImpl) ClearTags() {
s.client.ClearTags()
}

func (s *sparkSessionImpl) Table(name string) (DataFrame, error) {
return s.Read().Table(name)
}
Expand Down
23 changes: 23 additions & 0 deletions spark/sql/sparksession_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,29 @@ func TestNewSessionBuilderFailsIfConnectionStringIsInvalid(t *testing.T) {
assert.Nil(t, spark)
}

func TestSparkSessionTagsRoundTripThroughClient(t *testing.T) {
s := testutils.NewConnectServiceClientMock(nil, nil, nil, t)
c := client.NewSparkExecutorFromClient(s, nil, "")
session := &sparkSessionImpl{client: c}

assert.Empty(t, session.GetTags())

assert.NoError(t, session.AddTag("etl"))
assert.NoError(t, session.AddTag("nightly"))
assert.Equal(t, []string{"etl", "nightly"}, session.GetTags())

// Invalid tags must be rejected at the session layer too.
assert.ErrorIs(t, session.AddTag(""), sparkerrors.InvalidArgumentError)
assert.ErrorIs(t, session.AddTag("a,b"), sparkerrors.InvalidArgumentError)
assert.Equal(t, []string{"etl", "nightly"}, session.GetTags(), "invalid tags must not be stored")

assert.NoError(t, session.RemoveTag("etl"))
assert.Equal(t, []string{"nightly"}, session.GetTags())

session.ClearTags()
assert.Empty(t, session.GetTags())
}

func TestWriteResultStreamsArrowResultToCollector(t *testing.T) {
ctx := context.Background()

Expand Down