diff --git a/config/config.go b/config/config.go index 1c3ade4..2bd823f 100644 --- a/config/config.go +++ b/config/config.go @@ -180,7 +180,10 @@ type Model struct { Embedding bool `hcl:"embedding,optional"` Images bool `hcl:"images,optional"` Videos bool `hcl:"videos,optional"` - Reasoning bool `hcl:"reasoning,optional"` + Thinking bool `hcl:"thinking,optional"` + Batch bool `hcl:"batch,optional"` + // Region is relevant for providers with inconsistent availability. + Region string `hcl:"region,optional"` ModelDefault } diff --git a/config/config_test.go b/config/config_test.go index 5869ee7..57393a3 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -39,7 +39,7 @@ func TestParsePath(t *testing.T) { Name: "api", Models: []Model{ {Name: "gpt-4o", Alias: list{"4o"}}, - {Name: "o3-mini", Alias: list{"o3"}, Reasoning: true}, + {Name: "o3-mini", Alias: list{"o3"}, Thinking: true}, }, }, { diff --git a/config/simp.hcl.tmpl b/config/simp.hcl.tmpl index ad99004..f20817c 100644 --- a/config/simp.hcl.tmpl +++ b/config/simp.hcl.tmpl @@ -95,8 +95,8 @@ provider "{{ .Driver }}" "{{ .Name }}" { {{- if .Images }} images = {{ .Images }} {{- end }} - {{- if .Reasoning }} - reasoning = {{ .Reasoning }} + {{- if .Thinking }} + thinking = {{ .Thinking }} {{- end }} } {{- end }} diff --git a/driver/vertex.go b/driver/vertex.go index a2f8a69..73e6c6d 100644 --- a/driver/vertex.go +++ b/driver/vertex.go @@ -47,6 +47,14 @@ type Vertex struct { uploads map[string]string } +func (v *Vertex) region(ctx context.Context) string { + m, ok := ctx.Value(simp.KeyModel).(config.Model) + if !ok || m.Region == "" { + return v.Region + } + return m.Region +} + func (v *Vertex) List(ctx context.Context) ([]openai.Model, error) { client, err := v.genaiClient(ctx) if err != nil { @@ -421,10 +429,13 @@ func (v *Vertex) BatchUpload(ctx context.Context, batch *openai.Batch, inputs [] if !v.Batch { return simp.ErrNotImplemented } - modelName, ok := ctx.Value(simp.KeyModel).(config.Model) + model, ok := ctx.Value(simp.KeyModel).(config.Model) if !ok { return fmt.Errorf("model not found") } + if !model.Batch { + return fmt.Errorf("model %q does not support batching", model.Name) + } client, err := v.bigqueryClient(ctx) if err != nil { @@ -451,7 +462,7 @@ func (v *Vertex) BatchUpload(ctx context.Context, batch *openai.Batch, inputs [] contents, config := sect.Contents, sect.Config // Build the request map with only non-nil values req := map[string]any{ - "model": v.googleModel(modelName.Name), + "model": v.googleModel(model.Name), "contents": contents, } if config.SystemInstruction != nil { @@ -524,6 +535,7 @@ func (v *Vertex) BatchUpload(ctx context.Context, batch *openai.Batch, inputs [] } chunk := rows[input:end] + fmt.Println("inserting chunk", len(chunk), "into", v.Dataset, table) if err := inserter.Put(ctx, chunk); err != nil { return fmt.Errorf("failed to insert batch chunk: %w", err) } @@ -777,7 +789,7 @@ func (v *Vertex) genaiClient(ctx context.Context) (*genai.Client, error) { client, err := genai.NewClient(ctx, &genai.ClientConfig{ Backend: genai.BackendVertexAI, Project: v.Project, - Location: v.Region, + Location: v.region(ctx), Credentials: auth.NewCredentials(&auth.CredentialsOptions{ JSON: []byte(v.APIKey), }), @@ -794,7 +806,7 @@ func (v *Vertex) genaiClient(ctx context.Context) (*genai.Client, error) { func (v *Vertex) jobClient(ctx context.Context) (*aipl.JobClient, error) { client, err := aipl.NewJobClient(ctx, - option.WithEndpoint(v.Region+"-aiplatform.googleapis.com:443"), + option.WithEndpoint(v.region(ctx)+"-aiplatform.googleapis.com:443"), v.credentials()) if err != nil { return nil, fmt.Errorf("cannot make job client: %w", err)