diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index d451a499b7..e4f8645972 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -462,13 +462,19 @@ class OpenAIPrompt(override val uid: String) extends Transformer attachmentMap: Map[String, String], attachmentOrder: Seq[String] ): Seq[OpenAICompositeMessage] = { + // Filter to get only non-null, non-empty path values val orderedAttachments = attachmentOrder.flatMap { columnName => - attachmentMap.get(columnName).map(_.trim).filter(_.nonEmpty) + attachmentMap.get(columnName).flatMap(v => Option(v).map(_.trim).filter(_.nonEmpty)) } - val contentParts = buildContentParts(userMessage, orderedAttachments) - val messages = getPromptsForMessage(Left(contentParts)) - messages + // If there are path columns but all are null/empty, pass through null + if (attachmentOrder.nonEmpty && orderedAttachments.isEmpty) { + null //scalastyle:ignore null + } else { + val contentParts = buildContentParts(userMessage, orderedAttachments) + val messages = getPromptsForMessage(Left(contentParts)) + messages + } } private def buildContentParts(promptText: String, attachmentPaths: Seq[String]): Seq[Map[String, String]] = { diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 469d608fa5..b66aa6a203 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -86,6 +86,29 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } } + // scalastyle:off null + test("createMessagesForRow returns null when all path columns are null") { + val prompt = new OpenAIPrompt() + val attachments = Map("filePath" -> null) + val messages = prompt.createMessagesForRow("Summarize the file", attachments, Seq("filePath")) + assert(messages == null) + } + + test("createMessagesForRow returns messages when at least one path column has value") { + val prompt = new OpenAIPrompt() + val tempFile = Files.createTempFile("synapseml-openai", ".txt") + try { + Files.write(tempFile, "example content".getBytes(StandardCharsets.UTF_8)) + val attachments = Map("filePath" -> null, "anotherPath" -> tempFile.toString) + val messages = prompt.createMessagesForRow("Summarize", attachments, Seq("filePath", "anotherPath")) + assert(messages != null) + assert(messages.nonEmpty) + } finally { + Files.deleteIfExists(tempFile) + } + } + // scalastyle:on null + test("RAI Usage") { val result = prompt .setDeploymentName(deploymentName) @@ -321,6 +344,45 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } } + test("null path columns return null output") { + val promptResponses = new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setDeploymentName(deploymentName) + .setCustomServiceName(openAIServiceName) + .setApiVersion("2025-04-01-preview") + .setApiType("responses") + .setColumnType("images", "path") + .setOutputCol("outParsed") + .setPromptTemplate("{questions}: {images}") + + val urlDF = Seq( + ( + "What's in this document?", + "https://mmlspark.blob.core.windows.net/datasets/OCR/paper.pdf" + ), + ( + "What's in this image?", + null // scalastyle:ignore null + ), + ( + "What's in this image?", + "https://mmlspark.blob.core.windows.net/datasets/OCR/test2.png" + ) + ).toDF("questions", "images") + + val results = promptResponses + .transform(urlDF) + .select("outParsed") + .collect() + + // First row: valid path, should have output + assert(results(0).getString(0) != null) + // Second row: null path, should have null output + assert(results(1).get(0) == null) + // Third row: valid path, should have output + assert(results(2).getString(0) != null) + } + ignore("Custom EndPoint") { lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "") lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")