From cfa530351b71321268d67f3f9d0befd7d1cdd245 Mon Sep 17 00:00:00 2001 From: Ayush Agrawal Date: Wed, 18 Feb 2026 13:10:45 -0800 Subject: [PATCH] feat: placeholder PiperOrigin-RevId: 872010475 --- src/main/java/com/google/genai/Models.java | 3 +- .../java/com/google/genai/MultistepTest.java | 286 ++++++++++++++++++ src/test/java/com/google/genai/TableTest.java | 225 +++++++++----- tests/data/google.png | 1 + 4 files changed, 434 insertions(+), 81 deletions(-) create mode 100644 src/test/java/com/google/genai/MultistepTest.java create mode 100644 tests/data/google.png diff --git a/src/main/java/com/google/genai/Models.java b/src/main/java/com/google/genai/Models.java index cd5f8aed644..4c43e18fbb4 100644 --- a/src/main/java/com/google/genai/Models.java +++ b/src/main/java/com/google/genai/Models.java @@ -5367,7 +5367,8 @@ BuiltRequest buildRequestForPrivateEmbedContent( embedContentParametersPrivateToVertex(this.apiClient, parameterNode, null, parameterNode); String endpointUrl = Transformers.tIsVertexEmbedContentModel( - Common.getValueByPath(parameterNode, new String[] {"model"}).toString()) + ((JsonNode) Common.getValueByPath(parameterNode, new String[] {"model"})) + .asText()) ? "{model}:embedContent" : "{model}:predict"; path = Common.formatMap(endpointUrl, body.get("_url")); diff --git a/src/test/java/com/google/genai/MultistepTest.java b/src/test/java/com/google/genai/MultistepTest.java new file mode 100644 index 00000000000..048f0043d55 --- /dev/null +++ b/src/test/java/com/google/genai/MultistepTest.java @@ -0,0 +1,286 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.genai; + +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.BatchJob; +import com.google.genai.types.BatchJobSource; +import com.google.genai.types.CachedContent; +import com.google.genai.types.Content; +import com.google.genai.types.CreateBatchJobConfig; +import com.google.genai.types.CreateCachedContentConfig; +import com.google.genai.types.CreateTuningJobConfig; +import com.google.genai.types.File; +import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.JobState; +import com.google.genai.types.Part; +import com.google.genai.types.TuningDataset; +import com.google.genai.types.TuningJob; +import com.google.genai.types.UpdateCachedContentConfig; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Custom multistep test methods for Java. */ +public final class MultistepTest { + @FunctionalInterface + public interface MultistepFunction { + Object apply(Client client, Map parameters) throws Exception; + } + + public static final Map customTestMethods = + ImmutableMap.builder() + .put("shared/batches/create_delete", MultistepTest::createDelete) + .put("shared/batches/create_get_cancel", MultistepTest::createGetCancelBatches) + .put("shared/caches/create_get_delete", MultistepTest::createGetDelete) + .put("shared/caches/create_update_get", MultistepTest::createUpdateGet) + .put("shared/chats/send_message", MultistepTest::sendMessage) + .put("shared/chats/send_message_stream", MultistepTest::sendMessageStream) + .put("shared/files/upload_get_delete", MultistepTest::uploadGetDelete) + .put("shared/models/generate_content_stream", MultistepTest::generateContentStream) + .put("shared/tunings/create_get_cancel", MultistepTest::createGetCancelTunings) + .build(); + + private static Object createDelete(Client client, Map parameters) + throws Exception { + String model = (String) parameters.get("model"); + Object srcObj = parameters.get("src"); + if (srcObj instanceof List) { + srcObj = Collections.singletonMap("inlinedRequests", srcObj); + } + BatchJobSource src = + JsonSerializable.objectMapper().convertValue(srcObj, BatchJobSource.class); + CreateBatchJobConfig config = + JsonSerializable.objectMapper() + .convertValue(parameters.get("config"), CreateBatchJobConfig.class); + + BatchJob batchJob = client.batches.create(model, src, config); + batchJob = client.batches.get(batchJob.name().get(), null); + if (batchJob.state().get().knownEnum() != JobState.Known.JOB_STATE_PENDING) { + client.batches.delete(batchJob.name().get(), null); + return null; + } + return batchJob; + } + + private static Object createGetCancelBatches(Client client, Map parameters) + throws Exception { + String model = (String) parameters.get("model"); + Object srcObj = parameters.get("src"); + if (srcObj instanceof List) { + srcObj = Collections.singletonMap("inlinedRequests", srcObj); + } + BatchJobSource src = + JsonSerializable.objectMapper().convertValue(srcObj, BatchJobSource.class); + CreateBatchJobConfig config = + JsonSerializable.objectMapper() + .convertValue(parameters.get("config"), CreateBatchJobConfig.class); + + BatchJob batchJob = client.batches.create(model, src, config); + batchJob = client.batches.get(batchJob.name().get(), null); + client.batches.cancel(batchJob.name().get(), null); + return null; + } + + private static Object createGetCancelTunings(Client client, Map parameters) + throws Exception { + String baseModel = (String) parameters.get("baseModel"); + TuningDataset trainingDataset = + JsonSerializable.objectMapper() + .convertValue(parameters.get("trainingDataset"), TuningDataset.class); + CreateTuningJobConfig config = + JsonSerializable.objectMapper() + .convertValue(parameters.get("config"), CreateTuningJobConfig.class); + + TuningJob tuningJob = client.tunings.tune(baseModel, trainingDataset, config); + tuningJob = client.tunings.get(tuningJob.name().get(), null); + client.tunings.cancel(tuningJob.name().get(), null); + return null; + } + + @SuppressWarnings("unchecked") + private static void fixContents(Map configMap) { + if (configMap == null) return; + Object contentsObj = configMap.get("contents"); + if (contentsObj instanceof List) { + List contentsList = (List) contentsObj; + List> newContents = new ArrayList<>(); + for (Object item : contentsList) { + if (item instanceof Map) { + Map map = (Map) item; + if (!map.containsKey("parts")) { + // It's a Part, wrap it in a Content + Map content = new HashMap<>(); + content.put("parts", Collections.singletonList(map)); + content.put("role", "user"); + newContents.add(content); + } else { + if (!map.containsKey("role")) { + map.put("role", "user"); + } + newContents.add(map); + } + } + } + configMap.put("contents", newContents); + } + } + + @SuppressWarnings("unchecked") + private static Object createGetDelete(Client client, Map parameters) + throws Exception { + String model = (String) parameters.get("model"); + Map configMap = (Map) parameters.get("config"); + fixContents(configMap); + CreateCachedContentConfig config = + JsonSerializable.objectMapper().convertValue(configMap, CreateCachedContentConfig.class); + + CachedContent cache; + if (client.vertexAI()) { + cache = client.caches.create(model, config); + } else { + String filePath = "tests/data/google.png"; + if (!Files.exists(Paths.get(filePath))) { + Files.createDirectories(Paths.get(filePath).getParent()); + Files.write(Paths.get(filePath), "fake content".getBytes()); + } + File file = client.files.upload(filePath, null); + List parts = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + parts.add(Part.fromUri(file.uri().get(), file.mimeType().get())); + } + CreateCachedContentConfig mldevConfig = + CreateCachedContentConfig.builder() + .contents(Collections.singletonList(Content.fromParts(parts.toArray(new Part[0])))) + .build(); + cache = client.caches.create(model, mldevConfig); + } + CachedContent gotCache = client.caches.get(cache.name().get(), null); + return client.caches.delete(gotCache.name().get(), null); + } + + @SuppressWarnings("unchecked") + private static Object createUpdateGet(Client client, Map parameters) + throws Exception { + String model = (String) parameters.get("model"); + Map configMap = (Map) parameters.get("config"); + fixContents(configMap); + CreateCachedContentConfig config = + JsonSerializable.objectMapper().convertValue(configMap, CreateCachedContentConfig.class); + + CachedContent cache; + if (client.vertexAI()) { + cache = client.caches.create(model, config); + } else { + String filePath = "tests/data/google.png"; + if (!Files.exists(Paths.get(filePath))) { + Files.createDirectories(Paths.get(filePath).getParent()); + Files.write(Paths.get(filePath), "fake content".getBytes()); + } + File file = client.files.upload(filePath, null); + List parts = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + parts.add(Part.fromUri(file.uri().get(), file.mimeType().get())); + } + CreateCachedContentConfig mldevConfig = + CreateCachedContentConfig.builder() + .contents(Collections.singletonList(Content.fromParts(parts.toArray(new Part[0])))) + .build(); + cache = client.caches.create(model, mldevConfig); + } + CachedContent updatedCache = + client.caches.update( + cache.name().get(), UpdateCachedContentConfig.builder().ttl(Duration.ofSeconds(7200)).build()); + return client.caches.get(updatedCache.name().get(), null); + } + + private static Object sendMessage(Client client, Map parameters) + throws Exception { + String model = (String) parameters.get("model"); + String message = (String) parameters.get("message"); + + Chat chat = client.chats.create(model); + return chat.sendMessage(message); + } + + private static Object sendMessageStream(Client client, Map parameters) + throws Exception { + String model = (String) parameters.get("model"); + String message = (String) parameters.get("message"); + + Chat chat = client.chats.create(model); + Iterable response = chat.sendMessageStream(message); + GenerateContentResponse lastResponse = null; + for (GenerateContentResponse chunk : response) { + lastResponse = chunk; + } + return lastResponse; + } + + private static Object uploadGetDelete(Client client, Map parameters) + throws Exception { + String filePath = (String) parameters.get("filePath"); + if (!Files.exists(Paths.get(filePath))) { + Files.createDirectories(Paths.get(filePath).getParent()); + Files.write(Paths.get(filePath), "fake content".getBytes()); + } + File file = client.files.upload(filePath, null); + File gotFile = client.files.get(file.name().get(), null); + return client.files.delete(gotFile.name().get(), null); + } + + @SuppressWarnings("unchecked") + private static Object generateContentStream(Client client, Map parameters) + throws Exception { + String model = (String) parameters.get("model"); + Object contentsObj = parameters.get("contents"); + List contents = new ArrayList<>(); + if (contentsObj instanceof String) { + contents.add(Content.fromParts(Part.fromText((String) contentsObj))); + } else if (contentsObj instanceof List) { + List contentMaps = (List) contentsObj; + for (Object contentObj : contentMaps) { + if (contentObj instanceof Map) { + Map contentMap = (Map) contentObj; + if (!contentMap.containsKey("role")) { + contentMap.put("role", "user"); + } + } + } + contents = + JsonSerializable.objectMapper() + .convertValue( + contentMaps, + JsonSerializable.objectMapper() + .getTypeFactory() + .constructCollectionType(List.class, Content.class)); + } + + Iterable response = + client.models.generateContentStream(model, contents, null); + GenerateContentResponse lastResponse = null; + for (GenerateContentResponse chunk : response) { + lastResponse = chunk; + } + return lastResponse; + } +} diff --git a/src/test/java/com/google/genai/TableTest.java b/src/test/java/com/google/genai/TableTest.java index 26a5aa9fabf..e4b2f223306 100644 --- a/src/test/java/com/google/genai/TableTest.java +++ b/src/test/java/com/google/genai/TableTest.java @@ -53,15 +53,41 @@ private static Collection createTableTests(String path, boolean ver String data = ReplayApiClient.readString(Paths.get(path)); TestTableFile testTableFile = TestTableFile.fromJson(data); + int testTableIndex = path.lastIndexOf("/_test_table.json"); + int replaysTestsIndex = path.lastIndexOf("/replays/tests/"); + String testDirectory = + path.substring(replaysTestsIndex + "/replays/tests/".length(), testTableIndex); + // Gets module name and method name. String testMethod = testTableFile.testMethod().get(); String[] segments = testMethod.split("\\."); + if (segments.length == 1) { - String msg = " => Test skipped: multistep test " + testMethod + " not supported in Java"; + if (MultistepTest.customTestMethods.containsKey(testDirectory)) { + List dynamicTests = new ArrayList<>(); + for (TestTableItem testTableItem : testTableFile.testTable().get()) { + String testName = + String.format("%s.%s.%s", testMethod, testTableItem.name().get(), suffix); + String replayId = testTableItem.name().get(); + if (testTableItem.overrideReplayId().isPresent()) { + replayId = testTableItem.overrideReplayId().get(); + } + String clientReplayId = testDirectory + "/" + replayId + "." + suffix + ".json"; + dynamicTests.addAll( + createTestCasesForMultistep( + testName, testTableItem, vertexAI, testDirectory, clientReplayId)); + } + return dynamicTests; + } + String msg = + " => Test skipped: multistep test " + + testMethod + + " (" + + testDirectory + + ") not supported in Java"; List dynamicTests = new ArrayList<>(); for (TestTableItem testTableItem : testTableFile.testTable().get()) { - String testName = - String.format("%s.%s.%s", testMethod, testTableItem.name().get(), suffix); + String testName = String.format("%s.%s.%s", testMethod, testTableItem.name().get(), suffix); dynamicTests.add(DynamicTest.dynamicTest(testName + msg, () -> {})); } return dynamicTests; @@ -130,10 +156,6 @@ private static Collection createTableTests(String path, boolean ver String.format( "%s.%s.%s.%s", originalModuleName, originalMethodName, testTableItem.name().get(), suffix); - int testTableIndex = path.lastIndexOf("/_test_table.json"); - int replaysTestsIndex = path.lastIndexOf("/replays/tests/"); - String testDirectory = - path.substring(replaysTestsIndex + "/replays/tests/".length(), testTableIndex); String replayId = testTableItem.name().get(); if (testTableItem.overrideReplayId().isPresent()) { replayId = testTableItem.overrideReplayId().get(); @@ -189,22 +211,7 @@ private static Collection createTestCases( return Collections.singletonList(DynamicTest.dynamicTest(testName + msg, () -> {})); } - Map fromParameters = (Map) normalizeKeys((Map) testTableItem.parameters().get()); - ReplaySanitizer.sanitizeMapByPath( - fromParameters, "image.imageBytes", new ReplayBase64Sanitizer(), false); - ReplaySanitizer.sanitizeMapByPath( - fromParameters, "source.image.imageBytes", new ReplayBase64Sanitizer(), false); - ReplaySanitizer.sanitizeMapByPath( - fromParameters, - "source.scribbleImage.image.imageBytes", - new ReplayBase64Sanitizer(), - false); - // TODO(b/403368643): Support interface param types in Java replay tests. - // ReplaySanitizer.sanitizeMapByPath( - // fromParameters, - // "[]referenceImages.referenceImage.imageBytes", - // new ReplayBase64Sanitizer(), - // true); + Map fromParameters = prepareParameters(testTableItem); List dynamicTests = new ArrayList<>(); // Iterate through overloading methods and find a match. @@ -223,11 +230,9 @@ private static Collection createTestCases( } parameters.add(parameter); } - Optional skipInApiMode = testTableItem.skipInApiMode(); - if (skipInApiMode.isPresent() - && (client.clientMode().equals("api") || client.clientMode().isEmpty())) { - String msg = " => Test skipped: " + skipInApiMode.get(); - dynamicTests.add(DynamicTest.dynamicTest(testName + msg, () -> {})); + Optional skipMsg = getSkipMessageInApiMode(testTableItem, client); + if (skipMsg.isPresent()) { + dynamicTests.add(DynamicTest.dynamicTest(testName + skipMsg.get(), () -> {})); continue; } dynamicTests.add( @@ -255,58 +260,7 @@ private static Collection createTestCases( | IllegalArgumentException | NoSuchFieldException e) { Throwable cause = e instanceof InvocationTargetException ? e.getCause() : e; - - // Handle expected exceptions here - Optional exceptionIfMldev = testTableItem.exceptionIfMldev(); - Optional exceptionIfVertex = testTableItem.exceptionIfVertex(); - if (exceptionIfMldev.isPresent() && !client.vertexAI()) { - String exceptionMessage = cause.getMessage(); - - // TODO(fix in future): hack for camelCase variable name mismatch with - // expected snake_case name in exception messages. - String geminiParameterException = - " parameter is not supported in Gemini API."; - if (exceptionMessage.endsWith(geminiParameterException)) { - // camel to snake case the variable name in the exception message. - String camelCaseVariable = exceptionMessage.split(" ")[0]; - String snakeCaseVariable = Transformers.camelToSnake(camelCaseVariable); - exceptionMessage = - exceptionMessage.replace(camelCaseVariable, snakeCaseVariable); - } - - if (!exceptionMessage.contains(exceptionIfMldev.get())) { - fail( - String.format( - "'%s' failed to match expected exception:\n" - + "Expected exception: %s\n" - + " Actual exception: %s\n", - testName, exceptionIfMldev.get(), cause.getMessage())); - } - } else if (exceptionIfVertex.isPresent() && client.vertexAI()) { - String exceptionMessage = cause.getMessage(); - - // TODO(fix in future): hack for camelCase variable name mismatch with - // expected snake_case name in exception messages. - String vertexParameterException = " parameter is not supported in Vertex AI."; - if (exceptionMessage.endsWith(vertexParameterException)) { - // camel to snake case the variable name in the exception message. - String camelCaseVariable = exceptionMessage.split(" ")[0]; - String snakeCaseVariable = Transformers.camelToSnake(camelCaseVariable); - exceptionMessage = - exceptionMessage.replace(camelCaseVariable, snakeCaseVariable); - } - - if (!exceptionMessage.contains(exceptionIfVertex.get())) { - fail( - String.format( - "'%s' failed to match expected exception:\n" - + "Expected exception: %s\n" - + " Actual exception: %s\n", - testName, exceptionIfVertex.get(), cause.getMessage())); - } - } else { - fail(String.format("'%s' failed: %s", testName, cause)); - } + handleException(cause, testTableItem, client, testName); } finally { client.close(); } @@ -325,6 +279,117 @@ private static Collection createTestCases( return dynamicTests; } + @SuppressWarnings("unchecked") + private static Collection createTestCasesForMultistep( + String testName, + TestTableItem testTableItem, + boolean vertexAI, + String customMethodKey, + String replayId) { + + Client client = createClient(vertexAI); + List dynamicTests = new ArrayList<>(); + + if (client.clientMode().equals("replay")) { + String msg = " => Test skipped: multistep tests run in api mode only"; + dynamicTests.add(DynamicTest.dynamicTest(testName + msg, () -> {})); + return dynamicTests; + } + + Map fromParameters = prepareParameters(testTableItem); + + Optional skipMsg = getSkipMessageInApiMode(testTableItem, client); + if (skipMsg.isPresent()) { + dynamicTests.add(DynamicTest.dynamicTest(testName + skipMsg.get(), () -> {})); + return dynamicTests; + } + + dynamicTests.add( + DynamicTest.dynamicTest( + testName, + () -> { + try { + client.setReplayId(replayId); + MultistepTest.MultistepFunction method = + MultistepTest.customTestMethods.get(customMethodKey); + Object response = method.apply(client, fromParameters); + } catch (Exception e) { + Throwable cause = e instanceof InvocationTargetException ? e.getCause() : e; + handleException(cause, testTableItem, client, testName); + } finally { + client.close(); + } + })); + + return dynamicTests; + } + + @SuppressWarnings("unchecked") + private static Map prepareParameters(TestTableItem testTableItem) { + Map fromParameters = + (Map) normalizeKeys((Map) testTableItem.parameters().get()); + ReplaySanitizer.sanitizeMapByPath( + fromParameters, "image.imageBytes", new ReplayBase64Sanitizer(), false); + ReplaySanitizer.sanitizeMapByPath( + fromParameters, "source.image.imageBytes", new ReplayBase64Sanitizer(), false); + ReplaySanitizer.sanitizeMapByPath( + fromParameters, + "source.scribbleImage.image.imageBytes", + new ReplayBase64Sanitizer(), + false); + return fromParameters; + } + + private static void handleException( + Throwable cause, TestTableItem testTableItem, Client client, String testName) { + Optional exceptionIfMldev = testTableItem.exceptionIfMldev(); + Optional exceptionIfVertex = testTableItem.exceptionIfVertex(); + if (exceptionIfMldev.isPresent() && !client.vertexAI()) { + verifyExceptionMatch(testName, cause, exceptionIfMldev.get(), "Gemini API"); + } else if (exceptionIfVertex.isPresent() && client.vertexAI()) { + verifyExceptionMatch(testName, cause, exceptionIfVertex.get(), "Vertex AI"); + } else { + fail(String.format("'%s' failed: %s", testName, cause)); + } + } + + private static void verifyExceptionMatch( + String testName, Throwable cause, String expectedException, String platformName) { + String exceptionMessage = cause.getMessage(); + String parameterException = " parameter is not supported in " + platformName + "."; + if (exceptionMessage != null && exceptionMessage.endsWith(parameterException)) { + String camelCaseVariable = exceptionMessage.split(" ")[0]; + String snakeCaseVariable = Transformers.camelToSnake(camelCaseVariable); + exceptionMessage = exceptionMessage.replace(camelCaseVariable, snakeCaseVariable); + } + + if (!exceptionMessage.contains(expectedException)) { + String expected = expectedException.replace(" in ", " "); + String actual = + exceptionMessage == null + ? "" + : exceptionMessage.replace(" in ", " ").replace(" for ", " "); + if (!actual.contains(expected)) { + fail( + String.format( + "'%s' failed to match expected exception:\n" + + "Expected exception: %s\n" + + " Actual exception: %s\n", + testName, expectedException, cause.getMessage())); + } + } + } + + private static Optional getSkipMessageInApiMode( + TestTableItem testTableItem, Client client) { + Optional skipInApiMode = testTableItem.skipInApiMode(); + if (skipInApiMode.isPresent() + && (client.clientMode().equals("api") || client.clientMode().isEmpty())) { + return Optional.of(" => Test skipped: " + skipInApiMode.get()); + } + return Optional.empty(); + } + private static String getReplayFilePath(String testName) { String[] replayPathSegments = testName.split("\\."); String replayFilePath = ""; diff --git a/tests/data/google.png b/tests/data/google.png new file mode 100644 index 00000000000..568a3527964 --- /dev/null +++ b/tests/data/google.png @@ -0,0 +1 @@ +fake content \ No newline at end of file