diff --git a/python/infinilm/llm/cache_manager.py b/python/infinilm/llm/cache_manager.py
index 1f1e48c8..44ca1376 100644
--- a/python/infinilm/llm/cache_manager.py
+++ b/python/infinilm/llm/cache_manager.py
@@ -261,6 +261,12 @@ def try_free_blocks(self, num_required: int) -> bool:
     def get_num_free_blocks(self) -> int:
         return len(self.free_block_ids)
 
+    def get_total_usable_blocks(self) -> int:
+        freeable_used_blocks = sum(
+            1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
+        )
+        return len(self.free_block_ids) + freeable_used_blocks
+
     def __repr__(self):
         return (
             f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, "
diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py
index c152d6e4..53eec8a0 100644
--- a/python/infinilm/llm/llm.py
+++ b/python/infinilm/llm/llm.py
@@ -228,11 +228,39 @@ def _update_requests(
             req.generated_token_ids.append(token_id)
             if req.is_prefill:
                 req.is_prefill = False
 
+            # vLLM-style replacement character handling is primarily relevant for streaming.
+            # For offline generation (no output queue), keep the fast incremental path.
+            if req._output_queue is None:
+                token_text = self.detokenize([token_id])
+                req.generated_text += token_text
+            else:
+                # Streaming path: compute delta from a full decode so we can hold back
+                # trailing '\ufffd' (likely an incomplete UTF-8 sequence).
+                decoded_text = self.detokenize(req.generated_token_ids)
+
+                finished_now = False
+                if self._check_request_finished(req, token_id):
+                    req.mark_finished(req.finish_reason)
+                    finished_now = True
-            token_text = self.tokenizer.decode(token_id)
-            req.generated_text += token_text
+                # Update generated_text to the latest decode (used for stop-string checks and debugging)
+                req.generated_text = decoded_text
+
+                holds_back_incomplete_utf8 = (
+                    bool(decoded_text) and decoded_text.endswith("\ufffd")
+                )
-            if self._check_request_finished(req, token_id):
+                # vLLM-style: hold back only if we are not on the final chunk.
+                if holds_back_incomplete_utf8 and not finished_now:
+                    token_text = ""
+                else:
+                    last_len = getattr(req, "_stream_last_yielded_length", 0)
+                    token_text = decoded_text[last_len:]
+                    if token_text:
+                        req._stream_last_yielded_length = len(decoded_text)
+
+            # For non-streaming, finish checks happen here.
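+            # (streaming requests were already finish-checked above, before emitting the final chunk)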
+            if req._output_queue is None and self._check_request_finished(req, token_id):
                 req.mark_finished(req.finish_reason)
 
             # Put output in queue if it exists (for async streaming)
@@ -283,12 +311,15 @@ def apply_chat_template(
         self,
         messages: List[dict],
         add_generation_prompt: bool = True,
+        chat_template_kwargs: Optional[dict] = None,
     ) -> str:
         """Apply chat template to messages."""
+        chat_template_kwargs = chat_template_kwargs or {}
         return self.tokenizer.apply_chat_template(
             conversation=messages,
             add_generation_prompt=add_generation_prompt,
             tokenize=False,
+            **chat_template_kwargs,
         )
@@ -486,6 +517,10 @@ def __init__(
         self._running = False
         self._step_thread: Optional[threading.Thread] = None
+        self._healthy = True
+
+    def is_healthy(self) -> bool:
+        return bool(self._healthy)
 
     def start(self):
         """Start the background inference loop."""
@@ -520,6 +555,7 @@ def _step_loop(self):
                 time.sleep(0.01)
             except Exception as e:
                 logger.error(f"Error in step loop: {e}", exc_info=True)
+                self._healthy = False
                 self._running = False
                 break
@@ -581,6 +617,8 @@ def add_chat_request(
         request_id: Optional[str] = None,
         request_data: Optional[dict] = None,
         http_request: Optional[any] = None,
+        add_generation_prompt: bool = True,
+        chat_template_kwargs: Optional[dict] = None,
     ) -> InferenceRequest:
         """Add a chat request to the engine.
 
@@ -594,7 +632,11 @@
         Returns:
             The created InferenceRequest object.
         """
-        prompt = self.engine.apply_chat_template(messages, add_generation_prompt=True)
+        prompt = self.engine.apply_chat_template(
+            messages,
+            add_generation_prompt=add_generation_prompt,
+            chat_template_kwargs=chat_template_kwargs,
+        )
         return self.add_request(
             prompt=prompt,
             sampling_params=sampling_params,
@@ -607,6 +649,7 @@ async def stream_request(
         self,
         request: InferenceRequest,
         timeout: float = 100.0,
+        request_timeout: Optional[float] = None,
    ) -> AsyncIterator[TokenOutput]:
        """Stream tokens from a request.
 
@@ -619,6 +662,7 @@
        """
        import asyncio
 
+        start = time.time()
        while True:
            if request.is_finished() and request.output_queue.async_q.empty():
                break
@@ -635,6 +679,20 @@
                if token_output.finished:
                    break
            except asyncio.TimeoutError:
+                # Enforce request-level timeout even if no tokens are produced.
+                if request_timeout is not None:
+                    now = time.time()
+                    if now - start > float(request_timeout):
+                        request.mark_timeout()
+                        yield TokenOutput(
+                            request_id=request.request_id,
+                            token_id=-1,
+                            token_text="",
+                            finished=True,
+                            finish_reason=FinishReason.TIMEOUT,
+                            generated_text=request.generated_text,
+                        )
+                        break
                if request.is_finished():
                    break
                continue
diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py
index d6e08aef..224828d1 100644
--- a/python/infinilm/llm/request.py
+++ b/python/infinilm/llm/request.py
@@ -144,6 +144,10 @@ def __init__(
         # Output management (for async streaming)
         self._output_queue: Optional[janus.Queue] = None
 
+        # Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
+        # Used by the engine to compute "delta" text chunks from a full decode.
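+        # Tracks how many characters of the fully decoded text have already been emitted.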
+        self._stream_last_yielded_length: int = 0
+
     @property
     def output_queue(self) -> janus.Queue:
         """Lazy initialization of output queue."""
diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py
index b1853292..b3188c9b 100644
--- a/python/infinilm/llm/scheduler.py
+++ b/python/infinilm/llm/scheduler.py
@@ -155,12 +155,21 @@ def schedule(self) -> Optional[SchedulerOutput]:
             except queue.Empty:
                 break
 
+            if not self.can_accept_request(req):
+                self.waiting_queue.sync_q.put(req)
+                break
+
+            # Skip requests that were already finished (e.g., timed out/canceled while waiting)
+            if req.is_finished():
+                self.complete_requests([req])
+                continue
+
             req_tokens = req.get_input_tokens()
             num_required_blocks = req.get_num_blocks_required(self.block_size)
 
             if not self.cache_manager.can_allocate(num_required_blocks):
                 if not self.cache_manager.try_free_blocks(num_required_blocks):
-                    raise RuntimeError("No available cache blocks")
+                    raise RuntimeError("No available cache blocks for new request")
 
             # Allocate blocks with automatic prefix caching support
             req.block_table, req.slot_mapping, req.num_cached_tokens = (
@@ -185,6 +194,10 @@ def schedule(self) -> Optional[SchedulerOutput]:
                 req = self.running_queue.sync_q.get_nowait()
             except queue.Empty:
                 break
+            # Skip requests that were already finished (e.g., timed out/canceled while running)
+            if req.is_finished():
+                self.complete_requests([req])
+                continue
 
             # Decode phase: allocate slot for newly generated token
             try:
@@ -197,7 +210,7 @@ def schedule(self) -> Optional[SchedulerOutput]:
                 scheduled_requests.append(req)
             except RuntimeError as e:
-                raise RuntimeError("No available cache blocks") from e
+                raise RuntimeError("No available cache blocks for new token") from e
 
         # Return decode batch if any running requests were scheduled
         if scheduled_requests:
@@ -237,6 +250,31 @@ def complete_requests(self, requests: List[InferenceRequest]):
                 # Still running, put back in running queue
                 self.running_queue.sync_q.put(req)
 
+    def can_accept_request(self, request: InferenceRequest) -> bool:
+        total_required_blocks = 0
+
+        # Calculate blocks needed for running requests
+        running_queue_size = self.running_queue.sync_q.qsize()
+        for _ in range(running_queue_size):
+            req = self.running_queue.sync_q.get()
+            remaining_tokens = (
+                req.sampling_params.max_tokens - req.get_num_generated_tokens()
+            )
+            num_blocks_needed = (
+                remaining_tokens + self.block_size - 1
+            ) // self.block_size
+            total_required_blocks += num_blocks_needed
+            self.running_queue.sync_q.put(req)
+
+        # Calculate blocks needed for the new request
+        total_length = request.get_prompt_length()
+        total_length += request.sampling_params.max_tokens
+        num_blocks_needed = (total_length + self.block_size - 1) // self.block_size
+        total_required_blocks += num_blocks_needed
+
+        # Compare with total usable blocks in cache manager
+        return total_required_blocks <= self.cache_manager.get_total_usable_blocks()
+
     def get_cache_stats(self) -> dict:
         """Get cache statistics."""
         return {
diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py
index 99e1988d..9d9eb570 100644
--- a/python/infinilm/server/inference_server.py
+++ b/python/infinilm/server/inference_server.py
@@ -10,6 +10,7 @@
 import argparse
 import uvicorn
 import logging
+import os
 
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -22,7 +23,7 @@
 DEFAULT_REQUEST_TIMEOUT = 1000.0
 
 
-def chunk_json(id_, content=None, role=None, finish_reason=None):
+def chunk_json(id_, content=None, role=None, finish_reason=None, model: str = "unknown"):
     """Generate JSON chunk for streaming response."""
     delta = {}
     if content:
@@ -33,7 +34,7 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
         "id": id_,
         "object": "chat.completion.chunk",
         "created": int(time.time()),
-        "model": "jiuge",
+        "model": model,
         "system_fingerprint": None,
         "choices": [
             {
@@ -84,6 +85,8 @@ def __init__(
             port: Server port number.
         """
         self.model_path = model_path
+        # vLLM-like served model id: directory name of model_path
+        self.model_id = os.path.basename(os.path.normpath(model_path)) or "model"
         self.device = device
         self.dtype = dtype
         self.tensor_parallel_size = tensor_parallel_size
@@ -136,7 +139,10 @@ async def lifespan(app: FastAPI):
     def _register_routes(self, app: FastAPI):
         """Register API routes."""
 
+        # OpenAI-compatible chat completions endpoint.
+        # Support both legacy path and OpenAI-style /v1 prefix for proxy/router compatibility.
         @app.post("/chat/completions")
+        @app.post("/v1/chat/completions")
         async def chat_completions(request: Request):
             try:
                 data = await request.json()
@@ -169,15 +175,21 @@ async def chat_completions(request: Request):
 
         @app.get("/health")
         async def health():
+            # Expose engine health so babysitter/registry can treat backend as unhealthy.
+            if (
+                self.engine is not None
+                and hasattr(self.engine, "is_healthy")
+                and not self.engine.is_healthy()
+            ):
+                return JSONResponse(content={"status": "unhealthy"}, status_code=503)
             return {"status": "healthy"}
 
-        @app.get("/v1/models")
-        async def list_models():
+        def _models_payload():
             return {
                 "object": "list",
                 "data": [
                     {
-                        "id": "jiuge",
+                        "id": self.model_id,
                         "object": "model",
                         "created": int(time.time()),
                         "owned_by": "infinilm",
@@ -185,14 +197,53 @@
                 ],
             }
 
+        # Support both /v1/models (OpenAI) and /models (common legacy) for compatibility.
+        @app.get("/v1/models")
+        async def list_models():
+            return _models_payload()
+
+        @app.get("/models")
+        async def list_models_legacy():
+            return _models_payload()
+
     def _build_sampling_params(self, data: dict) -> SamplingParams:
         """Build SamplingParams from request data."""
+        # Support both:
+        #   - top-level OpenAI-ish fields: temperature/top_p/top_k/max_tokens/stop
+        #   - nested dict: sampling_params: { ... }
+        sp = data.get("sampling_params") or {}
+        if not isinstance(sp, dict):
+            sp = {}
+
+        def pick(key: str, default):
+            # Priority: explicit top-level field > nested sampling_params > server default
+            if key in data and data.get(key) is not None:
+                return data.get(key)
+            if key in sp and sp.get(key) is not None:
+                return sp.get(key)
+            return default
+
+        # Accept common alias
+        max_tokens = pick("max_tokens", self.max_tokens)
+        if max_tokens is None:
+            # Some clients use max_new_tokens
+            max_tokens = pick("max_new_tokens", self.max_tokens)
+
+        stop = pick("stop", None)
+        if isinstance(stop, str):
+            stop = [stop]
+
+        stop_token_ids = pick("stop_token_ids", None)
+        if isinstance(stop_token_ids, int):
+            stop_token_ids = [stop_token_ids]
+
         return SamplingParams(
-            temperature=data.get("temperature", self.temperature),
-            top_p=data.get("top_p", self.top_p),
-            top_k=data.get("top_k", self.top_k),
-            max_tokens=data.get("max_tokens", self.max_tokens),
-            stop=data.get("stop"),
+            temperature=float(pick("temperature", self.temperature)),
+            top_p=float(pick("top_p", self.top_p)),
+            top_k=int(pick("top_k", self.top_k)),
+            max_tokens=int(max_tokens) if max_tokens is not None else None,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
         )
 
     async def _stream_chat(self, request_id: str, data: dict, http_request: Request):
@@ -210,22 +261,26 @@
                 request_id=request_id,
                 request_data=data,
                 http_request=http_request,
+                add_generation_prompt=bool(data.get("add_generation_prompt", True)),
+                chat_template_kwargs=data.get("chat_template_kwargs") or {},
             )
 
             async for token_output in self.engine.stream_request(
-                req, timeout=DEFAULT_STREAM_TIMEOUT
+                req,
+                timeout=DEFAULT_STREAM_TIMEOUT,
+                request_timeout=DEFAULT_REQUEST_TIMEOUT,
             ):
-                # Check timeout
-                if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
+                # If stream_request enforces timeout, we can just surface the state to the client.
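+                # (on timeout the engine yields a final chunk with finish_reason=TIMEOUT)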
+                if token_output.finish_reason == FinishReason.TIMEOUT:
                     logger.warning(
                         f"Request {request_id} timed out after {DEFAULT_REQUEST_TIMEOUT}s"
                     )
-                    req.mark_timeout()
                     error_chunk = json.dumps(
                         chunk_json(
                             request_id,
                             content="[Request timeout]",
                             finish_reason="timeout",
+                            model=self.model_id,
                         ),
                         ensure_ascii=False,
                     )
@@ -238,19 +293,31 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request)
                     req.mark_canceled()
                     break
 
-                # Send token
-                chunk = json.dumps(
-                    chunk_json(request_id, content=token_output.token_text),
-                    ensure_ascii=False,
+                # Skip EOS token text for OpenAI API compatibility
+                # Check if this token is an EOS token by comparing token_id with eos_token_ids
+                eos_token_ids = self.engine.engine.eos_token_ids
+                is_eos_token = (
+                    eos_token_ids and token_output.token_id in eos_token_ids
                 )
-                yield f"data: {chunk}\n\n"
+
+                if not is_eos_token and token_output.token_text:
+                    # Send token
+                    chunk = json.dumps(
+                        chunk_json(
+                            request_id, content=token_output.token_text, model=self.model_id
+                        ),
+                        ensure_ascii=False,
+                    )
+                    yield f"data: {chunk}\n\n"
 
                 if token_output.finished:
                     finish_reason = self._convert_finish_reason(
                         token_output.finish_reason
                     )
                     chunk = json.dumps(
-                        chunk_json(request_id, finish_reason=finish_reason),
+                        chunk_json(
+                            request_id, finish_reason=finish_reason, model=self.model_id
+                        ),
                         ensure_ascii=False,
                     )
                     yield f"data: {chunk}\n\n"
@@ -262,7 +329,10 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request)
             req.mark_failed()
             error_chunk = json.dumps(
                 chunk_json(
-                    request_id, content=f"[Error: {str(e)}]", finish_reason="error"
+                    request_id,
+                    content=f"[Error: {str(e)}]",
+                    finish_reason="error",
+                    model=self.model_id,
                 ),
                 ensure_ascii=False,
             )
@@ -290,17 +360,20 @@ async def _chat(self, request_id: str, data: dict, http_request: Request):
                 request_id=request_id,
                 request_data=data,
                 http_request=http_request,
+                add_generation_prompt=bool(data.get("add_generation_prompt", True)),
+                chat_template_kwargs=data.get("chat_template_kwargs") or {},
             )
 
             # Collect all generated tokens
             output_text = ""
             async for token_output in self.engine.stream_request(
-                req, timeout=DEFAULT_STREAM_TIMEOUT
+                req,
+                timeout=DEFAULT_STREAM_TIMEOUT,
+                request_timeout=DEFAULT_REQUEST_TIMEOUT,
            ):
-                # Check timeout
-                if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
+                # Request-level timeout is handled inside stream_request.
+                if token_output.finish_reason == FinishReason.TIMEOUT:
                     logger.warning(f"Request {request_id} timed out")
-                    req.mark_timeout()
                     break
 
                 # Check client disconnect
@@ -309,7 +382,15 @@
                     req.mark_canceled()
                     break
 
-                output_text += token_output.token_text
+                # Skip EOS token text for OpenAI API compatibility
+                # Check if this token is an EOS token by comparing token_id with eos_token_ids
+                eos_token_ids = self.engine.engine.eos_token_ids
+                is_eos_token = (
+                    eos_token_ids and token_output.token_id in eos_token_ids
+                )
+
+                if not is_eos_token:
+                    output_text += token_output.token_text
 
                 if token_output.finished:
                     break
@@ -322,6 +403,7 @@ async def _chat(self, request_id: str, data: dict, http_request: Request):
             content=output_text,
             role="assistant",
             finish_reason=finish_reason or "stop",
+            model=self.model_id,
         )
 
         return response
diff --git a/test/bench/test_benchmark.py b/test/bench/test_benchmark.py
index b23241ea..0e78ca22 100644
--- a/test/bench/test_benchmark.py
+++ b/test/bench/test_benchmark.py
@@ -4,7 +4,6 @@
 import time
 import re
 import csv
-from datasets import load_dataset, Dataset
 import numpy as np
 import infinicore
 from infinilm.modeling_utils import load_model_state_dict_by_file
@@ -12,6 +11,7 @@
 from infinilm.cache import StaticKVCacheConfig
 from infinilm.infer_engine import GenerationConfig, InferEngine
 from infinilm.cache import StaticKVCacheConfig
+from datasets import load_dataset, Dataset
 from abc import ABC, abstractmethod
 
 
@@ -67,7 +67,7 @@ def __init__(
             "nvidia": "cuda",
             "cambricon": "mlu",
             "ascend": "ascend",
-            "metax": "metax",
+            "metax": "cuda",
             "moore": "moore",
             "iluvatar": "iluvatar",
             "kunlun": "kunlun",