fix(openai): harden chat request handling #789
Conversation
Pull request overview
This PR hardens the OpenAI-compatible API request/streaming paths by normalizing tool-call arguments, improving streaming detokenization to avoid UTF-8 replacement characters on token boundaries, and rejecting over-length prompts before they enter engine scheduling.
Changes:
- Normalize `tool_calls[*].function.arguments` from JSON strings (and empty values) into decoded objects before chat template rendering (see the illustrative payload below).
- Decode streaming text using cumulative token context and emit only the newly decoded suffix, avoiding U+FFFD from partial byte-fallback sequences.
- Add early `max_model_len` enforcement in engine preprocessing for both prompt length and prompt + generation length.
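For context, the OpenAI chat format transmits `function.arguments` as a JSON-encoded string, so an assistant message typically looks like the illustrative payload below before normalization (the tool name and call id are made up for this example):

```python
# Illustrative assistant message as an OpenAI-compatible client would send it.
# "arguments" arrives as a JSON string rather than a decoded object; the
# normalization added in protocol.py turns it into a dict before the chat
# template is rendered.
raw_message = {
    "role": "assistant",
    "tool_calls": [
        {
            "id": "call_0",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\": \"Berlin\", \"unit\": \"celsius\"}",
            },
        }
    ],
}
```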
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| tests/entrypoints/test_protocol.py | Adds unit tests covering tool-call argument normalization behavior in ChatMessage.to_template_dict(). |
| atom/model_engine/llm_engine.py | Adds early validation to reject requests that exceed max_model_len before scheduling/model-runner work. |
| atom/entrypoints/openai/protocol.py | Normalizes tool-call arguments (JSON string → object, empty → {}) prior to chat template rendering. |
| atom/entrypoints/openai/api_server.py | Introduces cumulative-token streaming decode state to avoid replacement characters and updates streaming callbacks/cleanup accordingly. |
    if arguments:
        if not isinstance(arguments, (dict, list)):
            function["arguments"] = json.loads(arguments)
    else:
        function["arguments"] = {}
    normalized = []
    for item in tool_calls:
        if not isinstance(item, dict):
            raise TypeError(f"tool_calls entries must be dicts, got {type(item)!r}")
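Read together, the two hunks above decode JSON-string arguments, map empty arguments to `{}`, and reject non-dict entries. A standalone sketch of the same rules as a hypothetical helper (the real logic lives in `atom/entrypoints/openai/protocol.py` and its exact shape may differ):

```python
import json
from typing import Any


def normalize_tool_calls(tool_calls: list[Any]) -> list[dict]:
    """Hypothetical helper mirroring the PR's normalization rules:
    JSON-string arguments are decoded, empty arguments become {},
    and non-dict entries are rejected."""
    normalized = []
    for item in tool_calls:
        if not isinstance(item, dict):
            raise TypeError(f"tool_calls entries must be dicts, got {type(item)!r}")
        call = dict(item)
        function = dict(call.get("function") or {})
        arguments = function.get("arguments")
        if arguments:
            if not isinstance(arguments, (dict, list)):
                function["arguments"] = json.loads(arguments)
        else:
            function["arguments"] = {}
        call["function"] = function
        normalized.append(call)
    return normalized


calls = [{"function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'}}]
print(normalize_tool_calls(calls))
# [{'function': {'name': 'get_weather', 'arguments': {'city': 'Berlin'}}}]
```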
    ) -> str:
        """Decode streaming output with request-local token context.

        The scheduler sends only newly generated token ids. Decoding that slice
        directly can split byte-fallback or multi-token Unicode sequences and
        emit U+FFFD. Keep cumulative token ids per stream and return only the
        newly decoded suffix, similar to vLLM's incremental detokenizer.
        """
        global tokenizer

        state = _stream_decode_states.setdefault(
            state_key,
            {"token_ids": [], "sent_text": ""},
        )
        state["token_ids"].extend(new_token_ids)

        decoded_text = tokenizer.decode(state["token_ids"], skip_special_tokens=True)
        sent_text = state["sent_text"]
        if decoded_text.startswith(sent_text):
            delta = decoded_text[len(sent_text) :]
        else:
            # Tokenizers may normalize whitespace around special tokens. Fall back
            # to the changed suffix instead of replaying the entire decoded text.
            prefix_len = _common_prefix_len(decoded_text, sent_text)
            delta = decoded_text[prefix_len:]

        if delta.endswith("\ufffd") and not finished:
            return ""

        state["sent_text"] = decoded_text
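To see why the cumulative decode matters, here is a small self-contained illustration (plain UTF-8 bytes stand in for byte-fallback tokens, so no tokenizer is required): decoding each chunk in isolation produces U+FFFD when a multi-byte character is split across chunk boundaries, while decoding the cumulative bytes and emitting only the new suffix does not.

```python
# "é" is two UTF-8 bytes (0xC3 0xA9); pretend each byte is its own
# byte-fallback token and they arrive in separate stream chunks.
chunks = [b"caf", b"\xc3", b"\xa9!"]

# Naive per-chunk decode: the split character turns into U+FFFD twice.
print([c.decode("utf-8", errors="replace") for c in chunks])
# ['caf', '\ufffd', '\ufffd!']

# Cumulative decode, emitting only the newly decoded suffix (the strategy
# this PR uses for streaming): hold the delta back while it still ends in
# U+FFFD and flush it once the remaining bytes of the character arrive.
buffered, sent = b"", ""
for chunk in chunks:
    buffered += chunk
    decoded = buffered.decode("utf-8", errors="replace")
    delta = decoded[len(sent):]
    if delta.endswith("\ufffd"):
        continue  # wait for the rest of the multi-byte sequence
    sent = decoded
    print(repr(delta))
# 'caf'
# 'é!'
```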
    _stream_queues.pop(request_id, None)
    _seq_id_to_request_id.pop(seq_id, None)
    _stream_loops.pop(request_id, None)
    _request_start_times.pop(request_id, None)
    for key in [key for key in _stream_decode_states if key[0] == request_id]:
        _stream_decode_states.pop(key, None)
    state["sent_text"] = decoded_text
    return delta
    prompt_len = len(tokens)
    max_model_len = self.config.max_model_len
    if max_model_len is not None and prompt_len > max_model_len:
        raise ValueError(
            f"Input has {prompt_len} tokens, which exceeds "
            f"max_model_len={max_model_len}. Shorten the prompt or "
            "restart the server with a larger --max-model-len."
        )
    max_tokens = max(0, int(getattr(sampling_params, "max_tokens", 0)))
    if (
        max_model_len is not None
        and prompt_len + max_tokens > max_model_len
    ):
        raise ValueError(
            f"Requested context length is {prompt_len + max_tokens} "
            f"tokens ({prompt_len} input + {max_tokens} max output), "
            f"which exceeds max_model_len={max_model_len}. Shorten the "
            "prompt, lower max_tokens, or restart the server with a "
            "larger --max-model-len."
        )
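The same validation, pulled out into a hypothetical free function so both failure modes can be exercised in isolation (the name, signature, and sample values below are illustrative, not the engine's actual API):

```python
def check_context_length(prompt_len: int, max_tokens: int, max_model_len: int | None) -> None:
    """Hypothetical standalone version of the early length validation."""
    if max_model_len is None:
        return
    if prompt_len > max_model_len:
        raise ValueError(
            f"Input has {prompt_len} tokens, which exceeds max_model_len={max_model_len}."
        )
    if prompt_len + max_tokens > max_model_len:
        raise ValueError(
            f"Requested context length is {prompt_len + max_tokens} tokens "
            f"({prompt_len} input + {max_tokens} max output), which exceeds "
            f"max_model_len={max_model_len}."
        )


check_context_length(prompt_len=100, max_tokens=20, max_model_len=4096)      # ok
# check_context_length(prompt_len=5000, max_tokens=0, max_model_len=4096)    # raises: prompt too long
# check_context_length(prompt_len=4000, max_tokens=200, max_model_len=4096)  # raises: prompt + output too long
```

Raising in preprocessing is what lets the API layer map the failure to an HTTP 400 instead of letting an oversized request reach the scheduler.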
Could you please fix the code style issue?
Force-pushed from c78429f to 1af2a97.
@san-tian thank you for the contribution. Could you elaborate on the cases where this PR is needed: is it fixing a bug, or a deployment hang you ran into?
| f"max_model_len={max_model_len}. Shorten the prompt or " | ||
| "restart the server with a larger --max-model-len." | ||
| ) | ||
| max_tokens = max(0, int(getattr(sampling_params, "max_tokens", 0))) |
    request_id, sibling_index = state_key
    states_for_request = _stream_decode_states.setdefault(request_id, {})
    state = states_for_request.setdefault(
        sibling_index,
        {"token_ids": [], "sent_text": "", "context_start": 0, "context_text": ""},
    )
    old_token_count = len(state["token_ids"])
    state["token_ids"].extend(new_token_ids)

    token_ids = state["token_ids"]
    new_context_start = max(
        0,
        len(token_ids) - len(new_token_ids) - STREAM_DECODE_CONTEXT_TOKENS,
    )
    decoded_text = tokenizer.decode(
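The revised state keeps a per-request decode window rather than re-decoding the full token history on every chunk. A rough sketch of the windowing arithmetic, assuming `STREAM_DECODE_CONTEXT_TOKENS` bounds how many already-seen tokens are re-decoded alongside each new chunk (the names follow the hunk above, but this is not the exact implementation):

```python
STREAM_DECODE_CONTEXT_TOKENS = 8  # illustrative window size, not the real default


def decode_window(token_ids: list[int], new_token_ids: list[int]) -> tuple[int, list[int]]:
    """Return the start index and the token slice that actually gets decoded:
    the new tokens plus at most STREAM_DECODE_CONTEXT_TOKENS of trailing context."""
    start = max(0, len(token_ids) - len(new_token_ids) - STREAM_DECODE_CONTEXT_TOKENS)
    return start, token_ids[start:]


history = list(range(100))   # 100 tokens already accumulated for this stream
new = [100, 101, 102]        # 3 freshly generated tokens
history.extend(new)
start, window = decode_window(history, new)
print(start, len(window))    # 92 11 -> only 8 context tokens + 3 new tokens are decoded
```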
Summary
- Reject prompts that exceed `max_model_len` before they reach scheduling/model-runner internals.

Tests
- `python -m pytest -q -p no:cacheprovider tests/entrypoints/test_protocol.py` in `rocm/atom-dev:latest` on the AMD host: 27 passed.
- `atom/entrypoints/openai/protocol.py`, `atom/entrypoints/openai/api_server.py`, `atom/model_engine/llm_engine.py`, and `tests/entrypoints/test_protocol.py` in `rocm/atom-dev:latest` on the AMD host.
- …5; oversized prompt above `max_model_len` returns HTTP 400 instead of being scheduled (a reproduction sketch follows).
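For anyone reproducing the last check by hand, a request along these lines should be rejected with HTTP 400 once the prompt tokenizes past the limit (the `/v1/chat/completions` path, port, and model name are assumptions based on the server being OpenAI-compatible; adjust them to your deployment):

```python
import requests

# Hypothetical reproduction of the oversized-prompt case: the exact route,
# host, and model name depend on how the server is launched.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "word " * 200_000}],
        "max_tokens": 16,
    },
    timeout=60,
)
print(resp.status_code)  # expected: 400 once the prompt exceeds max_model_len
```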