
fix(openai): harden chat request handling #789

Open

san-tian wants to merge 4 commits into ROCm:main from san-tian:pr/openai-chat-request-guards

Conversation

@san-tian

Summary

  • Normalize OpenAI JSON-string tool call arguments before chat template rendering.
  • Decode streaming responses with cumulative token context so partial UTF-8/token-boundary output does not emit replacement characters.
  • Reject requests whose prompt length or requested prompt+generation length exceeds max_model_len before they reach scheduling/model-runner internals.

Tests

  • python -m pytest -q -p no:cacheprovider tests/entrypoints/test_protocol.py in rocm/atom-dev:latest on the AMD host: 27 passed.
  • Compile check for atom/entrypoints/openai/protocol.py, atom/entrypoints/openai/api_server.py, atom/model_engine/llm_engine.py, and tests/entrypoints/test_protocol.py in rocm/atom-dev:latest on the AMD host.
  • GLM-5.1 service smoke test on the AMD host: a normal completion returns 5; an oversized prompt above max_model_len returns HTTP 400 instead of being scheduled (reproduction sketch below).
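For reproducibility, a minimal sketch of that oversized-prompt smoke test as a plain HTTP request. The host, port, and model name are placeholders, not values from this PR:

# Hypothetical reproduction of the oversized-prompt smoke test; host, port,
# and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "glm-5.1",  # placeholder model name
        "messages": [
            # A prompt far longer than any plausible max_model_len.
            {"role": "user", "content": "token " * 200_000},
        ],
        "max_tokens": 16,
    },
)
# With this PR, the request is rejected before it reaches scheduling.
assert resp.status_code == 400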

Copilot AI review requested due to automatic review settings May 14, 2026 14:08
Contributor

Copilot AI left a comment


Pull request overview

This PR hardens the OpenAI-compatible API request/streaming paths by normalizing tool-call arguments, improving streaming detokenization to avoid UTF-8 replacement characters on token boundaries, and rejecting over-length prompts before they enter engine scheduling.

Changes:

  • Normalize tool_calls[*].function.arguments from JSON strings (and empty values) into decoded objects before chat template rendering.
  • Decode streaming text using cumulative token context and emit only the newly decoded suffix to avoid U+FFFD from partial byte-fallback sequences.
  • Add early max_model_len enforcement in engine preprocessing for prompt length and prompt+generation length.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.

File descriptions:

  • tests/entrypoints/test_protocol.py: adds unit tests covering tool-call argument normalization behavior in ChatMessage.to_template_dict().
  • atom/model_engine/llm_engine.py: adds early validation to reject requests that exceed max_model_len before scheduling/model-runner work.
  • atom/entrypoints/openai/protocol.py: normalizes tool-call arguments (JSON string → object, empty → {}) prior to chat template rendering.
  • atom/entrypoints/openai/api_server.py: introduces cumulative-token streaming decode state to avoid replacement characters and updates streaming callbacks/cleanup accordingly.


Comment thread atom/entrypoints/openai/protocol.py Outdated
Comment on lines +66 to +70
if arguments:
    if not isinstance(arguments, (dict, list)):
        function["arguments"] = json.loads(arguments)
else:
    function["arguments"] = {}
Comment thread atom/entrypoints/openai/protocol.py Outdated
normalized = []
for item in tool_calls:
    if not isinstance(item, dict):
        raise TypeError(f"tool_calls entries must be dicts, got {type(item)!r}")
Comment thread atom/entrypoints/openai/api_server.py Outdated
Comment on lines +176 to +205
) -> str:
    """Decode streaming output with request-local token context.

    The scheduler sends only newly generated token ids. Decoding that slice
    directly can split byte-fallback or multi-token Unicode sequences and
    emit U+FFFD. Keep cumulative token ids per stream and return only the
    newly decoded suffix, similar to vLLM's incremental detokenizer.
    """
    global tokenizer

    state = _stream_decode_states.setdefault(
        state_key,
        {"token_ids": [], "sent_text": ""},
    )
    state["token_ids"].extend(new_token_ids)

    decoded_text = tokenizer.decode(state["token_ids"], skip_special_tokens=True)
    sent_text = state["sent_text"]
    if decoded_text.startswith(sent_text):
        delta = decoded_text[len(sent_text) :]
    else:
        # Tokenizers may normalize whitespace around special tokens. Fall back
        # to the changed suffix instead of replaying the entire decoded text.
        prefix_len = _common_prefix_len(decoded_text, sent_text)
        delta = decoded_text[prefix_len:]

    if delta.endswith("\ufffd") and not finished:
        return ""

    state["sent_text"] = decoded_text
Comment thread atom/entrypoints/openai/api_server.py Outdated
Comment on lines +519 to +524
_stream_queues.pop(request_id, None)
_seq_id_to_request_id.pop(seq_id, None)
_stream_loops.pop(request_id, None)
_request_start_times.pop(request_id, None)
for key in [key for key in _stream_decode_states if key[0] == request_id]:
    _stream_decode_states.pop(key, None)
state["sent_text"] = decoded_text
return delta


Comment on lines +261 to +280
prompt_len = len(tokens)
max_model_len = self.config.max_model_len
if max_model_len is not None and prompt_len > max_model_len:
    raise ValueError(
        f"Input has {prompt_len} tokens, which exceeds "
        f"max_model_len={max_model_len}. Shorten the prompt or "
        "restart the server with a larger --max-model-len."
    )
max_tokens = max(0, int(getattr(sampling_params, "max_tokens", 0)))
if (
    max_model_len is not None
    and prompt_len + max_tokens > max_model_len
):
    raise ValueError(
        f"Requested context length is {prompt_len + max_tokens} "
        f"tokens ({prompt_len} input + {max_tokens} max output), "
        f"which exceeds max_model_len={max_model_len}. Shorten the "
        "prompt, lower max_tokens, or restart the server with a "
        "larger --max-model-len."
    )
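A worked example of the two checks, with illustrative numbers only:

# Illustrative values, not from the PR.
prompt_len, max_tokens, max_model_len = 4000, 200, 4096

assert prompt_len <= max_model_len               # first check passes: the prompt alone fits
assert prompt_len + max_tokens > max_model_len   # 4200 > 4096: second check raises ValueError
# -> "Requested context length is 4200 tokens (4000 input + 200 max output),
#     which exceeds max_model_len=4096. ..."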
@valarLip
Collaborator

Could you please fix the Code Style issue?

Copilot AI review requested due to automatic review settings May 15, 2026 08:46
@san-tian force-pushed the pr/openai-chat-request-guards branch from c78429f to 1af2a97 on May 15, 2026 08:46
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

Comment thread atom/entrypoints/openai/protocol.py Outdated
normalized = []
for item in tool_calls:
    if not isinstance(item, dict):
        raise TypeError(f"tool_calls entries must be dicts, got {type(item)!r}")
Comment thread atom/entrypoints/openai/protocol.py Outdated
Comment on lines +66 to +70
if arguments:
    if not isinstance(arguments, (dict, list)):
        function["arguments"] = json.loads(arguments)
else:
    function["arguments"] = {}
Comment thread atom/entrypoints/openai/api_server.py Outdated
Comment on lines +179 to +205
    The scheduler sends only newly generated token ids. Decoding that slice
    directly can split byte-fallback or multi-token Unicode sequences and
    emit U+FFFD. Keep cumulative token ids per stream and return only the
    newly decoded suffix, similar to vLLM's incremental detokenizer.
    """
    global tokenizer

    state = _stream_decode_states.setdefault(
        state_key,
        {"token_ids": [], "sent_text": ""},
    )
    state["token_ids"].extend(new_token_ids)

    decoded_text = tokenizer.decode(state["token_ids"], skip_special_tokens=True)
    sent_text = state["sent_text"]
    if decoded_text.startswith(sent_text):
        delta = decoded_text[len(sent_text) :]
    else:
        # Tokenizers may normalize whitespace around special tokens. Fall back
        # to the changed suffix instead of replaying the entire decoded text.
        prefix_len = _common_prefix_len(decoded_text, sent_text)
        delta = decoded_text[prefix_len:]

    if delta.endswith("\ufffd") and not finished:
        return ""

    state["sent_text"] = decoded_text
Comment on lines +261 to +280
prompt_len = len(tokens)
max_model_len = self.config.max_model_len
if max_model_len is not None and prompt_len > max_model_len:
    raise ValueError(
        f"Input has {prompt_len} tokens, which exceeds "
        f"max_model_len={max_model_len}. Shorten the prompt or "
        "restart the server with a larger --max-model-len."
    )
max_tokens = max(0, int(getattr(sampling_params, "max_tokens", 0)))
if (
    max_model_len is not None
    and prompt_len + max_tokens > max_model_len
):
    raise ValueError(
        f"Requested context length is {prompt_len + max_tokens} "
        f"tokens ({prompt_len} input + {max_tokens} max output), "
        f"which exceeds max_model_len={max_model_len}. Shorten the "
        "prompt, lower max_tokens, or restart the server with a "
        "larger --max-model-len."
    )
@carlushuang
Contributor

@san-tian thank you for contributing. Could you elaborate on the cases where this change is needed: did you hit a bug, or a deployment hang?
Also, please address Copilot's comments and fix the formatting with black (please see CI).

Copilot AI review requested due to automatic review settings May 16, 2026 11:24
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Comment thread atom/model_engine/llm_engine.py Outdated
f"max_model_len={max_model_len}. Shorten the prompt or "
"restart the server with a larger --max-model-len."
)
max_tokens = max(0, int(getattr(sampling_params, "max_tokens", 0)))
Comment on lines +199 to +207
request_id, sibling_index = state_key
states_for_request = _stream_decode_states.setdefault(request_id, {})
state = states_for_request.setdefault(
    sibling_index,
    {"token_ids": [], "sent_text": "", "context_start": 0, "context_text": ""},
)
old_token_count = len(state["token_ids"])
state["token_ids"].extend(new_token_ids)

Comment thread atom/entrypoints/openai/api_server.py Outdated
Comment on lines +201 to +213
state = states_for_request.setdefault(
    sibling_index,
    {"token_ids": [], "sent_text": "", "context_start": 0, "context_text": ""},
)
old_token_count = len(state["token_ids"])
state["token_ids"].extend(new_token_ids)

token_ids = state["token_ids"]
new_context_start = max(
    0,
    len(token_ids) - len(new_token_ids) - STREAM_DECODE_CONTEXT_TOKENS,
)
decoded_text = tokenizer.decode(
