From 62b56d1364ea39e4a884c2a029910c19d37faa3f Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Feb 2026 16:15:04 -0800 Subject: [PATCH] feat: Add inference_generation_config to EvaluationConfig for Tuning PiperOrigin-RevId: 871544397 --- src/main/java/com/google/genai/Tunings.java | 350 ++++++++++++++++++ .../google/genai/types/EvaluationConfig.java | 34 ++ .../java/com/google/genai/TuningsTest.java | 4 + 3 files changed, 388 insertions(+) diff --git a/src/main/java/com/google/genai/Tunings.java b/src/main/java/com/google/genai/Tunings.java index cf1476c4d11..68279068755 100644 --- a/src/main/java/com/google/genai/Tunings.java +++ b/src/main/java/com/google/genai/Tunings.java @@ -650,6 +650,17 @@ ObjectNode evaluationConfigFromVertex( Common.getValueByPath(fromObject, new String[] {"autoraterConfig"})); } + if (Common.getValueByPath(fromObject, new String[] {"inferenceGenerationConfig"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"inferenceGenerationConfig"}, + generationConfigFromVertex( + JsonSerializable.toJsonNode( + Common.getValueByPath(fromObject, new String[] {"inferenceGenerationConfig"})), + toObject, + rootObject)); + } + return toObject; } @@ -678,6 +689,345 @@ ObjectNode evaluationConfigToVertex( Common.getValueByPath(fromObject, new String[] {"autoraterConfig"})); } + if (Common.getValueByPath(fromObject, new String[] {"inferenceGenerationConfig"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"inferenceGenerationConfig"}, + generationConfigToVertex( + JsonSerializable.toJsonNode( + Common.getValueByPath(fromObject, new String[] {"inferenceGenerationConfig"})), + toObject, + rootObject)); + } + + return toObject; + } + + @ExcludeFromGeneratedCoverageReport + ObjectNode generationConfigFromVertex( + JsonNode fromObject, ObjectNode parentObject, JsonNode rootObject) { + ObjectNode toObject = JsonSerializable.objectMapper().createObjectNode(); + if (Common.getValueByPath(fromObject, new String[] {"modelConfig"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"modelSelectionConfig"}, + Common.getValueByPath(fromObject, new String[] {"modelConfig"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"responseJsonSchema"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"responseJsonSchema"}, + Common.getValueByPath(fromObject, new String[] {"responseJsonSchema"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"audioTimestamp"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"audioTimestamp"}, + Common.getValueByPath(fromObject, new String[] {"audioTimestamp"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"candidateCount"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"candidateCount"}, + Common.getValueByPath(fromObject, new String[] {"candidateCount"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"enableAffectiveDialog"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"enableAffectiveDialog"}, + Common.getValueByPath(fromObject, new String[] {"enableAffectiveDialog"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"frequencyPenalty"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"frequencyPenalty"}, + Common.getValueByPath(fromObject, new String[] {"frequencyPenalty"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"logprobs"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"logprobs"}, + Common.getValueByPath(fromObject, new String[] {"logprobs"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"maxOutputTokens"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"maxOutputTokens"}, + Common.getValueByPath(fromObject, new String[] {"maxOutputTokens"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"mediaResolution"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"mediaResolution"}, + Common.getValueByPath(fromObject, new String[] {"mediaResolution"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"presencePenalty"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"presencePenalty"}, + Common.getValueByPath(fromObject, new String[] {"presencePenalty"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"responseLogprobs"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"responseLogprobs"}, + Common.getValueByPath(fromObject, new String[] {"responseLogprobs"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"responseMimeType"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"responseMimeType"}, + Common.getValueByPath(fromObject, new String[] {"responseMimeType"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"responseModalities"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"responseModalities"}, + Common.getValueByPath(fromObject, new String[] {"responseModalities"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"responseSchema"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"responseSchema"}, + Common.getValueByPath(fromObject, new String[] {"responseSchema"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"routingConfig"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"routingConfig"}, + Common.getValueByPath(fromObject, new String[] {"routingConfig"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"seed"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"seed"}, + Common.getValueByPath(fromObject, new String[] {"seed"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"speechConfig"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"speechConfig"}, + Common.getValueByPath(fromObject, new String[] {"speechConfig"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"stopSequences"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"stopSequences"}, + Common.getValueByPath(fromObject, new String[] {"stopSequences"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"temperature"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"temperature"}, + Common.getValueByPath(fromObject, new String[] {"temperature"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"thinkingConfig"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"thinkingConfig"}, + Common.getValueByPath(fromObject, new String[] {"thinkingConfig"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"topK"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"topK"}, + Common.getValueByPath(fromObject, new String[] {"topK"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"topP"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"topP"}, + Common.getValueByPath(fromObject, new String[] {"topP"})); + } + + return toObject; + } + + @ExcludeFromGeneratedCoverageReport + ObjectNode generationConfigToVertex( + JsonNode fromObject, ObjectNode parentObject, JsonNode rootObject) { + ObjectNode toObject = JsonSerializable.objectMapper().createObjectNode(); + if (Common.getValueByPath(fromObject, new String[] {"modelSelectionConfig"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"modelConfig"}, + Common.getValueByPath(fromObject, new String[] {"modelSelectionConfig"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"responseJsonSchema"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"responseJsonSchema"}, + Common.getValueByPath(fromObject, new String[] {"responseJsonSchema"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"audioTimestamp"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"audioTimestamp"}, + Common.getValueByPath(fromObject, new String[] {"audioTimestamp"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"candidateCount"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"candidateCount"}, + Common.getValueByPath(fromObject, new String[] {"candidateCount"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"enableAffectiveDialog"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"enableAffectiveDialog"}, + Common.getValueByPath(fromObject, new String[] {"enableAffectiveDialog"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"frequencyPenalty"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"frequencyPenalty"}, + Common.getValueByPath(fromObject, new String[] {"frequencyPenalty"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"logprobs"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"logprobs"}, + Common.getValueByPath(fromObject, new String[] {"logprobs"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"maxOutputTokens"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"maxOutputTokens"}, + Common.getValueByPath(fromObject, new String[] {"maxOutputTokens"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"mediaResolution"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"mediaResolution"}, + Common.getValueByPath(fromObject, new String[] {"mediaResolution"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"presencePenalty"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"presencePenalty"}, + Common.getValueByPath(fromObject, new String[] {"presencePenalty"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"responseLogprobs"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"responseLogprobs"}, + Common.getValueByPath(fromObject, new String[] {"responseLogprobs"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"responseMimeType"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"responseMimeType"}, + Common.getValueByPath(fromObject, new String[] {"responseMimeType"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"responseModalities"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"responseModalities"}, + Common.getValueByPath(fromObject, new String[] {"responseModalities"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"responseSchema"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"responseSchema"}, + Common.getValueByPath(fromObject, new String[] {"responseSchema"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"routingConfig"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"routingConfig"}, + Common.getValueByPath(fromObject, new String[] {"routingConfig"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"seed"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"seed"}, + Common.getValueByPath(fromObject, new String[] {"seed"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"speechConfig"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"speechConfig"}, + Common.getValueByPath(fromObject, new String[] {"speechConfig"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"stopSequences"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"stopSequences"}, + Common.getValueByPath(fromObject, new String[] {"stopSequences"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"temperature"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"temperature"}, + Common.getValueByPath(fromObject, new String[] {"temperature"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"thinkingConfig"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"thinkingConfig"}, + Common.getValueByPath(fromObject, new String[] {"thinkingConfig"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"topK"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"topK"}, + Common.getValueByPath(fromObject, new String[] {"topK"})); + } + + if (Common.getValueByPath(fromObject, new String[] {"topP"}) != null) { + Common.setValueByPath( + toObject, + new String[] {"topP"}, + Common.getValueByPath(fromObject, new String[] {"topP"})); + } + + if (!Common.isZero( + Common.getValueByPath(fromObject, new String[] {"enableEnhancedCivicAnswers"}))) { + throw new IllegalArgumentException( + "enableEnhancedCivicAnswers parameter is not supported in Vertex AI."); + } + return toObject; } diff --git a/src/main/java/com/google/genai/types/EvaluationConfig.java b/src/main/java/com/google/genai/types/EvaluationConfig.java index 742e3637417..3e452209de8 100644 --- a/src/main/java/com/google/genai/types/EvaluationConfig.java +++ b/src/main/java/com/google/genai/types/EvaluationConfig.java @@ -46,6 +46,10 @@ public abstract class EvaluationConfig extends JsonSerializable { @JsonProperty("autoraterConfig") public abstract Optional autoraterConfig(); + /** Generation config for inference. */ + @JsonProperty("inferenceGenerationConfig") + public abstract Optional inferenceGenerationConfig(); + /** Instantiates a builder for EvaluationConfig. */ @ExcludeFromGeneratedCoverageReport public static Builder builder() { @@ -161,6 +165,36 @@ public Builder clearAutoraterConfig() { return autoraterConfig(Optional.empty()); } + /** + * Setter for inferenceGenerationConfig. + * + *

inferenceGenerationConfig: Generation config for inference. + */ + @JsonProperty("inferenceGenerationConfig") + public abstract Builder inferenceGenerationConfig(GenerationConfig inferenceGenerationConfig); + + /** + * Setter for inferenceGenerationConfig builder. + * + *

inferenceGenerationConfig: Generation config for inference. + */ + @CanIgnoreReturnValue + public Builder inferenceGenerationConfig( + GenerationConfig.Builder inferenceGenerationConfigBuilder) { + return inferenceGenerationConfig(inferenceGenerationConfigBuilder.build()); + } + + @ExcludeFromGeneratedCoverageReport + abstract Builder inferenceGenerationConfig( + Optional inferenceGenerationConfig); + + /** Clears the value of inferenceGenerationConfig field. */ + @ExcludeFromGeneratedCoverageReport + @CanIgnoreReturnValue + public Builder clearInferenceGenerationConfig() { + return inferenceGenerationConfig(Optional.empty()); + } + public abstract EvaluationConfig build(); } diff --git a/src/test/java/com/google/genai/TuningsTest.java b/src/test/java/com/google/genai/TuningsTest.java index 3f26067de60..a15b3572df6 100644 --- a/src/test/java/com/google/genai/TuningsTest.java +++ b/src/test/java/com/google/genai/TuningsTest.java @@ -29,6 +29,7 @@ import com.google.genai.types.CustomOutputFormatConfig; import com.google.genai.types.EvaluationConfig; import com.google.genai.types.GcsDestination; +import com.google.genai.types.GenerationConfig; import com.google.genai.types.JobState; import com.google.genai.types.ListTuningJobsConfig; import com.google.genai.types.OutputConfig; @@ -254,6 +255,8 @@ public void testTuneWithEvaluationConfig(boolean vertexAI) { .autoraterConfig( AutoraterConfig.builder().autoraterModel("test-model").samplingCount(1).build()) .metrics(metrics) + .inferenceGenerationConfig( + GenerationConfig.builder().temperature(0.5f).maxOutputTokens(1024).build()) .build(); CreateTuningJobConfig tuningConfig = CreateTuningJobConfig.builder() @@ -277,6 +280,7 @@ public void testTuneWithEvaluationConfig(boolean vertexAI) { // Assert assertNotNull(currentJob); assertTrue(currentJob.evaluationConfig().isPresent()); + assertTrue(currentJob.evaluationConfig().get().inferenceGenerationConfig().isPresent()); assertTrue(currentJob.state().get().knownEnum() == JobState.Known.JOB_STATE_PENDING); } }