110 changes: 110 additions & 0 deletions qa/L0_vertex_ai/vertex_ai_test.py
@@ -324,5 +324,115 @@ def test_malformed_binary_header_large_number(self):
)


# =====================================================================
# KServe V1 Protocol Tests
# =====================================================================
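# For reference, a sketch of the request/response shapes these tests
# exercise (values are illustrative; tensor names are taken from the
# tests below):
#
#   V1 request:  {"instances": [{"INPUT0": [...], "INPUT1": [...]}],
#                 "parameters": {...}}   # "parameters" is optional
#   V1 response: {"predictions": [{"OUTPUT0": [...], "OUTPUT1": [...]}]}
#
# V2 (KServe inference protocol) requests built with tritonclient.http
# should continue to work unchanged; see test_v2_predict_backward_compat.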

def test_v1_predict(self):
"""V1 format: {"instances": [...]} -> {"predictions": [...]}"""
request_body = {
"instances": [
{
"INPUT0": self.input_data_,
"INPUT1": self.input_data_,
}
]
}

import json

headers = {"Content-Type": "application/json"}
r = requests.post(
self.url_, data=json.dumps(request_body), headers=headers
)
r.raise_for_status()

result = r.json()
self.assertIn("predictions", result)
self.assertEqual(len(result["predictions"]), 1)

prediction = result["predictions"][0]
self.assertIn("OUTPUT0", prediction)
self.assertIn("OUTPUT1", prediction)

for i in range(16):
self.assertEqual(prediction["OUTPUT0"][i], self.expected_output0_data_[i])
self.assertEqual(prediction["OUTPUT1"][i], self.expected_output1_data_[i])

def test_v1_predict_with_parameters(self):
"""V1 format with top-level parameters passthrough"""
import json

request_body = {
"instances": [
{
"INPUT0": self.input_data_,
"INPUT1": self.input_data_,
}
],
"parameters": {},
}

headers = {"Content-Type": "application/json"}
r = requests.post(
self.url_, data=json.dumps(request_body), headers=headers
)
r.raise_for_status()

result = r.json()
self.assertIn("predictions", result)
prediction = result["predictions"][0]
for i in range(16):
self.assertEqual(prediction["OUTPUT0"][i], self.expected_output0_data_[i])
self.assertEqual(prediction["OUTPUT1"][i], self.expected_output1_data_[i])

def test_v2_predict_backward_compat(self):
"""V2 format still works (backward compatibility regression test)"""
inputs = []
outputs = []
inputs.append(httpclient.InferInput("INPUT0", [1, 16], "INT32"))
inputs.append(httpclient.InferInput("INPUT1", [1, 16], "INT32"))

input_data = np.array(self.input_data_, dtype=np.int32)
input_data = np.expand_dims(input_data, axis=0)
inputs[0].set_data_from_numpy(input_data, binary_data=False)
inputs[1].set_data_from_numpy(input_data, binary_data=False)

outputs.append(httpclient.InferRequestedOutput("OUTPUT0", binary_data=False))
outputs.append(httpclient.InferRequestedOutput("OUTPUT1", binary_data=False))
request_body, _ = httpclient.InferenceServerClient.generate_request_body(
inputs, outputs=outputs
)

headers = {"Content-Type": "application/json"}
r = requests.post(self.url_, data=request_body, headers=headers)
r.raise_for_status()

result = httpclient.InferenceServerClient.parse_response_body(r.content)
output0_data = result.as_numpy("OUTPUT0")
output1_data = result.as_numpy("OUTPUT1")
for i in range(16):
self.assertEqual(output0_data[0][i], self.expected_output0_data_[i])
self.assertEqual(output1_data[0][i], self.expected_output1_data_[i])

def test_v1_predict_empty_instances(self):
"""V1 format with empty instances should return an error"""
import json

request_body = {"instances": []}

headers = {"Content-Type": "application/json"}
r = requests.post(
self.url_, data=json.dumps(request_body), headers=headers
)
self.assertEqual(
400,
r.status_code,
"Expected error code 400 for empty instances; got: {}".format(
r.status_code
),
)


if __name__ == "__main__":
unittest.main()
35 changes: 35 additions & 0 deletions src/common.cc
@@ -35,6 +35,7 @@

extern "C" {
#include <b64/cdecode.h>
#include <b64/cencode.h>
}

namespace triton { namespace server {
@@ -163,6 +164,40 @@ DecodeBase64(
return nullptr;
}

TRITONSERVER_Error*
EncodeBase64(
const char* input, size_t input_len, std::string& encoded_data)
{
if (input_len > static_cast<size_t>(INT_MAX)) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"input data exceeds the maximum allowed data size limit INT_MAX");
}

// Base64 output is ceil(input_len/3)*4 chars. libb64 inserts a newline
// every 72 OUTPUT chars, so newline count is based on output size.
size_t base64_chars = ((input_len + 2) / 3) * 4;
size_t max_encoded_size = base64_chars + (base64_chars / 72) + 4;
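// Worked example (illustrative): for input_len = 100 bytes,
// base64_chars = ((100 + 2) / 3) * 4 = 136 and
// max_encoded_size = 136 + 136 / 72 + 4 = 141.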
encoded_data.resize(max_encoded_size);

base64_encodestate state;
base64_init_encodestate(&state);

size_t encoded_len = base64_encode_block(
input, input_len, &encoded_data[0], &state);
encoded_len += base64_encode_blockend(&encoded_data[0] + encoded_len, &state);

// Remove any trailing newlines added by libb64
while (encoded_len > 0 &&
(encoded_data[encoded_len - 1] == '\n' ||
encoded_data[encoded_len - 1] == '\r')) {
encoded_len--;
}
encoded_data.resize(encoded_len);

return nullptr;
}
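
// Usage sketch (illustrative; `buffer` and `byte_size` are hypothetical
// caller-side names, not part of this change):
//
//   std::string b64;
//   TRITONSERVER_Error* err = EncodeBase64(
//       reinterpret_cast<const char*>(buffer), byte_size, b64);
//   if (err != nullptr) {
//     return err;  // propagate the error to the caller
//   }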

TRITONSERVER_Error*
ValidateSharedMemoryKey(const std::string& name, const std::string& shm_key)
{
9 changes: 9 additions & 0 deletions src/common.h
@@ -211,6 +211,15 @@ TRITONSERVER_Error* DecodeBase64(
const char* input, size_t input_len, std::vector<char>& decoded_data,
size_t& decoded_size, const std::string& name);

/// Encodes binary data to a Base64 encoded string.
///
/// \param input The raw binary data to encode.
/// \param input_len The length of the input data.
/// \param encoded_data A string to store the Base64 encoded result.
/// \return The error status.
TRITONSERVER_Error* EncodeBase64(
const char* input, size_t input_len, std::string& encoded_data);


/// Validate shared memory key
///