Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -462,13 +462,19 @@ class OpenAIPrompt(override val uid: String) extends Transformer
attachmentMap: Map[String, String],
attachmentOrder: Seq[String]
): Seq[OpenAICompositeMessage] = {
// Filter to get only non-null, non-empty path values
val orderedAttachments = attachmentOrder.flatMap { columnName =>
attachmentMap.get(columnName).map(_.trim).filter(_.nonEmpty)
attachmentMap.get(columnName).flatMap(v => Option(v).map(_.trim).filter(_.nonEmpty))
}

val contentParts = buildContentParts(userMessage, orderedAttachments)
val messages = getPromptsForMessage(Left(contentParts))
messages
// If there are path columns but all are null/empty, pass through null
if (attachmentOrder.nonEmpty && orderedAttachments.isEmpty) {
null //scalastyle:ignore null
} else {
val contentParts = buildContentParts(userMessage, orderedAttachments)
val messages = getPromptsForMessage(Left(contentParts))
messages
}
}

private def buildContentParts(promptText: String, attachmentPaths: Seq[String]): Seq[Map[String, String]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,29 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}
}

// scalastyle:off null
test("createMessagesForRow returns null when all path columns are null") {
val prompt = new OpenAIPrompt()
val attachments = Map("filePath" -> null)
val messages = prompt.createMessagesForRow("Summarize the file", attachments, Seq("filePath"))
assert(messages == null)
}

test("createMessagesForRow returns messages when at least one path column has value") {
val prompt = new OpenAIPrompt()
val tempFile = Files.createTempFile("synapseml-openai", ".txt")
try {
Files.write(tempFile, "example content".getBytes(StandardCharsets.UTF_8))
val attachments = Map("filePath" -> null, "anotherPath" -> tempFile.toString)
val messages = prompt.createMessagesForRow("Summarize", attachments, Seq("filePath", "anotherPath"))
assert(messages != null)
assert(messages.nonEmpty)
} finally {
Files.deleteIfExists(tempFile)
}
}
// scalastyle:on null

test("RAI Usage") {
val result = prompt
.setDeploymentName(deploymentName)
Expand Down Expand Up @@ -321,6 +344,45 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}
}

test("null path columns return null output") {
val promptResponses = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentName)
.setCustomServiceName(openAIServiceName)
.setApiVersion("2025-04-01-preview")
.setApiType("responses")
.setColumnType("images", "path")
.setOutputCol("outParsed")
.setPromptTemplate("{questions}: {images}")

val urlDF = Seq(
(
"What's in this document?",
"https://mmlspark.blob.core.windows.net/datasets/OCR/paper.pdf"
),
(
"What's in this image?",
null // scalastyle:ignore null
),
(
"What's in this image?",
"https://mmlspark.blob.core.windows.net/datasets/OCR/test2.png"
)
).toDF("questions", "images")

val results = promptResponses
.transform(urlDF)
.select("outParsed")
.collect()

// First row: valid path, should have output
assert(results(0).getString(0) != null)
// Second row: null path, should have null output
assert(results(1).get(0) == null)
// Third row: valid path, should have output
assert(results(2).getString(0) != null)
}

ignore("Custom EndPoint") {
lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "")
lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")
Expand Down
Loading