diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml index 3ed194d5..5d643e46 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yaml +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -1,6 +1,6 @@ name: Bug Report description: File a bug report. -title: "πŸ› " +title: "bug - " labels: ["bug"] body: - type: markdown @@ -9,19 +9,21 @@ body: Thanks for taking the time to fill out this bug report! - type: textarea attributes: - label: Bug description + label: Bug description & expected description: Describe the bug placeholder: Ran {this}, did {that}, expected {the other} - validations: + validations: required: true + - type: input + attributes: + label: Narrow cargo test + description: If possible, give the narrow cargo test to show the error + placeholder: cargo test --test tests_p_openai test_tool_full_flow_ok - type: input attributes: label: Adapter - description: The AdapterKind if known + description: The AdapterKind if known - type: input attributes: label: Model description: The Model name if known - - type: textarea - attributes: - label: Suggested Resolution \ No newline at end of file diff --git a/.gitignore b/.gitignore index 8d6ad4fd..57e68e32 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ _* # '_' in src dir, ok. !**/src/**/_* +!**/spec/**/_* *.lock *.lockb @@ -23,6 +24,17 @@ target/ dist/ out/ +*.html +# Doc Files +*.pdf +*.docx +*.xlsx +*.pptx +*.doc +*.xls +*.ppt +*.page + # Data Files *.db3 *.parquet @@ -40,6 +52,11 @@ out/ *.ogg *.avi +# Audio +*.mp3 +*.wav +*.flac + # Images *.icns *.ico @@ -48,7 +65,10 @@ out/ *.png *.bmp -!tests/data/*.jpg +# -- Test data (one by one) +!tests/data/duck-small.jpg +!tests/data/other-one.png +!tests/data/small.pdf # -- Nodejs node_modules/ @@ -60,7 +80,7 @@ __pycache__/ # -- others -# Allows .env (make sure only dev info) +# Allows .env (make sure only dev info) # !.env # Commented by default # Allow vscode diff --git a/BIG-THANKS.md b/BIG-THANKS.md new file mode 100644 index 00000000..ab56fae3 --- /dev/null +++ b/BIG-THANKS.md @@ -0,0 +1,34 @@ +# Big Thanks to + +- v0.5.1 + - [anagrius](https://github.com/anagrius) for [#119](https://github.com/jeremychone/rust-genai/pull/119) openai_resp assistant content fix + - [BinaryMuse](https://github.com/BinaryMuse) for [#117](https://github.com/jeremychone/rust-genai/pull/117) WebStream status check and [#116](https://github.com/jeremychone/rust-genai/pull/116) extra headers fix + - [vlmutolo](https://github.com/vlmutolo) for [#115](https://github.com/jeremychone/rust-genai/pull/115) Gemini 3 tool thoughtSignature fix +- v0.5.x + - [BinaryMuse](https://github.com/BinaryMuse) for [#114](https://github.com/jeremychone/rust-genai/pull/114) Anthropic ToolCalls streaming fix + - [Himmelschmidt](https://github.com/Himmelschmidt) for [#111](https://github.com/jeremychone/rust-genai/pull/111) Gemini `responseJsonSchema` support, [#103](https://github.com/jeremychone/rust-genai/pull/103) error body capture, and Gemini Thought signatures + - [malyavi-nochum](https://github.com/malyavi-nochum) for [#109](https://github.com/jeremychone/rust-genai/pull/109) Fireworks default streaming fix + - [mengdehong](https://github.com/mengdehong) for [#108](https://github.com/jeremychone/rust-genai/pull/108) Ollama reasoning streaming fix + - [Akagi201](https://github.com/Akagi201) for [#105](https://github.com/jeremychone/rust-genai/pull/105) MIMO model adapter +- v0.1.x .. 
v0.4.x + - [Vagmi Mudumbai](https://github.com/vagmi) for [#96](https://github.com/jeremychone/rust-genai/pull/96) openai audio_type + - [Himmelschmidt](https://github.com/Himmelschmidt) for [#98](https://github.com/jeremychone/rust-genai/pull/98) openai service_tier + - [Bart Carroll](https://github.com/bartCarroll) for [#91](https://github.com/jeremychone/rust-genai/pull/91) Fixed streaming tool calls for openai models + - [Rui Andrada](https://github.com/shingonoide) for [#95](https://github.com/jeremychone/rust-genai/pull/95) refactoring ZHIPU adapter to ZAI + - [Adrien](https://github.com/XciD) Extra headers in requests, seed for chat requests, and fixes (with [Julien Chaumond](https://github.com/julien-c) for extra headers) + - [Andrew Rademacher](https://github.com/AndrewRademacher) for PDF support, Anthropic streamer + - [Jesus Santander](https://github.com/jsantanders) Embedding support [PR #83](https://github.com/jeremychone/rust-genai/pull/83) + - [4t145](https://github.com/4t145) for raw body capture [PR #68](https://github.com/jeremychone/rust-genai/pull/68) + - [Vagmi Mudumbai](https://github.com/vagmi) exec_chat bug fix [PR #86](https://github.com/jeremychone/rust-genai/pull/86) + - [Maximilian Goisser](https://github.com/hobofan) Fix OpenAI adapter to use ServiceTarget + - [ClanceyLu](https://github.com/ClanceyLu) for tool use streaming support, web configuration support, and fixes + - [@SilasMarvin](https://github.com/SilasMarvin) for fixing content/tools issues with some Ollama models [PR #55](https://github.com/jeremychone/rust-genai/pull/55) + - [@una-spirito](https://github.com/luna-spirito) for Gemini `ReasoningEffort::Budget` support + - [@jBernavaPrah](https://github.com/jBernavaPrah) for adding tracing (it was long overdue). [PR #45](https://github.com/jeremychone/rust-genai/pull/45) + - [@GustavoWidman](https://github.com/GustavoWidman) for the initial Gemini tool/function support! 
[PR #41](https://github.com/jeremychone/rust-genai/pull/41) + - [@AdamStrojek](https://github.com/AdamStrojek) for initial image support [PR #36](https://github.com/jeremychone/rust-genai/pull/36) + - [@semtexzv](https://github.com/semtexzv) for `stop_sequences` Anthropic support [PR #34](https://github.com/jeremychone/rust-genai/pull/34) + - [@omarshehab221](https://github.com/omarshehab221) for de/serialize on structs [PR #19](https://github.com/jeremychone/rust-genai/pull/19) + - [@tusharmath](https://github.com/tusharmath) for making webc::Error [PR #12](https://github.com/jeremychone/rust-genai/pull/12) + - [@giangndm](https://github.com/giangndm) for making stream Send [PR #10](https://github.com/jeremychone/rust-genai/pull/10) + - [@stargazing-dino](https://github.com/stargazing-dino) for [PR #2](https://github.com/jeremychone/rust-genai/pull/2), implement Groq completions diff --git a/CHANGELOG.md b/CHANGELOG.md index feea5827..66ed3c1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,70 @@ `.` minor | `-` Fix | `+` Addition | `^` improvement | `!` Change | `*` Refactor +## 2026-01-31 - [v0.5.3](https://github.com/jeremychone/rust-genai/compare/v0.5.2...v0.5.3) + +- `^` error - add request payload / response body when to chat response fail +- `>` refactor captured_raw_body into client .exec_chat (prep for #137) +- `.` tracing - add traced to web-client for ai response (#132) +- `-` Fix incorrect empty output from MessageContent::joined_texts for β‰₯ 2 text parts (fixes #135) (#136) Co-authored-by: Ross MacLeod +- `.` ChatRole - Add PartialEq / Eq (#131) + + +## 2026-01-27 - [v0.5.2](https://github.com/jeremychone/rust-genai/compare/v0.5.1...v0.5.2) + +- `-` Does not capture body when json parse fail (#128) +- `^` Anthropic - Add separate reasoning content and thought signature for anthropic messages api (#125) +- `-` fix - Ollama tool calls are silently swallowed in OpenAI adapter (streaming) (#124) +- `^` test - ollama - add tool tests +- `^` gemini - Include thoughts and capture thoughts are reasoning content (#121) + +## 2026-01-17 - [v0.5.1](https://github.com/jeremychone/rust-genai/compare/v0.5.0...v0.5.1) + +`!` `Error::WebStream` - added error field to preserve original error +`^` gemini - allow empty tool `thoughtSignature` for Gemini 3 (#115) +`-` webc - check HTTP status in `WebStream` before processing byte stream (#117) +`-` client - ensure extra headers are applied in `exec_chat` and `exec_chat_stream` (#116) +`-` openai_resp - fix assistant message content to use `output_text` (#119) + +## 2026-01-09 - [v0.5.0](https://github.com/jeremychone/rust-genai/compare/v0.4.4...v0.5.0) + +- `!` zai - change namespace strategy with (zai:: for default, and zai-codding:: for subscription, same Adapter) +- `+` New Adapter: bigmodel - add back bigmodel.cn and BigModel adapter (only via namespace) +- `+` MessageContent - Add from ContentPart and Binary +- `+` New Adatper: : Add MIMO model adapter (#105) +- `+` gemini adapter - impl thought signature - ThoughtSignature api update +- `^` anthropic - implemented new output_config.effort for opus-4-5 (matching ReasonningEffort) +- `^` gemini - for gemini-3, convert ReasoningEffort Low/High to the appropriate gemini thinkingLevel LOW/HIGH, fall back on budget if not gemini 3 or other effort +- `^` reasoning - add RasoningEffort::None +- `^` dependency - update to reqwest 0.13 +- `^` MessageContent - add .binaries() and .into_binaries() +- `^` .size - implement .size in ContentPart and MessageContent +- `^` ContentPart - Binary 
from file (as base64) +- `^` binary - add constructors (from_base64, from_url, from_file) +- `-` pr-anthropic-tool-fix - #pr 114 - Anthropic ToolCalls with no parameters are not parsed correctly while streaming +- `-` Fix Gemini adapter to use responseJsonSchema (PR #111) +- `-` Fix Ollama reasoning streaming (Skip empty reasoning chunks in streaming) +- `-` Fix Fireworks default depending on streaming (#109) +- `-` Capture response body in ResponseFailedNotJson error (#103) +- `>` anthropic - Refactor streamer to use webc::EventSourceStream +- `>` adapter_openai - switched to custom webc::EventSourceStream based on WebStream +- `>` webc - remove 'reqwest-eventsource' dependency, all based in same WebStream (EventsourceStream wrapper) +- `>` ModelName - add namespace_is(..), namespace(), namespace_and_name() +- `>` binary - refactor openai to use into_url for the base64 url +- `>` content_part - refactor binary into own file + +## 2025-11-14 - [v0.4.4](https://github.com/jeremychone/rust-genai/compare/v0.4.3...v0.4.4) + +- `+` openai - adding support for gpt-5-pro (must be mapped to OpenaiResp adapter) +- `+` Add support for openai audio_type content part for voice agent support. ([PR #96](https://github.com/jeremychone/rust-genai/pull/96) thanks to [Vagmi Mudumbai](https://github.com/vagmi)) +- `+` Add support for OpenAI `service_tier` parameter. ([PR #98](https://github.com/jeremychone/rust-genai/pull/98) thanks to [Himmelschmidt](https://github.com/Himmelschmidt)) + + +## 2025-10-25 - [v0.4.3](https://github.com/jeremychone/rust-genai/compare/v0.4.2...v0.4.3) + +- `!` Refactor ZHIPU adapter to ZAI with namespace-based endpoint routing (#95) +- `-` openai - stream tool - Fix streaming too issue (#91) +- `.` added ModelName partial eq implementations for string types (#94) +- `.` anthropic - update model name for haiku 4.5 ## 2026-02-01 - Fork Sync with Upstream v0.6.0-alpha.2 @@ -386,4 +451,4 @@ Some **API Changes** - See [migration-v_0_3_to_0_4](doc/migration/migration-v_0_ - `+` Added AdapterKindResolver - `-` Adapter::list_models api impl and change -- `^` chat_printer - added PrintChatStreamOptions with print_events \ No newline at end of file +- `^` chat_printer - added PrintChatStreamOptions with print_events diff --git a/Cargo.toml b/Cargo.toml index 485509de..91dbcd38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,16 +33,17 @@ base64 = "0.22.0" mime_guess = "2" # -- Others derive_more = { version = "2", features = ["from", "display"] } -value-ext = "0.1.2" +value-ext = "0.1.2" [dev-dependencies] simple-fs = "0.9.2" tracing-subscriber = {version = "0.3", features = ["env-filter"]} serial_test = "3.2" +base64 = "0.22.0" bitflags = "2.8" gcp_auth = "0.12" -# Mock server dependencies +# Mock server dependencies (fork-specific) wiremock = "0.6.5" uuid = { version = "1.11.0", features = ["v4", "serde"] } -# Test utilities +# Test utilities (fork-specific) scopeguard = "1.2.0" diff --git a/README.md b/README.md index 23419c3b..f738a32f 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,8 @@ -# genai - Multi-AI Providers Library for Rust +# genai, Multi-AI Providers Library for Rust -> **Terraphim Fork**: This is a synchronized fork of [jeremychone/rust-genai](https://github.com/jeremychone/rust-genai) merged with upstream v0.6.0-alpha.2 while preserving additional features including AWS Bedrock, Z.AI with namespace routing, Cerebras, and comprehensive testing infrastructure. 
+Currently natively supports: **OpenAI**, **Anthropic**, **Gemini**, **xAI**, **Ollama**, **Groq**, **DeepSeek**, **Cohere**, **Together**, **Fireworks**, **Nebius**, **Mimo**, **Zai** (Zhipu AI), **BigModel**. -Currently natively supports: **OpenAI**, **Anthropic**, **AWS Bedrock**, **Gemini**, **XAI/Grok**, **Ollama**, **Groq**, **DeepSeek** (deepseek.com & Groq), **Cohere**, **Cerebras**, **Z.AI** (GLM models), **Zhipu** (more to come) - -Also allows a custom URL with `ServiceTargetResolver` (see [examples/c06-target-resolver.rs](examples/c06-target-resolver.rs)) +Also supports a custom URL with `ServiceTargetResolver` (see [examples/c06-target-resolver.rs](examples/c06-target-resolver.rs)).
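For reference, a minimal resolver sketch in the spirit of `examples/c06-target-resolver.rs` (the gateway URL, env-var name, and `build_client` helper are placeholders, and the resolver-fn signature is assumed from that example, not prescribed here):

```rust
use genai::adapter::AdapterKind;
use genai::resolver::{AuthData, Endpoint, ServiceTargetResolver};
use genai::{Client, ModelIden, ServiceTarget};

fn build_client() -> Client {
    // Route every request to a custom OpenAI-compatible endpoint (placeholder URL and env var).
    let target_resolver = ServiceTargetResolver::from_resolver_fn(
        |service_target: ServiceTarget| -> Result<ServiceTarget, genai::resolver::Error> {
            let ServiceTarget { model, .. } = service_target;
            let endpoint = Endpoint::from_static("https://my-gateway.example.com/v1/");
            let auth = AuthData::from_env("MY_GATEWAY_API_KEY");
            let model = ModelIden::new(AdapterKind::OpenAI, model.model_name);
            Ok(ServiceTarget { endpoint, auth, model })
        },
    );

    Client::builder()
        .with_service_target_resolver(target_resolver)
        .build()
}
```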
@@ -14,72 +12,39 @@ Also allows a custom URL with `ServiceTargetResolver` (see [examples/c06-target-
-
- -Provides a single, ergonomic API to many generative AI providers, such as Anthropic, OpenAI, Gemini, xAI, Ollama, Groq, and more. - -## 🍴 Fork Features (v0.6.0-alpha.2-fork) +Provides a single, ergonomic API for many generative AI providers, such as Anthropic, OpenAI, Gemini, xAI, Ollama, Groq, and more. -This fork includes all upstream improvements plus: +**NOTE:** Big update with **v0.5.0**: New adapters (BigModel, MIMO), Gemini Thinking support, Anthropic Reasoning Effort, and a more robust internal streaming engine. -- **AWS Bedrock Support** - Full Converse API implementation with streaming support for Claude, Llama, Titan, Mistral, and Cohere models via Bearer token authentication -- **Z.AI Adapter** - Anthropic-compatible adapter with namespace-based routing (`zai::` for default, `zai-codding::` for subscription endpoints) -- **Cerebras Integration** - OpenAI-compatible adapter for Cerebras AI inference -- **Comprehensive Testing** - Live API tests, model verification tests, and mock server infrastructure -- **Merged Upstream v0.6.0-alpha.2** - 114 commits including ModelSpec, custom EventSourceStream, enhanced error handling, and more +[Docs for LLMs](doc/for-llm/api-reference-for-llm.md) | [CHANGELOG](CHANGELOG.md) | [BIG THANKS](BIG-THANKS.md) -See [CHANGELOG](CHANGELOG.md) for full details. - -**NOTE:** Big update with **v0.4.x** - More adapters, PDF and image support, embeddings, custom headers, and transparent support for the OpenAI Responses API (gpt-5-codex) - -## v0.4.0 Big Release +## v0.5.x - (2026-01-09...) - **What's new**: - - **PDF and Images** support (thanks to [Andrew Rademacher](https://github.com/AndrewRademacher)) - - **Embedding** support (thanks to [Jesus Santander](https://github.com/jsantanders)) - - **Custom Headers** support (for AWS Bedrock, Vertex, etc.) (thanks to [Adrien](https://github.com/XciD)/[Julien Chaumond](https://github.com/julien-c)) - - **Simpler, flatter `MessageContent`** multi-part format (API change) (thanks to [Andrew Rademacher](https://github.com/AndrewRademacher) for insights) - - **Raw body capture** with `ChatOptions::with_capture_raw_body(true)` (thanks to [4t145](https://github.com/4t145)) - - **Transparent gpt-5-codex support with the Responses API**, even if gpt-5-codex uses a new API protocol (OpenAI Responses API) + - **New Adapters**: BigModel.cn and the MIMO model adapter (thanks to [Akagi201](https://github.com/Akagi201)). + - **zai: changed namespace strategy**, with (zai:: for default, and zai-codding:: for subscription, same adapter) + - **Gemini Thinking & Thought**: Full support for Gemini Thought signatures (thanks to [Himmelschmidt](https://github.com/Himmelschmidt)) and thinking levels. + - **Reasoning Effort Control**: Support for `ReasoningEffort` for Anthropic (Claude 3.7/4.5) and Gemini (Thinking levels), including `ReasoningEffort::None`. + - **Content & Binary Improvements**: Enhanced binary/PDF API and size tracking. + - **Internal Stream Refactor**: Switched to a unified `EventSourceStream` and `WebStream` for better reliability and performance across all providers. + - **Dependency Upgrade**: Now using `reqwest 0.13`. - **What's still awesome**: - - Normalized and ergonomic Chat API across all providers - - Most providers built in (OpenAI, Gemini, Anthropic, xAI, Groq, Together.ai, Fireworks.ai, ...) 
- - Native protocol support for Gemini and Anthropic protocols, for example allowing full budget controls with Gemini models - - Can override auth, endpoint, and headers to connect to AWS Bedrock, Vertex AI, etc. - -See: - - [migration from v0.3 to v0.4](doc/migration/migration-v_0_3_to_0_4.md) - - [CHANGELOG](CHANGELOG.md) - -## Big Thanks to - -- [Adrien](https://github.com/XciD) Extra headers in requests, seed for chat requests, and fixes (with [Julien Chaumond](https://github.com/julien-c) for extra headers) -- [Andrew Rademacher](https://github.com/AndrewRademacher) for PDF support, Anthropic streamer, and insight on flattening the message content (e.g., ContentParts) -- [Jesus Santander](https://github.com/jsantanders) Embedding support [PR #83](https://github.com/jeremychone/rust-genai/pull/83) -- [4t145](https://github.com/4t145) for raw body capture [PR #68](https://github.com/jeremychone/rust-genai/pull/68) -- [Vagmi Mudumbai](https://github.com/vagmi) exec_chat bug fix [PR #86](https://github.com/jeremychone/rust-genai/pull/86) -- [Maximilian Goisser](https://github.com/hobofan) Fix OpenAI adapter to use ServiceTarget -- [ClanceyLu](https://github.com/ClanceyLu) for Tool Use Streaming support, web configuration support, and fixes -- [@SilasMarvin](https://github.com/SilasMarvin) for fixing content/tools issues with some Ollama models [PR #55](https://github.com/jeremychone/rust-genai/pull/55) -- [@una-spirito](https://github.com/luna-spirito) for Gemini `ReasoningEffort::Budget` support -- [@jBernavaPrah](https://github.com/jBernavaPrah) for adding tracing (it was long overdue). [PR #45](https://github.com/jeremychone/rust-genai/pull/45) -- [@GustavoWidman](https://github.com/GustavoWidman) for the initial Gemini tool/function support! [PR #41](https://github.com/jeremychone/rust-genai/pull/41) -- [@AdamStrojek](https://github.com/AdamStrojek) for initial image support [PR #36](https://github.com/jeremychone/rust-genai/pull/36) -- [@semtexzv](https://github.com/semtexzv) for `stop_sequences` Anthropic support [PR #34](https://github.com/jeremychone/rust-genai/pull/34) -- [@omarshehab221](https://github.com/omarshehab221) for de/serialize on structs [PR #19](https://github.com/jeremychone/rust-genai/pull/19) -- [@tusharmath](https://github.com/tusharmath) for making webc::Error [PR #12](https://github.com/jeremychone/rust-genai/pull/12) -- [@giangndm](https://github.com/giangndm) for making stream Send [PR #10](https://github.com/jeremychone/rust-genai/pull/10) -- [@stargazing-dino](https://github.com/stargazing-dino) for [PR #2](https://github.com/jeremychone/rust-genai/pull/2) - implement Groq completions + - Normalized and ergonomic Chat API across all major providers. + - Native protocol support for Gemini and Anthropic protocols (Reasoning/Thinking controls). + - PDF, image, and embedding support. + - Custom auth, endpoint, and header overrides. + +See [CHANGELOG](CHANGELOG.md) ## Usage examples - Check out [AIPACK](https://aipack.ai), which wraps this **genai** library into an agentic runtime to run, build, and share AI Agent Packs. See [`pro@coder`](https://www.youtube.com/watch?v=zL1BzPVM8-Y&list=PL7r-PXl6ZPcB2zN0XHsYIDaD5yW8I40AE) for a simple example of how I use AI PACK/genai for production coding. -> Note: Feel free to send me a short description and a link to your application or library using genai. +> Note: Feel free to send me a short description and a link to your application or library that uses genai. 
## Key Features -- Native Multi-AI Provider/Model: OpenAI, Anthropic, AWS Bedrock, Gemini, Ollama, Groq, xAI, DeepSeek, Cerebras (Direct chat and stream) (see [examples/c00-readme.rs](examples/c00-readme.rs)) +- Native Multi-AI Provider/Model: OpenAI, Anthropic, Gemini, Ollama, Groq, xAI, DeepSeek (direct chat and streaming) (see [examples/c00-readme.rs](examples/c00-readme.rs)) - DeepSeekR1 support, with `reasoning_content` (and stream support), plus DeepSeek Groq and Ollama support (and `reasoning_content` normalization) - Image Analysis (for OpenAI, Gemini flash-2, Anthropic) (see [examples/c07-image.rs](examples/c07-image.rs)) - Custom Auth/API Key (see [examples/c02-auth.rs](examples/c02-auth.rs)) @@ -101,12 +66,16 @@ use genai::Client; const MODEL_OPENAI: &str = "gpt-4o-mini"; // o1-mini, gpt-4o-mini const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307"; -const MODEL_COHERE: &str = "command-light"; +// or namespaced with simple name "fireworks::qwen3-30b-a3b", or "fireworks::accounts/fireworks/models/qwen3-30b-a3b" +const MODEL_FIREWORKS: &str = "accounts/fireworks/models/qwen3-30b-a3b"; +const MODEL_TOGETHER: &str = "together::openai/gpt-oss-20b"; const MODEL_GEMINI: &str = "gemini-2.0-flash"; const MODEL_GROQ: &str = "llama-3.1-8b-instant"; const MODEL_OLLAMA: &str = "gemma:2b"; // sh: `ollama pull gemma:2b` -const MODEL_XAI: &str = "grok-beta"; +const MODEL_XAI: &str = "grok-3-mini"; const MODEL_DEEPSEEK: &str = "deepseek-chat"; +const MODEL_ZAI: &str = "glm-4-plus"; +const MODEL_COHERE: &str = "command-r7b-12-2024"; // NOTE: These are the default environment keys for each AI Adapter Type. // They can be customized; see `examples/c02-auth.rs` @@ -114,12 +83,15 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[ // -- De/activate models/providers (MODEL_OPENAI, "OPENAI_API_KEY"), (MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"), - (MODEL_COHERE, "COHERE_API_KEY"), (MODEL_GEMINI, "GEMINI_API_KEY"), + (MODEL_FIREWORKS, "FIREWORKS_API_KEY"), + (MODEL_TOGETHER, "TOGETHER_API_KEY"), (MODEL_GROQ, "GROQ_API_KEY"), (MODEL_XAI, "XAI_API_KEY"), (MODEL_DEEPSEEK, "DEEPSEEK_API_KEY"), (MODEL_OLLAMA, ""), + (MODEL_ZAI, "ZAI_API_KEY"), + (MODEL_COHERE, "COHERE_API_KEY"), ]; // NOTE: Model to AdapterKind (AI Provider) type mapping rule @@ -128,6 +100,7 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[ // - starts_with "command" -> Cohere // - starts_with "gemini" -> Gemini // - model in Groq models -> Groq +// - starts_with "glm" -> ZAI // - For anything else -> Ollama // // This can be customized; see `examples/c03-mapper.rs` @@ -180,11 +153,10 @@ async fn main() -> Result<(), Box> { - [examples/c01-conv.rs](examples/c01-conv.rs) - Shows how to build a conversation flow. - [examples/c02-auth.rs](examples/c02-auth.rs) - Demonstrates how to provide a custom `AuthResolver` to provide auth data (i.e., for api_key) per adapter kind. - [examples/c03-mapper.rs](examples/c03-mapper.rs) - Demonstrates how to provide a custom `AdapterKindResolver` to customize the "model name" to "adapter kind" mapping. -- [examples/c04-chat-options.rs](examples/c04-chat-options.rs) - Demonstrates how to set chat generation options such as `temperature` and `max_tokens` at the client level (for all requests) and per-request level. +- [examples/c04-chat-options.rs](examples/c04-chat-options.rs) - Demonstrates how to set chat generation options such as `temperature` and `max_tokens` at the client level (for all requests) and at the per-request level. 
- [examples/c05-model-names.rs](examples/c05-model-names.rs) - Shows how to get model names per AdapterKind. - [examples/c06-target-resolver.rs](examples/c06-target-resolver.rs) - For custom auth, endpoint, and model. - [examples/c07-image.rs](examples/c07-image.rs) - Image analysis support -- [examples/c11-cerebras.rs](examples/c11-cerebras.rs) - Cerebras chat + streaming (set `CEREBRAS_API_KEY`)
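To complement the chat-options example listed above, here is a minimal sketch of client-level defaults versus per-request overrides (the model name, prompt, and option values are illustrative only, and the `with_temperature`/`with_max_tokens` setters are assumed from `examples/c04-chat-options.rs`):

```rust
use genai::Client;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Client-level defaults, applied to every request made through this client.
    let client = Client::builder()
        .with_chat_options(ChatOptions::default().with_temperature(0.2))
        .build();

    let chat_req = ChatRequest::new(vec![ChatMessage::user("Why is the sky blue? (one sentence)")]);

    // Per-request options take precedence over the client-level defaults.
    let options = ChatOptions::default().with_temperature(0.9).with_max_tokens(256);
    let chat_res = client.exec_chat("gpt-4o-mini", chat_req, Some(&options)).await?;

    println!("{}", chat_res.first_text().unwrap_or("NO ANSWER"));
    Ok(())
}
```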
Static Badge @@ -216,7 +188,7 @@ async fn main() -> Result<(), Box> { ## ChatOptions - **(1)** - **OpenAI-compatible** notes - - Models: OpenAI, DeepSeek, Groq, Ollama, xAI + - Models: OpenAI, DeepSeek, Groq, Ollama, xAI, Mimo, Together, Fireworks, Nebius, Zai, Together, Fireworks, Nebius, Zai | Property | OpenAI Compatibles (*1) | Anthropic | Gemini `generationConfig.` | Cohere | |---------------|-------------------------|-----------------------------|----------------------------|---------------| @@ -230,78 +202,31 @@ async fn main() -> Result<(), Box> { |-----------------------------|-----------------------------|-------------------------|----------------------------|-----------------------| | `prompt_tokens` | `prompt_tokens` | `input_tokens` (added) | `promptTokenCount` (2) | `input_tokens` | | `completion_tokens` | `completion_tokens` | `output_tokens` (added) | `candidatesTokenCount` (2) | `output_tokens` | -| `total_tokens` | `total_tokens` | (computed) | `totalTokenCount` (2) | (computed) | +| `total_tokens` | `total_tokens` | (computed) | `totalTokenCount` (2) | (computed) | | `prompt_tokens_details` | `prompt_tokens_details` | `cached/cache_creation` | N/A for now | N/A for now | | `completion_tokens_details` | `completion_tokens_details` | N/A for now | N/A for now | N/A for now | - - **(1)** - **OpenAI-compatible** notes - - Models: OpenAI, DeepSeek, Groq, Ollama, xAI - - For **Groq**, the property `x_groq.usage.` - - At this point, **Ollama** does not emit input/output tokens when streaming due to the Ollama OpenAI compatibility layer limitation. (see [ollama #4448 - Streaming Chat Completion via OpenAI API should support stream option to include Usage](https://github.com/ollama/ollama/issues/4448)) + - Models: OpenAI, DeepSeek, Groq, Ollama, xAI, Mimo + - For **Groq**, the property `x_groq.usage.` + - At this point, **Ollama** does not emit input/output tokens when streaming due to a limitation in the Ollama OpenAI compatibility layer. (see [ollama #4448 - Streaming Chat Completion via OpenAI API should support stream option to include Usage](https://github.com/ollama/ollama/issues/4448)) - `prompt_tokens_details` and `completion_tokens_details` will have the value sent by the compatible provider (or None) - **(2)**: **Gemini** tokens - - Right now, with the [Gemini Stream API](https://ai.google.dev/api/rest/v1beta/models/streamGenerateContent), it's not clear whether usage for each event is cumulative or must be summed. It appears to be cumulative, meaning the last message shows the total amount of input, output, and total tokens, so that is the current assumption. See [possible tweet answer](https://twitter.com/jeremychone/status/1813734565967802859) for more info. 
- - -## AWS Bedrock Support - -AWS Bedrock is now natively supported with Bearer token authentication: - -```rust -// Set environment variable -// export AWS_BEARER_TOKEN_BEDROCK="your-api-key" -// export AWS_REGION="us-east-1" (optional, defaults to us-east-1) - -let client = Client::default(); -let chat_req = ChatRequest::new(vec![ - ChatMessage::user("Hello from Bedrock!") -]); - -// Use Titan models directly -let response = client.exec_chat("bedrock::amazon.titan-text-express-v1", chat_req, None).await?; - -// Or use explicit model IDs -let response = client.exec_chat("amazon.titan-text-lite-v1", chat_req, None).await?; -``` - -**Supported Models:** -- Amazon Titan: `amazon.titan-text-express-v1`, `amazon.titan-text-lite-v1`, `amazon.titan-text-premier-v1:0` -- Anthropic Claude: `anthropic.claude-3-5-sonnet-20241022-v2:0`, `anthropic.claude-3-5-haiku-20241022-v1:0`, etc. -- Meta Llama: `meta.llama3-70b-instruct-v1:0`, `meta.llama3-8b-instruct-v1:0`, etc. -- Mistral: `mistral.mistral-large-2407-v1:0`, `mistral.mistral-7b-instruct-v0:2`, etc. -- Cohere: `cohere.command-r-plus-v1:0`, `cohere.command-r-v1:0` -- AI21: `ai21.jamba-1-5-large-v1:0`, `ai21.jamba-1-5-mini-v1:0` - -**Limitations:** -- Streaming requires AWS SigV4 (not supported with Bearer token) -- Titan models don't support system messages or tool calling -- Claude/Llama models may require inference profiles - -See [AWS Bedrock API Keys documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys-use.html) for setup. + - Right now, with the [Gemini Stream API](https://ai.google.dev/api/rest/v1beta/models/streamGenerateContent), it's not clear whether usage for each event is cumulative or must be summed. It appears to be cumulative, meaning the last message shows the total number of input, output, and total tokens, so that is the current assumption. See [possible tweet answer](https://twitter.com/jeremychone/status/1813734565967802859) for more info. ## Notes on Possible Direction -- Will add more data on ChatResponse and ChatStream, especially metadata about usage. +- Will add more data to ChatResponse and ChatStream, especially usage metadata. - Add vision/image support to chat messages and responses. - Add function calling support to chat messages and responses. - Add `embed` and `embed_batch`. +- Add the AWS Bedrock variants (e.g., Mistral and Anthropic). Most of the work will be on the "interesting" token signature scheme. To avoid bringing in large SDKs, this might be a lower-priority feature. - Add the Google Vertex AI variants. - May add the Azure OpenAI variant (not sure yet). -## Contributing - -We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on: - -- Development setup -- Running tests (including live API tests) -- Code quality standards -- Submitting changes - ## Links - crates.io: [crates.io/crates/genai](https://crates.io/crates/genai) - GitHub: [github.com/jeremychone/rust-genai](https://github.com/jeremychone/rust-genai) -- Contributing: [CONTRIBUTING.md](CONTRIBUTING.md) - Sponsored by [BriteSnow](https://britesnow.com) (Jeremy Chone's consulting company) \ No newline at end of file diff --git a/dev/spec/_spec-rules.md b/dev/spec/_spec-rules.md new file mode 100644 index 00000000..a666acb3 --- /dev/null +++ b/dev/spec/_spec-rules.md @@ -0,0 +1,59 @@ +# Specification Guidelines + +This document defines the rules for creating and maintaining specification files. + +Important formatting rules + +- Use `-` for bullet points. 
+- For numbering bullet point style, have empty lines between numbering line. + + +## Types of Specification Files + +### `spec--index.md` + +A single file providing a high-level summary of the entire system. + +### `spec-module_name.md` + +A specification file for each individual module. +- `module-path-name` represents the module’s hierarchy path, flattened with `-`. +- Each file documents the specification for a single module. + +Make sure that the `module_name` is the top most common just after `src/` + +For example `src/module_01/sub_mod/some_file.rs` the spec module name will be `dev/spec/spec-module_01.md` + +(module_name is lowercase) + +## Required Structure for Module Specification Files + +Each `spec-module-path-name.md` file must include the following sections. + + + +## module-path-name + +### Goal + +A clear description of the module’s purpose and responsibilities. + +### Public Module API + +A description of the APIs exposed by the module. +- Define what is exported and how it can be consumed by other modules. +- Include function signatures, data structures, or endpoints as needed. + +### Module Parts + +A breakdown of the module’s internal components. +- May reference sub-files or sub-modules. +- Should explain how the parts work together. + +### Key Design Considerations + +Key design considerations of this module and of its key parts. + + + + diff --git a/dev/spec/spec-adapter.md b/dev/spec/spec-adapter.md new file mode 100644 index 00000000..17e24158 --- /dev/null +++ b/dev/spec/spec-adapter.md @@ -0,0 +1,33 @@ +## adapter + +### Goal + +The `adapter` module is responsible for abstracting the communication with various Generative AI providers (e.g., OpenAI, Gemini, Anthropic, Groq, DeepSeek). It translates generic GenAI requests (like `ChatRequest` and `EmbedRequest`) into provider-specific HTTP request data and converts provider-specific web responses back into generic GenAI response structures. It acts as the translation and dispatch layer between the client logic and the underlying web communication. + +### Public Module API + +The primary public API exposed by the `adapter` module is: + +- `AdapterKind`: An enum identifying the AI provider or protocol type (e.g., `OpenAI`, `Gemini`, `Anthropic`, `Cohere`). This type is used by the client and resolver layers to determine which adapter implementation should handle a specific model request. + +### Module Parts + +- `adapter_kind.rs`: Defines the `AdapterKind` enum. It includes implementation details for serialization, environment variable name resolution, and a default static mapping logic (`from_model`) to associate model names with a specific `AdapterKind`. + +- `adapter_types.rs`: Defines the `Adapter` trait, which sets the contract for all concrete adapter implementations. It also defines common types like `ServiceType` (Chat, ChatStream, Embed) and `WebRequestData` (the normalized structure holding URL, headers, and payload before web execution). + +- `dispatcher.rs`: Contains the `AdapterDispatcher` struct, which acts as the central routing mechanism. It dispatches calls from the client layer to the correct concrete adapter implementation based on the resolved `AdapterKind`. + +- `inter_stream.rs`: Defines internal types (`InterStreamEvent`, `InterStreamEnd`) used by streaming adapters to standardize the output format from diverse provider streaming protocols. 
This intermediary layer handles complex stream features like capturing usage, reasoning content, and tool calls before conversion to public `ChatStreamResponse` events. + +- `adapters/`: This submodule contains the concrete implementation of the `Adapter` trait for each provider (e.g., `openai`, `gemini`, `anthropic`, `zai`). These submodules handle the specific request/response translation logic for their respective protocols. + +### Key Design Considerations + +- **Stateless and Static Dispatch:** Adapters are designed to be stateless, with all methods in the `Adapter` trait being associated functions (static). Requests are routed efficiently using static dispatch through the `AdapterDispatcher`, minimizing runtime overhead and simplifying dependency management. + +- **Request/Response Normalization:** The adapter layer ensures that incoming requests and outgoing responses conform to generic GenAI types, hiding provider-specific implementation details from the rest of the library. + +- **Dynamic Resolution:** While `AdapterKind::from_model` provides a default mapping from model names (based on common prefixes or keywords), the system allows this to be overridden by custom `ServiceTargetResolver` configurations, enabling flexible routing (e.g., mapping a custom model name to an `OpenAI` adapter with a custom endpoint). + +- **Stream Intermediation:** The introduction of `InterStreamEvent` is crucial for handling the variance in streaming protocols across providers. it ensures that complex data transmitted at the end of a stream (like final usage statistics or aggregated tool calls) can be correctly collected and normalized, regardless of the provider's specific event format. diff --git a/dev/spec/spec-chat.md b/dev/spec/spec-chat.md new file mode 100644 index 00000000..52c1be62 --- /dev/null +++ b/dev/spec/spec-chat.md @@ -0,0 +1,66 @@ +## chat + +### Goal + +The `chat` module provides the core primitives for constructing chat requests, defining messages (including multi-part content like text, binary, and tool data), and handling synchronous and asynchronous (streaming) chat responses across all supported AI providers. It standardizes the data structures necessary for modern LLM interactions. + +### Public Module API + +The module exports the following key data structures: + +- **Request/Message Structure:** + - `ChatRequest`: The primary structure for initiating a chat completion call, containing the history (`messages`), an optional system prompt (`system`), and tool definitions (`tools`). + - `ChatMessage`: Represents a single interaction turn, comprising a `ChatRole`, `MessageContent`, and optional `MessageOptions`. + - `ChatRole`: Enum defining message roles (`System`, `User`, `Assistant`, `Tool`). + - `MessageContent`: A unified container for multi-part content, wrapping a list of `ContentPart`s. + - `ContentPart`: Enum defining content types: `Text`, `Binary`, `ToolCall`, `ToolResponse`. + - `Binary`, `BinarySource`: Structures defining binary payloads (e.g., images), sourced via base64 or URL. + - `MessageOptions`, `CacheControl`: Per-message configuration hints (e.g., for cache behavior). + +- **Configuration:** + - `ChatOptions`: General request configuration, including sampling parameters (`temperature`, `max_tokens`, `top_p`, `seed`), streaming capture flags, and format control. + - `ReasoningEffort`, `Verbosity`: Provider-specific hints for reasoning intensity or output verbosity. 
+ - `ChatResponseFormat`, `JsonSpec`: Defines desired structured output formats (e.g., JSON mode). + +- **Responses:** + - `ChatResponse`: The result of a non-streaming request, including final content, usage, and model identifiers. + - `ChatStreamResponse`: The result wrapper for streaming requests, containing the `ChatStream` and model identity. + +- **Streaming:** + - `ChatStream`: A `futures::Stream` implementation yielding `ChatStreamEvent`s. + - `ChatStreamEvent`: Enum defining streaming events: `Start`, `Chunk` (content), `ReasoningChunk`, `ToolCallChunk`, and `End`. + - `StreamEnd`: Terminal event data including optional captured usage, content, and reasoning content. + +- **Tooling:** + - `Tool`: Metadata and schema defining a function the model can call. + - `ToolCall`: The model's invocation request for a specific tool. + - `ToolResponse`: The output returned from executing a tool, matched by call ID. + +- **Metadata:** + - `Usage`, `PromptTokensDetails`, `CompletionTokensDetails`: Normalized token usage statistics. + +- **Utilities:** + - `printer` module: Contains `print_chat_stream` for console output utilities. + +### Module Parts + +The functionality is divided into specialized files/sub-modules: + +- `chat_message.rs`: Defines the `ChatMessage` fundamental structure and associated types (`ChatRole`, `MessageOptions`). +- `chat_options.rs`: Manages request configuration (`ChatOptions`) and provides parsing logic for provider-specific hints like `ReasoningEffort` and `Verbosity`. +- `chat_req_response_format.rs`: Handles configuration for structured output (`ChatResponseFormat`, `JsonSpec`). +- `chat_request.rs`: Defines the top-level `ChatRequest` and methods for managing the request history and properties. +- `chat_response.rs`: Defines synchronous chat response structures (`ChatResponse`). +- `chat_stream.rs`: Implements the public `ChatStream` and its events, mapping from the internal adapter stream. +- `content_part.rs`: Defines `ContentPart`, `Binary`, and `BinarySource` for handling multi-modal inputs/outputs. +- `message_content.rs`: Defines `MessageContent`, focusing on collection management and convenient accessors for content parts (e.g., joining all text). +- `tool/mod.rs` (and associated files): Defines the tooling primitives (`Tool`, `ToolCall`, `ToolResponse`). +- `usage.rs`: Defines the normalized token counting structures (`Usage`). +- `printer.rs`: Provides utility functions for rendering stream events to standard output. + +### Key Design Considerations + +- **Unified Content Model:** The use of `MessageContent` composed of `ContentPart` allows any message role (user, assistant, tool) to handle complex, multi-part data seamlessly, including text, binary payloads, and tooling actions. +- **Decoupled Streaming:** The public `ChatStream` is an abstraction layer over an internal stream (`InterStream`), ensuring a consistent external interface regardless of adapter implementation details (like internal handling of usage reporting or reasoning chunks). +- **Normalized Usage Metrics:** The `Usage` structure provides an OpenAI-compatible interface while allowing for provider-specific breakdowns (e.g., caching or reasoning tokens) via detailed sub-structures. +- **Hierarchical Options:** `ChatOptions` can be applied globally at the client level or specifically per request. The internal resolution logic ensures request-specific options take precedence over client defaults. 
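As a quick, non-normative illustration of these primitives (a minimal sketch; the system prompt and image path are only examples, with the path taken from this repository's test data):

```rust
use genai::chat::{ChatMessage, ChatRequest, ContentPart};

fn build_request() -> Result<ChatRequest, Box<dyn std::error::Error>> {
    // A system prompt plus a multi-part user message (text + a binary part loaded from a file).
    let chat_req = ChatRequest::default()
        .with_system("Answer in one sentence")
        .append_message(ChatMessage::user(vec![
            ContentPart::from_text("What is in this picture?"),
            ContentPart::from_binary_file("tests/data/duck-small.jpg")?,
        ]));
    Ok(chat_req)
}
```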
diff --git a/dev/spec/spec-client.md b/dev/spec/spec-client.md new file mode 100644 index 00000000..cedd505d --- /dev/null +++ b/dev/spec/spec-client.md @@ -0,0 +1,59 @@ +## client + +### Goal + +The `client` module provides the core entry point (`Client`) for interacting with various Generative AI providers. It encapsulates configuration (`ClientConfig`, `WebConfig`), a builder pattern (`ClientBuilder`), request execution (`exec_chat`, `exec_embed`), and service resolution logic (e.g., determining endpoints and authentication). + +### Public Module API + +The `client` module exposes the following public types: + +- **`Client`**: The main interface for executing AI requests (chat, embedding, streaming, model listing). + - `Client::builder()`: Starts the configuration process. + - `Client::default()`: Creates a client with default configuration. + - Core execution methods: `exec_chat`, `exec_chat_stream`, `exec_embed`, `embed`, `embed_batch`. + - Resolution/Discovery methods: `all_model_names`, `resolve_service_target`. + +- **`ClientBuilder`**: Provides a fluent interface for constructing a `Client`. Used to set `ClientConfig`, default `ChatOptions`, `EmbedOptions`, and custom resolvers (`AuthResolver`, `ServiceTargetResolver`, `ModelMapper`). + +- **`ClientConfig`**: Holds the resolved and default configurations used by the `Client`, including resolver functions and default options. + +- **`Headers`**: A simple map wrapper (`HashMap`) for managing HTTP headers in requests. + +- **`ServiceTarget`**: A struct containing the final resolved components needed to execute a request: `Endpoint`, `AuthData`, and `ModelIden`. + +- **`WebConfig`**: Configuration options specifically for building the underlying `reqwest::Client` (e.g., timeouts, proxies, default headers). + +### Module Parts + +The module is composed of several files that implement the layered client architecture: + +- `builder.rs`: Implements `ClientBuilder`, handling the creation and configuration flow. It initializes or updates the nested `ClientConfig` and optionally an internal `WebClient`. + +- `client_types.rs`: Defines the main `Client` struct and `ClientInner` (which holds `WebClient` and `ClientConfig` behind an `Arc`). + +- `config.rs`: Defines `ClientConfig` and the core `resolve_service_target` logic, which orchestrates calls to `ModelMapper`, `AuthResolver`, and `ServiceTargetResolver` before falling back to adapter defaults. + +- `client_impl.rs`: Contains the main implementation of the public API methods on `Client`, such as `exec_chat` and `exec_embed`. These methods perform service resolution and delegate to `AdapterDispatcher` for request creation and response parsing. + +- `headers.rs`: Implements the `Headers` utility for managing key-value HTTP header maps. + +- `service_target.rs`: Defines the `ServiceTarget` structure for resolved endpoints, authentication, and model identifiers. + +- `web_config.rs`: Defines `WebConfig` and its logic for applying settings to a `reqwest::ClientBuilder`. + +### Key Design Considerations + +- **Client Immutability and Sharing**: The `Client` holds its internal state (`ClientInner` with `WebClient` and `ClientConfig`) wrapped in an `Arc`. This design ensures that the client is thread-safe and cheaply cloneable, aligning with common client patterns in asynchronous Rust applications. + +- **Config Layering and Resolution**: The client architecture employs a sophisticated resolution process managed by `ClientConfig::resolve_service_target`. 
+ - It first applies a `ModelMapper` to potentially translate the input model identifier. + - It then consults the `AuthResolver` for authentication data. If the resolver is absent or returns `None`, it defaults to the adapter's standard authentication mechanism (e.g., API key headers). + - It determines the adapter's default endpoint. + - Finally, it applies the optional `ServiceTargetResolver`, allowing users to override the endpoint, auth, or model for complex scenarios (e.g., custom proxies or routing). + +- **WebClient Abstraction**: The core HTTP client logic is delegated to the `WebClient` (from the `webc` module), which handles low-level request execution and streaming setup. This separation keeps the `client` module focused on business logic and AI provider orchestration. + +- **Builder Pattern for Configuration**: `ClientBuilder` enforces configuration before client creation, simplifying object construction and ensuring necessary dependencies are set up correctly. + +- **Headers Simplification**: The `Headers` struct abstracts HTTP header management, ensuring that subsequent merges or overrides result in a single, final header value, which is typical for API key authorization overrides. diff --git a/dev/spec/spec-common.md b/dev/spec/spec-common.md new file mode 100644 index 00000000..b2d13024 --- /dev/null +++ b/dev/spec/spec-common.md @@ -0,0 +1,36 @@ +## common + +### Goal + +The `common` module provides fundamental data structures used throughout the `genai` library, primarily focusing on identifying models and adapters in a clear and efficient manner. + +### Public Module API + +The module exposes two main types: `ModelName` and `ModelIden`. + +- `ModelName`: Represents a generative AI model identifier (e.g., `"gpt-4o"`, `"claude-3-opus"`). + - It wraps an `Arc` for efficient cloning and sharing across threads. + - Implements `From`, `From<&String>`, `From<&str>`, and `Deref`. + - Supports equality comparison (`PartialEq`) with various string types (`&str`, `String`). + +- `ModelIden`: Uniquely identifies a model by coupling an `AdapterKind` with a `ModelName`. + - Fields: + - `adapter_kind: AdapterKind` + - `model_name: ModelName` + - Constructor: `fn new(adapter_kind: AdapterKind, model_name: impl Into) -> Self` + - Utility methods for creating new identifiers based on name changes: + - `fn from_name(&self, new_name: T) -> ModelIden` + - `fn from_optional_name(&self, new_name: Option) -> ModelIden` + +### Module Parts + +The `common` module consists of: + +- `model_name.rs`: Defines the `ModelName` type and related string manipulation utilities, including parsing optional namespaces (e.g., `namespace::model_name`). +- `model_iden.rs`: Defines the `ModelIden` type, which associates a `ModelName` with an `AdapterKind`. + +### Key Design Considerations + +- **Efficiency of ModelName:** `ModelName` uses `Arc` to ensure that cloning the model identifier is cheap, which is crucial as model identifiers are frequently passed around in request and response structures. +- **Deref Implementation:** Implementing `Deref` for `ModelName` allows it to be used naturally as a string reference. +- **ModelIden Immutability:** `ModelIden` is designed to be immutable and fully identifiable, combining the model string identity (`ModelName`) with the service provider identity (`AdapterKind`). 
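A small usage sketch of these types, based on the constructors and comparisons described above (the model names are arbitrary examples):

```rust
use genai::ModelIden;
use genai::adapter::AdapterKind;

fn demo() {
    // Couple an AdapterKind with a model name.
    let iden = ModelIden::new(AdapterKind::OpenAI, "gpt-4o-mini");

    // Derive a new identifier that keeps the adapter kind but changes the name.
    let other = iden.from_name("gpt-4o");

    // ModelName compares directly against string types.
    assert!(other.model_name == "gpt-4o");
    assert_eq!(other.adapter_kind, AdapterKind::OpenAI);
}
```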
diff --git a/dev/spec/spec-webc.md b/dev/spec/spec-webc.md new file mode 100644 index 00000000..a4bd3232 --- /dev/null +++ b/dev/spec/spec-webc.md @@ -0,0 +1,36 @@ +## webc + +### Goal + +The `webc` module provides a low-level, internal web client layer utilizing `reqwest`. Its primary role is to abstract standard HTTP requests (GET/POST) and manage complex streaming responses required by various AI providers, especially those that do not fully conform to the Server-Sent Events (SSE) standard (`text/event-stream`). It handles standard JSON requests/responses and custom stream parsing. + +### Public Module API + +The `webc` module is primarily an internal component, only exposing its dedicated error type publicly. + +- `pub use error::Error;` + - `Error`: An enum representing all possible errors originating from the web communication layer (e.g., failed status codes, JSON parsing errors, reqwest errors, stream clone errors). + +(All other types like `WebClient`, `WebResponse`, `WebStream`, and `Result` are exported as `pub(crate)` for internal library use.) + +### Module Parts + +The module consists of three main internal components: + +- `error.rs`: Defines the `Error` enum and the module-scoped `Result` type alias. It captures network/HTTP related failures and external errors like `reqwest::Error` and `value_ext::JsonValueExtError`. + +- `web_client.rs`: Contains the `WebClient` struct, a thin wrapper around `reqwest::Client`. It provides methods (`do_get`, `do_post`) for non-streaming standard HTTP communication, which assumes the response body is JSON and is parsed into `serde_json::Value`. It also defines `WebResponse`, which encapsulates the HTTP status and parsed JSON body. + +- `web_stream.rs`: Implements `WebStream`, a custom `futures::Stream` implementation designed for handling non-SSE streaming protocols used by some AI providers (e.g., Cohere, Gemini). It defines `StreamMode` to specify how stream chunks should be parsed (either by a fixed delimiter or specialized handling for "Pretty JSON Array" formats). + +### Key Design Considerations + +- **Internal Focus:** The module is designed strictly for internal use (`pub(crate)`) except for the public error type. This shields the rest of the library from direct `reqwest` dependency details. + +- **Custom Streaming:** `WebStream` exists specifically to manage streaming protocols that deviate from the standard SSE format, providing message splitting based on `StreamMode`. This ensures compatibility with providers like Cohere (delimiter-based) and Gemini (JSON array chunking). + +- **Generic JSON Response Handling:** `WebResponse` abstracts successful non-streaming responses by immediately parsing the body into `serde_json::Value`. This allows adapter modules to deserialize into their specific structures subsequently. + +- **Error Richness:** The `Error::ResponseFailedStatus` variant includes the `StatusCode`, full `body`, and `HeaderMap` to provide comprehensive debugging information upon API failure. + +- **Async Implementation:** All network operations rely on `tokio` and `reqwest`, ensuring non-blocking execution throughout the I/O layer. `WebStream` leverages `futures::Stream` traits for integration with standard Rust async infrastructure. diff --git a/doc/for-llm/api-reference-for-llm.md b/doc/for-llm/api-reference-for-llm.md new file mode 100644 index 00000000..f1b13d92 --- /dev/null +++ b/doc/for-llm/api-reference-for-llm.md @@ -0,0 +1,175 @@ +# GenAI API Reference for LLMs + +Dry, concise reference for the `genai` library. 
+ +```toml +genai = "0.5.3" +``` + +## Core Concepts + +- **Client**: Main entry point (`genai::Client`). Thread-safe (`Arc` wrapper). +- **ModelIden**: `AdapterKind` + `ModelName`. Identifies which provider to use. +- **ServiceTarget**: Resolved `ModelIden`, `Endpoint`, and `AuthData`. +- **Resolvers**: Hooks to customize model mapping, authentication, and service endpoints. +- **AdapterKind**: Supported: `OpenAI`, `OpenAIResp`, `Gemini`, `Anthropic`, `Fireworks`, `Together`, `Groq`, `Mimo`, `Nebius`, `Xai`, `DeepSeek`, `Zai`, `BigModel`, `Cohere`, `Ollama`. + +## Client & Configuration + +### `Client` +- `Client::default()`: Standard client. +- `Client::builder()`: Returns `ClientBuilder`. +- `exec_chat(model, chat_req, options)`: Returns `ChatResponse`. +- `exec_chat_stream(model, chat_req, options)`: Returns `ChatStreamResponse`. +- `exec_embed(model, embed_req, options)`: Returns `EmbedResponse`. +- `embed(model, text, options)`: Convenience for single text embedding. +- `embed_batch(model, texts, options)`: Convenience for batch embedding. +- `resolve_service_target(model_name)`: Returns `ServiceTarget`. +- `all_model_names(adapter_kind)`: Returns a list of models for a provider (Ollama is dynamic). + +### `ClientBuilder` +- `with_auth_resolver(resolver)` / `with_auth_resolver_fn(f)`: Set sync/async auth lookup. +- `with_service_target_resolver(resolver)` / `with_service_target_resolver_fn(f)`: Full control over URL/Headers/Auth per call. +- `with_model_mapper(mapper)` / `with_model_mapper_fn(f)`: Map model names before execution. +- `with_chat_options(options)`: Set client-level default chat options. +- `with_web_config(web_config)`: Configure `reqwest` (timeouts, proxies, default headers). + +## Chat Request Structure + +### `ChatRequest` +- `system`: Initial system string (optional). +- `messages`: `Vec`. +- `tools`: `Vec` (optional). +- `from_system(text)`, `from_user(text)`, `from_messages(vec)`: Constructors. +- `append_message(msg)`: Adds a message to the sequence. +- `append_messages(iter)`: Adds multiple messages. +- `append_tool(tool)`: Adds a single tool definition. +- `append_tool_use_from_stream_end(end, tool_response)`: Simplifies tool-use loops by appending the assistant turn (with thoughts/tools) and the tool result. +- `join_systems()`: Concatenates all system content (top-level + system-role messages) into one string. + +### `ChatMessage` +- `role`: `System`, `User`, `Assistant`, `Tool`. +- `content`: `MessageContent` (multipart). +- `options`: `MessageOptions` (e.g., `cache_control: Ephemeral` for Anthropic). +- **Constructors**: `ChatMessage::system(text)`, `user(text)`, `assistant(text)`. +- **Tool Handoff**: `assistant_tool_calls_with_thoughts(calls, thoughts)` for continuing tool exchanges where thoughts must precede tool calls. + +### `MessageContent` (Multipart) +- Transparent wrapper for `Vec`. +- **Constructors**: `from_text(text)`, `from_parts(vec)`, `from_tool_calls(vec)`. +- **Methods**: `joined_texts()` (joins with blank line), `first_text()`, `prepend(part)`, `extend_front(parts)`. +- `ContentPart` variants: + - `Text(String)`: Plain text. + - `Binary(Binary)`: Images/PDFs/Audio. + - `ToolCall(ToolCall)`: Model-requested function call. + - `ToolResponse(ToolResponse)`: Result of function call. + - `ThoughtSignature(String)`: Reasoning/thoughts (e.g., Gemini/Anthropic). + +### `Binary` +- `content_type`: MIME (e.g., `image/jpeg`, `application/pdf`). +- `source`: `Url(String)` or `Base64(Arc)`. +- `from_file(path)`: Reads file and detects MIME. 
+- `is_image()`, `is_audio()`, `is_pdf()`: Type checks. +- `size()`: Approximate in-memory size in bytes. + +## Chat Options & Features + +### `ChatOptions` +- `temperature`, `max_tokens`, `top_p`. +- `stop_sequences`: `Vec`. +- `response_format`: `ChatResponseFormat::JsonMode` or `JsonSpec(name, schema)`. +- `reasoning_effort`: `Low`, `Medium`, `High`, `Budget(u32)`, `None`. +- `verbosity`: `Low`, `Medium`, `High` (e.g., for GPT-5). +- `normalize_reasoning_content`: Extract `` blocks into response field. +- `capture_usage`, `capture_content`, `capture_reasoning_content`, `capture_tool_calls`: (Streaming) Accumulate results in `StreamEnd`. +- `seed`: Deterministic generation. +- `service_tier`: `Flex`, `Auto`, `Default` (OpenAI). +- `extra_headers`: `Headers` added to the request. + +## Embedding + +### `EmbedRequest` +- `input`: `EmbedInput` (Single string or Batch `Vec`). + +### `EmbedOptions` +- `dimensions`, `encoding_format` ("float", "base64"). +- `user`, `truncate` ("NONE", "START", "END"). +- `embedding_type`: Provider specific (e.g., "search_document" for Cohere, "RETRIEVAL_QUERY" for Gemini). + +### `EmbedResponse` +- `embeddings`: `Vec` (contains `vector: Vec`, `index`, `dimensions`). +- `usage`: `Usage`. +- `model_iden`, `provider_model_iden`. + +## Tooling + +### `Tool` +- `name`, `description`, `schema` (JSON Schema). +- `config`: Optional provider-specific config. + +### `ToolCall` +- `call_id`, `fn_name`, `fn_arguments` (JSON `Value`). +- `thought_signatures`: Leading thoughts associated with the call (captured during streaming). + +### `ToolResponse` +- `call_id`, `content` (Result as string, usually JSON). + +## Responses & Streaming + +### `ChatResponse` +- `content`: `MessageContent`. +- `reasoning_content`: Extracted thoughts (if normalized). +- `usage`: `Usage`. +- `model_iden`, `provider_model_iden`. +- `first_text()`, `into_first_text()`, `tool_calls()`. + +### `ChatStream` +- Sequence of `ChatStreamEvent`: `Start`, `Chunk(text)`, `ReasoningChunk(text)`, `ThoughtSignatureChunk(text)`, `ToolCallChunk(ToolCall)`, `End(StreamEnd)`. + +### `StreamEnd` +- `captured_usage`: `Option`. +- `captured_content`: Concatenated `MessageContent` (text, tools, thoughts). +- `captured_reasoning_content`: Concatenated reasoning content. +- `captured_first_text()`, `captured_tool_calls()`, `captured_thought_signatures()`. +- `into_assistant_message_for_tool_use()`: Returns a `ChatMessage` ready for the next request in a tool-use flow. + +## Usage & Metadata + +### `Usage` +- `prompt_tokens`: Total input tokens. +- `completion_tokens`: Total output tokens. +- `total_tokens`: Sum of input and output. +- `prompt_tokens_details`: `cache_creation_tokens`, `cached_tokens`, `audio_tokens`. +- `completion_tokens_details`: `reasoning_tokens`, `audio_tokens`, `accepted_prediction_tokens`, `rejected_prediction_tokens`. + +## Resolvers & Auth + +### `AuthData` +- `Key(String)`: The API key. +- `FromEnv(String)`: Env var name to lookup. +- `RequestOverride { url, headers }`: For unorthodox auth or endpoint overrides (e.g., Vertex AI, Bedrock). + +### `AuthResolver` +- `from_resolver_fn(f)` / `from_resolver_async_fn(f)`. +- Resolves `AuthData` based on `ModelIden`. + +### `ServiceTargetResolver` +- `from_resolver_fn(f)` / `from_resolver_async_fn(f)`. +- Maps `ServiceTarget` (Model, Auth, Endpoint) to a final call target. + +### `Headers` +- `merge(overlay)`, `applied_to(target)`. +- Iteration and `From` conversions for `HashMap`, `Vec<(K,V)>`, etc. 
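A hedged sketch of a custom auth lookup (the env-var name is a placeholder; the closure signature is assumed from `examples/c02-auth.rs`):

```rust
use genai::adapter::AdapterKind;
use genai::resolver::{AuthData, AuthResolver};
use genai::{Client, ModelIden};

fn build_client() -> Client {
    // Provide the key per adapter kind; returning `None` falls back to the adapter default (env key).
    let auth_resolver = AuthResolver::from_resolver_fn(
        |model_iden: ModelIden| -> Result<Option<AuthData>, genai::resolver::Error> {
            match model_iden.adapter_kind {
                AdapterKind::Anthropic => Ok(Some(AuthData::from_env("MY_ANTHROPIC_KEY"))),
                _ => Ok(None),
            }
        },
    );

    Client::builder().with_auth_resolver(auth_resolver).build()
}
```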
+ +## Model Resolution Nuances + +- **Auto-detection**: `AdapterKind` inferred from name (e.g., `gpt-` -> `OpenAI`, `claude-` -> `Anthropic`, `gemini-` -> `Gemini`, `command` -> `Cohere`, `grok` -> `Xai`, `glm` -> `Zai`). +- **Namespacing**: `namespace::model_name` (e.g., `together::meta-llama/...`, `nebius::Qwen/...`). +- **Ollama Fallback**: Unrecognized names default to `Ollama` adapter (localhost:11434). +- **Reasoning**: Automatic extraction for DeepSeek/Ollama when `normalize_reasoning_content` is enabled. + +## Error Handling + +- `genai::Error`: Covers `ChatReqHasNoMessages`, `RequiresApiKey`, `WebModelCall`, `StreamParse`, `AdapterNotSupported`, `Resolver`, etc. +- `Result`: Alias for `core::result::Result`. +- `size()`: Many types implement `.size()` for approximate memory tracking. diff --git a/examples/c00-readme.rs b/examples/c00-readme.rs index eb2f7fe7..83fb4760 100644 --- a/examples/c00-readme.rs +++ b/examples/c00-readme.rs @@ -15,7 +15,7 @@ const MODEL_GROQ: &str = "llama-3.1-8b-instant"; const MODEL_OLLAMA: &str = "gemma:2b"; // sh: `ollama pull gemma:2b` const MODEL_XAI: &str = "grok-3-mini"; const MODEL_DEEPSEEK: &str = "deepseek-chat"; -const MODEL_ZHIPU: &str = "glm-4-plus"; +const MODEL_ZAI: &str = "glm-4-plus"; const MODEL_COHERE: &str = "command-r7b-12-2024"; // NOTE: These are the default environment keys for each AI Adapter Type. @@ -31,7 +31,7 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[ (MODEL_XAI, "XAI_API_KEY"), (MODEL_DEEPSEEK, "DEEPSEEK_API_KEY"), (MODEL_OLLAMA, ""), - (MODEL_ZHIPU, "ZHIPU_API_KEY"), + (MODEL_ZAI, "ZAI_API_KEY"), (MODEL_COHERE, "COHERE_API_KEY"), ]; @@ -41,7 +41,7 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[ // - starts_with "command" -> Cohere // - starts_with "gemini" -> Gemini // - model in Groq models -> Groq -// - starts_with "glm" -> Zhipu +// - starts_with "glm" -> ZAI // - For anything else -> Ollama // // This can be customized; see `examples/c03-mapper.rs` diff --git a/examples/c06-model-spec.rs b/examples/c06-model-spec.rs new file mode 100644 index 00000000..518cdd0e --- /dev/null +++ b/examples/c06-model-spec.rs @@ -0,0 +1,66 @@ +//! This example shows how to use a custom AdapterKindResolver to have some custom +//! mapping from a model name to an AdapterKind. +//! This allows mapping missing models to their Adapter implementations. + +use genai::adapter::AdapterKind; +use genai::chat::{ChatMessage, ChatRequest}; +use genai::resolver::{AuthData, Endpoint}; +use genai::{Client, ModelIden, ModelSpec, ServiceTarget}; +use tracing_subscriber::EnvFilter; + +pub enum AppModel { + Fast, + Pro, + Local, + Custom(String), +} + +impl From<&AppModel> for ModelSpec { + fn from(model: &AppModel) -> Self { + match model { + AppModel::Fast => ModelSpec::from_static_name("gemini-3-flash-preview"), + + // ModelName will be Arc (use `ModelIden::from_static(..) 
for micro optimization) + AppModel::Pro => ModelSpec::from_iden((AdapterKind::Anthropic, "claude-opus-4-5")), + + AppModel::Local => ModelSpec::Target(ServiceTarget { + model: ModelIden::from_static(AdapterKind::Ollama, "gemma3:1b"), + endpoint: Endpoint::from_static("http://localhost:11434"), + auth: AuthData::Key("".to_string()), + }), + + AppModel::Custom(name) => ModelSpec::from_name(name), + } + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::new("genai=debug")) + // .with_max_level(tracing::Level::DEBUG) // To enable all sub-library tracing + .init(); + + // -- Model Spec (unselect one below) + let model_spec = AppModel::Fast; + // let model_spec = AppModel::Custom("gpt-5.2".to_string()); + + let question = "Why is the sky red? (be concise)"; + + // -- Build the new client with this client_config + let client = Client::default(); + + // -- Build the chat request + let chat_req = ChatRequest::new(vec![ChatMessage::user(question)]); + + // -- Execute and print + println!("\n--- Question:\n{question}"); + let chat_res = client.exec_chat(&model_spec, chat_req.clone(), None).await?; + + let model_iden = chat_res.model_iden; + let res_txt = chat_res.content.into_joined_texts().ok_or("Should have some response")?; + + println!("\n--- Answer: ({model_iden})\n{res_txt}"); + + Ok(()) +} diff --git a/examples/c07-image.rs b/examples/c07-image.rs index 7b17e0dd..2520c425 100644 --- a/examples/c07-image.rs +++ b/examples/c07-image.rs @@ -1,12 +1,16 @@ //! This example demonstrates how to properly attach image to the conversations use genai::Client; -use genai::chat::printer::print_chat_stream; use genai::chat::{ChatMessage, ChatRequest, ContentPart}; use tracing_subscriber::EnvFilter; -const MODEL: &str = "gpt-4o-mini"; -const IMAGE_URL: &str = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"; +const MODEL: &str = "gpt-5.1"; +// const MODEL: &str = "claude-sonnet-4-5"; + +// const IMAGE_URL: &str = "https://aipack.ai/images/test-duck.jpg"; + +const IMAGE_SOME_PATH: &str = "tests/data/duck-small.jpg"; +const IMAGE_OTHER_ONE_PATH: &str = "tests/data/other-one.png"; #[tokio::main] async fn main() -> Result<(), Box> { @@ -17,20 +21,46 @@ async fn main() -> Result<(), Box> { let client = Client::default(); - let question = "What is in this picture?"; - let mut chat_req = ChatRequest::default().with_system("Answer in one sentence"); - // This is similar to sending initial system chat messages (which will be cumulative with system chat messages) - chat_req = chat_req.append_message(ChatMessage::user(vec![ - ContentPart::from_text(question), - ContentPart::from_binary_url("image/jpg", IMAGE_URL, None), - ])); + chat_req = chat_req + .append_message(ChatMessage::user(vec![ + ContentPart::from_text("here is the file: 'some-image.jpg'"), // To test when name is different, should take precedence + ContentPart::from_binary_file(IMAGE_SOME_PATH)?, + ])) + .append_message(ChatMessage::user(vec![ + ContentPart::from_text("here is the file: 'other-one.png'"), // this is the most model portable way to provide image name/info + ContentPart::from_binary_file(IMAGE_OTHER_ONE_PATH)?, + ])); + + let questions = [ + "What is the first image about? and what is the file name for this image if you have it?", + "What is the second image about? 
and what is the file name for this image if you have it?", + ]; + + for question in questions { + println!("\nQuestion: {question}"); + + let chat_req = chat_req.clone().append_message(ChatMessage::user(question)); + let chat_res = client.exec_chat(MODEL, chat_req, None).await?; + + let usage = chat_res.usage; + let response_content = chat_res.content.joined_texts().ok_or("Should have response")?; + + println!("\nAnswer: {response_content}"); + println!( + "prompt: {:?} tokens | completion: {:?} tokens", + usage.prompt_tokens, usage.completion_tokens + ); - println!("\n--- Question:\n{question}"); - let chat_res = client.exec_chat_stream(MODEL, chat_req.clone(), None).await?; + println!(); + } - println!("\n--- Answer: (streaming)"); - let _assistant_answer = print_chat_stream(chat_res, None).await?; + // NOTE: For web url image, we can `from_binary_url` but not supported by all models (e.g., Anthropic does not support those) + // ContentPart::from_binary_url( + // "image/jpg", + // IMAGE_URL, + // None, + // ) Ok(()) } diff --git a/examples/c07-zai.rs b/examples/c07-zai.rs new file mode 100644 index 00000000..3228bb1f --- /dev/null +++ b/examples/c07-zai.rs @@ -0,0 +1,58 @@ +//! ZAI (Zhipu AI) adapter example +//! +//! Demonstrates how to use ZAI models with automatic endpoint routing: +//! - `glm-4.6` β†’ Regular credit-based API +//! - `zai::glm-4.6` β†’ Coding subscription API (automatically routed) + +use genai::Client; +use genai::chat::{ChatMessage, ChatRequest}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::builder().build(); + + // Test cases demonstrating automatic endpoint routing + let test_cases = vec![ + ("glm-4.6", "Regular ZAI model"), + ("zai-coding::glm-4.6", "Coding subscription model"), + ]; + + for (model_name, description) in test_cases { + println!("\n=== {} ===", description); + println!("Model: {}", model_name); + + let chat_req = ChatRequest::default() + .with_system("You are a helpful assistant.") + .append_message(ChatMessage::user("Say 'hello' and nothing else.")); + + match client.exec_chat(model_name, chat_req, None).await { + Ok(response) => { + println!("βœ… Success!"); + if let Some(content) = response.first_text() { + println!("Response: {}", content); + } + if response.usage.prompt_tokens.is_some() || response.usage.completion_tokens.is_some() { + println!( + "Usage: prompt={}, output={}", + response.usage.prompt_tokens.unwrap_or(0), + response.usage.completion_tokens.unwrap_or(0) + ); + } + } + Err(e) => { + println!("❌ Error: {}", e); + if e.to_string().contains("insufficient balance") { + println!("ℹ️ This model requires credits or subscription"); + } else if e.to_string().contains("401") { + println!("ℹ️ Set ZAI_API_KEY environment variable"); + } + } + } + } + + println!("\n=== SUMMARY ==="); + println!("βœ… ZAI adapter handles namespace routing automatically"); + println!("βœ… Use ZAI_API_KEY environment variable"); + + Ok(()) +} diff --git a/examples/c10-tooluse-streaming.rs b/examples/c10-tooluse-streaming.rs index 1589001d..800093ed 100644 --- a/examples/c10-tooluse-streaming.rs +++ b/examples/c10-tooluse-streaming.rs @@ -7,7 +7,8 @@ use serde_json::json; use tracing_subscriber::EnvFilter; // const MODEL: &str = "gemini-2.0-flash"; -const MODEL: &str = "deepseek-chat"; +// const MODEL: &str = "deepseek-chat"; +const MODEL: &str = "gemini-3-pro-preview"; #[tokio::main] async fn main() -> Result<(), Box> { @@ -18,6 +19,8 @@ async fn main() -> Result<(), Box> { let client = Client::default(); + println!("--- Model: 
{MODEL}"); + // 1. Define a tool for getting weather information let weather_tool = Tool::new("get_weather") .with_description("Get the current weather for a location") @@ -53,6 +56,8 @@ async fn main() -> Result<(), Box> { let mut chat_stream = client.exec_chat_stream(MODEL, chat_req.clone(), Some(&chat_options)).await?; let mut tool_calls: Vec = [].to_vec(); + let mut captured_thoughts: Option> = None; + // print_chat_stream(chat_res, Some(&print_options)).await?; println!("--- Streaming response with tool calls"); while let Some(result) = chat_stream.stream.next().await { @@ -63,25 +68,53 @@ async fn main() -> Result<(), Box> { ChatStreamEvent::Chunk(chunk) => { print!("{}", chunk.content); } - ChatStreamEvent::ToolCallChunk(tool_chunk) => { - println!( - "\nTool Call: {} with args: {}", - tool_chunk.tool_call.fn_name, tool_chunk.tool_call.fn_arguments - ); + ChatStreamEvent::ToolCallChunk(chunk) => { + println!(" ToolCallChunk: {:?}", chunk.tool_call); } ChatStreamEvent::ReasoningChunk(chunk) => { - println!("\nReasoning: {}", chunk.content); + println!(" ReasoningChunk: {:?}", chunk.content); + } + ChatStreamEvent::ThoughtSignatureChunk(chunk) => { + println!(" ThoughtSignatureChunk: {:?}", chunk.content); } ChatStreamEvent::End(end) => { println!("\nStream ended"); // Check if we captured any tool calls - if let Some(captured_tool_calls) = end.captured_into_tool_calls() { - println!("\nCaptured Tool Calls:"); - tool_calls = captured_tool_calls.clone(); - for tool_call in captured_tool_calls { - println!("- Function: {}", tool_call.fn_name); - println!(" Arguments: {}", tool_call.fn_arguments); + // Note: captured_into_tool_calls consumes self, so we can't use end afterwards. + // We should access captured_content directly or use references if possible, + // but StreamEnd getters often consume or clone. + // Let's access captured_content directly since we need both tool calls and thoughts. + + if let Some(content) = end.captured_content { + // Let's refactor to avoid ownership issues. + // We have `content` (MessageContent). + // We want `tool_calls` (Vec) and `thoughts` (Vec). + + // We can iterate and split. + let parts = content.into_parts(); + let mut extracted_tool_calls = Vec::new(); + let mut extracted_thoughts = Vec::new(); + + for part in parts { + match part { + genai::chat::ContentPart::ToolCall(tc) => extracted_tool_calls.push(tc), + genai::chat::ContentPart::ThoughtSignature(t) => extracted_thoughts.push(t), + _ => {} + } + } + + if !extracted_tool_calls.is_empty() { + println!("\nCaptured Tool Calls:"); + for tool_call in &extracted_tool_calls { + println!("- Function: {}", tool_call.fn_name); + println!(" Arguments: {}", tool_call.fn_arguments); + } + tool_calls = extracted_tool_calls; + } + + if !extracted_thoughts.is_empty() { + captured_thoughts = Some(extracted_thoughts); } } } @@ -107,7 +140,19 @@ async fn main() -> Result<(), Box> { ); // Add both the tool calls and response to chat history - let chat_req = chat_req.append_message(tool_calls).append_message(tool_response); + // Note: For Gemini 3, we MUST include the thoughtSignature in the history if it was generated. + let mut assistant_msg = ChatMessage::from(tool_calls); + if let Some(thoughts) = captured_thoughts { + // We need to insert the thought at the beginning. + // MessageContent wraps Vec, but doesn't expose insert. + // We can convert to Vec, insert, and convert back. 
+ let mut parts = assistant_msg.content.into_parts(); + for thought in thoughts.into_iter().rev() { + parts.insert(0, genai::chat::ContentPart::ThoughtSignature(thought)); + } + assistant_msg.content = genai::chat::MessageContent::from_parts(parts); + } + let chat_req = chat_req.append_message(assistant_msg).append_message(tool_response); // Get final streaming response let chat_options = ChatOptions::default(); diff --git a/examples/c11-tooluse-deterministic.rs b/examples/c11-tooluse-deterministic.rs new file mode 100644 index 00000000..9eb51654 --- /dev/null +++ b/examples/c11-tooluse-deterministic.rs @@ -0,0 +1,62 @@ +use genai::Client; +use genai::chat::{ChatMessage, ChatRequest, Tool, ToolCall, ToolResponse}; +use serde_json::json; + +const MODEL: &str = "gemini-3-flash-preview"; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::default(); + + let weather_tool = Tool::new("get_weather") + .with_description("Get the current weather for a location") + .with_schema(json!({ + "type": "object", + "properties": { + "city": { "type": "string" }, + "unit": { "type": "string", "enum": ["C", "F"] } + }, + "required": ["city", "unit"] + })); + + // Create a synthetic conversation history. These tool calls were not generated by + // Gemini 3 (or any LLM) and do not have a thought signature. This is useful for e.g. + // pre-seeding a conversation with tool calls and responses, or executing deterministic + // logic "in-band" with the LLM conversation. genai will correctly inject the required + // string "skip_thought_signature_validator" in place of a valid signature. Otherwise, + // the call would error out on Gemini 3 models. + let messages = vec![ + ChatMessage::user("What's the weather like in Paris?"), + ChatMessage::assistant(vec![ToolCall { + call_id: "call_123".to_string(), + fn_name: "get_weather".to_string(), + fn_arguments: json!({"city": "Paris", "unit": "C"}), + thought_signatures: None, + }]), + ChatMessage::from(ToolResponse::new( + "call_123".to_string(), + json!({"temperature": 15, "condition": "Cloudy"}).to_string(), + )), + ]; + + let chat_req = ChatRequest::new(messages).with_tools(vec![weather_tool]); + + println!("--- Model: {MODEL}"); + println!("--- Sending deterministic history (synthetic tool call)..."); + + match client.exec_chat(MODEL, chat_req, None).await { + Ok(chat_res) => { + println!("\n--- Response received successfully:"); + if let Some(text) = chat_res.first_text() { + println!("{}", text); + } + } + Err(e) => { + eprintln!("\n--- Error: Request failed!"); + eprintln!("{}", e); + return Err(e.into()); + } + } + + Ok(()) +} diff --git a/examples/c18-cache-ttl.rs b/examples/c18-cache-ttl.rs new file mode 100644 index 00000000..6504c47b --- /dev/null +++ b/examples/c18-cache-ttl.rs @@ -0,0 +1,179 @@ +//! Example demonstrating prompt caching with mixed TTLs (1h + 5m). +//! +//! This example shows how to: +//! 1. Use `CacheControl::Ephemeral1h` and `CacheControl::Ephemeral5m` on system messages +//! 2. Verify cache creation on the first request +//! 3. Verify cache hits on a subsequent identical request +//! 4. Inspect TTL-specific token breakdowns +//! +//! Requires: ANTHROPIC_API_KEY environment variable +//! +//! Run with: `cargo run --example c18-cache-ttl` + +use genai::Client; +use genai::chat::{CacheControl, ChatMessage, ChatRequest}; + +const MODEL: &str = "claude-haiku-4-5-20251001"; + +/// Large text for the 1h-cached system message (~3000 tokens). +/// Anthropic requires a minimum of 2048 tokens per cacheable block for Haiku. 
+fn long_system_text() -> String { + let paragraph = "The field of artificial intelligence has seen remarkable progress over \ + the past decade. Machine learning models have grown from simple classifiers to complex \ + systems capable of generating human-like text, creating images from descriptions, and \ + solving intricate reasoning problems. These advances have been driven by improvements \ + in hardware, the availability of massive datasets, and breakthroughs in model \ + architectures such as transformers. The transformer architecture, introduced in 2017, \ + revolutionized natural language processing by enabling models to attend to all parts of \ + an input sequence simultaneously, rather than processing tokens one at a time. This \ + parallel processing capability, combined with scaling laws that showed predictable \ + performance improvements with increased compute and data, led to the development of \ + large language models that can perform a wide variety of tasks. "; + // Repeat enough to exceed 4096 tokens (Haiku 4.5 minimum for caching) + paragraph.repeat(40) +} + +/// Large text for the 5m-cached system message (~3000 tokens). +fn medium_system_text() -> String { + let paragraph = "When responding to user queries, always provide clear, accurate, and \ + well-structured answers. Break down complex topics into digestible parts. Use examples \ + where appropriate to illustrate concepts. Maintain a professional and helpful tone \ + throughout the conversation. If you are unsure about something, say so rather than \ + guessing. Cite sources when possible. Keep responses concise but thorough. "; + // Repeat enough to exceed 4096 tokens (Haiku 4.5 minimum for caching) + paragraph.repeat(55) +} + +fn build_chat_request(user_msg: &str) -> ChatRequest { + let sys1 = ChatMessage::system(long_system_text()).with_options(CacheControl::Ephemeral1h); + let sys2 = ChatMessage::system(medium_system_text()).with_options(CacheControl::Ephemeral5m); + + ChatRequest::default() + .append_message(sys1) + .append_message(sys2) + .append_message(ChatMessage::user(user_msg)) +} + +fn get_or_zero(val: Option) -> i32 { + val.unwrap_or(0) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::default(); + let mut all_passed = true; + + // -- Request 1: Cache creation + println!("=== Request 1: Cache Creation ===\n"); + + let req = build_chat_request("What is 2+2?"); + let res = client.exec_chat(MODEL, req, None).await?; + + if let Some(text) = res.content.first_text() { + println!("Response: {}\n", text); + } + + let usage = &res.usage; + let details = usage.prompt_tokens_details.as_ref(); + + let prompt_tokens = get_or_zero(usage.prompt_tokens); + let completion_tokens = get_or_zero(usage.completion_tokens); + let total_tokens = get_or_zero(usage.total_tokens); + let cache_creation_tokens = get_or_zero(details.and_then(|d| d.cache_creation_tokens)); + let cached_tokens = get_or_zero(details.and_then(|d| d.cached_tokens)); + let eph_1h = get_or_zero( + details + .and_then(|d| d.cache_creation_details.as_ref()) + .and_then(|cd| cd.ephemeral_1h_tokens), + ); + let eph_5m = get_or_zero( + details + .and_then(|d| d.cache_creation_details.as_ref()) + .and_then(|cd| cd.ephemeral_5m_tokens), + ); + + println!(" prompt_tokens: {prompt_tokens}"); + println!(" completion_tokens: {completion_tokens}"); + println!(" total_tokens: {total_tokens}"); + println!(" cache_creation_tokens: {cache_creation_tokens}"); + println!(" cached_tokens: {cached_tokens}"); + println!(" ephemeral_1h_tokens: 
{eph_1h}"); + println!(" ephemeral_5m_tokens: {eph_5m}"); + println!(); + + // Verify creation request + if cache_creation_tokens <= 0 { + println!(" FAIL: cache_creation_tokens should be > 0"); + all_passed = false; + } + if cached_tokens != 0 { + println!(" FAIL: cached_tokens should be 0 on first request, got {cached_tokens}"); + all_passed = false; + } + if eph_1h <= 0 { + println!(" FAIL: ephemeral_1h_tokens should be > 0"); + all_passed = false; + } + if eph_5m <= 0 { + println!(" FAIL: ephemeral_5m_tokens should be > 0"); + all_passed = false; + } + + // -- Request 2: Cache read + println!("=== Request 2: Cache Read ===\n"); + + let req = build_chat_request("What is 3+3?"); + let res = client.exec_chat(MODEL, req, None).await?; + + if let Some(text) = res.content.first_text() { + println!("Response: {}\n", text); + } + + let usage = &res.usage; + let details = usage.prompt_tokens_details.as_ref(); + + let prompt_tokens = get_or_zero(usage.prompt_tokens); + let completion_tokens = get_or_zero(usage.completion_tokens); + let total_tokens = get_or_zero(usage.total_tokens); + let cache_creation_tokens = get_or_zero(details.and_then(|d| d.cache_creation_tokens)); + let cached_tokens = get_or_zero(details.and_then(|d| d.cached_tokens)); + let eph_1h = get_or_zero( + details + .and_then(|d| d.cache_creation_details.as_ref()) + .and_then(|cd| cd.ephemeral_1h_tokens), + ); + let eph_5m = get_or_zero( + details + .and_then(|d| d.cache_creation_details.as_ref()) + .and_then(|cd| cd.ephemeral_5m_tokens), + ); + + println!(" prompt_tokens: {prompt_tokens}"); + println!(" completion_tokens: {completion_tokens}"); + println!(" total_tokens: {total_tokens}"); + println!(" cache_creation_tokens: {cache_creation_tokens}"); + println!(" cached_tokens: {cached_tokens}"); + println!(" ephemeral_1h_tokens: {eph_1h}"); + println!(" ephemeral_5m_tokens: {eph_5m}"); + println!(); + + // Verify cache hit + if cached_tokens <= 0 { + println!(" FAIL: cached_tokens should be > 0 (cache hit)"); + all_passed = false; + } + if cache_creation_tokens != 0 { + println!(" FAIL: cache_creation_tokens should be 0 on cache hit, got {cache_creation_tokens}"); + all_passed = false; + } + + // -- Final result + println!(); + if all_passed { + println!("Cache TTL test PASSED"); + } else { + println!("Cache TTL test FAILED"); + } + + Ok(()) +} diff --git a/src/adapter/adapter_kind.rs b/src/adapter/adapter_kind.rs index 3e442ce6..d9ada5ef 100644 --- a/src/adapter/adapter_kind.rs +++ b/src/adapter/adapter_kind.rs @@ -1,22 +1,24 @@ use crate::adapter::adapters::together::TogetherAdapter; +use crate::adapter::adapters::zai::ZaiAdapter; use crate::adapter::anthropic::AnthropicAdapter; use crate::adapter::bedrock::{self, BedrockAdapter}; +use crate::adapter::bigmodel::BigModelAdapter; use crate::adapter::cerebras::CerebrasAdapter; use crate::adapter::cohere::CohereAdapter; use crate::adapter::deepseek::{self, DeepSeekAdapter}; use crate::adapter::fireworks::FireworksAdapter; use crate::adapter::gemini::GeminiAdapter; use crate::adapter::groq::{self, GroqAdapter}; +use crate::adapter::mimo::{self, MimoAdapter}; use crate::adapter::nebius::NebiusAdapter; use crate::adapter::openai::OpenAIAdapter; use crate::adapter::openrouter::OpenRouterAdapter; use crate::adapter::xai::XaiAdapter; -use crate::adapter::zai::{self, ZAiAdapter}; +use crate::adapter::zai; use crate::adapter::zhipu::ZhipuAdapter; use crate::{ModelName, Result}; use derive_more::Display; use serde::{Deserialize, Serialize}; -use tracing::info; /// AdapterKind is an enum 
that represents the different types of adapters that can be used to interact with the API. /// @@ -39,24 +41,28 @@ pub enum AdapterKind { Together, /// Reuse some of the OpenAI adapter behavior, customize some (e.g., normalize thinking budget) Groq, + /// For Mimo (Mostly use OpenAI) + Mimo, /// For Nebius (Mostly use OpenAI) Nebius, /// For xAI (Mostly use OpenAI) Xai, /// For DeepSeek (Mostly use OpenAI) DeepSeek, - /// For Zhipu (Mostly use OpenAI) - Zhipu, + /// For ZAI (OpenAI-compatible with dual endpoint support: zai:: and zai-coding::) + Zai, + /// For big model (only accessible via namespace bigmodel::) + BigModel, /// Cohere today use it's own native protocol but might move to OpenAI Adapter Cohere, /// OpenAI shared behavior + some custom. (currently, localhost only, can be customize with ServerTargetResolver). Ollama, /// Cerebras (OpenAI-compatible protocol) Cerebras, - /// Z.AI (Anthropic-compatible protocol) - ZAi, - /// AWS Bedrock (uses Converse API with AWS SigV4 authentication) + /// AWS Bedrock (uses Converse API with Bearer token authentication) Bedrock, + /// For Zhipu (legacy, kept for backwards compatibility) + Zhipu, } /// Serialization/Parse implementations @@ -72,15 +78,17 @@ impl AdapterKind { AdapterKind::Fireworks => "Fireworks", AdapterKind::Together => "Together", AdapterKind::Groq => "Groq", + AdapterKind::Mimo => "Mimo", AdapterKind::Nebius => "Nebius", AdapterKind::Xai => "xAi", AdapterKind::DeepSeek => "DeepSeek", - AdapterKind::Zhipu => "Zhipu", + AdapterKind::Zai => "Zai", + AdapterKind::BigModel => "BigModel", AdapterKind::Cohere => "Cohere", AdapterKind::Ollama => "Ollama", AdapterKind::Cerebras => "Cerebras", - AdapterKind::ZAi => "ZAi", AdapterKind::Bedrock => "Bedrock", + AdapterKind::Zhipu => "Zhipu", } } @@ -95,15 +103,17 @@ impl AdapterKind { AdapterKind::Fireworks => "fireworks", AdapterKind::Together => "together", AdapterKind::Groq => "groq", + AdapterKind::Mimo => "mimo", AdapterKind::Nebius => "nebius", AdapterKind::Xai => "xai", AdapterKind::DeepSeek => "deepseek", - AdapterKind::Zhipu => "zhipu", + AdapterKind::Zai => "zai", + AdapterKind::BigModel => "bigmodel", AdapterKind::Cohere => "cohere", AdapterKind::Ollama => "ollama", AdapterKind::Cerebras => "cerebras", - AdapterKind::ZAi => "zai", AdapterKind::Bedrock => "bedrock", + AdapterKind::Zhipu => "zhipu", } } @@ -117,15 +127,17 @@ impl AdapterKind { "fireworks" => Some(AdapterKind::Fireworks), "together" => Some(AdapterKind::Together), "groq" => Some(AdapterKind::Groq), + "mimo" => Some(AdapterKind::Mimo), "nebius" => Some(AdapterKind::Nebius), "xai" => Some(AdapterKind::Xai), "deepseek" => Some(AdapterKind::DeepSeek), - "zhipu" => Some(AdapterKind::Zhipu), + "zai" => Some(AdapterKind::Zai), + "bigmodel" => Some(AdapterKind::BigModel), "cohere" => Some(AdapterKind::Cohere), "ollama" => Some(AdapterKind::Ollama), "cerebras" => Some(AdapterKind::Cerebras), - "zai" => Some(AdapterKind::ZAi), "bedrock" => Some(AdapterKind::Bedrock), + "zhipu" => Some(AdapterKind::Zhipu), _ => None, } } @@ -144,16 +156,17 @@ impl AdapterKind { AdapterKind::Fireworks => Some(FireworksAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Together => Some(TogetherAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Groq => Some(GroqAdapter::API_KEY_DEFAULT_ENV_NAME), + AdapterKind::Mimo => Some(MimoAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Nebius => Some(NebiusAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Xai => Some(XaiAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::DeepSeek => 
Some(DeepSeekAdapter::API_KEY_DEFAULT_ENV_NAME), - AdapterKind::Zhipu => Some(ZhipuAdapter::API_KEY_DEFAULT_ENV_NAME), + AdapterKind::Zai => Some(ZaiAdapter::API_KEY_DEFAULT_ENV_NAME), + AdapterKind::BigModel => Some(BigModelAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Cohere => Some(CohereAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Ollama => None, AdapterKind::Cerebras => Some(CerebrasAdapter::API_KEY_DEFAULT_ENV_NAME), - AdapterKind::ZAi => Some(ZAiAdapter::API_KEY_DEFAULT_ENV_NAME), - // Bedrock uses Bearer token authentication AdapterKind::Bedrock => Some(BedrockAdapter::API_KEY_ENV), + AdapterKind::Zhipu => Some(ZhipuAdapter::API_KEY_DEFAULT_ENV_NAME), } } } @@ -171,7 +184,7 @@ impl AdapterKind { /// - Fireworks - contains "fireworks" (might add leading or trailing '/' later) /// - Groq - model in Groq models /// - DeepSeek - model in DeepSeek models (deepseek.com) - /// - Zhipu - starts_with "glm" + /// - Zai - model in ZAI models (glm series) /// - Cohere - starts_with "command" /// - Ollama - For anything else /// @@ -179,25 +192,22 @@ impl AdapterKind { /// - e.g., for together.ai `together::meta-llama/Llama-3-8b-chat-hf` /// - e.g., for nebius with `nebius::Qwen/Qwen3-235B-A22B` /// - e.g., for cerebras with `cerebras::llama-3.1-8b` + /// - e.g., for ZAI coding plan with `zai-coding::glm-4.6` /// /// And all adapters can be force namspaced as well. /// /// Note: At this point, this will never fail as the fallback is the Ollama adapter. /// This might change in the future, hence the Result return type. pub fn from_model(model: &str) -> Result { - // -- First check if namespaced (explicit :: namespace has priority) - if let (_, Some(ns)) = ModelName::model_name_and_namespace(model) { - if let Some(adapter) = Self::from_lower_str(ns) { - return Ok(adapter); - } else { - info!("No AdapterKind found for '{ns}'") - } - } + // -- First check if namespaced + if let Some(adapter) = Self::from_model_namespace(model) { + return Ok(adapter); + }; // -- Special handling for OpenRouter models (they start with provider names) // Only catch patterns without explicit :: namespace if model.contains('/') - && !model.contains("::") // Don't override explicit namespaces + && !model.contains("::") && (model.starts_with("openai/") || model.starts_with("anthropic/") || model.starts_with("meta-llama/") @@ -206,7 +216,7 @@ impl AdapterKind { return Ok(Self::OpenRouter); } - // -- Resolve from modelname + // -- Otherwise, Resolve from modelname if model.starts_with("o3") || model.starts_with("o4") || model.starts_with("o1") @@ -214,9 +224,8 @@ impl AdapterKind { || model.starts_with("codex") || (model.starts_with("gpt") && !model.starts_with("gpt-oss")) || model.starts_with("text-embedding") - // migh be a little generic on this one { - if model.starts_with("gpt") && model.contains("codex") { + if model.starts_with("gpt") && (model.contains("codex") || model.contains("pro")) { Ok(Self::OpenAIResp) } else { Ok(Self::OpenAI) @@ -226,11 +235,13 @@ impl AdapterKind { } else if model.starts_with("claude") { Ok(Self::Anthropic) } else if zai::MODELS.contains(&model) { - Ok(Self::ZAi) + Ok(Self::Zai) } else if model.contains("fireworks") { Ok(Self::Fireworks) } else if groq::MODELS.contains(&model) { Ok(Self::Groq) + } else if mimo::MODELS.contains(&model) { + Ok(Self::Mimo) } else if model.starts_with("command") || model.starts_with("embed-") { Ok(Self::Cohere) } else if deepseek::MODELS.contains(&model) { @@ -238,7 +249,7 @@ impl AdapterKind { } else if model.starts_with("grok") { Ok(Self::Xai) } else 
if model.starts_with("glm") { - Ok(Self::Zhipu) + Ok(Self::Zai) } // AWS Bedrock models (provider.model-name format) else if bedrock::MODELS.contains(&model) { @@ -250,3 +261,29 @@ impl AdapterKind { } } } + +// region: --- Support + +/// Inner api to return +impl AdapterKind { + fn from_model_namespace(model: &str) -> Option { + let (namespace, _) = ModelName::split_as_namespace_and_name(model); + let namespace = namespace?; + + // -- First, check if simple adapter lower string match + if let Some(adapter) = Self::from_lower_str(namespace) { + Some(adapter) + } + // -- Second, custom, for now, we hardcode this exception here (might become more generic later) + else if namespace == zai::ZAI_CODING_NAMESPACE { + Some(Self::Zai) + } + // + // -- Otherwise, no adapter from namespace, because no matching namespace + else { + None + } + } +} + +// endregion: --- Support diff --git a/src/adapter/adapters/anthropic/adapter_impl.rs b/src/adapter/adapters/anthropic/adapter_impl.rs index 33612efd..789b3b77 100644 --- a/src/adapter/adapters/anthropic/adapter_impl.rs +++ b/src/adapter/adapters/anthropic/adapter_impl.rs @@ -360,9 +360,7 @@ impl AnthropicAdapter { let completion_tokens: i32 = usage_value.x_take("output_tokens").ok().unwrap_or(0); // Parse cache_creation breakdown if present (TTL-specific breakdown) - let cache_creation_details = usage_value - .get("cache_creation") - .and_then(parse_cache_creation_details); + let cache_creation_details = usage_value.get("cache_creation").and_then(parse_cache_creation_details); // compute the prompt_tokens let prompt_tokens = input_tokens + cache_creation_input_tokens + cache_read_input_tokens; diff --git a/src/adapter/adapters/anthropic/mod.rs b/src/adapter/adapters/anthropic/mod.rs index e685fdae..133ff33e 100644 --- a/src/adapter/adapters/anthropic/mod.rs +++ b/src/adapter/adapters/anthropic/mod.rs @@ -1,7 +1,8 @@ -//! API Documentation: https://docs.anthropic.com/en/api/messages -//! Tool Documentation: https://docs.anthropic.com/en/docs/build-with-claude/tool-use -//! Model Names: https://docs.anthropic.com/en/docs/models-overview -//! Pricing: https://www.anthropic.com/pricing#anthropic-api +//! API Documentation: +//! Tool Documentation: +//! Effort Documentation: +//! Model Names: +//! 
Pricing: // region: --- Modules diff --git a/src/adapter/adapters/bedrock/adapter_impl.rs b/src/adapter/adapters/bedrock/adapter_impl.rs index 53c088ad..3f75ee96 100644 --- a/src/adapter/adapters/bedrock/adapter_impl.rs +++ b/src/adapter/adapters/bedrock/adapter_impl.rs @@ -280,7 +280,7 @@ impl Adapter for BedrockAdapter { fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> Result { let base_url = endpoint.base_url(); - let (model_name, _) = model.model_name.as_model_name_and_namespace(); + let (_, model_name) = model.model_name.namespace_and_name(); // URL encode the model ID (Bedrock model IDs contain colons) let encoded_model = urlencoding_encode(model_name); diff --git a/src/adapter/adapters/bigmodel/adapter_impl.rs b/src/adapter/adapters/bigmodel/adapter_impl.rs new file mode 100644 index 00000000..5e347b78 --- /dev/null +++ b/src/adapter/adapters/bigmodel/adapter_impl.rs @@ -0,0 +1,88 @@ +use crate::ModelIden; +use crate::adapter::openai::OpenAIAdapter; +use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; +use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse}; +use crate::resolver::{AuthData, Endpoint}; +use crate::webc::WebResponse; +use crate::{Result, ServiceTarget}; +use reqwest::RequestBuilder; + +/// The BigModel adapter. Only available via namespace. +/// +pub struct BigModelAdapter; + +pub(in crate::adapter) const MODELS: &[&str] = &[]; + +impl BigModelAdapter { + pub const API_KEY_DEFAULT_ENV_NAME: &str = "BIGMODEL_API_KEY"; +} + +// The ZAI API is mostly compatible with the OpenAI API. +impl Adapter for BigModelAdapter { + fn default_endpoint() -> Endpoint { + const BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4/"; + Endpoint::from_static(BASE_URL) + } + + fn default_auth() -> AuthData { + AuthData::from_env(Self::API_KEY_DEFAULT_ENV_NAME) + } + + async fn all_model_names(_kind: AdapterKind) -> Result> { + Ok(MODELS.iter().map(|s| s.to_string()).collect()) + } + + fn get_service_url(_model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> Result { + // For ZAI, we need to handle model-specific routing at this level + // because get_service_url is called with the modified endpoint from to_web_request_data + let base_url = endpoint.base_url(); + + let url = match service_type { + ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}chat/completions"), + ServiceType::Embed => format!("{base_url}embeddings"), + }; + Ok(url) + } + + fn to_web_request_data( + target: ServiceTarget, + service_type: ServiceType, + chat_req: ChatRequest, + chat_options: ChatOptionsSet<'_, '_>, + ) -> Result { + // Parse model name and determine appropriate endpoint + OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options, None) + } + + fn to_chat_response( + model_iden: ModelIden, + web_response: WebResponse, + options_set: ChatOptionsSet<'_, '_>, + ) -> Result { + OpenAIAdapter::to_chat_response(model_iden, web_response, options_set) + } + + fn to_chat_stream( + model_iden: ModelIden, + reqwest_builder: RequestBuilder, + options_set: ChatOptionsSet<'_, '_>, + ) -> Result { + OpenAIAdapter::to_chat_stream(model_iden, reqwest_builder, options_set) + } + + fn to_embed_request_data( + service_target: crate::ServiceTarget, + embed_req: crate::embed::EmbedRequest, + options_set: crate::embed::EmbedOptionsSet<'_, '_>, + ) -> Result { + OpenAIAdapter::to_embed_request_data(service_target, embed_req, options_set) + } + + fn to_embed_response( + model_iden: 
crate::ModelIden, + web_response: crate::webc::WebResponse, + options_set: crate::embed::EmbedOptionsSet<'_, '_>, + ) -> Result { + OpenAIAdapter::to_embed_response(model_iden, web_response, options_set) + } +} diff --git a/src/adapter/adapters/bigmodel/mod.rs b/src/adapter/adapters/bigmodel/mod.rs new file mode 100644 index 00000000..63404fee --- /dev/null +++ b/src/adapter/adapters/bigmodel/mod.rs @@ -0,0 +1,14 @@ +//! Click the globe icon on the top-right corner of the page to switch language. +//! API Documentation: +//! Model Names: +//! Pricing: +//! +//! + +// region: --- Modules + +mod adapter_impl; + +pub use adapter_impl::*; + +// endregion: --- Modules diff --git a/src/adapter/adapters/cohere/adapter_impl.rs b/src/adapter/adapters/cohere/adapter_impl.rs index ef6b2a07..e0e07a16 100644 --- a/src/adapter/adapters/cohere/adapter_impl.rs +++ b/src/adapter/adapters/cohere/adapter_impl.rs @@ -80,7 +80,7 @@ impl Adapter for CohereAdapter { } = Self::into_cohere_request_parts(model.clone(), chat_req)?; // -- Build the basic payload - let (model_name, _) = model.model_name.as_model_name_and_namespace(); + let (_, model_name) = model.model_name.namespace_and_name(); let stream = matches!(service_type, ServiceType::ChatStream); let mut payload = json!({ "model": model_name.to_string(), @@ -118,10 +118,9 @@ impl Adapter for CohereAdapter { fn to_chat_response( model_iden: ModelIden, web_response: WebResponse, - options_set: ChatOptionsSet<'_, '_>, + _options_set: ChatOptionsSet<'_, '_>, ) -> Result { let WebResponse { mut body, .. } = web_response; - let captured_raw_body = options_set.capture_raw_body().unwrap_or_default().then(|| body.clone()); // -- Capture the provider_model_iden // TODO: Need to be implemented (if available), for now, just clone model_iden @@ -148,7 +147,7 @@ impl Adapter for CohereAdapter { model_iden, provider_model_iden, usage, - captured_raw_body, + captured_raw_body: None, // Set by the client exec_chat }) } diff --git a/src/adapter/adapters/cohere/embed.rs b/src/adapter/adapters/cohere/embed.rs index 119446f1..0bfb32ba 100644 --- a/src/adapter/adapters/cohere/embed.rs +++ b/src/adapter/adapters/cohere/embed.rs @@ -1,5 +1,5 @@ //! Cohere Embeddings API implementation -//! API Documentation: https://docs.cohere.com/reference/embed +//! API Documentation: use crate::adapter::adapters::support::get_api_key; use crate::adapter::{Adapter, ServiceType, WebRequestData}; @@ -83,7 +83,7 @@ pub fn to_embed_request_data( let api_key = get_api_key(auth, &model)?; // Extract the actual model name (without namespace) - let (model_name, _) = model.model_name.as_model_name_and_namespace(); + let (_, model_name) = model.model_name.namespace_and_name(); // Build headers let mut headers = Headers::from(vec![ diff --git a/src/adapter/adapters/cohere/mod.rs b/src/adapter/adapters/cohere/mod.rs index 8d35d71a..d105ec4e 100644 --- a/src/adapter/adapters/cohere/mod.rs +++ b/src/adapter/adapters/cohere/mod.rs @@ -1,6 +1,6 @@ -//! API DOC: https://docs.cohere.com/reference/chat -//! MODEL NAMES: https://docs.cohere.com/docs/models -//! PRICING: https://cohere.com/pricing +//! API DOC: +//! MODEL NAMES: +//! PRICING: // region: --- Modules diff --git a/src/adapter/adapters/deepseek/mod.rs b/src/adapter/adapters/deepseek/mod.rs index 2d2c105b..bfe5b5ae 100644 --- a/src/adapter/adapters/deepseek/mod.rs +++ b/src/adapter/adapters/deepseek/mod.rs @@ -1,6 +1,6 @@ -//! API Documentation: https://api-docs.deepseek.com/ -//! Model Names: https://api-docs.deepseek.com/quick_start/pricing -//! 
Pricing: https://api-docs.deepseek.com/quick_start/pricing +//! API Documentation: +//! Model Names: +//! Pricing: // region: --- Modules diff --git a/src/adapter/adapters/fireworks/adapter_impl.rs b/src/adapter/adapters/fireworks/adapter_impl.rs index e361c545..9c82c287 100644 --- a/src/adapter/adapters/fireworks/adapter_impl.rs +++ b/src/adapter/adapters/fireworks/adapter_impl.rs @@ -59,7 +59,7 @@ impl Adapter for FireworksAdapter { if !target.model.model_name.contains('/') { target.model = target.model.from_name(format!( "accounts/fireworks/models/{}", - target.model.model_name.as_model_name_and_namespace().0 + target.model.model_name.namespace_and_name().1 )) } // NOTE: Fireworks max_tokens is set at 2K by default, which is unpractical for most task. @@ -68,9 +68,13 @@ impl Adapter for FireworksAdapter { // NOTE: The `genai` strategy is to set a large max_tokens value, letting the model enforce its own lower limit by default to avoid unpleasant and confusing surprises. // Users can use [`ChatOptions`] to specify a specific max_tokens value. // NOTE: As mentioned in the Fireworks FAQ above, typically, for Fireworks-hosted models the top max_tokens is equal to the context window. - // Since, Qwen3 models are at 256k, so we will use this upper bound (without going to the 1M/10M of Llama 4). + // Since, Qwen3 models are at 256k, so we will use this upper bound (without going to the 1M/10M of Llama 4) for non-streaming. + // However, since anything over 5k requires streaming API, we cap the default to 5k for non-streaming here so that the request doesn't fail. let custom = ToWebRequestCustom { - default_max_tokens: Some(256_000), + default_max_tokens: match service_type { + ServiceType::ChatStream => Some(256_000), + _ => Some(5_000), + }, }; OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options, Some(custom)) diff --git a/src/adapter/adapters/fireworks/mod.rs b/src/adapter/adapters/fireworks/mod.rs index 8ff83801..f7a984f8 100644 --- a/src/adapter/adapters/fireworks/mod.rs +++ b/src/adapter/adapters/fireworks/mod.rs @@ -1,7 +1,7 @@ //! Click the globe icon on the top-right corner of the page to switch language. -//! API Documentation: https://fireworks.ai/docs/getting-started/introduction -//! Model Names: https://fireworks.ai/models -//! Pricing: https://fireworks.ai/pricing#serverless-pricing +//! API Documentation: +//! Model Names: +//! 
Pricing: // region: --- Modules diff --git a/src/adapter/adapters/gemini/adapter_impl.rs b/src/adapter/adapters/gemini/adapter_impl.rs index 8469e0b9..c2f32a97 100644 --- a/src/adapter/adapters/gemini/adapter_impl.rs +++ b/src/adapter/adapters/gemini/adapter_impl.rs @@ -18,6 +18,7 @@ pub struct GeminiAdapter; // Note: Those model names are just informative, as the Gemini AdapterKind is selected on `startsWith("gemini")` const MODELS: &[&str] = &[ // + "gemini-3-pro-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite", @@ -29,6 +30,25 @@ const REASONING_LOW: u32 = 1000; const REASONING_MEDIUM: u32 = 8000; const REASONING_HIGH: u32 = 24000; +/// Important +/// - For now Low and Minimal aare the same for geminia +/// - +fn insert_gemini_thinking_budget_value(payload: &mut Value, effort: &ReasoningEffort) -> Result<()> { + // -- for now, match minimal to Low (because zero is not supported by 2.5 pro) + let budget = match effort { + ReasoningEffort::None => None, + ReasoningEffort::Low | ReasoningEffort::Minimal => Some(REASONING_LOW), + ReasoningEffort::Medium => Some(REASONING_MEDIUM), + ReasoningEffort::High => Some(REASONING_HIGH), + ReasoningEffort::Budget(budget) => Some(*budget), + }; + + if let Some(budget) = budget { + payload.x_insert("/generationConfig/thinkingConfig/thinkingBudget", budget)?; + } + Ok(()) +} + // curl \ // -H 'Content-Type: application/json' \ // -d '{"contents":[{"parts":[{"text":"Explain how AI works"}]}]}' \ @@ -57,7 +77,7 @@ impl Adapter for GeminiAdapter { /// this will return the URL without the API_KEY in it. The API_KEY will need to be added by the caller. fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> Result { let base_url = endpoint.base_url(); - let (model_name, _) = model.model_name.as_model_name_and_namespace(); + let (_, model_name) = model.model_name.namespace_and_name(); let url = match service_type { ServiceType::Chat => format!("{base_url}models/{model_name}:generateContent"), ServiceType::ChatStream => format!("{base_url}models/{model_name}:streamGenerateContent"), @@ -73,7 +93,7 @@ impl Adapter for GeminiAdapter { options_set: ChatOptionsSet<'_, '_>, ) -> Result { let ServiceTarget { endpoint, auth, model } = target; - let (model_name, _) = model.model_name.as_model_name_and_namespace(); + let (_, model_name) = model.model_name.namespace_and_name(); // -- api_key let api_key = get_api_key(auth, &model)?; @@ -82,16 +102,18 @@ impl Adapter for GeminiAdapter { let headers = Headers::from(("x-goog-api-key".to_string(), api_key.to_string())); // -- Reasoning Budget - let (provider_model_name, reasoning_budget) = match (model_name, options_set.reasoning_effort()) { - // No explicity reasoning_effor, try to infer from model name suffix (supports -zero) + let (provider_model_name, computed_reasoning_effort) = match (model_name, options_set.reasoning_effort()) { + // No explicity reasoning_effort, try to infer from model name suffix (supports -zero) (model, None) => { // let model_name: &str = &model.model_name; if let Some((prefix, last)) = model_name.rsplit_once('-') { let reasoning = match last { - "zero" => Some(REASONING_ZERO), - "low" => Some(REASONING_LOW), - "medium" => Some(REASONING_MEDIUM), - "high" => Some(REASONING_HIGH), + // 'zero' is a gemini special + "zero" => Some(ReasoningEffort::Budget(REASONING_ZERO)), + "none" => Some(ReasoningEffort::None), + "low" | "minimal" => Some(ReasoningEffort::Low), + "medium" => Some(ReasoningEffort::Medium), + "high" => 
Some(ReasoningEffort::High), _ => None, }; // create the model name if there was a `-..` reasoning suffix @@ -102,19 +124,8 @@ impl Adapter for GeminiAdapter { (model, None) } } - // If reasoning effort, turn the low, medium, budget ones into Budget - (model, Some(effort)) => { - let effort = match effort { - // -- for now, match minimal to Low (because zero is not supported by 2.5 pro) - ReasoningEffort::None => 0, // No reasoning for None - ReasoningEffort::Minimal => REASONING_LOW, - ReasoningEffort::Low => REASONING_LOW, - ReasoningEffort::Medium => REASONING_MEDIUM, - ReasoningEffort::High => REASONING_HIGH, - ReasoningEffort::Budget(budget) => *budget, - }; - (model, Some(effort)) - } + // TOOD: make it more elegant + (model, Some(effort)) => (model, Some(effort.clone())), }; // -- parts @@ -130,8 +141,28 @@ impl Adapter for GeminiAdapter { }); // -- Set the reasoning effort - if let Some(budget) = reasoning_budget { - payload.x_insert("/generationConfig/thinkingConfig/thinkingBudget", budget)?; + if let Some(computed_reasoning_effort) = computed_reasoning_effort { + // -- For gemini-3 use the thinkingLevel if Low or High (does not support medium for now) + if provider_model_name.contains("gemini-3") { + match computed_reasoning_effort { + ReasoningEffort::Low | ReasoningEffort::Minimal => { + payload.x_insert("/generationConfig/thinkingConfig/thinkingLevel", "LOW")?; + } + ReasoningEffort::High => { + payload.x_insert("/generationConfig/thinkingConfig/thinkingLevel", "HIGH")?; + } + // Fallback on thinkingBudget + other => { + insert_gemini_thinking_budget_value(&mut payload, &other)?; + } + } + } + // -- Otherwise, Do thinking budget + else { + insert_gemini_thinking_budget_value(&mut payload, &computed_reasoning_effort)?; + } + // -- Always include thoughts when reasoning effort is set since you are already paying for them + payload.x_insert("/generationConfig/thinkingConfig/includeThoughts", true)?; } // Note: It's unclear from the spec if the content of systemInstruction should have a role. @@ -164,7 +195,7 @@ impl Adapter for GeminiAdapter { } true }); - payload.x_insert("/generationConfig/responseSchema", schema)?; + payload.x_insert("/generationConfig/responseJsonSchema", schema)?; } // -- Add supported ChatOptions @@ -193,12 +224,10 @@ impl Adapter for GeminiAdapter { fn to_chat_response( model_iden: ModelIden, web_response: WebResponse, - options_set: ChatOptionsSet<'_, '_>, + _options_set: ChatOptionsSet<'_, '_>, ) -> Result { let WebResponse { mut body, .. 
} = web_response; - let captured_raw_body = options_set.capture_raw_body().unwrap_or_default().then(|| body.clone()); - // -- Capture the provider_model_iden // TODO: Need to be implemented (if available), for now, just clone model_iden let provider_model_name: Option = body.x_remove("modelVersion").ok(); @@ -209,22 +238,56 @@ impl Adapter for GeminiAdapter { usage, } = gemini_response; - // FIXME: Needs to take the content list - let mut content: MessageContent = MessageContent::default(); + let mut thoughts: Vec = Vec::new(); + let mut reasonings: Vec = Vec::new(); + let mut texts: Vec = Vec::new(); + let mut tool_calls: Vec = Vec::new(); + for g_item in gemini_content { match g_item { - GeminiChatContent::Text(text) => content.push(text), - GeminiChatContent::ToolCall(tool_call) => content.push(tool_call), + GeminiChatContent::Text(text) => texts.push(text), + GeminiChatContent::ToolCall(tool_call) => tool_calls.push(tool_call), + GeminiChatContent::ThoughtSignature(thought) => thoughts.push(thought), + GeminiChatContent::Reasoning(reasoning_text) => reasonings.push(reasoning_text), } } + let thought_signatures_for_call = (!thoughts.is_empty() && !tool_calls.is_empty()).then(|| thoughts.clone()); + let mut parts: Vec = thoughts.into_iter().map(ContentPart::ThoughtSignature).collect(); + + if let Some(signatures) = thought_signatures_for_call + && let Some(first_call) = tool_calls.first_mut() + { + first_call.thought_signatures = Some(signatures); + } + + if !texts.is_empty() { + let total_len: usize = texts.iter().map(|t| t.len()).sum(); + let mut combined_text = String::with_capacity(total_len); + for text in texts { + combined_text.push_str(&text); + } + if !combined_text.is_empty() { + parts.push(ContentPart::Text(combined_text)); + } + } + let mut reasoning_text = String::new(); + if !reasonings.is_empty() { + for reasoning in &reasonings { + reasoning_text.push_str(reasoning); + } + } + + parts.extend(tool_calls.into_iter().map(ContentPart::ToolCall)); + let content = MessageContent::from_parts(parts); + Ok(ChatResponse { content, - reasoning_content: None, + reasoning_content: Some(reasoning_text), model_iden, provider_model_iden, usage, - captured_raw_body, + captured_raw_body: None, // Set by the client exec_chat }) } @@ -294,6 +357,36 @@ impl GeminiAdapter { }; for mut part in parts { + // -- Capture eventual thought signature + { + if let Some(thought_signature) = part + .x_take::("thoughtSignature") + .ok() + .and_then(|v| if let Value::String(v) = v { Some(v) } else { None }) + { + content.push(GeminiChatContent::ThoughtSignature(thought_signature)); + } + // Note: sometime the thought is in "thought" (undocumented, but observed in some cases or older models?) + // But for Gemini 3 it is thoughtSignature. Keeping this just in case or for backward compat if it was used. + // Actually, let's stick to thoughtSignature as per docs, but if we see "thought" we might want to capture it too. + // Let's check for "thought" if "thoughtSignature" was not found. 
+ else if let Some(thought) = part + .x_take::("thought") + .ok() + .and_then(|v| if let Value::Bool(v) = v { Some(v) } else { None }) + { + if thought { + if let Some(val) = part + .x_take::("text") + .ok() + .and_then(|v| if let Value::String(v) = v { Some(v) } else { None }) + { + content.push(GeminiChatContent::Reasoning(val)); + } + } + } + } + // -- Capture eventual function call if let Ok(fn_call_value) = part.x_take::("functionCall") { let tool_call = ToolCall { @@ -461,8 +554,10 @@ impl GeminiAdapter { } })); } - ContentPart::ThoughtSignature(_) => { - // Thought signatures are not directly supported by Gemini, skip for now + ContentPart::ThoughtSignature(thought) => { + parts_values.push(json!({ + "thoughtSignature": thought + })); } } } @@ -471,23 +566,71 @@ impl GeminiAdapter { } ChatRole::Assistant => { let mut parts_values: Vec = Vec::new(); + let mut pending_thought: Option = None; + let mut is_first_tool_call = true; + for part in msg.content { match part { - ContentPart::Text(text) => parts_values.push(json!({"text": text})), + ContentPart::Text(text) => { + if let Some(thought) = pending_thought.take() { + parts_values.push(json!({"thoughtSignature": thought})); + } + parts_values.push(json!({"text": text})); + } ContentPart::ToolCall(tool_call) => { - parts_values.push(json!({ - "functionCall": { + let mut part_obj = serde_json::Map::new(); + part_obj.insert( + "functionCall".to_string(), + json!({ "name": tool_call.fn_name, "args": tool_call.fn_arguments, + }), + ); + + match pending_thought.take() { + Some(thought) => { + // Inject thoughtSignature alongside functionCall in the same Part object + part_obj.insert("thoughtSignature".to_string(), json!(thought)); } - })); + None => { + // For Gemini 3 models, if there haven't been any thoughts, and this is + // still the first tool call, we are required to inject a special flag. 
+ // See: https://ai.google.dev/gemini-api/docs/thought-signatures#faqs + let is_gemini_3 = model_iden.model_name.contains("gemini-3"); + if is_gemini_3 && is_first_tool_call { + part_obj.insert( + "thoughtSignature".to_string(), + json!("skip_thought_signature_validator"), + ); + } + } + } + + parts_values.push(Value::Object(part_obj)); + is_first_tool_call = false; + } + ContentPart::ThoughtSignature(thought) => { + if let Some(prev_thought) = pending_thought.take() { + parts_values.push(json!({"thoughtSignature": prev_thought})); + } + pending_thought = Some(thought); } // Ignore unsupported parts for Assistant role - ContentPart::Binary(_) => {} - ContentPart::ToolResponse(_) => {} - ContentPart::ThoughtSignature(_) => {} + ContentPart::Binary(_) => { + if let Some(thought) = pending_thought.take() { + parts_values.push(json!({"thoughtSignature": thought})); + } + } + ContentPart::ToolResponse(_) => { + if let Some(thought) = pending_thought.take() { + parts_values.push(json!({"thoughtSignature": thought})); + } + } } } + if let Some(thought) = pending_thought { + parts_values.push(json!({"thoughtSignature": thought})); + } if !parts_values.is_empty() { contents.push(json!({"role": "model", "parts": parts_values})); } @@ -515,10 +658,15 @@ impl GeminiAdapter { } })); } + ContentPart::ThoughtSignature(thought) => { + parts_values.push(json!({ + "thoughtSignature": thought + })); + } _ => { return Err(Error::MessageContentTypeNotSupported { model_iden: model_iden.clone(), - cause: "ChatRole::Tool can only contain ToolCall or ToolResponse content parts", + cause: "ChatRole::Tool can only contain ToolCall, ToolResponse, or Thought content parts", }); } } @@ -587,6 +735,8 @@ pub(super) struct GeminiChatResponse { pub(super) enum GeminiChatContent { Text(String), ToolCall(ToolCall), + Reasoning(String), + ThoughtSignature(String), } struct GeminiChatRequestParts { diff --git a/src/adapter/adapters/gemini/embed.rs b/src/adapter/adapters/gemini/embed.rs index 7e142872..e1a63ab4 100644 --- a/src/adapter/adapters/gemini/embed.rs +++ b/src/adapter/adapters/gemini/embed.rs @@ -1,5 +1,5 @@ //! Gemini Embeddings API implementation -//! API Documentation: https://ai.google.dev/gemini-api/docs/embeddings +//! API Documentation: use crate::adapter::adapters::support::get_api_key; use crate::adapter::{Adapter, ServiceType, WebRequestData}; @@ -83,9 +83,6 @@ pub fn to_embed_request_data( let ServiceTarget { model, auth, .. } = service_target; let api_key = get_api_key(auth, &model)?; - // Extract the actual model name (without namespace) - not needed for Gemini request body - let (_model_name, _) = model.model_name.as_model_name_and_namespace(); - // Build headers - Gemini uses x-goog-api-key header let mut headers = Headers::from(vec![ ("x-goog-api-key".to_string(), api_key), @@ -97,8 +94,9 @@ pub fn to_embed_request_data( headers.merge_with(custom_headers); } + // Extract the actual model name (without namespace) - not needed for Gemini request body // Get the model name for the request - let (model_name, _) = model.model_name.as_model_name_and_namespace(); + let (_, model_name) = model.model_name.namespace_and_name(); let full_model_name = format!("models/{model_name}",); // Convert EmbedRequest to Gemini format and determine URL diff --git a/src/adapter/adapters/gemini/mod.rs b/src/adapter/adapters/gemini/mod.rs index c34237b8..d62d7492 100644 --- a/src/adapter/adapters/gemini/mod.rs +++ b/src/adapter/adapters/gemini/mod.rs @@ -1,6 +1,6 @@ -//! 
API Documentation: https://ai.google.dev/api/rest/v1beta/models/generateContent -//! Model Names: https://ai.google.dev/gemini-api/docs/models/gemini -//! Pricing: https://ai.google.dev/pricing +//! API Documentation: +//! Model Names: +//! Pricing: // region: --- Modules diff --git a/src/adapter/adapters/gemini/streamer.rs b/src/adapter/adapters/gemini/streamer.rs index df18fd3d..0ee1ea13 100644 --- a/src/adapter/adapters/gemini/streamer.rs +++ b/src/adapter/adapters/gemini/streamer.rs @@ -10,6 +10,8 @@ use std::task::{Context, Poll}; use super::GeminiChatContent; +use std::collections::VecDeque; + pub struct GeminiStreamer { inner: WebStream, options: StreamerOptions, @@ -18,6 +20,7 @@ pub struct GeminiStreamer { /// Flag to not poll the EventSource after a MessageStop event. done: bool, captured_data: StreamerCapturedData, + pending_events: VecDeque, } impl GeminiStreamer { @@ -27,6 +30,7 @@ impl GeminiStreamer { done: false, options: StreamerOptions::new(model_iden, options_set), captured_data: Default::default(), + pending_events: VecDeque::new(), } } } @@ -40,6 +44,11 @@ impl futures::Stream for GeminiStreamer { return Poll::Ready(None); } + // 1. Check if we have pending events + if let Some(event) = self.pending_events.pop_front() { + return Poll::Ready(Some(Ok(event))); + } + while let Poll::Ready(item) = Pin::new(&mut self.inner).poll_next(cx) { match item { Some(Ok(raw_message)) => { @@ -47,18 +56,18 @@ impl futures::Stream for GeminiStreamer { // - `[` document start // - `{...}` block // - `]` document end - let inter_event = match raw_message.as_str() { - "[" => InterStreamEvent::Start, + match raw_message.as_str() { + "[" => return Poll::Ready(Some(Ok(InterStreamEvent::Start))), "]" => { let inter_stream_end = InterStreamEnd { captured_usage: self.captured_data.usage.take(), captured_text_content: self.captured_data.content.take(), captured_reasoning_content: self.captured_data.reasoning_content.take(), captured_tool_calls: self.captured_data.tool_calls.take(), - captured_thought_signatures: None, + captured_thought_signatures: self.captured_data.thought_signatures.take(), }; - InterStreamEvent::End(inter_stream_end) + return Poll::Ready(Some(Ok(InterStreamEvent::End(inter_stream_end)))); } block_string => { // -- Parse the block to JSON @@ -91,16 +100,54 @@ impl futures::Stream for GeminiStreamer { // -- Extract text and toolcall // WARNING: Assume that only ONE tool call per message (or take the last one) let mut stream_text_content: String = String::new(); + let mut stream_reasoning_content: Option = None; let mut stream_tool_call: Option = None; + let mut stream_thought: Option = None; + for g_content_item in content { match g_content_item { + GeminiChatContent::Reasoning(reasoning) => { + stream_reasoning_content = Some(reasoning) + } GeminiChatContent::Text(text) => stream_text_content.push_str(&text), GeminiChatContent::ToolCall(tool_call) => stream_tool_call = Some(tool_call), + GeminiChatContent::ThoughtSignature(thought) => stream_thought = Some(thought), + } + } + + // -- Queue Events + // Priority: Thought -> Text -> ToolCall + + // 1. 
Thought + if let Some(thought) = stream_thought { + // Capture thought + match self.captured_data.thought_signatures { + Some(ref mut thoughts) => thoughts.push(thought.clone()), + None => self.captured_data.thought_signatures = Some(vec![thought.clone()]), + } + + if self.options.capture_usage { + self.captured_data.usage = Some(usage.clone()); } + + self.pending_events.push_back(InterStreamEvent::ThoughtSignatureChunk(thought)); + } + if let Some(reasoning_content) = stream_reasoning_content { + // Capture reasoning content + if self.options.capture_content { + match self.captured_data.reasoning_content { + Some(ref mut rc) => rc.push_str(&reasoning_content), + None => self.captured_data.reasoning_content = Some(reasoning_content.clone()), + } + } + if self.options.capture_usage { + self.captured_data.usage = Some(usage.clone()); + } + self.pending_events + .push_back(InterStreamEvent::ReasoningChunk(reasoning_content)); } - // -- Send Event - // WARNING: Assume only text or toolcall (not both on the same event) + // 2. Text if !stream_text_content.is_empty() { // Capture content if self.options.capture_content { @@ -110,18 +157,15 @@ impl futures::Stream for GeminiStreamer { } } - // NOTE: Apparently in the Gemini API, all events have cumulative usage, - // meaning each message seems to include the tokens for all previous streams. - // Thus, we do not need to add it; we only need to replace captured_data.usage with the latest one. - // See https://twitter.com/jeremychone/status/1813734565967802859 for potential additional information. if self.options.capture_usage { - self.captured_data.usage = Some(usage); + self.captured_data.usage = Some(usage.clone()); } - InterStreamEvent::Chunk(stream_text_content) + self.pending_events.push_back(InterStreamEvent::Chunk(stream_text_content)); } - // tool call - else if let Some(tool_call) = stream_tool_call { + + // 3. Tool Call + if let Some(tool_call) = stream_tool_call { if self.options.capture_tool_calls { match self.captured_data.tool_calls { Some(ref mut tool_calls) => tool_calls.push(tool_call.clone()), @@ -131,14 +175,15 @@ impl futures::Stream for GeminiStreamer { if self.options.capture_usage { self.captured_data.usage = Some(usage); } - InterStreamEvent::ToolCallChunk(tool_call) - } else { - continue; + self.pending_events.push_back(InterStreamEvent::ToolCallChunk(tool_call)); + } + + // Return the first event if any + if let Some(event) = self.pending_events.pop_front() { + return Poll::Ready(Some(Ok(event))); } } }; - - return Poll::Ready(Some(Ok(inter_event))); } Some(Err(err)) => { tracing::error!("Gemini Adapter Stream Error: {}", err); diff --git a/src/adapter/adapters/groq/mod.rs b/src/adapter/adapters/groq/mod.rs index 9c5bbd0c..e95858f2 100644 --- a/src/adapter/adapters/groq/mod.rs +++ b/src/adapter/adapters/groq/mod.rs @@ -1,6 +1,6 @@ -//! API Documentation: https://console.groq.com/docs/api-reference#chat -//! Model Names: https://console.groq.com/docs/models -//! Pricing: https://groq.com/pricing/ +//! API Documentation: +//! Model Names: +//! 
Pricing: // region: --- Modules diff --git a/src/adapter/adapters/mimo/adapter_impl.rs b/src/adapter/adapters/mimo/adapter_impl.rs new file mode 100644 index 00000000..d89ef898 --- /dev/null +++ b/src/adapter/adapters/mimo/adapter_impl.rs @@ -0,0 +1,76 @@ +use crate::ModelIden; +use crate::adapter::openai::OpenAIAdapter; +use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; +use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse}; +use crate::resolver::{AuthData, Endpoint}; +use crate::webc::WebResponse; +use crate::{Result, ServiceTarget}; +use reqwest::RequestBuilder; + +pub struct MimoAdapter; + +pub(in crate::adapter) const MODELS: &[&str] = &["mimo-v2-flash"]; + +impl MimoAdapter { + pub const API_KEY_DEFAULT_ENV_NAME: &str = "MIMO_API_KEY"; +} + +impl Adapter for MimoAdapter { + fn default_auth() -> AuthData { + AuthData::from_env(Self::API_KEY_DEFAULT_ENV_NAME) + } + + fn default_endpoint() -> Endpoint { + const BASE_URL: &str = "https://api.xiaomimimo.com/v1/"; + Endpoint::from_static(BASE_URL) + } + + async fn all_model_names(_kind: AdapterKind) -> Result> { + Ok(MODELS.iter().map(|s| s.to_string()).collect()) + } + + fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> Result { + OpenAIAdapter::util_get_service_url(model, service_type, endpoint) + } + + fn to_web_request_data( + target: ServiceTarget, + service_type: ServiceType, + chat_req: ChatRequest, + chat_options: ChatOptionsSet<'_, '_>, + ) -> Result { + OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options, None) + } + + fn to_chat_response( + model_iden: ModelIden, + web_response: WebResponse, + options_set: ChatOptionsSet<'_, '_>, + ) -> Result { + OpenAIAdapter::to_chat_response(model_iden, web_response, options_set) + } + + fn to_chat_stream( + model_iden: ModelIden, + reqwest_builder: RequestBuilder, + options_set: ChatOptionsSet<'_, '_>, + ) -> Result { + OpenAIAdapter::to_chat_stream(model_iden, reqwest_builder, options_set) + } + + fn to_embed_request_data( + service_target: crate::ServiceTarget, + embed_req: crate::embed::EmbedRequest, + options_set: crate::embed::EmbedOptionsSet<'_, '_>, + ) -> Result { + OpenAIAdapter::to_embed_request_data(service_target, embed_req, options_set) + } + + fn to_embed_response( + model_iden: crate::ModelIden, + web_response: crate::webc::WebResponse, + options_set: crate::embed::EmbedOptionsSet<'_, '_>, + ) -> Result { + OpenAIAdapter::to_embed_response(model_iden, web_response, options_set) + } +} diff --git a/src/adapter/adapters/mimo/mod.rs b/src/adapter/adapters/mimo/mod.rs new file mode 100644 index 00000000..997f386a --- /dev/null +++ b/src/adapter/adapters/mimo/mod.rs @@ -0,0 +1,7 @@ +// region: --- Modules + +mod adapter_impl; + +pub use adapter_impl::*; + +// endregion: --- Modules diff --git a/src/adapter/adapters/mod.rs b/src/adapter/adapters/mod.rs index 935645ba..fffddf18 100644 --- a/src/adapter/adapters/mod.rs +++ b/src/adapter/adapters/mod.rs @@ -2,12 +2,14 @@ mod support; pub(super) mod anthropic; pub(super) mod bedrock; +pub(super) mod bigmodel; pub(super) mod cerebras; pub(super) mod cohere; pub(super) mod deepseek; pub(super) mod fireworks; pub(super) mod gemini; pub(super) mod groq; +pub(super) mod mimo; pub(super) mod nebius; pub(super) mod ollama; pub(super) mod openai; diff --git a/src/adapter/adapters/nebius/mod.rs b/src/adapter/adapters/nebius/mod.rs index a347bf0c..1fdeee1b 100644 --- a/src/adapter/adapters/nebius/mod.rs +++ 
b/src/adapter/adapters/nebius/mod.rs @@ -1,6 +1,6 @@ -//! API Documentation: https://studio.nebius.com/api-reference -//! Model Names: https://studio.nebius.com/ -//! Endpoint: https://api.studio.nebius.ai/v1/ +//! API Documentation: +//! Model Names: +//! Endpoint: // region: --- Modules diff --git a/src/adapter/adapters/ollama/adapter_impl.rs b/src/adapter/adapters/ollama/adapter_impl.rs index b0c2b061..4e6c173f 100644 --- a/src/adapter/adapters/ollama/adapter_impl.rs +++ b/src/adapter/adapters/ollama/adapter_impl.rs @@ -1,4 +1,4 @@ -//! API DOC: https://github.com/ollama/ollama/blob/main/docs/openai.md +//! API DOC: use crate::adapter::openai::OpenAIAdapter; use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; @@ -14,7 +14,7 @@ use value_ext::JsonValueExt; pub struct OllamaAdapter; /// Note: For now, it uses the OpenAI compatibility layer -/// (https://github.com/ollama/ollama/blob/main/docs/openai.md) +/// () /// Since the base Ollama API supports `application/x-ndjson` for streaming, whereas others support `text/event-stream` impl Adapter for OllamaAdapter { fn default_endpoint() -> Endpoint { diff --git a/src/adapter/adapters/ollama/mod.rs b/src/adapter/adapters/ollama/mod.rs index b0d785b7..2ddf1258 100644 --- a/src/adapter/adapters/ollama/mod.rs +++ b/src/adapter/adapters/ollama/mod.rs @@ -1,7 +1,7 @@ //! NOTE: Currently, GenAI uses the OpenAI compatibility layer, except for listing models. -//! OPENAI API DOC: https://platform.openai.com/docs/api-reference/chat -//! OLLAMA API DOC: https://github.com/ollama/ollama/blob/main/docs/api.md -//! OLLAMA Models: https://ollama.com/library +//! OPENAI API DOC: +//! OLLAMA API DOC: +//! OLLAMA Models: // region: --- Modules diff --git a/src/adapter/adapters/openai/embed.rs b/src/adapter/adapters/openai/embed.rs index 2a131b09..8c9d439c 100644 --- a/src/adapter/adapters/openai/embed.rs +++ b/src/adapter/adapters/openai/embed.rs @@ -83,7 +83,7 @@ pub fn to_embed_request_data( }; // Extract the actual model name (without namespace) - let (model_name, _) = model.model_name.as_model_name_and_namespace(); + let (_, model_name) = model.model_name.namespace_and_name(); let openai_req = OpenAIEmbedRequest { input, diff --git a/src/adapter/adapters/openai/streamer.rs b/src/adapter/adapters/openai/streamer.rs index b1d9b136..e16e5625 100644 --- a/src/adapter/adapters/openai/streamer.rs +++ b/src/adapter/adapters/openai/streamer.rs @@ -197,7 +197,7 @@ impl futures::Stream for OpenAIStreamer { self.captured_data.usage = Some(usage) } AdapterKind::DeepSeek - | AdapterKind::ZAi + | AdapterKind::Zai | AdapterKind::Fireworks | AdapterKind::Together => { let usage = message_data diff --git a/src/adapter/adapters/support.rs b/src/adapter/adapters/support.rs index 98712f72..d78f31d8 100644 --- a/src/adapter/adapters/support.rs +++ b/src/adapter/adapters/support.rs @@ -46,6 +46,7 @@ pub struct StreamerCapturedData { pub content: Option, pub reasoning_content: Option, pub tool_calls: Option>, + pub thought_signatures: Option>, } // endregion: --- Streamer Captured Data diff --git a/src/adapter/adapters/zai/adapter_impl.rs b/src/adapter/adapters/zai/adapter_impl.rs index 205bb0b8..82835549 100644 --- a/src/adapter/adapters/zai/adapter_impl.rs +++ b/src/adapter/adapters/zai/adapter_impl.rs @@ -2,70 +2,112 @@ use crate::ModelIden; use crate::adapter::openai::OpenAIAdapter; use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse}; -use 
crate::embed::{EmbedOptionsSet, EmbedRequest, EmbedResponse}; use crate::resolver::{AuthData, Endpoint}; use crate::webc::WebResponse; use crate::{Result, ServiceTarget}; use reqwest::RequestBuilder; -pub struct ZAiAdapter; +pub const ZAI_CODING_NAMESPACE: &str = "zai-coding"; + +/// Helper structure to hold ZAI model parsing information +struct ZaiModelEndpoint { + endpoint: Endpoint, +} + +impl ZaiModelEndpoint { + /// Parse ModelIden to determine if it's a coding model and return endpoint + fn from_model(model: &ModelIden) -> Self { + let (namespace, _) = model.model_name.namespace_and_name(); + + // Check if the namespace is "zai-coding" to route to the coding endpoint + let endpoint = match namespace { + Some(ZAI_CODING_NAMESPACE) => Endpoint::from_static("https://api.z.ai/api/coding/paas/v4/"), + _ => ZaiAdapter::default_endpoint(), + }; + + Self { endpoint } + } +} + +/// The ZAI API is mostly compatible with the OpenAI API. +/// +/// NOTE: This adapter will automatically route to the coding endpoint +/// when the model name starts with "zai-coding::". +/// +/// For example, `glm-4.6` uses the regular API endpoint, +/// while `zai-coding::glm-4.6` uses the coding plan endpoint. +/// +pub struct ZaiAdapter; -// Z.AI model names -// Based on https://z.ai/model-api documentation -// These are the models Z.AI supports pub(in crate::adapter) const MODELS: &[&str] = &[ - "glm-4.6", "glm-4.5", "glm-4", "glm-4.1v", "glm-4.5v", "vidu", "vidu-q1", - "vidu-2.0", - // Note: No turbo models are supported by Z.AI + "glm-4-plus", + "glm-4.6", + "glm-4.5", + "glm-4.5v", + "glm-4.5-x", + "glm-4.5-air", + "glm-4.5-airx", + "glm-4-32b-0414-128k", + "glm-4.5-flash", + "glm-4-air-250414", + "glm-4-flashx-250414", + "glm-4-flash-250414", + "glm-4-air", + "glm-4-airx", + "glm-4-long", + "glm-4-flash", + "glm-4v-plus-0111", + "glm-4v-flash", + "glm-z1-air", + "glm-z1-airx", + "glm-z1-flash", + "glm-z1-flashx", + "glm-4.1v-thinking-flash", + "glm-4.1v-thinking-flashx", ]; -impl ZAiAdapter { +impl ZaiAdapter { pub const API_KEY_DEFAULT_ENV_NAME: &str = "ZAI_API_KEY"; } -// Z.AI adapter uses OpenAI-compatible implementation (most common format) -// Note: This may need adjustment based on actual Z.AI API documentation -impl Adapter for ZAiAdapter { - fn default_auth() -> AuthData { - AuthData::from_env(Self::API_KEY_DEFAULT_ENV_NAME) - } - +// The ZAI API is mostly compatible with the OpenAI API.
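A hedged usage sketch of the namespace-based routing implemented above. It assumes a Tokio runtime, `Client::default()`, the `ChatRequest`/`ChatMessage` constructors from the rest of the crate, and that `ZAI_API_KEY` is set; the model names mirror the examples in the module docs.

```rust
use genai::Client;
use genai::chat::{ChatMessage, ChatRequest};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::default();
    let req = ChatRequest::new(vec![ChatMessage::user("Why is the sky blue?")]);

    // No namespace (or `zai::`): regular credit-based endpoint
    // https://api.z.ai/api/paas/v4/
    let _res = client.exec_chat("glm-4.6", req.clone(), None).await?;

    // `zai-coding::` namespace: coding-plan endpoint
    // https://api.z.ai/api/coding/paas/v4/
    let _res = client.exec_chat("zai-coding::glm-4.6", req, None).await?;

    Ok(())
}
```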
+impl Adapter for ZaiAdapter { fn default_endpoint() -> Endpoint { - const BASE_URL: &str = "https://api.z.ai/v1/"; + const BASE_URL: &str = "https://api.z.ai/api/paas/v4/"; Endpoint::from_static(BASE_URL) } + fn default_auth() -> AuthData { + AuthData::from_env(Self::API_KEY_DEFAULT_ENV_NAME) + } + async fn all_model_names(_kind: AdapterKind) -> Result> { Ok(MODELS.iter().map(|s| s.to_string()).collect()) } - fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> Result { - OpenAIAdapter::util_get_service_url(model, service_type, endpoint) + fn get_service_url(_model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> Result { + // For ZAI, we need to handle model-specific routing at this level + // because get_service_url is called with the modified endpoint from to_web_request_data + let base_url = endpoint.base_url(); + + let url = match service_type { + ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}chat/completions"), + ServiceType::Embed => format!("{base_url}embeddings"), + }; + Ok(url) } fn to_web_request_data( - target: ServiceTarget, + mut target: ServiceTarget, service_type: ServiceType, chat_req: ChatRequest, chat_options: ChatOptionsSet<'_, '_>, ) -> Result { - OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options, None) - } - - fn to_embed_request_data( - target: ServiceTarget, - embed_req: EmbedRequest, - options_set: EmbedOptionsSet<'_, '_>, - ) -> Result { - OpenAIAdapter::to_embed_request_data(target, embed_req, options_set) - } + // Parse model name and determine appropriate endpoint + let zai_info = ZaiModelEndpoint::from_model(&target.model); + target.endpoint = zai_info.endpoint; - fn to_embed_response( - model_iden: ModelIden, - web_response: WebResponse, - options_set: EmbedOptionsSet<'_, '_>, - ) -> Result { - OpenAIAdapter::to_embed_response(model_iden, web_response, options_set) + OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options, None) } fn to_chat_response( @@ -83,4 +125,23 @@ impl Adapter for ZAiAdapter { ) -> Result { OpenAIAdapter::to_chat_stream(model_iden, reqwest_builder, options_set) } + + fn to_embed_request_data( + mut service_target: crate::ServiceTarget, + embed_req: crate::embed::EmbedRequest, + options_set: crate::embed::EmbedOptionsSet<'_, '_>, + ) -> Result { + let zai_info = ZaiModelEndpoint::from_model(&service_target.model); + service_target.endpoint = zai_info.endpoint; + + OpenAIAdapter::to_embed_request_data(service_target, embed_req, options_set) + } + + fn to_embed_response( + model_iden: crate::ModelIden, + web_response: crate::webc::WebResponse, + options_set: crate::embed::EmbedOptionsSet<'_, '_>, + ) -> Result { + OpenAIAdapter::to_embed_response(model_iden, web_response, options_set) + } } diff --git a/src/adapter/adapters/zai/mod.rs b/src/adapter/adapters/zai/mod.rs index 2acb6732..be9942c4 100644 --- a/src/adapter/adapters/zai/mod.rs +++ b/src/adapter/adapters/zai/mod.rs @@ -1,7 +1,40 @@ -//! API Documentation: https://z.ai/docs -//! Model Names: https://z.ai/docs/models -//! Pricing: https://z.ai/docs/pricing -//! Note: Z.AI API is compatible with Anthropic's API +//! ZAI API Documentation +//! API Documentation: +//! Model Names: GLM series models +//! Pricing: +//! +//! ## Dual Endpoint Support +//! +//! ZAI supports two different API endpoints using the ServiceTargetResolver pattern: +//! +//! ### Regular API (Credit-based) (default for those models or with `zai::` namespace) +//! - Endpoint: `` +//! 
- Models: `glm-4.6`, `glm-4.5`, etc. +//! - Usage: Standard API calls billed per token +//! +//! ### Coding Plan (Subscription-based only with the `zai-coding::` namepace) +//! - Endpoint: `` +//! - Models: `zai-coding::glm-4.6`, `zai-coding::glm-4.5`, etc. +//! - Usage: Fixed monthly subscription for coding tasks +//! +//! ## For example +//! +//! ```rust +//! use genai::resolver::{Endpoint, ServiceTargetResolver}; +//! use genai::{Client, AdapterKind, ModelIden}; +//! +//! let client = Client::builder().with_service_target_resolver(target_resolver).build(); +//! +//! // Use regular API +//! let response = client.exec_chat("glm-4.6", chat_request, None).await?; +//! // Same, regular API +//! let response = client.exec_chat("zai::glm-4.6", chat_request, None).await?; +//! +//! // Use coding plan +//! let response = client.exec_chat("zai-coding::glm-4.6", chat_request, None).await?; +//! ``` +//! +//! See `examples/c07-zai-dual-endpoints.rs` for a complete working example. // region: --- Modules diff --git a/src/adapter/dispatcher.rs b/src/adapter/dispatcher.rs index 1bc10b7e..90236524 100644 --- a/src/adapter/dispatcher.rs +++ b/src/adapter/dispatcher.rs @@ -1,7 +1,10 @@ use super::groq::GroqAdapter; +use crate::adapter::adapters::mimo::MimoAdapter; use crate::adapter::adapters::together::TogetherAdapter; +use crate::adapter::adapters::zai::ZaiAdapter; use crate::adapter::anthropic::AnthropicAdapter; use crate::adapter::bedrock::BedrockAdapter; +use crate::adapter::bigmodel::BigModelAdapter; use crate::adapter::cerebras::CerebrasAdapter; use crate::adapter::cohere::CohereAdapter; use crate::adapter::deepseek::DeepSeekAdapter; @@ -12,9 +15,7 @@ use crate::adapter::ollama::OllamaAdapter; use crate::adapter::openai::OpenAIAdapter; use crate::adapter::openai_resp::OpenAIRespAdapter; use crate::adapter::openrouter::OpenRouterAdapter; - use crate::adapter::xai::XaiAdapter; -use crate::adapter::zai::ZAiAdapter; use crate::adapter::zhipu::ZhipuAdapter; use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse}; @@ -42,16 +43,18 @@ impl AdapterDispatcher { AdapterKind::Fireworks => FireworksAdapter::default_endpoint(), AdapterKind::Together => TogetherAdapter::default_endpoint(), AdapterKind::Groq => GroqAdapter::default_endpoint(), + AdapterKind::Mimo => MimoAdapter::default_endpoint(), AdapterKind::Nebius => NebiusAdapter::default_endpoint(), AdapterKind::Xai => XaiAdapter::default_endpoint(), AdapterKind::DeepSeek => DeepSeekAdapter::default_endpoint(), - AdapterKind::Zhipu => ZhipuAdapter::default_endpoint(), + AdapterKind::Zai => ZaiAdapter::default_endpoint(), + AdapterKind::BigModel => BigModelAdapter::default_endpoint(), AdapterKind::Cohere => CohereAdapter::default_endpoint(), AdapterKind::Ollama => OllamaAdapter::default_endpoint(), + AdapterKind::OpenRouter => OpenRouterAdapter::default_endpoint(), AdapterKind::Cerebras => CerebrasAdapter::default_endpoint(), - AdapterKind::ZAi => ZAiAdapter::default_endpoint(), - AdapterKind::OpenRouter => Endpoint::from_static("https://openrouter.ai/api/v1/"), AdapterKind::Bedrock => BedrockAdapter::default_endpoint(), + AdapterKind::Zhipu => ZhipuAdapter::default_endpoint(), } } @@ -64,16 +67,18 @@ impl AdapterDispatcher { AdapterKind::Fireworks => FireworksAdapter::default_auth(), AdapterKind::Together => TogetherAdapter::default_auth(), AdapterKind::Groq => GroqAdapter::default_auth(), + AdapterKind::Mimo => MimoAdapter::default_auth(), AdapterKind::Nebius 
=> NebiusAdapter::default_auth(), AdapterKind::Xai => XaiAdapter::default_auth(), AdapterKind::DeepSeek => DeepSeekAdapter::default_auth(), - AdapterKind::Zhipu => ZhipuAdapter::default_auth(), + AdapterKind::Zai => ZaiAdapter::default_auth(), + AdapterKind::BigModel => BigModelAdapter::default_auth(), AdapterKind::Cohere => CohereAdapter::default_auth(), AdapterKind::Ollama => OllamaAdapter::default_auth(), + AdapterKind::OpenRouter => OpenRouterAdapter::default_auth(), AdapterKind::Cerebras => CerebrasAdapter::default_auth(), - AdapterKind::ZAi => ZAiAdapter::default_auth(), - AdapterKind::OpenRouter => AuthData::from_env(OpenRouterAdapter::API_KEY_DEFAULT_ENV_NAME), AdapterKind::Bedrock => BedrockAdapter::default_auth(), + AdapterKind::Zhipu => ZhipuAdapter::default_auth(), } } @@ -86,16 +91,18 @@ impl AdapterDispatcher { AdapterKind::Fireworks => FireworksAdapter::all_model_names(kind).await, AdapterKind::Together => TogetherAdapter::all_model_names(kind).await, AdapterKind::Groq => GroqAdapter::all_model_names(kind).await, + AdapterKind::Mimo => MimoAdapter::all_model_names(kind).await, AdapterKind::Nebius => NebiusAdapter::all_model_names(kind).await, AdapterKind::Xai => XaiAdapter::all_model_names(kind).await, AdapterKind::DeepSeek => DeepSeekAdapter::all_model_names(kind).await, - AdapterKind::Zhipu => ZhipuAdapter::all_model_names(kind).await, + AdapterKind::Zai => ZaiAdapter::all_model_names(kind).await, + AdapterKind::BigModel => BigModelAdapter::all_model_names(kind).await, AdapterKind::Cohere => CohereAdapter::all_model_names(kind).await, AdapterKind::Ollama => OllamaAdapter::all_model_names(kind).await, - AdapterKind::Cerebras => CerebrasAdapter::all_model_names(kind).await, - AdapterKind::ZAi => ZAiAdapter::all_model_names(kind).await, AdapterKind::OpenRouter => OpenRouterAdapter::all_model_names(kind).await, + AdapterKind::Cerebras => CerebrasAdapter::all_model_names(kind).await, AdapterKind::Bedrock => BedrockAdapter::all_model_names(kind).await, + AdapterKind::Zhipu => ZhipuAdapter::all_model_names(kind).await, } } @@ -108,16 +115,18 @@ impl AdapterDispatcher { AdapterKind::Fireworks => FireworksAdapter::get_service_url(model, service_type, endpoint), AdapterKind::Together => TogetherAdapter::get_service_url(model, service_type, endpoint), AdapterKind::Groq => GroqAdapter::get_service_url(model, service_type, endpoint), + AdapterKind::Mimo => MimoAdapter::get_service_url(model, service_type, endpoint), AdapterKind::Nebius => NebiusAdapter::get_service_url(model, service_type, endpoint), AdapterKind::Xai => XaiAdapter::get_service_url(model, service_type, endpoint), AdapterKind::DeepSeek => DeepSeekAdapter::get_service_url(model, service_type, endpoint), - AdapterKind::Zhipu => ZhipuAdapter::get_service_url(model, service_type, endpoint), + AdapterKind::Zai => ZaiAdapter::get_service_url(model, service_type, endpoint), + AdapterKind::BigModel => BigModelAdapter::get_service_url(model, service_type, endpoint), AdapterKind::Cohere => CohereAdapter::get_service_url(model, service_type, endpoint), AdapterKind::Ollama => OllamaAdapter::get_service_url(model, service_type, endpoint), - AdapterKind::Cerebras => CerebrasAdapter::get_service_url(model, service_type, endpoint), - AdapterKind::ZAi => ZAiAdapter::get_service_url(model, service_type, endpoint), AdapterKind::OpenRouter => OpenRouterAdapter::get_service_url(model, service_type, endpoint), + AdapterKind::Cerebras => CerebrasAdapter::get_service_url(model, service_type, endpoint), AdapterKind::Bedrock => 
BedrockAdapter::get_service_url(model, service_type, endpoint), + AdapterKind::Zhipu => ZhipuAdapter::get_service_url(model, service_type, endpoint), } } @@ -142,18 +151,20 @@ impl AdapterDispatcher { } AdapterKind::Together => TogetherAdapter::to_web_request_data(target, service_type, chat_req, options_set), AdapterKind::Groq => GroqAdapter::to_web_request_data(target, service_type, chat_req, options_set), + AdapterKind::Mimo => MimoAdapter::to_web_request_data(target, service_type, chat_req, options_set), AdapterKind::Nebius => NebiusAdapter::to_web_request_data(target, service_type, chat_req, options_set), AdapterKind::Xai => XaiAdapter::to_web_request_data(target, service_type, chat_req, options_set), AdapterKind::DeepSeek => DeepSeekAdapter::to_web_request_data(target, service_type, chat_req, options_set), - AdapterKind::Zhipu => ZhipuAdapter::to_web_request_data(target, service_type, chat_req, options_set), + AdapterKind::Zai => ZaiAdapter::to_web_request_data(target, service_type, chat_req, options_set), + AdapterKind::BigModel => BigModelAdapter::to_web_request_data(target, service_type, chat_req, options_set), AdapterKind::Cohere => CohereAdapter::to_web_request_data(target, service_type, chat_req, options_set), AdapterKind::Ollama => OllamaAdapter::to_web_request_data(target, service_type, chat_req, options_set), - AdapterKind::Cerebras => CerebrasAdapter::to_web_request_data(target, service_type, chat_req, options_set), - AdapterKind::ZAi => ZAiAdapter::to_web_request_data(target, service_type, chat_req, options_set), AdapterKind::OpenRouter => { OpenRouterAdapter::to_web_request_data(target, service_type, chat_req, options_set) } + AdapterKind::Cerebras => CerebrasAdapter::to_web_request_data(target, service_type, chat_req, options_set), AdapterKind::Bedrock => BedrockAdapter::to_web_request_data(target, service_type, chat_req, options_set), + AdapterKind::Zhipu => ZhipuAdapter::to_web_request_data(target, service_type, chat_req, options_set), } } @@ -170,16 +181,18 @@ impl AdapterDispatcher { AdapterKind::Fireworks => FireworksAdapter::to_chat_response(model_iden, web_response, options_set), AdapterKind::Together => TogetherAdapter::to_chat_response(model_iden, web_response, options_set), AdapterKind::Groq => GroqAdapter::to_chat_response(model_iden, web_response, options_set), + AdapterKind::Mimo => MimoAdapter::to_chat_response(model_iden, web_response, options_set), AdapterKind::Nebius => NebiusAdapter::to_chat_response(model_iden, web_response, options_set), AdapterKind::Xai => XaiAdapter::to_chat_response(model_iden, web_response, options_set), AdapterKind::DeepSeek => DeepSeekAdapter::to_chat_response(model_iden, web_response, options_set), - AdapterKind::Zhipu => ZhipuAdapter::to_chat_response(model_iden, web_response, options_set), + AdapterKind::Zai => ZaiAdapter::to_chat_response(model_iden, web_response, options_set), + AdapterKind::BigModel => BigModelAdapter::to_chat_response(model_iden, web_response, options_set), AdapterKind::Cohere => CohereAdapter::to_chat_response(model_iden, web_response, options_set), AdapterKind::Ollama => OllamaAdapter::to_chat_response(model_iden, web_response, options_set), - AdapterKind::Cerebras => CerebrasAdapter::to_chat_response(model_iden, web_response, options_set), - AdapterKind::ZAi => ZAiAdapter::to_chat_response(model_iden, web_response, options_set), AdapterKind::OpenRouter => OpenRouterAdapter::to_chat_response(model_iden, web_response, options_set), + AdapterKind::Cerebras => CerebrasAdapter::to_chat_response(model_iden, 
web_response, options_set), AdapterKind::Bedrock => BedrockAdapter::to_chat_response(model_iden, web_response, options_set), + AdapterKind::Zhipu => ZhipuAdapter::to_chat_response(model_iden, web_response, options_set), } } @@ -199,16 +212,18 @@ impl AdapterDispatcher { AdapterKind::Fireworks => FireworksAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::Together => TogetherAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::Groq => GroqAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), + AdapterKind::Mimo => MimoAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::Nebius => NebiusAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::Xai => XaiAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::DeepSeek => DeepSeekAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), - AdapterKind::Zhipu => ZhipuAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), + AdapterKind::Zai => ZaiAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), + AdapterKind::BigModel => BigModelAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::Cohere => CohereAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::Ollama => OllamaAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), - AdapterKind::Cerebras => CerebrasAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), - AdapterKind::ZAi => ZAiAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::OpenRouter => OpenRouterAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), + AdapterKind::Cerebras => CerebrasAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), AdapterKind::Bedrock => BedrockAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), + AdapterKind::Zhipu => ZhipuAdapter::to_chat_stream(model_iden, reqwest_builder, options_set), } } @@ -229,16 +244,18 @@ impl AdapterDispatcher { AdapterKind::Fireworks => FireworksAdapter::to_embed_request_data(target, embed_req, options_set), AdapterKind::Together => TogetherAdapter::to_embed_request_data(target, embed_req, options_set), AdapterKind::Groq => GroqAdapter::to_embed_request_data(target, embed_req, options_set), + AdapterKind::Mimo => MimoAdapter::to_embed_request_data(target, embed_req, options_set), AdapterKind::Nebius => NebiusAdapter::to_embed_request_data(target, embed_req, options_set), AdapterKind::Xai => XaiAdapter::to_embed_request_data(target, embed_req, options_set), AdapterKind::DeepSeek => DeepSeekAdapter::to_embed_request_data(target, embed_req, options_set), - AdapterKind::Zhipu => ZhipuAdapter::to_embed_request_data(target, embed_req, options_set), + AdapterKind::Zai => ZaiAdapter::to_embed_request_data(target, embed_req, options_set), + AdapterKind::BigModel => BigModelAdapter::to_embed_request_data(target, embed_req, options_set), AdapterKind::Cohere => CohereAdapter::to_embed_request_data(target, embed_req, options_set), AdapterKind::Ollama => OllamaAdapter::to_embed_request_data(target, embed_req, options_set), - AdapterKind::Cerebras => CerebrasAdapter::to_embed_request_data(target, embed_req, options_set), - AdapterKind::ZAi => ZAiAdapter::to_embed_request_data(target, embed_req, options_set), AdapterKind::OpenRouter => OpenRouterAdapter::to_embed_request_data(target, embed_req, options_set), + AdapterKind::Cerebras => CerebrasAdapter::to_embed_request_data(target, 
embed_req, options_set), AdapterKind::Bedrock => BedrockAdapter::to_embed_request_data(target, embed_req, options_set), + AdapterKind::Zhipu => ZhipuAdapter::to_embed_request_data(target, embed_req, options_set), } } @@ -258,16 +275,18 @@ impl AdapterDispatcher { AdapterKind::Fireworks => FireworksAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::Together => TogetherAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::Groq => GroqAdapter::to_embed_response(model_iden, web_response, options_set), + AdapterKind::Mimo => MimoAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::Nebius => NebiusAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::Xai => XaiAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::DeepSeek => DeepSeekAdapter::to_embed_response(model_iden, web_response, options_set), - AdapterKind::Zhipu => ZhipuAdapter::to_embed_response(model_iden, web_response, options_set), + AdapterKind::Zai => ZaiAdapter::to_embed_response(model_iden, web_response, options_set), + AdapterKind::BigModel => BigModelAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::Cohere => CohereAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::Ollama => OllamaAdapter::to_embed_response(model_iden, web_response, options_set), - AdapterKind::Cerebras => CerebrasAdapter::to_embed_response(model_iden, web_response, options_set), - AdapterKind::ZAi => ZAiAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::OpenRouter => OpenRouterAdapter::to_embed_response(model_iden, web_response, options_set), + AdapterKind::Cerebras => CerebrasAdapter::to_embed_response(model_iden, web_response, options_set), AdapterKind::Bedrock => BedrockAdapter::to_embed_response(model_iden, web_response, options_set), + AdapterKind::Zhipu => ZhipuAdapter::to_embed_response(model_iden, web_response, options_set), } } } diff --git a/src/chat/chat_request.rs b/src/chat/chat_request.rs index dc73a03b..62608eb3 100644 --- a/src/chat/chat_request.rs +++ b/src/chat/chat_request.rs @@ -1,6 +1,6 @@ //! This module contains all the types related to a Chat Request (except ChatOptions, which has its own file). -use crate::chat::{ChatMessage, ChatRole, Tool}; +use crate::chat::{ChatMessage, ChatRole, StreamEnd, Tool, ToolCall, ToolResponse}; use crate::support; use serde::{Deserialize, Serialize}; @@ -98,6 +98,29 @@ impl ChatRequest { self.tools.get_or_insert_with(Vec::new).push(tool.into()); self } + + /// Append an assistant tool-use turn and the corresponding tool response based on a + /// streaming `StreamEnd` capture. Thought signatures are included automatically and + /// ordered before tool calls when present. + /// + /// If neither content nor tool calls were captured, this is a no-op before appending + /// the provided tool response. 
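For context, a hedged sketch of the streaming tool-call round trip that `append_tool_use_from_stream_end` (defined next) is meant for. The capture option setters (`with_capture_content`, `with_capture_tool_calls`), the `ChatStreamResponse.stream` field, and `ToolResponse::new` are assumed from the rest of the crate; the model name and tool output are placeholders.

```rust
use futures::StreamExt;
use genai::Client;
use genai::chat::{ChatOptions, ChatRequest, ChatStreamEvent, ToolResponse};

async fn tool_round_trip(client: &Client, chat_req: ChatRequest) -> Result<(), Box<dyn std::error::Error>> {
    let options = ChatOptions::default()
        .with_capture_content(true)
        .with_capture_tool_calls(true);

    let mut stream_res = client
        .exec_chat_stream("gemini-2.5-flash", chat_req.clone(), Some(&options))
        .await?;

    // Drain the stream and keep the final StreamEnd capture
    let mut stream_end = None;
    while let Some(event) = stream_res.stream.next().await {
        if let ChatStreamEvent::End(end) = event? {
            stream_end = Some(end);
        }
    }
    let end = stream_end.ok_or("stream ended without an End event")?;

    // Answer the first captured tool call (actual tool execution elided)
    if let Some(call) = end.captured_tool_calls().and_then(|calls| calls.first().cloned()) {
        let tool_res = ToolResponse::new(call.call_id.clone(), r#"{"temperature": 21}"#);
        let next_req = chat_req.append_tool_use_from_stream_end(&end, tool_res);
        let _final_res = client.exec_chat("gemini-2.5-flash", next_req, None).await?;
    }

    Ok(())
}
```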
+ pub fn append_tool_use_from_stream_end(mut self, end: &StreamEnd, tool_response: ToolResponse) -> Self { + if let Some(content) = &end.captured_content { + // Use captured content directly (contains thoughts/text/tool calls in correct order) + self.messages.push(ChatMessage::assistant(content.clone())); + } else if let Some(calls_ref) = end.captured_tool_calls() { + // Fallback: build assistant message from tool calls only + let calls: Vec = calls_ref.into_iter().cloned().collect(); + if !calls.is_empty() { + self.messages.push(ChatMessage::from(calls)); + } + } + + // Append the tool response turn + self.messages.push(ChatMessage::from(tool_response)); + self + } } /// Getters diff --git a/src/chat/chat_response.rs b/src/chat/chat_response.rs index a4624260..594e2954 100644 --- a/src/chat/chat_response.rs +++ b/src/chat/chat_response.rs @@ -30,7 +30,8 @@ pub struct ChatResponse { /// Token usage reported by the provider. pub usage: Usage, - /// Raw response body for provider-specific features. + /// IMPORTANT: (since 0.5.3) This is populated at the client.exec_chat when the options capture_raw_body is set to true + /// Raw response body (only if asked via options.capture_raw_body) pub captured_raw_body: Option, } diff --git a/src/chat/chat_stream.rs b/src/chat/chat_stream.rs index 8d1057a7..9b9aecd8 100644 --- a/src/chat/chat_stream.rs +++ b/src/chat/chat_stream.rs @@ -1,5 +1,5 @@ use crate::adapter::inter_stream::{InterStreamEnd, InterStreamEvent}; -use crate::chat::{MessageContent, ToolCall, Usage}; +use crate::chat::{ChatMessage, ContentPart, MessageContent, ToolCall, Usage}; use futures::Stream; use serde::{Deserialize, Serialize}; use std::pin::Pin; @@ -34,31 +34,27 @@ impl Stream for ChatStream { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - loop { - match Pin::new(&mut this.inter_stream).poll_next(cx) { - Poll::Ready(Some(Ok(event))) => { - let chat_event = match event { - InterStreamEvent::Start => ChatStreamEvent::Start, - InterStreamEvent::Chunk(content) => ChatStreamEvent::Chunk(StreamChunk { content }), - InterStreamEvent::ReasoningChunk(content) => { - ChatStreamEvent::ReasoningChunk(StreamChunk { content }) - } - InterStreamEvent::ToolCallChunk(tool_call) => { - ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call }) - } - InterStreamEvent::ThoughtSignatureChunk(_signature) => { - // Thought signatures are internal metadata, not streamed to users - // Skip this event and continue polling - continue; - } - InterStreamEvent::End(inter_end) => ChatStreamEvent::End(inter_end.into()), - }; - return Poll::Ready(Some(Ok(chat_event))); - } - Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, + match Pin::new(&mut this.inter_stream).poll_next(cx) { + Poll::Ready(Some(Ok(event))) => { + let chat_event = match event { + InterStreamEvent::Start => ChatStreamEvent::Start, + InterStreamEvent::Chunk(content) => ChatStreamEvent::Chunk(StreamChunk { content }), + InterStreamEvent::ReasoningChunk(content) => { + ChatStreamEvent::ReasoningChunk(StreamChunk { content }) + } + InterStreamEvent::ThoughtSignatureChunk(content) => { + ChatStreamEvent::ThoughtSignatureChunk(StreamChunk { content }) + } + InterStreamEvent::ToolCallChunk(tool_call) => { + ChatStreamEvent::ToolCallChunk(ToolChunk { tool_call }) + } + InterStreamEvent::End(inter_end) => ChatStreamEvent::End(inter_end.into()), + }; + Poll::Ready(Some(Ok(chat_event))) } + 
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, } } } @@ -79,6 +75,9 @@ pub enum ChatStreamEvent { /// Reasoning content chunk. ReasoningChunk(StreamChunk), + /// Thought signature content chunk. + ThoughtSignatureChunk(StreamChunk), + /// Tool-call chunk. ToolCallChunk(ToolChunk), @@ -121,13 +120,42 @@ pub struct StreamEnd { impl From for StreamEnd { fn from(inter_end: InterStreamEnd) -> Self { let captured_text_content = inter_end.captured_text_content; - let captured_tool_calls = inter_end.captured_tool_calls; + let mut captured_tool_calls = inter_end.captured_tool_calls; // -- create public captured_content + // Ordering policy: ThoughtSignature -> Text -> ToolCall + // This matches provider expectations (e.g., Gemini 3 requires thought first). let mut captured_content: Option = None; + if let Some(captured_thoughts) = inter_end.captured_thought_signatures { + let thoughts_content = captured_thoughts + .into_iter() + .map(ContentPart::ThoughtSignature) + .collect::>(); + // Also attach thoughts to the first tool call so that + // ChatMessage::from(Vec) can auto-prepend them. + if let Some(tool_calls) = captured_tool_calls.as_mut() + && let Some(first_call) = tool_calls.first_mut() + { + first_call.thought_signatures = Some( + thoughts_content + .iter() + .filter_map(|p| p.as_thought_signature().map(|s| s.to_string())) + .collect(), + ); + } + if let Some(existing_content) = &mut captured_content { + existing_content.extend_front(thoughts_content); + } else { + captured_content = Some(MessageContent::from_parts(thoughts_content)); + } + } if let Some(captured_text_content) = captured_text_content { // This `captured_text_content` is the concatenation of all text chunks received. - captured_content = Some(MessageContent::from_text(captured_text_content)); + if let Some(existing_content) = &mut captured_content { + existing_content.extend(MessageContent::from_text(captured_text_content)); + } else { + captured_content = Some(MessageContent::from_text(captured_text_content)); + } } if let Some(captured_tool_calls) = captured_tool_calls { if let Some(existing_content) = &mut captured_content { @@ -186,6 +214,53 @@ impl StreamEnd { let captured_content = self.captured_content?; Some(captured_content.into_tool_calls()) } + + /// Returns all captured thought signatures, if any. + pub fn captured_thought_signatures(&self) -> Option> { + let captured_content = self.captured_content.as_ref()?; + Some( + captured_content + .parts() + .iter() + .filter_map(|p| p.as_thought_signature()) + .collect(), + ) + } + + /// Consumes `self` and returns all captured thought signatures, if any. + pub fn captured_into_thought_signatures(self) -> Option> { + let captured_content = self.captured_content?; + Some( + captured_content + .into_parts() + .into_iter() + .filter_map(|p| p.into_thought_signature()) + .collect(), + ) + } + + /// Convenience: build an assistant message for a tool-use handoff that places + /// thought signatures (if any) before tool calls. Returns None if no tool calls + /// were captured. 
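As a lower-level alternative to `append_tool_use_from_stream_end`, a hedged sketch of building the follow-up request by hand with the helper documented above. It assumes `ChatRequest::append_message` accepts anything convertible into a `ChatMessage` (including a `ToolResponse`); the function name here is purely illustrative.

```rust
use genai::chat::{ChatMessage, ChatRequest, StreamEnd, ToolResponse};

// Hypothetical helper: returns the follow-up request, or None if no tool calls were captured.
fn next_request_after_tool_use(
    chat_req: ChatRequest,
    stream_end: StreamEnd,
    tool_response: ToolResponse,
) -> Option<ChatRequest> {
    // Assistant tool-use turn with thought signatures (if any) placed before the tool calls
    let assistant_msg: ChatMessage = stream_end.into_assistant_message_for_tool_use()?;
    Some(chat_req.append_message(assistant_msg).append_message(tool_response))
}
```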
+ pub fn into_assistant_message_for_tool_use(self) -> Option { + let content = self.captured_content?; + let mut thought_signatures: Vec = Vec::new(); + let mut tool_calls: Vec = Vec::new(); + for part in content.into_parts() { + match part { + ContentPart::ThoughtSignature(t) => thought_signatures.push(t), + ContentPart::ToolCall(tc) => tool_calls.push(tc), + _ => {} + } + } + if tool_calls.is_empty() { + return None; + } + Some(ChatMessage::assistant_tool_calls_with_thoughts( + tool_calls, + thought_signatures, + )) + } } // endregion: --- ChatStreamEvent diff --git a/src/chat/message_content.rs b/src/chat/message_content.rs index 3ce3c39a..95f021b6 100644 --- a/src/chat/message_content.rs +++ b/src/chat/message_content.rs @@ -1,5 +1,5 @@ /// Note: MessageContent is used for ChatRequest and ChatResponse. -use crate::chat::{ContentPart, ToolCall, ToolResponse}; +use crate::chat::{Binary, ContentPart, ToolCall, ToolResponse}; use serde::{Deserialize, Serialize}; /// Message content container used in ChatRequest and ChatResponse. @@ -47,6 +47,29 @@ impl MessageContent { self.parts.push(part.into()); } + /// Insert one part at the given index (mutating). + pub fn insert(&mut self, index: usize, part: impl Into) { + self.parts.insert(index, part.into()); + } + + /// Prepend one part to the beginning (mutating). + pub fn prepend(&mut self, part: impl Into) { + self.parts.insert(0, part.into()); + } + + /// Prepend multiple parts while preserving their original order. + pub fn extend_front(&mut self, iter: I) + where + I: IntoIterator, + { + // Collect then insert in reverse so that the first element in `iter` + // ends up closest to the front after all insertions. + let collected: Vec = iter.into_iter().collect(); + for part in collected.into_iter().rev() { + self.parts.insert(0, part); + } + } + /// Extend with an iterator of parts, returning self. pub fn extended(mut self, iter: I) -> Self where @@ -63,6 +86,15 @@ impl Extend for MessageContent { } } +/// Computed accessors +impl MessageContent { + /// Returns an approximate in-memory size of this `MessageContent`, in bytes, + /// computed as the sum of the sizes of all parts. + pub fn size(&self) -> usize { + self.parts.iter().map(|p| p.size()).sum() + } +} + // region: --- Iterator Support use crate::support; @@ -126,6 +158,14 @@ impl MessageContent { self.parts.into_iter().filter_map(|p| p.into_text()).collect() } + pub fn binaries(&self) -> Vec<&Binary> { + self.parts.iter().filter_map(|p| p.as_binary()).collect() + } + + pub fn into_binaries(self) -> Vec { + self.parts.into_iter().filter_map(|p| p.into_binary()).collect() + } + /// Return references to all ToolCall parts. pub fn tool_calls(&self) -> Vec<&ToolCall> { self.parts @@ -222,9 +262,7 @@ impl MessageContent { let mut combined = String::new(); for text in texts { - if !combined.is_empty() { - support::combine_text_with_empty_line(&mut combined, text); - } + support::combine_text_with_empty_line(&mut combined, text); } Some(combined) } @@ -269,12 +307,6 @@ impl MessageContent { pub fn contains_tool_response(&self) -> bool { self.parts.iter().any(|p| p.is_tool_response()) } - - /// Returns an approximate in-memory size of this `MessageContent`, in bytes, - /// computed as the sum of the sizes of all content parts. 
- pub fn size(&self) -> usize { - self.parts.iter().map(|p| p.size()).sum() - } } // region: --- Froms @@ -319,6 +351,20 @@ impl From for MessageContent { } } +impl From for MessageContent { + fn from(part: ContentPart) -> Self { + Self { parts: vec![part] } + } +} + +impl From for MessageContent { + fn from(bin: Binary) -> Self { + Self { + parts: vec![bin.into()], + } + } +} + impl From> for MessageContent { fn from(parts: Vec) -> Self { Self { parts } @@ -326,3 +372,58 @@ impl From> for MessageContent { } // endregion: --- Froms + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_message_content_joined_texts_empty() { + assert_eq!(MessageContent::from_parts(vec![]).joined_texts(), None); + } + + #[test] + fn test_message_content_joined_texts_single_part() { + assert_eq!( + MessageContent::from_parts(vec![ContentPart::Text("Hello".to_string())]).joined_texts(), + Some("Hello".to_string()) + ); + } + + #[test] + fn test_message_content_joined_texts_two_parts() { + assert_eq!( + MessageContent::from_parts(vec![ + ContentPart::Text("Hello".to_string()), + ContentPart::Text("World".to_string()), + ]) + .joined_texts(), + Some("Hello\n\nWorld".to_string()) + ); + } + + #[test] + fn test_message_content_into_joined_texts_empty() { + assert_eq!(MessageContent::from_parts(vec![]).into_joined_texts(), None); + } + + #[test] + fn test_message_content_into_joined_texts_single_part() { + assert_eq!( + MessageContent::from_parts(vec![ContentPart::Text("Hello".to_string())]).into_joined_texts(), + Some("Hello".to_string()) + ); + } + + #[test] + fn test_message_content_into_joined_texts_two_parts() { + assert_eq!( + MessageContent::from_parts(vec![ + ContentPart::Text("Hello".to_string()), + ContentPart::Text("World".to_string()), + ]) + .into_joined_texts(), + Some("Hello\n\nWorld".to_string()) + ); + } +} diff --git a/src/chat/printer.rs b/src/chat/printer.rs index a14c9db9..9c08cd18 100644 --- a/src/chat/printer.rs +++ b/src/chat/printer.rs @@ -68,6 +68,7 @@ async fn print_chat_stream_inner( let mut first_chunk = true; let mut first_reasoning_chunk = true; + let mut first_thought_signature_chunk = true; let mut first_tool_chunk = true; while let Some(next) = stream.next().await { @@ -109,6 +110,19 @@ async fn print_chat_stream_inner( } } + ChatStreamEvent::ThoughtSignatureChunk(StreamChunk { content }) => { + if print_events && first_thought_signature_chunk { + first_thought_signature_chunk = false; + ( + Some("\n-- ChatStreamEvent::ThoughtSignatureChunk (concatenated):\n".to_string()), + Some(content), + false, // print but do not capture + ) + } else { + (None, Some(content), false) // print but do not capture + } + } + ChatStreamEvent::ToolCallChunk(tool_chunk) => { if print_events && first_tool_chunk { first_tool_chunk = false; diff --git a/src/chat/tool/tool_response.rs b/src/chat/tool/tool_response.rs index 34149816..288f292e 100644 --- a/src/chat/tool/tool_response.rs +++ b/src/chat/tool/tool_response.rs @@ -21,6 +21,17 @@ impl ToolResponse { } } +/// Computed accessors +impl ToolResponse { + /// Returns an approximate in-memory size of this `ToolResponse`, in bytes, + /// computed as the sum of the UTF-8 lengths of: + /// - `call_id` + /// - `content` + pub fn size(&self) -> usize { + self.call_id.len() + self.content.len() + } +} + /// Getters #[allow(unused)] impl ToolResponse { @@ -32,12 +43,3 @@ impl ToolResponse { &self.content } } - -/// Computed accessors -impl ToolResponse { - /// Returns an approximate in-memory size of this `ToolResponse`, in bytes, - /// computed 
as the sum of the UTF-8 lengths of `call_id` and `content`. - pub fn size(&self) -> usize { - self.call_id.len() + self.content.len() - } -} diff --git a/src/client/client_impl.rs b/src/client/client_impl.rs index 8a73ca4e..93e21322 100644 --- a/src/client/client_impl.rs +++ b/src/client/client_impl.rs @@ -1,5 +1,6 @@ use crate::adapter::{AdapterDispatcher, AdapterKind, ServiceType, WebRequestData}; use crate::chat::{ChatOptions, ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse}; +use crate::client::ModelSpec; use crate::embed::{EmbedOptions, EmbedOptionsSet, EmbedRequest, EmbedResponse}; use crate::resolver::AuthData; use crate::{Client, Error, ModelIden, Result, ServiceTarget}; @@ -41,26 +42,33 @@ impl Client { Ok(target.model) } - /// Resolves the service target (endpoint, auth, and model) for the given model name. - pub async fn resolve_service_target(&self, model_name: &str) -> Result { - let model = self.default_model(model_name)?; - self.config().resolve_service_target(model).await + /// Resolves the service target (endpoint, auth, and model) for the given model. + /// + /// Accepts any type that implements `Into`: + /// - `&str` or `String`: Model name with full inference + /// - `ModelIden`: Explicit adapter, resolves auth/endpoint + /// - `ServiceTarget`: Uses directly, bypasses model mapping and auth resolution + pub async fn resolve_service_target(&self, model: impl Into) -> Result { + self.config().resolve_model_spec(model.into()).await } /// Sends a chat request and returns the full response. + /// + /// Accepts any type that implements `Into`: + /// - `&str` or `String`: Model name with full inference + /// - `ModelIden`: Explicit adapter, resolves auth/endpoint + /// - `ServiceTarget`: Uses directly, bypasses model mapping and auth resolution pub async fn exec_chat( &self, - model: &str, + model: impl Into, chat_req: ChatRequest, - // options not implemented yet options: Option<&ChatOptions>, ) -> Result { let options_set = ChatOptionsSet::default() .with_chat_options(options) .with_client_options(self.config().chat_options()); - let model = self.default_model(model)?; - let target = self.config().resolve_service_target(model).await?; + let target = self.config().resolve_model_spec(model.into()).await?; let model = target.model.clone(); let auth_data = target.auth.clone(); @@ -70,6 +78,10 @@ impl Client { payload, } = AdapterDispatcher::to_web_request_data(target, ServiceType::Chat, chat_req, options_set.clone())?; + if let Some(extra_headers) = options.and_then(|o| o.extra_headers.as_ref()) { + headers.merge_with(&extra_headers); + } + if let AuthData::RequestOverride { url: override_url, headers: override_headers, @@ -88,24 +100,46 @@ impl Client { webc_error, })?; - let chat_res = AdapterDispatcher::to_chat_response(model, web_res, options_set)?; - - Ok(chat_res) + // Note: here we capture/clone the raw body if set in the options_set + let captured_raw_body = options_set.capture_raw_body().unwrap_or_default().then(|| web_res.body.clone()); + + match AdapterDispatcher::to_chat_response(model.clone(), web_res, options_set) { + Ok(mut chat_res) => { + chat_res.captured_raw_body = captured_raw_body; + Ok(chat_res) + } + Err(err) => { + let response_body = captured_raw_body.unwrap_or_else(|| { + "Raw response not captured. 
Use the ChatOptions.capture_raw_body flag to see raw response in this error".into() + }); + let err = Error::ChatResponseGeneration { + model_iden: model, + request_payload: Box::new(payload), + response_body: Box::new(response_body), + cause: err.to_string(), + }; + Err(err) + } + } } /// Streams a chat response. + /// + /// Accepts any type that implements `Into<ModelSpec>`: + /// - `&str` or `String`: Model name with full inference + /// - `ModelIden`: Explicit adapter, resolves auth/endpoint + /// - `ServiceTarget`: Uses directly, bypasses model mapping and auth resolution pub async fn exec_chat_stream( &self, - model: &str, - chat_req: ChatRequest, // options not implemented yet + model: impl Into<ModelSpec>, + chat_req: ChatRequest, options: Option<&ChatOptions>, ) -> Result<ChatStreamResponse> { let options_set = ChatOptionsSet::default() .with_chat_options(options) .with_client_options(self.config().chat_options()); - let model = self.default_model(model)?; - let target = self.config().resolve_service_target(model).await?; + let target = self.config().resolve_model_spec(model.into()).await?; let model = target.model.clone(); let auth_data = target.auth.clone(); @@ -115,6 +149,10 @@ impl Client { payload, } = AdapterDispatcher::to_web_request_data(target, ServiceType::ChatStream, chat_req, options_set.clone())?; + if let Some(extra_headers) = options.and_then(|o| o.extra_headers.as_ref()) { + headers.merge_with(&extra_headers); + } + // TODO: Need to check this. // This was part of the 429c5cee2241dbef9f33699b9c91202233c22816 commit // But now it is missing in the exec_chat(..) above, which is probably an issue. @@ -141,9 +179,11 @@ impl Client { } /// Creates embeddings for a single input string. + /// + /// Accepts any type that implements `Into<ModelSpec>` for the model parameter. pub async fn embed( &self, - model: &str, + model: impl Into<ModelSpec>, input: impl Into<String>, options: Option<&EmbedOptions>, ) -> Result<EmbedResponse> { @@ -152,9 +192,11 @@ impl Client { } /// Creates embeddings for multiple input strings. + /// + /// Accepts any type that implements `Into<ModelSpec>` for the model parameter. pub async fn embed_batch( &self, - model: &str, + model: impl Into<ModelSpec>, inputs: Vec<String>, options: Option<&EmbedOptions>, ) -> Result<EmbedResponse> { @@ -163,9 +205,14 @@ impl Client { } /// Sends an embedding request and returns the response.
+ /// + /// Accepts any type that implements `Into`: + /// - `&str` or `String`: Model name with full inference + /// - `ModelIden`: Explicit adapter, resolves auth/endpoint + /// - `ServiceTarget`: Uses directly, bypasses model mapping and auth resolution pub async fn exec_embed( &self, - model: &str, + model: impl Into, embed_req: EmbedRequest, options: Option<&EmbedOptions>, ) -> Result { @@ -173,8 +220,7 @@ impl Client { .with_request_options(options) .with_client_options(self.config().embed_options()); - let model = self.default_model(model)?; - let target = self.config().resolve_service_target(model).await?; + let target = self.config().resolve_model_spec(model.into()).await?; let model = target.model.clone(); let WebRequestData { headers, payload, url } = diff --git a/src/client/config.rs b/src/client/config.rs index bce41a8c..e5858e69 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -1,8 +1,8 @@ -use crate::adapter::AdapterDispatcher; +use crate::adapter::{AdapterDispatcher, AdapterKind}; use crate::chat::ChatOptions; -use crate::client::ServiceTarget; +use crate::client::{ModelSpec, ServiceTarget}; use crate::embed::EmbedOptions; -use crate::resolver::{AuthResolver, ModelMapper, ServiceTargetResolver}; +use crate::resolver::{AuthData, AuthResolver, ModelMapper, ServiceTargetResolver}; use crate::{Error, ModelIden, Result, WebConfig}; /// Configuration for building and customizing a `Client`. @@ -104,41 +104,65 @@ impl ClientConfig { /// Errors with Error::Resolver if any resolver step fails. pub async fn resolve_service_target(&self, model: ModelIden) -> Result { // -- Resolve the Model first - let model = match self.model_mapper() { - Some(model_mapper) => model_mapper.map_model(model.clone()), - None => Ok(model.clone()), - } - .map_err(|resolver_error| Error::Resolver { - model_iden: model.clone(), - resolver_error, - })?; + let model = self.run_model_mapper(model.clone())?; // -- Get the auth - let auth = if let Some(auth) = self.auth_resolver() { - // resolve async which may be async - auth.resolve(model.clone()) - .await - .map_err(|err| Error::Resolver { - model_iden: model.clone(), - resolver_error: err, - })? - // default the resolver resolves to nothing - .unwrap_or_else(|| AdapterDispatcher::default_auth(model.adapter_kind)) - } else { - AdapterDispatcher::default_auth(model.adapter_kind) - }; + let auth = self.run_auth_resolver(model.clone()).await?; // -- Get the default endpoint // For now, just get the default endpoint; the `resolve_target` will allow overriding it. let endpoint = AdapterDispatcher::default_endpoint(model.adapter_kind); - // -- Resolve the service_target + // -- Create the default service target let service_target = ServiceTarget { model: model.clone(), auth, endpoint, }; - let service_target = match self.service_target_resolver() { + + // -- Resolve the service target + let service_target = self.run_service_target_resolver(service_target).await?; + + Ok(service_target) + } + + /// Resolves a [`ModelIden`] to a [`ModelIden`] via the [`ModelMapper`] (if any). + fn run_model_mapper(&self, model: ModelIden) -> Result { + match self.model_mapper() { + Some(model_mapper) => model_mapper.map_model(model.clone()), + None => Ok(model.clone()), + } + .map_err(|resolver_error| Error::Resolver { + model_iden: model.clone(), + resolver_error, + }) + } + + /// Resolves a [`ModelIden`] to an [`AuthData`] via the [`AuthResolver`] (if any). 
+ async fn run_auth_resolver(&self, model: ModelIden) -> Result { + match self.auth_resolver() { + Some(auth_resolver) => { + let auth_data = auth_resolver + .resolve(model.clone()) + .await + .map_err(|err| Error::Resolver { + model_iden: model.clone(), + resolver_error: err, + })? + // default the resolver resolves to nothing + .unwrap_or_else(|| AdapterDispatcher::default_auth(model.adapter_kind)); + + Ok(auth_data) + } + None => Ok(AdapterDispatcher::default_auth(model.adapter_kind)), + } + } + + /// Resolves a [`ServiceTarget`] via the [`ServiceTargetResolver`] (if any). + async fn run_service_target_resolver(&self, service_target: ServiceTarget) -> Result { + let model = service_target.model.clone(); + + match self.service_target_resolver() { Some(service_target_resolver) => { service_target_resolver .resolve(service_target) @@ -146,11 +170,31 @@ impl ClientConfig { .map_err(|resolver_error| Error::Resolver { model_iden: model, resolver_error, - })? + }) } - None => service_target, - }; + None => Ok(service_target), + } + } - Ok(service_target) + /// Resolves a [`ModelSpec`] to a [`ServiceTarget`]. + /// + /// The resolution behavior depends on the variant: + /// + /// - [`ModelSpec::Name`]: Infers adapter from name, then applies full resolution + /// (model mapper, auth resolver, service target resolver). + /// + /// - [`ModelSpec::Iden`]: Skips adapter inference, applies full resolution. + /// + /// - [`ModelSpec::Target`]: Returns the target directly, running only the service target resolver. + pub async fn resolve_model_spec(&self, spec: ModelSpec) -> Result { + match spec { + ModelSpec::Name(name) => { + let adapter_kind = AdapterKind::from_model(&name)?; + let model = ModelIden::new(adapter_kind, name); + self.resolve_service_target(model).await + } + ModelSpec::Iden(model) => self.resolve_service_target(model).await, + ModelSpec::Target(target) => self.run_service_target_resolver(target).await, + } } } diff --git a/src/client/headers.rs b/src/client/headers.rs index 0aaa007e..2b3a74ef 100644 --- a/src/client/headers.rs +++ b/src/client/headers.rs @@ -17,7 +17,7 @@ pub struct Headers { impl Headers { /// Merge headers from overlay into self, consuming overlay. /// Later values override existing ones. - /// Use [`merge_with`] for a borrowed overlay. + /// Use [`Headers::merge_with`] for a borrowed overlay. pub fn merge(&mut self, overlay: impl Into) { let overlay = overlay.into(); diff --git a/src/client/mod.rs b/src/client/mod.rs index d5943c78..3fbb6d75 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -10,6 +10,7 @@ mod client_impl; mod client_types; mod config; mod headers; +mod model_spec; mod service_target; mod web_config; @@ -17,6 +18,7 @@ pub use builder::*; pub use client_types::*; pub use config::*; pub use headers::*; +pub use model_spec::*; pub use service_target::*; pub use web_config::*; diff --git a/src/client/model_spec.rs b/src/client/model_spec.rs new file mode 100644 index 00000000..2641eded --- /dev/null +++ b/src/client/model_spec.rs @@ -0,0 +1,133 @@ +use crate::{ModelIden, ModelName, ServiceTarget}; + +/// Specifies how to identify and resolve a model for API calls. +/// +/// `ModelSpec` provides three levels of control over model resolution: +/// +/// - [`ModelSpec::Name`]: Just a model name string. The adapter kind is inferred +/// from the name, and auth/endpoint are resolved via the client's configured resolvers. +/// +/// - [`ModelSpec::Iden`]: An explicit [`ModelIden`] with adapter kind specified. 
+///   Skips adapter inference but still resolves auth/endpoint via config.
+///
+/// - [`ModelSpec::Target`]: A complete [`ServiceTarget`] with endpoint, auth, and model.
+///   Used directly, only runs the service target resolver.
+///
+/// # Examples
+///
+/// ```rust
+/// use genai::adapter::AdapterKind;
+/// use genai::resolver::{AuthData, Endpoint};
+/// use genai::{ModelIden, ModelSpec, ServiceTarget};
+///
+/// // Using a string name (full inference)
+/// let spec: ModelSpec = "gpt-4".into();
+///
+/// // Using an explicit ModelIden (skip adapter inference)
+/// let spec: ModelSpec = ModelIden::new(AdapterKind::OpenAI, "gpt-4").into();
+///
+/// // Using a complete ServiceTarget (bypass all resolution)
+/// let target = ServiceTarget {
+///     endpoint: Endpoint::from_static("https://custom.api/v1/"),
+///     auth: AuthData::from_env("CUSTOM_API_KEY"),
+///     model: ModelIden::new(AdapterKind::OpenAI, "custom-model"),
+/// };
+/// let spec: ModelSpec = target.into();
+/// ```
+#[derive(Debug, Clone)]
+pub enum ModelSpec {
+	/// Model name - with or without a model namespace
+	Name(ModelName),
+
+	/// Explicit [`ModelIden`] - skips adapter inference, still resolves auth/endpoint.
+	Iden(ModelIden),
+
+	/// Complete [`ServiceTarget`] - used directly, bypasses model mapping and auth resolution
+	Target(ServiceTarget),
+}
+
+// region: --- Constructors
+
+impl ModelSpec {
+	/// Creates a `ModelSpec::Name` from a string.
+	pub fn from_name(name: impl Into<ModelName>) -> Self {
+		ModelSpec::Name(name.into())
+	}
+
+	/// Creates a `ModelSpec::Name` from a static str.
+	pub fn from_static_name(name: &'static str) -> Self {
+		let name = ModelName::from_static(name);
+		ModelSpec::Name(name)
+	}
+
+	/// Creates a `ModelSpec::Iden` from a `ModelIden`.
+	pub fn from_iden(model_iden: impl Into<ModelIden>) -> Self {
+		let model_iden = model_iden.into();
+		Self::Iden(model_iden)
+	}
+
+	/// Creates a `ModelSpec::Target` from a complete service target.
+	pub fn from_target(target: ServiceTarget) -> Self {
+		ModelSpec::Target(target)
+	}
+}
+
+// endregion: --- Constructors
+
+// region: --- From Implementations
+
+impl From<&str> for ModelSpec {
+	fn from(name: &str) -> Self {
+		ModelSpec::Name(name.into())
+	}
+}
+
+impl From<&&str> for ModelSpec {
+	fn from(name: &&str) -> Self {
+		ModelSpec::Name((*name).into())
+	}
+}
+
+impl From<String> for ModelSpec {
+	fn from(name: String) -> Self {
+		ModelSpec::Name(name.into())
+	}
+}
+
+impl From<&String> for ModelSpec {
+	fn from(name: &String) -> Self {
+		ModelSpec::Name(name.into())
+	}
+}
+
+impl From<ModelName> for ModelSpec {
+	fn from(model: ModelName) -> Self {
+		ModelSpec::Name(model)
+	}
+}
+
+impl From<&ModelName> for ModelSpec {
+	fn from(model: &ModelName) -> Self {
+		ModelSpec::Name(model.clone())
+	}
+}
+
+impl From<ModelIden> for ModelSpec {
+	fn from(model: ModelIden) -> Self {
+		ModelSpec::Iden(model)
+	}
+}
+
+impl From<&ModelIden> for ModelSpec {
+	fn from(model: &ModelIden) -> Self {
+		ModelSpec::Iden(model.clone())
+	}
+}
+
+impl From<ServiceTarget> for ModelSpec {
+	fn from(target: ServiceTarget) -> Self {
+		ModelSpec::Target(target)
+	}
+}
+
+// endregion: --- From Implementations
diff --git a/src/client/service_target.rs b/src/client/service_target.rs
index dac06af1..a2d5a37e 100644
--- a/src/client/service_target.rs
+++ b/src/client/service_target.rs
@@ -9,6 +9,7 @@ use crate::resolver::{AuthData, Endpoint};
 /// - `auth`: Authentication data for the request.
 ///
 /// - `model`: Target model identifier.
+#[derive(Debug, Clone)] pub struct ServiceTarget { pub endpoint: Endpoint, pub auth: AuthData, diff --git a/src/common/model_iden.rs b/src/common/model_iden.rs index 60e01861..ed40c64e 100644 --- a/src/common/model_iden.rs +++ b/src/common/model_iden.rs @@ -20,12 +20,20 @@ pub struct ModelIden { /// Contructor impl ModelIden { /// Create a new `ModelIden` with the given adapter kind and model name. - pub fn new(adapter_kind: AdapterKind, model_name: impl Into) -> Self { + pub fn new(adapter_kind: impl Into, model_name: impl Into) -> Self { Self { - adapter_kind, + adapter_kind: adapter_kind.into(), model_name: model_name.into(), } } + + /// Create a new `ModelIden` with the given adapter kind and model name. + pub fn from_static(adapter_kind: impl Into, model_name: &'static str) -> Self { + Self { + adapter_kind: adapter_kind.into(), + model_name: ModelName::from_static(model_name), + } + } } impl ModelIden { @@ -49,7 +57,7 @@ impl ModelIden { } /// Creates a new `ModelIden` with the specified name, or clones the existing one if the name is the same. - /// NOTE: Might be deprecated in favor of [`from_name`] + /// NOTE: Might be deprecated in favor of [`ModelIden::from_name`] pub fn from_optional_name(&self, new_name: Option) -> ModelIden { if let Some(new_name) = new_name { self.from_name(new_name) diff --git a/src/common/model_name.rs b/src/common/model_name.rs index 7ddd38c1..087a4d6b 100644 --- a/src/common/model_name.rs +++ b/src/common/model_name.rs @@ -56,15 +56,6 @@ impl ModelName { Self::split_as_namespace_and_name(self.as_str()) } - /// Backward compatibility - returns `(name, namespace)` - /// e.g.: - /// `openai::gpt4.1` β†’ ("gpt4.1", Some("openai")) - /// `gpt4.1` β†’ ("gpt4.1", None) - pub fn as_model_name_and_namespace(&self) -> (&str, Option<&str>) { - let (ns, name) = Self::split_as_namespace_and_name(self.as_str()); - (name, ns) - } - /// e.g.: /// `openai::gpt4.1` β†’ (Some("openai"), "gpt4.1") /// `gpt4.1` β†’ (None, "gpt4.1") @@ -77,12 +68,6 @@ impl ModelName { (None, model) } } - - /// Backward compatibility - static method that returns `(name, Option)` - pub fn model_name_and_namespace(model: &str) -> (&str, Option<&str>) { - let (ns, name) = Self::split_as_namespace_and_name(model); - (name, ns) - } } impl std::fmt::Display for ModelName { diff --git a/src/lib.rs b/src/lib.rs index f8591550..07a218f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ mod error; // -- Flatten pub use client::*; pub use common::*; -pub use error::{Error, Result}; +pub use error::{BoxError, Error, Result}; // -- Public Modules pub mod adapter; diff --git a/tests/data/other-one.png b/tests/data/other-one.png new file mode 100644 index 00000000..6e281aa4 Binary files /dev/null and b/tests/data/other-one.png differ diff --git a/tests/support/common_tests.rs b/tests/support/common_tests.rs index 10efdc20..162827e8 100644 --- a/tests/support/common_tests.rs +++ b/tests/support/common_tests.rs @@ -1,5 +1,8 @@ use crate::get_option_value; -use crate::support::data::{IMAGE_URL_JPG_DUCK, get_b64_duck, get_b64_pdf}; +use crate::support::data::{ + AUDIO_TEST_FILE_PATH, IMAGE_URL_JPG_DUCK, TEST_IMAGE_FILE_PATH, get_b64_audio, get_b64_duck, get_b64_pdf, + has_audio_file, +}; use crate::support::{ Check, StreamExtract, TestResult, assert_contains, assert_reasoning_content, assert_reasoning_usage, contains_checks, extract_stream_end, get_big_content, seed_chat_req_simple, seed_chat_req_tool_simple, @@ -20,7 +23,7 @@ use value_ext::JsonValueExt; // region: --- Chat pub async fn 
common_test_chat_simple_ok(model: &str, checks: Option) -> TestResult<()> { - validate_checks(checks.clone(), Check::REASONING | Check::REASONING_USAGE)?; + validate_checks(checks.clone(), Check::REASONING_CONTENT | Check::REASONING_USAGE)?; // -- Setup & Fixtures let client = Client::default(); @@ -50,7 +53,7 @@ pub async fn common_test_chat_simple_ok(model: &str, checks: Option) -> T } // -- Check Reasoning Content - if contains_checks(checks, Check::REASONING) { + if contains_checks(checks, Check::REASONING_CONTENT) { assert_reasoning_content(&chat_res)?; } @@ -58,11 +61,19 @@ pub async fn common_test_chat_simple_ok(model: &str, checks: Option) -> T } // NOTE: here we still have the options about checking REASONING_USAGE, because Anthropic does not have reasoning token. -pub async fn common_test_chat_reasoning_ok(model: &str, checks: Option) -> TestResult<()> { +pub async fn common_test_chat_reasoning_ok( + model: &str, + reasoning_effort: ReasoningEffort, + checks: Option, +) -> TestResult<()> { // -- Setup & Fixtures let client = Client::default(); - let chat_req = seed_chat_req_simple(); - let options = ChatOptions::default().with_reasoning_effort(ReasoningEffort::High); + let chat_req = ChatRequest::new(vec![ + // -- Messages (deactivate to see the differences) + ChatMessage::system("Answer in one sentence. But make think hard to make sure it is not a trick question."), + ChatMessage::user("Why is the sky red?"), + ]); + let options = ChatOptions::default().with_reasoning_effort(reasoning_effort); // -- Exec let chat_res = client.exec_chat(model, chat_req, Some(&options)).await?; @@ -93,7 +104,7 @@ pub async fn common_test_chat_reasoning_ok(model: &str, checks: Option) - } // -- Check Reasoning Content - if contains_checks(checks, Check::REASONING) { + if contains_checks(checks, Check::REASONING_CONTENT) { let reasoning_content = chat_res .reasoning_content .as_deref() @@ -507,12 +518,102 @@ pub async fn common_test_chat_cache_explicit_system_ok(model: &str) -> TestResul Ok(()) } +/// Test for 1-hour TTL cache (Ephemeral1h). +/// Note: 1h TTL is only supported on Claude 4.5 models (Opus 4.5, Sonnet 4.5, Haiku 4.5). 
+pub async fn common_test_chat_cache_explicit_1h_ttl_ok(model: &str) -> TestResult<()> { + // -- Setup & Fixtures + let client = Client::default(); + let big_content = get_big_content()?; + let chat_req = ChatRequest::new(vec![ + // -- Messages + ChatMessage::system("Give a very short summary of what each of those files are about"), + ChatMessage::user(big_content).with_options(CacheControl::Ephemeral1h), + ]); + + // -- Exec + let chat_res = client.exec_chat(model, chat_req, None).await?; + + // -- Check Content + let content = chat_res.first_text().ok_or("Should have content")?; + assert!(!content.trim().is_empty(), "Content should not be empty"); + + // -- Check Usage + let usage = &chat_res.usage; + + let total_tokens = get_option_value!(usage.total_tokens); + let prompt_tokens_details = usage + .prompt_tokens_details + .as_ref() + .ok_or("Should have prompt_tokens_details")?; + let cache_creation_tokens = get_option_value!(prompt_tokens_details.cache_creation_tokens); + let cached_tokens = get_option_value!(prompt_tokens_details.cached_tokens); + + assert!( + cache_creation_tokens > 0 || cached_tokens > 0, + "one of cache_creation_tokens or cached_tokens should be greater than 0" + ); + assert!(total_tokens > 0, "total_tokens should be > 0"); + + // Note: cache_creation_details may or may not be present depending on provider response format + // The API may return TTL-specific breakdown in cache_creation_details when using different TTLs + + Ok(()) +} + +/// Streaming test for 1-hour TTL cache (Ephemeral1h). +/// Note: 1h TTL is only supported on Claude 4.5 models (Opus 4.5, Sonnet 4.5, Haiku 4.5). +pub async fn common_test_chat_stream_cache_explicit_1h_ttl_ok(model: &str) -> TestResult<()> { + // -- Setup & Fixtures + let client = Client::builder() + .with_chat_options(ChatOptions::default().with_capture_usage(true)) + .build(); + let big_content = get_big_content()?; + let chat_req = ChatRequest::new(vec![ + // -- Messages + ChatMessage::system("Give a very short summary of what each of those files are about"), + ChatMessage::user(big_content).with_options(CacheControl::Ephemeral1h), + ]); + + // -- Exec + let chat_res = client.exec_chat_stream(model, chat_req, None).await?; + + // -- Extract Stream content + let StreamExtract { + stream_end, + content, + reasoning_content: _, + } = extract_stream_end(chat_res.stream).await?; + let content = content.ok_or("extract_stream_end SHOULD have extracted some content")?; + + // -- Check Content + assert!(!content.trim().is_empty(), "Content should not be empty"); + + // -- Check Usage + let usage = stream_end.captured_usage.as_ref().ok_or("Should have captured_usage")?; + + let total_tokens = get_option_value!(usage.total_tokens); + let prompt_tokens_details = usage + .prompt_tokens_details + .as_ref() + .ok_or("Should have prompt_tokens_details")?; + let cache_creation_tokens = get_option_value!(prompt_tokens_details.cache_creation_tokens); + let cached_tokens = get_option_value!(prompt_tokens_details.cached_tokens); + + assert!( + cache_creation_tokens > 0 || cached_tokens > 0, + "one of cache_creation_tokens or cached_tokens should be greater than 0" + ); + assert!(total_tokens > 0, "total_tokens should be > 0"); + + Ok(()) +} + // endregion: --- Chat Explicit Cache // region: --- Chat Stream Tests pub async fn common_test_chat_stream_simple_ok(model: &str, checks: Option) -> TestResult<()> { - validate_checks(checks.clone(), Check::REASONING)?; + validate_checks(checks.clone(), Check::REASONING_CONTENT)?; // -- Setup & Fixtures let 
client = Client::default(); @@ -541,7 +642,7 @@ pub async fn common_test_chat_stream_simple_ok(model: &str, checks: Option TestResu } pub async fn common_test_chat_stream_capture_all_ok(model: &str, checks: Option) -> TestResult<()> { - validate_checks(checks.clone(), Check::REASONING | Check::REASONING_USAGE)?; + validate_checks(checks.clone(), Check::REASONING_CONTENT | Check::REASONING_USAGE)?; // -- Setup & Fixtures let mut chat_options = ChatOptions::default() @@ -604,7 +705,7 @@ pub async fn common_test_chat_stream_capture_all_ok(model: &str, checks: Option< .with_capture_content(true) .with_capture_reasoning_content(true); - if contains_checks(checks.clone(), Check::REASONING | Check::REASONING_USAGE) { + if contains_checks(checks.clone(), Check::REASONING_CONTENT | Check::REASONING_USAGE) { chat_options = chat_options.with_reasoning_effort(ReasoningEffort::Medium); } @@ -648,13 +749,44 @@ pub async fn common_test_chat_stream_capture_all_ok(model: &str, checks: Option< } // -- Check Reasoning Content - if contains_checks(checks, Check::REASONING) { + if contains_checks(checks, Check::REASONING_CONTENT) { let _reasoning_content = reasoning_content.ok_or("Should have reasoning content")?; } Ok(()) } +/// Just making the tool request, and checking the tool call response +/// `complete_check` if for LLMs that are better at giving back the unit and weather. +pub async fn common_test_chat_stream_tool_capture_ok(model: &str) -> TestResult<()> { + // -- Setup & Fixtures + let client = Client::default(); + let chat_req = seed_chat_req_tool_simple(); + let mut chat_options = ChatOptions::default().with_capture_tool_calls(true); + + // -- Exec + let chat_res = client.exec_chat_stream(model, chat_req, Some(&chat_options)).await?; + + // Extract Stream content + let StreamExtract { + stream_end, + content, + reasoning_content, + } = extract_stream_end(chat_res.stream).await?; + + // -- Check + let mut tool_calls = stream_end.captured_tool_calls().ok_or("Should have captured tools")?; + if tool_calls.is_empty() { + return Err("Should have tool calls in chat_res".into()); + } + let tool_call = tool_calls.pop().ok_or("Should have at least one tool call")?; + assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("city")?, "Paris"); + assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("country")?, "France"); + assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("unit")?, "C"); + + Ok(()) +} + // endregion: --- Chat Stream Tests // region: --- Binaries @@ -700,6 +832,55 @@ pub async fn common_test_chat_image_b64_ok(model: &str) -> TestResult<()> { Ok(()) } +pub async fn common_test_chat_image_file_ok(model: &str) -> TestResult<()> { + // -- Setup + let client = Client::default(); + + // -- Build & Exec + let mut chat_req = ChatRequest::default().with_system("Answer in one sentence"); + // This is similar to sending initial system chat messages (which will be cumulative with system chat messages) + chat_req = chat_req.append_message(ChatMessage::user(vec![ + ContentPart::from_text("What is in this picture?"), + ContentPart::from_binary_file(TEST_IMAGE_FILE_PATH)?, + ])); + + let chat_res = client.exec_chat(model, chat_req, None).await?; + + // -- Check + let res = chat_res.first_text().ok_or("Should have text result")?; + assert_contains(res, "duck"); + + Ok(()) +} + +pub async fn common_test_chat_audio_b64_ok(model: &str) -> TestResult<()> { + if !has_audio_file() { + println!("No test audio file. 
Skipping this test."); + return Ok(()); + } + + // -- Setup + let client = Client::default(); + + // -- Build & Exec + let mut chat_req = ChatRequest::default().with_system("Transcribe the audio"); + let cp_audio = ContentPart::from_binary_file(AUDIO_TEST_FILE_PATH)?; + // similar as the from_binary_file but manual + // let cp_audio = ContentPart::from_binary_base64("audio/wav", get_b64_audio()?, None); + + chat_req = chat_req.append_message(ChatMessage::user(vec![cp_audio])); + + let chat_res = client.exec_chat(model, chat_req, None).await?; + + // -- Check + let res = chat_res.first_text().ok_or("Should have text result")?; + // NOTE: here we make the test a little loose as the point of this test is not to test the model accuracy + assert_contains(res, "one small step"); + assert_contains(res, "one giant leap"); + + Ok(()) +} + pub async fn common_test_chat_pdf_b64_ok(model: &str) -> TestResult<()> { // -- Setup let client = Client::default(); diff --git a/tests/support/data.rs b/tests/support/data.rs index 0e03fe4f..031fcf36 100644 --- a/tests/support/data.rs +++ b/tests/support/data.rs @@ -3,14 +3,25 @@ use crate::support::TestResult; use base64::Engine; use base64::engine::general_purpose; +use simple_fs::SPath; use std::fs::File; use std::io::Read; -pub const IMAGE_URL_JPG_DUCK: &str = "https://upload.wikimedia.org/wikipedia/commons/thumb/b/bf/Bucephala-albeola-010.jpg/440px-Bucephala-albeola-010.jpg"; +pub const IMAGE_URL_JPG_DUCK: &str = "https://aipack.ai/images/test-duck.jpg"; +pub const AUDIO_TEST_FILE_PATH: &str = "./tests/data/phrase_neil_armstrong.wav"; +pub const TEST_IMAGE_FILE_PATH: &str = "./tests/data/duck-small.jpg"; /// Get the base64 of the image above (but resized/lower to fit 5kb) pub fn get_b64_duck() -> TestResult { - get_b64_file("./tests/data/duck-small.jpg") + get_b64_file(TEST_IMAGE_FILE_PATH) +} + +pub fn has_audio_file() -> bool { + SPath::new(AUDIO_TEST_FILE_PATH).exists() +} + +pub fn get_b64_audio() -> TestResult { + get_b64_file(AUDIO_TEST_FILE_PATH) } pub fn get_b64_pdf() -> TestResult { diff --git a/tests/support/helpers.rs b/tests/support/helpers.rs index 0d29ab0e..f3a0764e 100644 --- a/tests/support/helpers.rs +++ b/tests/support/helpers.rs @@ -33,7 +33,7 @@ bitflags::bitflags! 
{ #[derive(Clone)] pub struct Check: u8 { /// Check if the - const REASONING = 0b00000001; + const REASONING_CONTENT = 0b00000001; const REASONING_USAGE = 0b00000010; const USAGE = 0b00000100; } @@ -91,7 +91,8 @@ pub async fn extract_stream_end(mut chat_stream: ChatStream) -> TestResult (), // nothing to do ChatStreamEvent::Chunk(s_chunk) => content.push(s_chunk.content), ChatStreamEvent::ReasoningChunk(s_chunk) => reasoning_content.push(s_chunk.content), - ChatStreamEvent::ToolCallChunk(_) => (), // ignore tool call chunks for now + ChatStreamEvent::ThoughtSignatureChunk(_) => (), // ignore thought signature chunks for now + ChatStreamEvent::ToolCallChunk(_) => (), // ignore tool call chunks for now ChatStreamEvent::End(s_end) => { stream_end = Some(s_end); break; diff --git a/tests/test_adapter_consistency.rs b/tests/test_adapter_consistency.rs index 5f1f86a5..f32456ab 100644 --- a/tests/test_adapter_consistency.rs +++ b/tests/test_adapter_consistency.rs @@ -90,18 +90,34 @@ fn get_expected_models() -> std::collections::HashMap> { ], ); - // Z.AI models + // Z.AI models (from upstream v0.6.0-alpha.2) expected.insert( "ZAi".to_string(), vec![ + "glm-4-plus".to_string(), "glm-4.6".to_string(), "glm-4.5".to_string(), - "glm-4".to_string(), - "glm-4.1v".to_string(), "glm-4.5v".to_string(), - "vidu".to_string(), - "vidu-q1".to_string(), - "vidu-2.0".to_string(), + "glm-4.5-x".to_string(), + "glm-4.5-air".to_string(), + "glm-4.5-airx".to_string(), + "glm-4-32b-0414-128k".to_string(), + "glm-4.5-flash".to_string(), + "glm-4-air-250414".to_string(), + "glm-4-flashx-250414".to_string(), + "glm-4-flash-250414".to_string(), + "glm-4-air".to_string(), + "glm-4-airx".to_string(), + "glm-4-long".to_string(), + "glm-4-flash".to_string(), + "glm-4v-plus-0111".to_string(), + "glm-4v-flash".to_string(), + "glm-z1-air".to_string(), + "glm-z1-airx".to_string(), + "glm-z1-flash".to_string(), + "glm-z1-flashx".to_string(), + "glm-4.1v-thinking-flash".to_string(), + "glm-4.1v-thinking-flashx".to_string(), ], ); diff --git a/tests/test_verify_model_lists.rs b/tests/test_verify_model_lists.rs index 4bb70ff3..5d06d590 100644 --- a/tests/test_verify_model_lists.rs +++ b/tests/test_verify_model_lists.rs @@ -18,18 +18,16 @@ fn get_expected_models() -> HashMap> { ], ); - // Z.AI models (from src/adapter/adapters/zai/adapter_impl.rs) + // Z.AI models (from upstream v0.6.0-alpha.2 src/adapter/adapters/zai/adapter_impl.rs) + // Note: Adapter is now named "Zai" (not "ZAi") expected.insert( - "ZAi".to_string(), + "Zai".to_string(), vec![ "glm-4.6".to_string(), "glm-4.5".to_string(), - "glm-4".to_string(), - "glm-4.1v".to_string(), "glm-4.5v".to_string(), - "vidu".to_string(), - "vidu-q1".to_string(), - "vidu-2.0".to_string(), + "glm-4-plus".to_string(), + "glm-4-flash".to_string(), ], ); @@ -57,7 +55,7 @@ async fn test_provider_model_lists() -> Result<(), Box> { let client = Client::default(); let expected_models = get_expected_models(); - println!("πŸ” Verifying model lists match actual provider APIs...\n"); + println!("Verifying model lists match actual provider APIs...\n"); let mut all_passed = true; @@ -73,15 +71,15 @@ async fn test_provider_model_lists() -> Result<(), Box> { // Check if the model resolves to the expected adapter if actual_adapter == provider { - println!(" βœ… {} -> {}", model, actual_adapter); + println!(" [OK] {} -> {}", model, actual_adapter); } else { - println!(" ❌ {} -> {} (expected {})", model, actual_adapter, provider); + println!(" [FAIL] {} -> {} (expected {})", model, actual_adapter, 
provider); provider_passed = false; all_passed = false; } } Err(e) => { - println!(" ❌ {} -> ERROR: {}", model, e); + println!(" [FAIL] {} -> ERROR: {}", model, e); provider_passed = false; all_passed = false; } @@ -89,17 +87,17 @@ async fn test_provider_model_lists() -> Result<(), Box> { } if provider_passed { - println!(" βœ“ All {} models resolved correctly\n", provider); + println!(" All {} models resolved correctly\n", provider); } else { - println!(" βœ— Some {} models failed to resolve\n", provider); + println!(" Some {} models failed to resolve\n", provider); } } - println!("πŸ“Š Summary:"); + println!("Summary:"); if all_passed { - println!("βœ… All model lists verified successfully!"); + println!("All model lists verified successfully!"); } else { - println!("❌ Some model lists need updating"); + println!("Some model lists need updating"); panic!("Model lists do not match actual provider APIs"); } @@ -111,22 +109,21 @@ async fn test_provider_model_lists() -> Result<(), Box> { async fn test_model_resolution_edge_cases() -> Result<(), Box> { let client = Client::default(); - println!("πŸ§ͺ Testing edge cases and conflicts...\n"); + println!("Testing edge cases and conflicts...\n"); // Test that Z.AI models work correctly + // Note: Adapter is now "Zai" (not "ZAi") from upstream let test_cases = vec![ - // Z.AI models should resolve to ZAi - ("glm-4.6", "ZAi"), - ("glm-4.5", "ZAi"), - ("glm-4", "ZAi"), - ("vidu", "ZAi"), - ("vidu-q1", "ZAi"), + // Z.AI models should resolve to Zai + ("glm-4.6", "Zai"), + ("glm-4.5", "Zai"), + ("glm-4-plus", "Zai"), // DeepSeek models ("deepseek-coder", "DeepSeek"), ("deepseek-reasoner", "DeepSeek"), ("deepseek-chat", "DeepSeek"), // Model namespace - ("zai::glm-4.6", "ZAi"), + ("zai::glm-4.6", "Zai"), ("openai::gpt-4o", "OpenAI"), ("cerebras::llama3.1-8b", "Cerebras"), ]; @@ -138,20 +135,23 @@ async fn test_model_resolution_edge_cases() -> Result<(), Box { let actual_adapter = format!("{:?}", target.model.adapter_kind); if actual_adapter == expected_adapter { - println!(" βœ… {} -> {}", model, actual_adapter); + println!(" [OK] {} -> {}", model, actual_adapter); } else { - println!(" ❌ {} -> {} (expected {})", model, actual_adapter, expected_adapter); + println!( + " [FAIL] {} -> {} (expected {})", + model, actual_adapter, expected_adapter + ); all_passed = false; } } Err(e) => { - println!(" ❌ {} -> ERROR: {}", model, e); + println!(" [FAIL] {} -> ERROR: {}", model, e); all_passed = false; } } } - println!("\n✨ Edge case tests completed!"); + println!("\nEdge case tests completed!"); assert!(all_passed, "Some edge cases failed"); @@ -163,7 +163,7 @@ async fn test_model_resolution_edge_cases() -> Result<(), Box Result<(), Box> { let client = Client::default(); - println!("🌐 Testing OpenRouter model patterns...\n"); + println!("Testing OpenRouter model patterns...\n"); // These should resolve to OpenRouter (non-namespaced with /) let openrouter_patterns = vec![ @@ -178,20 +178,20 @@ async fn test_openrouter_model_patterns() -> Result<(), Box { let adapter = format!("{:?}", target.model.adapter_kind); if adapter == "OpenRouter" { - println!(" βœ… {} -> {}", model, adapter); + println!(" [OK] {} -> {}", model, adapter); } else { - println!(" ❌ {} -> {} (expected OpenRouter)", model, adapter); + println!(" [FAIL] {} -> {} (expected OpenRouter)", model, adapter); return Err(format!("OpenRouter pattern failed for {}", model).into()); } } Err(e) => { - println!(" ❌ {} -> ERROR: {}", model, e); + println!(" [FAIL] {} -> ERROR: {}", model, e); return 
Err(format!("Failed to resolve {}: {}", model, e).into()); } } } - println!("\nβœ… OpenRouter patterns verified!"); + println!("\nOpenRouter patterns verified!"); Ok(()) } diff --git a/tests/test_zai_adapter.rs b/tests/test_zai_adapter.rs index 402c6c97..c43b5b67 100644 --- a/tests/test_zai_adapter.rs +++ b/tests/test_zai_adapter.rs @@ -1,4 +1,5 @@ -//! Test for Z.AI adapter support +//! Test for Z.AI adapter support (upstream v0.6.0-alpha.2) +//! Note: Adapter is now named "Zai" (not "ZAi") use genai::Client; use genai::chat::{ChatMessage, ChatRequest}; @@ -7,31 +8,27 @@ use genai::chat::{ChatMessage, ChatRequest}; async fn test_zai_model_resolution() -> Result<(), Box> { let client = Client::default(); - // Test that Z.AI models resolve correctly (only models in ZAI MODELS list) - let zai_models = vec!["glm-4.6", "glm-4", "glm-4.5", "vidu"]; + // Test that Z.AI models resolve correctly (models in upstream ZAI MODELS list) + let zai_models = vec!["glm-4.6", "glm-4.5", "glm-4-plus", "glm-4-flash"]; for model in zai_models { let target = client.resolve_service_target(model).await?; - assert_eq!(format!("{:?}", target.model.adapter_kind), "ZAi"); - println!("βœ… {} -> ZAi", model); + assert_eq!(format!("{:?}", target.model.adapter_kind), "Zai"); + println!("[OK] {} -> Zai", model); } // Test that namespaced Z.AI works let target = client.resolve_service_target("zai::glm-4.6").await?; - assert_eq!(format!("{:?}", target.model.adapter_kind), "ZAi"); - println!("βœ… zai::glm-4.6 -> ZAi"); + assert_eq!(format!("{:?}", target.model.adapter_kind), "Zai"); + println!("[OK] zai::glm-4.6 -> Zai"); - // Test that other GLM models not in Z.AI list go to Zhipu (not Ollama) + // Test that GLM models starting with "glm" go to Zai (per upstream logic) + // In upstream, any model starting with "glm" resolves to Zai adapter let target = client.resolve_service_target("glm-2").await?; - assert_eq!(format!("{:?}", target.model.adapter_kind), "Zhipu"); - println!("βœ… glm-2 -> Zhipu (not in Z.AI list, goes to Zhipu instead)"); + assert_eq!(format!("{:?}", target.model.adapter_kind), "Zai"); + println!("[OK] glm-2 -> Zai (any glm-* goes to Zai in upstream)"); - // Test that glm-3-turbo (not supported by Z.AI) goes to Zhipu - let target = client.resolve_service_target("glm-3-turbo").await?; - assert_eq!(format!("{:?}", target.model.adapter_kind), "Zhipu"); - println!("βœ… glm-3-turbo -> Zhipu (turbo models not supported by Z.AI)"); - - println!("\n✨ Z.AI model resolution tests passed!"); + println!("\nZ.AI model resolution tests passed!"); Ok(()) } @@ -39,18 +36,18 @@ async fn test_zai_model_resolution() -> Result<(), Box> { async fn test_zai_adapter_integration() -> Result<(), Box> { // Only run if API key is available if std::env::var("ZAI_API_KEY").is_err() { - println!("⚠️ ZAI_API_KEY not set, skipping integration test"); + println!("ZAI_API_KEY not set, skipping integration test"); return Ok(()); } let client = Client::default(); let chat_req = ChatRequest::new(vec![ChatMessage::user("Say 'Hello from Z.AI!'")]); - let result = client.exec_chat("glm-4", chat_req, None).await?; + let result = client.exec_chat("glm-4-flash", chat_req, None).await?; let content = result.first_text().ok_or("Should have content")?; assert!(!content.is_empty()); - println!("βœ… Z.AI response: {}", content); + println!("[OK] Z.AI response: {}", content); Ok(()) } diff --git a/tests/tests_p_anthropic.rs b/tests/tests_p_anthropic.rs index 93530135..22f3d50d 100644 --- a/tests/tests_p_anthropic.rs +++ b/tests/tests_p_anthropic.rs @@ 
-2,6 +2,7 @@ mod support; use crate::support::{Check, TestResult, common_tests}; use genai::adapter::AdapterKind; +use genai::chat::ReasoningEffort; use genai::resolver::AuthData; use serial_test::serial; @@ -10,7 +11,8 @@ use serial_test::serial; // "claude-sonnet-4-20250514" (fail on test_chat_json_mode_ok) // const MODEL: &str = "claude-3-5-haiku-latest"; -const MODEL_THINKING: &str = "claude-sonnet-4-5-20250929"; +// const MODEL_THINKING: &str = "claude-sonnet-4-5-20250929"; +const MODEL_THINKING: &str = "claude-opus-4-5"; const MODEL_NS: &str = "anthropic::claude-3-5-haiku-latest"; // region: --- Chat @@ -25,7 +27,8 @@ async fn test_chat_simple_ok() -> TestResult<()> { #[serial(anthropic)] async fn test_chat_reasoning_ok() -> TestResult<()> { // NOTE: Does not test REASONING_USAGE as Anthropic does not report it - common_tests::common_test_chat_reasoning_ok(MODEL_THINKING, Some(Check::REASONING)).await + common_tests::common_test_chat_reasoning_ok(MODEL_THINKING, ReasoningEffort::High, Some(Check::REASONING_CONTENT)) + .await } #[tokio::test] @@ -75,6 +78,20 @@ async fn test_chat_cache_explicit_system_ok() -> TestResult<()> { common_tests::common_test_chat_cache_explicit_system_ok(MODEL).await } +/// Test for 1-hour TTL cache (only supported on Claude 4.5 models) +#[tokio::test] +#[serial(anthropic)] +async fn test_chat_cache_explicit_1h_ttl_ok() -> TestResult<()> { + common_tests::common_test_chat_cache_explicit_1h_ttl_ok(MODEL_THINKING).await +} + +/// Streaming test for 1-hour TTL cache (only supported on Claude 4.5 models) +#[tokio::test] +#[serial(anthropic)] +async fn test_chat_stream_cache_explicit_1h_ttl_ok() -> TestResult<()> { + common_tests::common_test_chat_stream_cache_explicit_1h_ttl_ok(MODEL_THINKING).await +} + // endregion: --- Chat Explicit Cache // region: --- Chat Stream Tests @@ -119,6 +136,11 @@ async fn test_chat_binary_pdf_b64_ok() -> TestResult<()> { common_tests::common_test_chat_pdf_b64_ok(MODEL).await } +#[tokio::test] +async fn test_chat_binary_image_file_ok() -> TestResult<()> { + common_tests::common_test_chat_image_file_ok(MODEL).await +} + #[tokio::test] async fn test_chat_binary_multi_b64_ok() -> TestResult<()> { common_tests::common_test_chat_multi_binary_b64_ok(MODEL).await @@ -156,7 +178,7 @@ async fn test_resolver_auth_ok() -> TestResult<()> { #[tokio::test] async fn test_list_models() -> TestResult<()> { - common_tests::common_test_list_models(AdapterKind::Anthropic, "claude-sonnet-4-5-20250929").await + common_tests::common_test_list_models(AdapterKind::Anthropic, "claude-sonnet-4-5").await } // endregion: --- List diff --git a/tests/tests_p_cohere.rs b/tests/tests_p_cohere.rs index 4fed909a..fe581857 100644 --- a/tests/tests_p_cohere.rs +++ b/tests/tests_p_cohere.rs @@ -3,6 +3,7 @@ mod support; use crate::support::{TestResult, common_tests}; use genai::adapter::AdapterKind; use genai::resolver::AuthData; +use serial_test::serial; const MODEL: &str = "command-r7b-12-2024"; const MODEL_NS: &str = "cohere::command-r7b-12-2024"; @@ -10,21 +11,25 @@ const MODEL_NS: &str = "cohere::command-r7b-12-2024"; // region: --- Chat #[tokio::test] +#[serial(cohere)] async fn test_chat_simple_ok() -> TestResult<()> { common_tests::common_test_chat_simple_ok(MODEL, None).await } #[tokio::test] +#[serial(cohere)] async fn test_chat_namespaced_ok() -> TestResult<()> { common_tests::common_test_chat_simple_ok(MODEL_NS, None).await } #[tokio::test] +#[serial(cohere)] async fn test_chat_multi_system_ok() -> TestResult<()> { 
common_tests::common_test_chat_multi_system_ok(MODEL).await } #[tokio::test] +#[serial(cohere)] async fn test_chat_stop_sequences_ok() -> TestResult<()> { common_tests::common_test_chat_stop_sequences_ok(MODEL).await } @@ -34,6 +39,7 @@ async fn test_chat_stop_sequences_ok() -> TestResult<()> { // region: --- Chat Stream Tests #[tokio::test] +#[serial(cohere)] async fn test_chat_stream_simple_ok() -> TestResult<()> { common_tests::common_test_chat_stream_simple_ok(MODEL, None).await } @@ -41,16 +47,18 @@ async fn test_chat_stream_simple_ok() -> TestResult<()> { // NOTE 2024-06-23 - Occasionally, the last stream message sent by Cohere is malformed and cannot be parsed. // Will investigate further if requested. // #[tokio::test] +#[serial(cohere)] // async fn test_chat_stream_capture_content_ok() -> TestResult<()> { // common_tests::common_test_chat_stream_capture_content_ok(MODEL).await // } - #[tokio::test] +#[serial(cohere)] async fn test_chat_stream_capture_all_ok() -> TestResult<()> { common_tests::common_test_chat_stream_capture_all_ok(MODEL, None).await } #[tokio::test] +#[serial(cohere)] async fn test_chat_temperature_ok() -> TestResult<()> { common_tests::common_test_chat_temperature_ok(MODEL).await } @@ -60,6 +68,7 @@ async fn test_chat_temperature_ok() -> TestResult<()> { // region: --- Resolver Tests #[tokio::test] +#[serial(cohere)] async fn test_resolver_auth_ok() -> TestResult<()> { common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("COHERE_API_KEY")).await } @@ -69,6 +78,7 @@ async fn test_resolver_auth_ok() -> TestResult<()> { // region: --- List #[tokio::test] +#[serial(cohere)] async fn test_list_models() -> TestResult<()> { common_tests::common_test_list_models(AdapterKind::Cohere, "command-r-plus").await } diff --git a/tests/tests_p_deepseek_reasoning.rs b/tests/tests_p_deepseek_reasoning.rs index d6533ea9..54324e4a 100644 --- a/tests/tests_p_deepseek_reasoning.rs +++ b/tests/tests_p_deepseek_reasoning.rs @@ -11,7 +11,7 @@ const MODEL: &str = "deepseek-reasoner"; #[tokio::test] async fn test_chat_simple_ok() -> TestResult<()> { - common_tests::common_test_chat_simple_ok(MODEL, Some(Check::REASONING)).await + common_tests::common_test_chat_simple_ok(MODEL, Some(Check::REASONING_CONTENT)).await } #[tokio::test] @@ -47,7 +47,7 @@ async fn test_chat_reasoning_normalize_ok() -> TestResult<()> { #[tokio::test] async fn test_chat_stream_simple_ok() -> TestResult<()> { - common_tests::common_test_chat_stream_simple_ok(MODEL, Some(Check::REASONING)).await + common_tests::common_test_chat_stream_simple_ok(MODEL, Some(Check::REASONING_CONTENT)).await } #[tokio::test] @@ -57,7 +57,7 @@ async fn test_chat_stream_capture_content_ok() -> TestResult<()> { #[tokio::test] async fn test_chat_stream_capture_all_ok() -> TestResult<()> { - common_tests::common_test_chat_stream_capture_all_ok(MODEL, Some(Check::REASONING)).await + common_tests::common_test_chat_stream_capture_all_ok(MODEL, Some(Check::REASONING_CONTENT)).await } // endregion: --- Chat Stream Tests diff --git a/tests/tests_p_gemini.rs b/tests/tests_p_gemini.rs index b4ea38d1..aec90373 100644 --- a/tests/tests_p_gemini.rs +++ b/tests/tests_p_gemini.rs @@ -2,10 +2,13 @@ mod support; use crate::support::{Check, TestResult, common_tests}; use genai::adapter::AdapterKind; +use genai::chat::ReasoningEffort; use genai::resolver::AuthData; // "gemini-2.5-flash" "gemini-2.5-pro" "gemini-2.5-flash-lite" // "gemini-2.5-flash-zero" +const MODEL_GPRO_3: &str = "gemini-3-pro-preview"; +const MODEL_FLASH_3: &str = 
"gemini-3-flash-preview"; // pure gem, fast, cheap, and good! const MODEL: &str = "gemini-2.5-flash"; const MODEL_NS: &str = "gemini::gemini-2.5-flash"; @@ -16,6 +19,16 @@ async fn test_chat_simple_ok() -> TestResult<()> { common_tests::common_test_chat_simple_ok(MODEL, None).await } +#[tokio::test] +async fn test_chat_reasoning_ok() -> TestResult<()> { + common_tests::common_test_chat_reasoning_ok( + MODEL_GPRO_3, + ReasoningEffort::Low, + Some(Check::REASONING_USAGE | Check::REASONING_USAGE), + ) + .await +} + #[tokio::test] async fn test_chat_namespaced_ok() -> TestResult<()> { common_tests::common_test_chat_simple_ok(MODEL_NS, None).await @@ -91,6 +104,11 @@ async fn test_chat_binary_pdf_b64_ok() -> TestResult<()> { common_tests::common_test_chat_pdf_b64_ok(MODEL).await } +#[tokio::test] +async fn test_chat_binary_image_file_ok() -> TestResult<()> { + common_tests::common_test_chat_image_file_ok(MODEL).await +} + #[tokio::test] async fn test_chat_binary_multi_b64_ok() -> TestResult<()> { common_tests::common_test_chat_multi_binary_b64_ok(MODEL).await @@ -109,6 +127,74 @@ async fn test_tool_simple_ok() -> TestResult<()> { async fn test_tool_full_flow_ok() -> TestResult<()> { common_tests::common_test_tool_full_flow_ok(MODEL).await } + +#[tokio::test] +async fn test_tool_deterministic_history_gemini_3_ok() -> TestResult<()> { + use genai::chat::{ChatMessage, ChatRequest, Tool, ToolCall, ToolResponse}; + use serde_json::json; + + let client = genai::Client::default(); + + let weather_tool = Tool::new("get_weather").with_schema(json!({ + "type": "object", + "properties": { + "city": { "type": "string" }, + "unit": { "type": "string", "enum": ["C", "F"] } + }, + "required": ["city", "unit"] + })); + + // Pre-seed history with a "synthetic" tool call (missing thought signatures) + let messages = vec![ + ChatMessage::user("What's the weather like in Paris?"), + ChatMessage::assistant(vec![ToolCall { + call_id: "call_123".to_string(), + fn_name: "get_weather".to_string(), + fn_arguments: json!({"city": "Paris", "unit": "C"}), + thought_signatures: None, + }]), + ChatMessage::from(ToolResponse::new( + "call_123".to_string(), + json!({"temperature": 15, "condition": "Cloudy"}).to_string(), + )), + ]; + + let chat_req = ChatRequest::new(messages).with_tools(vec![weather_tool]); + + // This verifies that the adapter correctly injects 'skip_thought_signature_validator'. + // (Otherwise Gemini 3 would return a 400 error.) + let chat_res = client.exec_chat(MODEL_GPRO_3, chat_req, None).await?; + + assert!( + chat_res.first_text().is_some(), + "Expected a text response from the model" + ); + + Ok(()) +} + +// NOTE: Issue of this test is that it is pretty slow +#[tokio::test] +async fn test_tool_google_web_search_ok() -> TestResult<()> { + use genai::chat::{ChatRequest, Tool}; + use serde_json::json; + + // -- Fixtures & Setup + let client = genai::Client::default(); + let web_search_tool = Tool::new("googleSearch").with_config(json!({})); + let chat_req = + ChatRequest::from_user("What is the latest version of Rust? 
(be concise)").append_tool(web_search_tool); + + // Exec + let res = client.exec_chat(MODEL_FLASH_3, chat_req, None).await?; + + // Check + let res_txt = res.content.into_first_text().ok_or("Should have result")?; + assert!(res_txt.contains("Rust"), "should contains 'Rust'"); + + Ok(()) +} + // endregion: --- Tool Tests // region: --- Resolver Tests diff --git a/tests/tests_p_ollama.rs b/tests/tests_p_ollama.rs index 3c4ff68a..8bac571f 100644 --- a/tests/tests_p_ollama.rs +++ b/tests/tests_p_ollama.rs @@ -6,8 +6,9 @@ use genai::resolver::AuthData; // "gemma3:4b" "phi3:latest" "gpt-oss:20b" // NOTE: "gpt-oss:20b" has some issues on json_mode, stop_sequence -const MODEL: &str = "gemma3:4b"; // +const MODEL: &str = "gemma3:4b"; // const MODEL_NS: &str = "ollama::gemma3:4b"; +const MODEL_TOOL: &str = "ollama::gpt-oss:20b"; // region: --- Chat @@ -65,6 +66,19 @@ async fn test_chat_stream_capture_content_ok() -> TestResult<()> { // endregion: --- Chat Stream Tests +// region: --- Tool Tests + +#[tokio::test] +async fn test_tool_simple_ok() -> TestResult<()> { + common_tests::common_test_tool_simple_ok(MODEL_TOOL).await +} + +#[tokio::test] +async fn test_tool_full_flow_ok() -> TestResult<()> { + common_tests::common_test_tool_full_flow_ok(MODEL_TOOL).await +} +// endregion: --- Tool Tests + /* Added Binary Tests region (commented-out until Ollama supports binary inputs) */ // region: --- Binary Tests diff --git a/tests/tests_p_ollama_reasoning.rs b/tests/tests_p_ollama_reasoning.rs index bc19489e..eeda2911 100644 --- a/tests/tests_p_ollama_reasoning.rs +++ b/tests/tests_p_ollama_reasoning.rs @@ -1,9 +1,12 @@ mod support; -use crate::support::{TestResult, common_tests}; +use crate::support::{TestResult, common_tests, seed_chat_req_simple}; +use genai::Client; use genai::adapter::AdapterKind; +use genai::chat::ChatStreamEvent; use genai::resolver::AuthData; use serial_test::serial; +use tokio_stream::StreamExt; // NOTE: Sometimes the 1.5b model does not provide the reasoning or has some issues. // Rerunning the test or switching to the 8b model would generally solve the issues. @@ -11,6 +14,7 @@ use serial_test::serial; // NOTE: Also, #[serial(ollama)] seems more reliable when using it. const MODEL: &str = "deepseek-r1:1.5b"; // "deepseek-r1:8b" "deepseek-r1:1.5b" +const MODEL_QWEN3: &str = "qwen3:4b"; // region: --- Chat @@ -69,6 +73,68 @@ async fn test_chat_stream_capture_content_ok() -> TestResult<()> { common_tests::common_test_chat_stream_capture_content_ok(MODEL).await } +#[tokio::test] +#[serial(ollama)] +async fn test_chat_stream_reasoning_chunk_ok() -> TestResult<()> { + let client = Client::default(); + let chat_req = seed_chat_req_simple(); + + let chat_res = client.exec_chat_stream(MODEL_QWEN3, chat_req, None).await?; + let mut stream = chat_res.stream; + let mut reasoning_content = String::new(); + + while let Some(result) = stream.next().await { + match result? 
{ + ChatStreamEvent::ReasoningChunk(chunk) => { + reasoning_content.push_str(&chunk.content); + break; + } + ChatStreamEvent::End(_) => break, + _ => {} + } + } + assert!(!reasoning_content.is_empty(), "reasoning_content should not be empty"); + + Ok(()) +} + +#[tokio::test] +#[serial(ollama)] +async fn test_chat_stream_non_empty_chunk_deepseek_ok() -> TestResult<()> { + let client = Client::default(); + let chat_req = seed_chat_req_simple(); + + let chat_res = client.exec_chat_stream(MODEL, chat_req, None).await?; + let mut stream = chat_res.stream; + let mut found_non_empty = false; + + while let Some(result) = stream.next().await { + match result? { + ChatStreamEvent::Chunk(chunk) => { + if !chunk.content.is_empty() { + found_non_empty = true; + break; + } + } + ChatStreamEvent::ReasoningChunk(chunk) => { + if !chunk.content.is_empty() { + found_non_empty = true; + break; + } + } + ChatStreamEvent::End(_) => break, + _ => {} + } + } + + assert!( + found_non_empty, + "stream should yield non-empty content or reasoning chunks" + ); + + Ok(()) +} + // /// COMMENTED FOR NOW AS OLLAMA OpenAI Compatibility Layer does not support // /// usage tokens when streaming. See https://github.com/ollama/ollama/issues/4448 // #[tokio::test] diff --git a/tests/tests_p_openai.rs b/tests/tests_p_openai.rs index 8c3571e4..f5eaa3a2 100644 --- a/tests/tests_p_openai.rs +++ b/tests/tests_p_openai.rs @@ -2,11 +2,14 @@ mod support; use crate::support::{Check, TestResult, common_tests}; use genai::adapter::AdapterKind; +use genai::chat::ReasoningEffort; use genai::resolver::AuthData; // note: "gpt-4o-mini" has issue when image & pdf // as for 2025-08-08 gpt-5-mini does not support temperature & stop sequence -const MODEL: &str = "gpt-5-mini"; +const MODEL_LATEST: &str = "gpt-5.2"; +const MODEL_GPT_5_MINI: &str = "gpt-5-mini"; // for the streaming reasoning test +const AUDIO_MODEL: &str = "gpt-audio-mini"; const MODEL2: &str = "gpt-4.1-mini"; // for temperature & stop sequence const MODEL_NS: &str = "openai::gpt-4.1-mini"; @@ -18,24 +21,31 @@ async fn test_chat_reasoning_minimal_ok() -> TestResult<()> { common_tests::common_test_chat_simple_ok("gpt-5-mini-minimal", None).await } +// gpt-5-pro (different api than gpt-5) +// expensive, so, will be commented most of the time. 
+// #[tokio::test] +// async fn test_chat_gpt_5_pro_simple_ok() -> TestResult<()> { +// common_tests::common_test_chat_simple_ok("gpt-5-pro", None).await +// } + // endregion: --- Provider Specific // region: --- Chat #[tokio::test] async fn test_chat_simple_ok() -> TestResult<()> { - common_tests::common_test_chat_simple_ok(MODEL, None).await + common_tests::common_test_chat_simple_ok(MODEL_LATEST, None).await } #[tokio::test] async fn test_chat_reasoning_ok() -> TestResult<()> { // For now, do not test Check::REASONING, for OpenAI as it is not captured - common_tests::common_test_chat_reasoning_ok(MODEL, Some(Check::REASONING_USAGE)).await + common_tests::common_test_chat_reasoning_ok(MODEL_LATEST, ReasoningEffort::High, Some(Check::REASONING_USAGE)).await } #[tokio::test] async fn test_chat_verbosity_ok() -> TestResult<()> { - common_tests::common_test_chat_verbosity_ok(MODEL).await + common_tests::common_test_chat_verbosity_ok(MODEL_GPT_5_MINI).await } #[tokio::test] @@ -45,17 +55,17 @@ async fn test_chat_namespaced_ok() -> TestResult<()> { #[tokio::test] async fn test_chat_multi_system_ok() -> TestResult<()> { - common_tests::common_test_chat_multi_system_ok(MODEL).await + common_tests::common_test_chat_multi_system_ok(MODEL_LATEST).await } #[tokio::test] async fn test_chat_json_mode_ok() -> TestResult<()> { - common_tests::common_test_chat_json_mode_ok(MODEL, Some(Check::USAGE)).await + common_tests::common_test_chat_json_mode_ok(MODEL_LATEST, Some(Check::USAGE)).await } #[tokio::test] async fn test_chat_json_structured_ok() -> TestResult<()> { - common_tests::common_test_chat_json_structured_ok(MODEL, Some(Check::USAGE)).await + common_tests::common_test_chat_json_structured_ok(MODEL_LATEST, Some(Check::USAGE)).await } #[tokio::test] @@ -74,7 +84,7 @@ async fn test_chat_stop_sequences_ok() -> TestResult<()> { #[tokio::test] async fn test_chat_cache_implicit_simple_ok() -> TestResult<()> { - common_tests::common_test_chat_cache_implicit_simple_ok(MODEL).await + common_tests::common_test_chat_cache_implicit_simple_ok(MODEL_GPT_5_MINI).await } // endregion: --- Chat Implicit Cache @@ -83,18 +93,24 @@ async fn test_chat_cache_implicit_simple_ok() -> TestResult<()> { #[tokio::test] async fn test_chat_stream_simple_ok() -> TestResult<()> { - common_tests::common_test_chat_stream_simple_ok(MODEL, None).await + common_tests::common_test_chat_stream_simple_ok(MODEL_LATEST, None).await } #[tokio::test] async fn test_chat_stream_capture_content_ok() -> TestResult<()> { - common_tests::common_test_chat_stream_capture_content_ok(MODEL).await + common_tests::common_test_chat_stream_capture_content_ok(MODEL_LATEST).await } #[tokio::test] async fn test_chat_stream_capture_all_ok() -> TestResult<()> { + // NOTE: gpt-5.1 even when reasoning is Medium, does not give reasoning when simple chat when streaming + common_tests::common_test_chat_stream_capture_all_ok(MODEL_GPT_5_MINI, Some(Check::REASONING_USAGE)).await +} + +#[tokio::test] +async fn test_chat_stream_tool_capture_ok() -> TestResult<()> { // NOTE: For now the OpenAI Adapter do not capture the thinking as not available in chat completions - common_tests::common_test_chat_stream_capture_all_ok(MODEL, Some(Check::REASONING_USAGE)).await + common_tests::common_test_chat_stream_tool_capture_ok(MODEL_LATEST).await } // endregion: --- Chat Stream Tests @@ -103,22 +119,32 @@ async fn test_chat_stream_capture_all_ok() -> TestResult<()> { #[tokio::test] async fn test_chat_binary_image_url_ok() -> TestResult<()> { - 
common_tests::common_test_chat_image_url_ok(MODEL).await + common_tests::common_test_chat_image_url_ok(MODEL_LATEST).await } #[tokio::test] async fn test_chat_binary_image_b64_ok() -> TestResult<()> { - common_tests::common_test_chat_image_b64_ok(MODEL).await + common_tests::common_test_chat_image_b64_ok(MODEL_LATEST).await +} + +#[tokio::test] +async fn test_chat_binary_image_file_ok() -> TestResult<()> { + common_tests::common_test_chat_image_file_ok(MODEL_LATEST).await +} + +#[tokio::test] +async fn test_chat_binary_audio_b64_ok() -> TestResult<()> { + common_tests::common_test_chat_audio_b64_ok(AUDIO_MODEL).await } #[tokio::test] async fn test_chat_binary_pdf_b64_ok() -> TestResult<()> { - common_tests::common_test_chat_pdf_b64_ok(MODEL).await + common_tests::common_test_chat_pdf_b64_ok(MODEL_LATEST).await } #[tokio::test] async fn test_chat_binary_multi_b64_ok() -> TestResult<()> { - common_tests::common_test_chat_multi_binary_b64_ok(MODEL).await + common_tests::common_test_chat_multi_binary_b64_ok(MODEL_LATEST).await } // endregion: --- Binary Tests @@ -127,12 +153,12 @@ async fn test_chat_binary_multi_b64_ok() -> TestResult<()> { #[tokio::test] async fn test_tool_simple_ok() -> TestResult<()> { - common_tests::common_test_tool_simple_ok(MODEL).await + common_tests::common_test_tool_simple_ok(MODEL_LATEST).await } #[tokio::test] async fn test_tool_full_flow_ok() -> TestResult<()> { - common_tests::common_test_tool_full_flow_ok(MODEL).await + common_tests::common_test_tool_full_flow_ok(MODEL_LATEST).await } // endregion: --- Tool Tests @@ -140,7 +166,7 @@ async fn test_tool_full_flow_ok() -> TestResult<()> { #[tokio::test] async fn test_resolver_auth_ok() -> TestResult<()> { - common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("OPENAI_API_KEY")).await + common_tests::common_test_resolver_auth_ok(MODEL_LATEST, AuthData::from_env("OPENAI_API_KEY")).await } // endregion: --- Resolver Tests diff --git a/tests/tests_p_openai_resp.rs b/tests/tests_p_openai_resp.rs index 4f69febf..24250d95 100644 --- a/tests/tests_p_openai_resp.rs +++ b/tests/tests_p_openai_resp.rs @@ -14,7 +14,7 @@ const MODEL_NS: &str = "openai_resp::gpt-5-mini"; // openai specific #[tokio::test] async fn test_chat_reasoning_minimal_ok() -> TestResult<()> { - common_tests::common_test_chat_simple_ok("gpt-5-codex-minimal", None).await + common_tests::common_test_chat_simple_ok("gpt-5-minimal", None).await } // endregion: --- Provider Specific diff --git a/tests/tests_p_zhipu.rs b/tests/tests_p_zai.rs similarity index 95% rename from tests/tests_p_zhipu.rs rename to tests/tests_p_zai.rs index d32f53df..70bc6952 100644 --- a/tests/tests_p_zhipu.rs +++ b/tests/tests_p_zai.rs @@ -5,7 +5,7 @@ use genai::adapter::AdapterKind; use genai::resolver::AuthData; const MODEL: &str = "glm-4-plus"; -const MODEL_NS: &str = "zhipu::glm-4-plus"; +const MODEL_NS: &str = "zai::glm-4-plus"; const MODEL_V: &str = "glm-4v-flash"; // Visual language model does not support function calling // region: --- Chat @@ -106,7 +106,7 @@ async fn test_tool_full_flow_ok() -> TestResult<()> { #[tokio::test] async fn test_resolver_auth_ok() -> TestResult<()> { - common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("ZHIPU_API_KEY")).await + common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("ZAI_API_KEY")).await } // endregion: --- Resolver Tests @@ -115,7 +115,7 @@ async fn test_resolver_auth_ok() -> TestResult<()> { #[tokio::test] async fn test_list_models() -> TestResult<()> { - 
common_tests::common_test_list_models(AdapterKind::Zhipu, "glm-4-plus").await + common_tests::common_test_list_models(AdapterKind::Zai, "glm-4-plus").await } // endregion: --- List diff --git a/tests/tests_p_zhipu_reasoning.rs b/tests/tests_p_zai_reasoning.rs similarity index 96% rename from tests/tests_p_zhipu_reasoning.rs rename to tests/tests_p_zai_reasoning.rs index 9031a409..c405e759 100644 --- a/tests/tests_p_zhipu_reasoning.rs +++ b/tests/tests_p_zai_reasoning.rs @@ -66,7 +66,7 @@ async fn test_chat_stream_capture_content_ok() -> TestResult<()> { #[tokio::test] async fn test_resolver_auth_ok() -> TestResult<()> { - common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("ZHIPU_API_KEY")).await + common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("ZAI_API_KEY")).await } // endregion: --- Resolver Tests @@ -75,7 +75,7 @@ async fn test_resolver_auth_ok() -> TestResult<()> { #[tokio::test] async fn test_list_models() -> TestResult<()> { - common_tests::common_test_list_models(AdapterKind::Zhipu, "glm-z1-flash").await + common_tests::common_test_list_models(AdapterKind::Zai, "glm-z1-flash").await } // endregion: --- List
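
Taken together, the `ModelSpec` additions and the `impl Into<ModelSpec>` parameters above let callers choose how much resolution the client performs. A minimal usage sketch, not part of the diff, assuming `exec_chat` gained the same `impl Into<ModelSpec>` parameter shown for `exec_chat_stream`; the model names, endpoint URL, and env var are placeholders:

```rust
use genai::adapter::AdapterKind;
use genai::chat::ChatRequest;
use genai::resolver::{AuthData, Endpoint};
use genai::{Client, ModelIden, ServiceTarget};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
	let client = Client::default();
	let req = || ChatRequest::from_user("Why is the sky blue? (one sentence)");

	// 1) Plain model name: adapter kind inferred from the name, all resolvers applied.
	let res = client.exec_chat("gpt-4o-mini", req(), None).await?;
	println!("{}", res.first_text().unwrap_or_default());

	// 2) Explicit ModelIden: skips adapter inference, still resolves auth/endpoint.
	let model = ModelIden::new(AdapterKind::OpenAI, "gpt-4o-mini");
	let _res = client.exec_chat(model, req(), None).await?;

	// 3) Full ServiceTarget: used as-is, only the service-target resolver runs.
	let target = ServiceTarget {
		endpoint: Endpoint::from_static("https://custom.api/v1/"),
		auth: AuthData::from_env("CUSTOM_API_KEY"),
		model: ModelIden::new(AdapterKind::OpenAI, "custom-model"),
	};
	let _res = client.exec_chat(target, req(), None).await?;

	Ok(())
}
```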
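The new 1-hour TTL cache tests exercise `CacheControl::Ephemeral1h` through `ChatMessage::with_options`. A sketch of the same pattern outside the test harness, assuming the `genai::chat` import path for `CacheControl` and using `claude-opus-4-5` as in the test constants; the file path is a placeholder for any large, stable context:

```rust
use genai::Client;
use genai::chat::{CacheControl, ChatMessage, ChatRequest};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
	let client = Client::default();
	// Placeholder path: any large, stable context worth caching across calls.
	let big_context = std::fs::read_to_string("docs/reference.md")?;

	let chat_req = ChatRequest::new(vec![
		ChatMessage::system("Give a very short summary of the document"),
		// Mark the large user message as cacheable with the 1-hour TTL.
		ChatMessage::user(big_context).with_options(CacheControl::Ephemeral1h),
	]);

	// Per the test notes above, the 1h TTL is only supported on Claude 4.5 models.
	let chat_res = client.exec_chat("claude-opus-4-5", chat_req, None).await?;

	if let Some(details) = chat_res.usage.prompt_tokens_details.as_ref() {
		println!(
			"cache_creation_tokens: {:?}, cached_tokens: {:?}",
			details.cache_creation_tokens, details.cached_tokens
		);
	}
	Ok(())
}
```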