diff --git a/lib/model_context_protocol/server/client_logger.rb b/lib/model_context_protocol/server/client_logger.rb index ea25b7d..aebb764 100644 --- a/lib/model_context_protocol/server/client_logger.rb +++ b/lib/model_context_protocol/server/client_logger.rb @@ -30,7 +30,7 @@ class Server::ClientLogger Logger::UNKNOWN => "emergency" }.freeze - attr_accessor :transport + attr_reader :transport attr_reader :logger_name def initialize(logger_name: "server", level: "info") diff --git a/lib/model_context_protocol/server/prompt.rb b/lib/model_context_protocol/server/prompt.rb index 18673ad..bd26795 100644 --- a/lib/model_context_protocol/server/prompt.rb +++ b/lib/model_context_protocol/server/prompt.rb @@ -70,20 +70,6 @@ def define(&block) @defined_arguments.concat(definition_dsl.arguments) end - def with_argument(&block) - @defined_arguments ||= [] - - argument_dsl = ArgumentDSL.new - argument_dsl.instance_eval(&block) - - @defined_arguments << { - name: argument_dsl.name, - description: argument_dsl.description, - required: argument_dsl.required, - completion: argument_dsl.completion - } - end - def inherited(subclass) subclass.instance_variable_set(:@name, @name) subclass.instance_variable_set(:@description, @description) diff --git a/lib/model_context_protocol/server/redis_client_proxy.rb b/lib/model_context_protocol/server/redis_client_proxy.rb index b711ce5..aa6df2a 100644 --- a/lib/model_context_protocol/server/redis_client_proxy.rb +++ b/lib/model_context_protocol/server/redis_client_proxy.rb @@ -39,10 +39,6 @@ def hset(key, *args) with_connection { |redis| redis.hset(key, *args) } end - def hgetall(key) - with_connection { |redis| redis.hgetall(key) } - end - def hmget(key, *fields) with_connection { |redis| redis.hmget(key, *fields) } end @@ -71,10 +67,6 @@ def incr(key) with_connection { |redis| redis.incr(key) } end - def decr(key) - with_connection { |redis| redis.decr(key) } - end - def keys(pattern) with_connection { |redis| redis.keys(pattern) } end @@ -105,14 +97,6 @@ def eval(script, keys: [], argv: []) with_connection { |redis| redis.eval(script, keys: keys, argv: argv) } end - def ping - with_connection { |redis| redis.ping } - end - - def flushdb - with_connection { |redis| redis.flushdb } - end - private def with_connection(&block) diff --git a/lib/model_context_protocol/server/redis_config.rb b/lib/model_context_protocol/server/redis_config.rb index 771204b..ee2e5fc 100644 --- a/lib/model_context_protocol/server/redis_config.rb +++ b/lib/model_context_protocol/server/redis_config.rb @@ -36,10 +36,6 @@ def self.stats instance.stats end - def self.pool_manager - instance.manager - end - def initialize reset! end diff --git a/lib/model_context_protocol/server/redis_pool_manager.rb b/lib/model_context_protocol/server/redis_pool_manager.rb index 4a6a6f7..5f44207 100644 --- a/lib/model_context_protocol/server/redis_pool_manager.rb +++ b/lib/model_context_protocol/server/redis_pool_manager.rb @@ -1,6 +1,6 @@ module ModelContextProtocol class Server::RedisPoolManager - attr_reader :pool, :reaper_thread + attr_reader :pool def initialize(redis_url:, pool_size: 20, pool_timeout: 5, ssl_params: nil) @redis_url = redis_url @@ -36,16 +36,6 @@ def shutdown close_pool end - def healthy? - return false unless @pool - - @pool.with do |conn| - conn.ping == "PONG" - end - rescue - false - end - def reap_now return unless @pool diff --git a/lib/model_context_protocol/server/router.rb b/lib/model_context_protocol/server/router.rb index ed3fc22..81e3cac 100644 --- a/lib/model_context_protocol/server/router.rb +++ b/lib/model_context_protocol/server/router.rb @@ -38,7 +38,7 @@ def route(message, request_store: nil, session_id: nil, transport: nil, stream_i result = nil begin - execute_with_context(handler, message, session_context:) do + execute_with_context do context = { jsonrpc_request_id:, request_store:, @@ -322,7 +322,7 @@ def build_capabilities end # Execute handler with appropriate context setup - def execute_with_context(handler, message, session_context:, &block) + def execute_with_context(&block) # Skip ENV manipulation for streamable_http transport because ENV is # global state and modifying it is thread-unsafe in multi-threaded servers. # For stdio transport, apply ENV variables as before (single-threaded). diff --git a/lib/model_context_protocol/server/stdio_transport/request_store.rb b/lib/model_context_protocol/server/stdio_transport/request_store.rb index dabeacb..a358a26 100644 --- a/lib/model_context_protocol/server/stdio_transport/request_store.rb +++ b/lib/model_context_protocol/server/stdio_transport/request_store.rb @@ -56,47 +56,6 @@ def unregister_request(jsonrpc_request_id) @requests.delete(jsonrpc_request_id) end end - - # Get information about a specific request - # - # @param jsonrpc_request_id [String] the unique JSON-RPC request identifier - # @return [Hash, nil] request information or nil if not found - def get_request(jsonrpc_request_id) - @mutex.synchronize do - @requests[jsonrpc_request_id]&.dup - end - end - - # Get all active request IDs - # - # @return [Array] list of active request IDs - def active_requests - @mutex.synchronize do - @requests.keys.dup - end - end - - # Clean up old requests (useful for preventing memory leaks) - # - # @param max_age_seconds [Integer] maximum age of requests to keep - # @return [Array] list of cleaned up request IDs - def cleanup_old_requests(max_age_seconds = 300) - cutoff_time = Time.now - max_age_seconds - removed_ids = [] - - @mutex.synchronize do - @requests.delete_if do |jsonrpc_request_id, data| - if data[:started_at] < cutoff_time - removed_ids << jsonrpc_request_id - true - else - false - end - end - end - - removed_ids - end end end end diff --git a/lib/model_context_protocol/server/streamable_http_transport.rb b/lib/model_context_protocol/server/streamable_http_transport.rb index f559f58..8f6de3d 100644 --- a/lib/model_context_protocol/server/streamable_http_transport.rb +++ b/lib/model_context_protocol/server/streamable_http_transport.rb @@ -785,12 +785,6 @@ def cleanup_session(session_id) @server_request_store.cleanup_session_requests(session_id) end - # Check if this transport instance has any active local streams - # Used to determine if notifications should be queued or delivered immediately - def has_active_streams? - @stream_registry.has_any_local_streams? - end - # Broadcast notification to all active streams on this transport instance # Handles connection errors gracefully and removes disconnected streams def deliver_to_active_streams(notification) diff --git a/lib/model_context_protocol/server/streamable_http_transport/event_counter.rb b/lib/model_context_protocol/server/streamable_http_transport/event_counter.rb index bc21cf3..4dcc8ff 100644 --- a/lib/model_context_protocol/server/streamable_http_transport/event_counter.rb +++ b/lib/model_context_protocol/server/streamable_http_transport/event_counter.rb @@ -17,19 +17,6 @@ def next_event_id count = @redis.incr(@counter_key) "#{@server_instance}-#{count}" end - - def current_count - count = @redis.get(@counter_key) - count ? count.to_i : 0 - end - - def reset - @redis.set(@counter_key, 0) - end - - def set_count(value) - @redis.set(@counter_key, value.to_i) - end end end end diff --git a/lib/model_context_protocol/server/streamable_http_transport/notification_queue.rb b/lib/model_context_protocol/server/streamable_http_transport/notification_queue.rb index 97d8aa3..f890d02 100644 --- a/lib/model_context_protocol/server/streamable_http_transport/notification_queue.rb +++ b/lib/model_context_protocol/server/streamable_http_transport/notification_queue.rb @@ -22,13 +22,6 @@ def push(notification) end end - def pop - notification_json = @redis.rpop(@queue_key) - return nil unless notification_json - - JSON.parse(notification_json) - end - def pop_all notification_jsons = @redis.multi do |multi| multi.lrange(@queue_key, 0, -1) @@ -41,40 +34,6 @@ def pop_all JSON.parse(notification_json) end end - - def peek_all - notification_jsons = @redis.lrange(@queue_key, 0, -1) - return [] if notification_jsons.empty? - - notification_jsons.reverse.map do |notification_json| - JSON.parse(notification_json) - end - end - - def size - @redis.llen(@queue_key) - end - - def empty? - size == 0 - end - - def clear - @redis.del(@queue_key) - end - - def push_bulk(notifications) - return if notifications.empty? - - notification_jsons = notifications.map(&:to_json) - - @redis.multi do |multi| - notification_jsons.each do |json| - multi.lpush(@queue_key, json) - end - multi.ltrim(@queue_key, 0, @max_size - 1) - end - end end end end diff --git a/lib/model_context_protocol/server/streamable_http_transport/request_store.rb b/lib/model_context_protocol/server/streamable_http_transport/request_store.rb index 3815e85..df816d3 100644 --- a/lib/model_context_protocol/server/streamable_http_transport/request_store.rb +++ b/lib/model_context_protocol/server/streamable_http_transport/request_store.rb @@ -64,17 +64,6 @@ def cancelled?(jsonrpc_request_id) @redis.exists("#{CANCELLED_KEY_PREFIX}#{jsonrpc_request_id}") == 1 end - # Get cancellation information for a request - # - # @param jsonrpc_request_id [String] the unique JSON-RPC request identifier - # @return [Hash, nil] cancellation data or nil if not cancelled - def get_cancellation_info(jsonrpc_request_id) - data = @redis.get("#{CANCELLED_KEY_PREFIX}#{jsonrpc_request_id}") - data ? JSON.parse(data) : nil - rescue JSON::ParserError - nil - end - # Unregister a request (typically called when request completes) # # @param jsonrpc_request_id [String] the unique JSON-RPC request identifier @@ -101,25 +90,6 @@ def unregister_request(jsonrpc_request_id) @redis.del(*keys_to_delete) unless keys_to_delete.empty? end - # Get information about a specific request - # - # @param jsonrpc_request_id [String] the unique JSON-RPC request identifier - # @return [Hash, nil] request information or nil if not found - def get_request(jsonrpc_request_id) - data = @redis.get("#{REQUEST_KEY_PREFIX}#{jsonrpc_request_id}") - data ? JSON.parse(data) : nil - rescue JSON::ParserError - nil - end - - # Check if a request is currently active - # - # @param jsonrpc_request_id [String] the unique JSON-RPC request identifier - # @return [Boolean] true if the request is active, false otherwise - def active?(jsonrpc_request_id) - @redis.exists("#{REQUEST_KEY_PREFIX}#{jsonrpc_request_id}") == 1 - end - # Clean up all requests associated with a session # This is typically called when a session is terminated # @@ -146,79 +116,6 @@ def cleanup_session_requests(session_id) @redis.del(*all_keys) unless all_keys.empty? jsonrpc_request_ids end - - # Get all active request IDs for a specific session - # - # @param session_id [String] the session identifier - # @return [Array] list of active request IDs for the session - def get_session_requests(session_id) - pattern = "#{SESSION_KEY_PREFIX}#{session_id}:*" - request_keys = @redis.keys(pattern) - - request_keys.map do |key| - key.sub("#{SESSION_KEY_PREFIX}#{session_id}:", "") - end - end - - # Get all active request IDs across all sessions - # - # @return [Array] list of all active request IDs - def get_all_active_requests - pattern = "#{REQUEST_KEY_PREFIX}*" - request_keys = @redis.keys(pattern) - - request_keys.map do |key| - key.sub(REQUEST_KEY_PREFIX, "") - end - end - - # Clean up expired requests based on TTL - # This method can be called periodically to ensure cleanup - # - # @return [Integer] number of expired requests cleaned up - def cleanup_expired_requests - active_keys = @redis.keys("#{REQUEST_KEY_PREFIX}*") - expired_count = 0 - key_exists_without_expiration = -1 - key_does_not_exist = -2 - - active_keys.each do |key| - ttl = @redis.ttl(key) - if ttl == key_exists_without_expiration - @redis.expire(key, @ttl) - elsif ttl == key_does_not_exist - expired_count += 1 - end - end - - expired_count - end - - # Refresh the TTL for an active request - # - # @param jsonrpc_request_id [String] the unique JSON-RPC request identifier - # @return [Boolean] true if TTL was refreshed, false if request doesn't exist - def refresh_request_ttl(jsonrpc_request_id) - request_data = @redis.get("#{REQUEST_KEY_PREFIX}#{jsonrpc_request_id}") - return false unless request_data - - @redis.multi do |multi| - multi.expire("#{REQUEST_KEY_PREFIX}#{jsonrpc_request_id}", @ttl) - multi.expire("#{CANCELLED_KEY_PREFIX}#{jsonrpc_request_id}", @ttl) - - begin - data = JSON.parse(request_data) - session_id = data["session_id"] - if session_id - multi.expire("#{SESSION_KEY_PREFIX}#{session_id}:#{jsonrpc_request_id}", @ttl) - end - rescue JSON::ParserError - nil - end - end - - true - end end end end diff --git a/lib/model_context_protocol/server/streamable_http_transport/server_request_store.rb b/lib/model_context_protocol/server/streamable_http_transport/server_request_store.rb index 37eb0e8..2b96118 100644 --- a/lib/model_context_protocol/server/streamable_http_transport/server_request_store.rb +++ b/lib/model_context_protocol/server/streamable_http_transport/server_request_store.rb @@ -112,20 +112,6 @@ def get_expired_requests(timeout_seconds) expired_requests end - # Clean up expired requests based on timeout - # - # @param timeout_seconds [Integer] timeout in seconds - # @return [Array] list of cleaned up request IDs - def cleanup_expired_requests(timeout_seconds) - expired_requests = get_expired_requests(timeout_seconds) - - expired_requests.each do |request_info| - unregister_request(request_info[:request_id]) - end - - expired_requests.map { |r| r[:request_id] } - end - # Unregister a request (typically called when request completes or times out) # # @param request_id [String] the unique JSON-RPC request identifier @@ -176,56 +162,6 @@ def cleanup_session_requests(session_id) @redis.del(*all_keys) unless all_keys.empty? request_ids end - - # Get all pending request IDs for a specific session - # - # @param session_id [String] the session identifier - # @return [Array] list of pending request IDs for the session - def get_session_requests(session_id) - pattern = "#{SESSION_KEY_PREFIX}#{session_id}:*" - request_keys = @redis.keys(pattern) - - request_keys.map do |key| - key.sub("#{SESSION_KEY_PREFIX}#{session_id}:", "") - end - end - - # Get all pending request IDs across all sessions - # - # @return [Array] list of all pending request IDs - def get_all_pending_requests - pattern = "#{REQUEST_KEY_PREFIX}*" - request_keys = @redis.keys(pattern) - - request_keys.map do |key| - key.sub(REQUEST_KEY_PREFIX, "") - end - end - - # Refresh the TTL for a pending request - # - # @param request_id [String] the unique JSON-RPC request identifier - # @return [Boolean] true if TTL was refreshed, false if request doesn't exist - def refresh_request_ttl(request_id) - request_data = @redis.get("#{REQUEST_KEY_PREFIX}#{request_id}") - return false unless request_data - - @redis.multi do |multi| - multi.expire("#{REQUEST_KEY_PREFIX}#{request_id}", @ttl) - - begin - data = JSON.parse(request_data) - session_id = data["session_id"] - if session_id - multi.expire("#{SESSION_KEY_PREFIX}#{session_id}:#{request_id}", @ttl) - end - rescue JSON::ParserError - nil - end - end - - true - end end end end diff --git a/lib/model_context_protocol/server/streamable_http_transport/session_message_queue.rb b/lib/model_context_protocol/server/streamable_http_transport/session_message_queue.rb index 92b2a24..46c9eb2 100644 --- a/lib/model_context_protocol/server/streamable_http_transport/session_message_queue.rb +++ b/lib/model_context_protocol/server/streamable_http_transport/session_message_queue.rb @@ -1,20 +1,16 @@ require "json" -require "securerandom" module ModelContextProtocol class Server::StreamableHttpTransport class SessionMessageQueue QUEUE_KEY_PREFIX = "session_messages:" - LOCK_KEY_PREFIX = "session_lock:" DEFAULT_TTL = 3600 # 1 hour MAX_MESSAGES = 1000 - LOCK_TIMEOUT = 5 # seconds def initialize(redis_client, session_id, ttl: DEFAULT_TTL) @redis = redis_client @session_id = session_id @queue_key = "#{QUEUE_KEY_PREFIX}#{session_id}" - @lock_key = "#{LOCK_KEY_PREFIX}#{session_id}" @ttl = ttl end @@ -28,20 +24,6 @@ def push_message(message) end end - def push_messages(messages) - return if messages.empty? - - message_jsons = messages.map { |msg| serialize_message(msg) } - - @redis.multi do |multi| - message_jsons.each do |json| - multi.lpush(@queue_key, json) - end - multi.expire(@queue_key, @ttl) - multi.ltrim(@queue_key, 0, MAX_MESSAGES - 1) - end - end - def poll_messages lua_script = <<~LUA local messages = redis.call('lrange', KEYS[1], 0, -1) @@ -58,52 +40,12 @@ def poll_messages [] end - def peek_messages - messages = @redis.lrange(@queue_key, 0, -1) - messages.reverse.map { |json| deserialize_message(json) } - rescue - [] - end - def has_messages? @redis.exists(@queue_key) > 0 rescue false end - def message_count - @redis.llen(@queue_key) - rescue - 0 - end - - def clear - @redis.del(@queue_key) - rescue - end - - def with_lock(timeout: LOCK_TIMEOUT, &block) - lock_id = SecureRandom.hex(16) - - acquired = @redis.set(@lock_key, lock_id, nx: true, ex: timeout) - return false unless acquired - - begin - yield - ensure - lua_script = <<~LUA - if redis.call("get", KEYS[1]) == ARGV[1] then - return redis.call("del", KEYS[1]) - else - return 0 - end - LUA - @redis.eval(lua_script, keys: [@lock_key], argv: [lock_id]) - end - - true - end - private def serialize_message(message) diff --git a/lib/model_context_protocol/server/streamable_http_transport/session_store.rb b/lib/model_context_protocol/server/streamable_http_transport/session_store.rb index 0bafcb4..610a489 100644 --- a/lib/model_context_protocol/server/streamable_http_transport/session_store.rb +++ b/lib/model_context_protocol/server/streamable_http_transport/session_store.rb @@ -45,11 +45,6 @@ def mark_stream_inactive(session_id) end end - def get_session_server(session_id) - server_data = @redis.hget("session:#{session_id}", "stream_server") - server_data ? JSON.parse(server_data) : nil - end - def session_exists?(session_id) @redis.exists("session:#{session_id}") == 1 end @@ -87,37 +82,6 @@ def poll_messages_for_session(session_id) [] end - def get_sessions_with_messages - session_keys = @redis.keys("session:*") - sessions_with_messages = [] - - session_keys.each do |key| - session_id = key.sub("session:", "") - queue = SessionMessageQueue.new(@redis, session_id, ttl: @ttl) - if queue.has_messages? - sessions_with_messages << session_id - end - end - - sessions_with_messages - rescue - [] - end - - def get_all_active_sessions - keys = @redis.keys("session:*") - active_sessions = [] - - keys.each do |key| - session_id = key.sub("session:", "") - if session_has_active_stream?(session_id) - active_sessions << session_id - end - end - - active_sessions - end - def store_registered_handlers(session_id, prompts:, resources:, tools:) @redis.hset("session:#{session_id}", "registered_prompts", prompts.to_json, diff --git a/lib/model_context_protocol/server/streamable_http_transport/stream_registry.rb b/lib/model_context_protocol/server/streamable_http_transport/stream_registry.rb index 79f5d17..55eeba8 100644 --- a/lib/model_context_protocol/server/streamable_http_transport/stream_registry.rb +++ b/lib/model_context_protocol/server/streamable_http_transport/stream_registry.rb @@ -43,14 +43,6 @@ def has_local_stream?(session_id) @local_streams.key?(session_id) end - def get_stream_server(session_id) - @redis.get("#{STREAM_KEY_PREFIX}#{session_id}") - end - - def stream_active?(session_id) - @redis.exists("#{STREAM_KEY_PREFIX}#{session_id}") == 1 - end - def refresh_heartbeat(session_id) @redis.multi do |multi| multi.set("#{HEARTBEAT_KEY_PREFIX}#{session_id}", Time.now.to_f, ex: @ttl) @@ -88,32 +80,6 @@ def cleanup_expired_streams expired_sessions end - - def get_stale_streams(max_age_seconds = 90) - current_time = Time.now.to_f - stale_streams = [] - - # Get all heartbeat keys - heartbeat_keys = @redis.keys("#{HEARTBEAT_KEY_PREFIX}*") - - return stale_streams if heartbeat_keys.empty? - - # Get all heartbeat timestamps - heartbeat_values = @redis.mget(heartbeat_keys) - - heartbeat_keys.each_with_index do |key, index| - next unless heartbeat_values[index] - - session_id = key.sub(HEARTBEAT_KEY_PREFIX, "") - last_heartbeat = heartbeat_values[index].to_f - - if current_time - last_heartbeat > max_age_seconds - stale_streams << session_id - end - end - - stale_streams - end end end end diff --git a/spec/lib/model_context_protocol/server/redis_client_proxy_spec.rb b/spec/lib/model_context_protocol/server/redis_client_proxy_spec.rb index 1246ca9..c1e3582 100644 --- a/spec/lib/model_context_protocol/server/redis_client_proxy_spec.rb +++ b/spec/lib/model_context_protocol/server/redis_client_proxy_spec.rb @@ -118,16 +118,6 @@ expect(result).to eq(2) end end - - describe "#hgetall" do - it "calls hgetall on the Redis connection" do - hash_data = {"field1" => "value1", "field2" => "value2"} - expect(redis_mock).to receive(:hgetall).with("hash_key").and_return(hash_data) - - result = wrapper.hgetall("hash_key") - expect(result).to eq(hash_data) - end - end end describe "list operations" do @@ -194,15 +184,6 @@ expect(result).to eq(1) end end - - describe "#decr" do - it "calls decr on the Redis connection" do - expect(redis_mock).to receive(:decr).with("counter_key").and_return(0) - - result = wrapper.decr("counter_key") - expect(result).to eq(0) - end - end end describe "key pattern operations" do @@ -249,26 +230,6 @@ end end - describe "utility methods" do - describe "#ping" do - it "calls ping on the Redis connection" do - expect(redis_mock).to receive(:ping).and_return("PONG") - - result = wrapper.ping - expect(result).to eq("PONG") - end - end - - describe "#flushdb" do - it "calls flushdb on the Redis connection" do - expect(redis_mock).to receive(:flushdb).and_return("OK") - - result = wrapper.flushdb - expect(result).to eq("OK") - end - end - end - describe "transaction support" do describe "#multi" do let(:multi_mock) { double("redis_multi") } diff --git a/spec/lib/model_context_protocol/server/redis_pool_manager_spec.rb b/spec/lib/model_context_protocol/server/redis_pool_manager_spec.rb index 7fb3ad7..b2e88d4 100644 --- a/spec/lib/model_context_protocol/server/redis_pool_manager_spec.rb +++ b/spec/lib/model_context_protocol/server/redis_pool_manager_spec.rb @@ -100,31 +100,6 @@ end end - context "with reaper enabled" do - before do - manager.configure_reaper(enabled: true, interval: 1) - end - - it "starts reaper thread" do - manager.start - - aggregate_failures do - expect(manager.reaper_thread).to be_a(Thread) - expect(manager.reaper_thread).to be_alive - end - - manager.shutdown - end - - it "names the reaper thread" do - manager.start - - expect(manager.reaper_thread.name).to eq("MCP-Redis-Reaper") - - manager.shutdown - end - end - context "with ssl_params" do let(:ssl_params) { {verify_mode: OpenSSL::SSL::VERIFY_NONE} } let(:redis_double) { double("redis", close: nil, ping: "PONG") } @@ -194,56 +169,12 @@ manager.configure_reaper(enabled: true) manager.start - thread = manager.reaper_thread - expect(thread).to be_alive + reaper = manager.instance_variable_get(:@reaper_thread) + expect(reaper).to be_alive manager.shutdown - expect(thread.alive?).to be_falsey - end - end - - describe "#healthy?" do - context "when pool does not exist" do - it "returns false" do - expect(manager.healthy?).to be false - end - end - - context "when pool exists" do - before do - manager.start - end - - context "and Redis responds to ping" do - it "returns true" do - expect(manager.healthy?).to be true - end - end - - context "and Redis connection fails" do - before do - redis_mock = double("redis") - allow(redis_mock).to receive(:ping).and_raise(StandardError.new("Connection failed")) - allow(manager.pool).to receive(:with).and_yield(redis_mock) - end - - it "returns false" do - expect(manager.healthy?).to be false - end - end - - context "and Redis returns unexpected response" do - before do - redis_mock = double("redis") - allow(redis_mock).to receive(:ping).and_return("UNEXPECTED") - allow(manager.pool).to receive(:with).and_yield(redis_mock) - end - - it "returns false" do - expect(manager.healthy?).to be false - end - end + expect(reaper.alive?).to be_falsey end end diff --git a/spec/lib/model_context_protocol/server/stdio_transport/request_store_spec.rb b/spec/lib/model_context_protocol/server/stdio_transport/request_store_spec.rb index 4607274..7b6a1ec 100644 --- a/spec/lib/model_context_protocol/server/stdio_transport/request_store_spec.rb +++ b/spec/lib/model_context_protocol/server/stdio_transport/request_store_spec.rb @@ -9,12 +9,11 @@ store.register_request(request_id) - request = store.get_request(request_id) - expect(request).to include( - thread: Thread.current, - cancelled: false, - started_at: be_a(Time) - ) + aggregate_failures do + expect(store.cancelled?(request_id)).to be false + data = store.unregister_request(request_id) + expect(data).to include(thread: Thread.current, cancelled: false, started_at: be_a(Time)) + end end it "registers a request with specific thread" do @@ -23,8 +22,8 @@ store.register_request(request_id, test_thread) - request = store.get_request(request_id) - expect(request[:thread]).to eq(test_thread) + data = store.unregister_request(request_id) + expect(data[:thread]).to eq(test_thread) test_thread.kill test_thread.join @@ -36,8 +35,8 @@ store.register_request(request_id) - request = store.get_request(request_id) - expect(request[:started_at]).to be >= start_time + data = store.unregister_request(request_id) + expect(data[:started_at]).to be >= start_time end end @@ -55,14 +54,13 @@ end it "preserves other request data" do - original_request = store.get_request(request_id) - store.mark_cancelled(request_id) - updated_request = store.get_request(request_id) + data = store.unregister_request(request_id) aggregate_failures do - expect(updated_request[:thread]).to eq(original_request[:thread]) - expect(updated_request[:started_at]).to eq(original_request[:started_at]) + expect(data[:thread]).to eq(Thread.current) + expect(data[:started_at]).to be_a(Time) + expect(data[:cancelled]).to be true end end end @@ -113,14 +111,14 @@ store.register_request(request_id) end - it "removes the request" do + it "removes the request and returns its data" do removed_data = store.unregister_request(request_id) - expect(removed_data).to include( - thread: Thread.current, - cancelled: false - ) - expect(store.get_request(request_id)).to be_nil + aggregate_failures do + expect(removed_data).to include(thread: Thread.current, cancelled: false) + expect(store.cancelled?(request_id)).to be false + expect(store.unregister_request(request_id)).to be_nil + end end end @@ -131,108 +129,20 @@ end end - describe "#get_request" do - let(:request_id) { "test-request" } - - context "when request exists" do - before do - store.register_request(request_id) - end - - it "returns a copy of request data" do - request = store.get_request(request_id) - - expect(request).to include( - thread: Thread.current, - cancelled: false, - started_at: be_a(Time) - ) - - request[:cancelled] = true - expect(store.cancelled?(request_id)).to be false - end - end - - context "when request does not exist" do - it "returns nil" do - expect(store.get_request(request_id)).to be_nil - end - end - end - - describe "#active_requests" do - it "returns empty array when no requests" do - expect(store.active_requests).to eq([]) - end - - it "returns list of active request IDs" do - request_ids = ["request-1", "request-2", "request-3"] - - request_ids.each { |id| store.register_request(id) } - - expect(store.active_requests).to match_array(request_ids) - end - - it "excludes unregistered requests" do - store.register_request("request-1") - store.register_request("request-2") - store.unregister_request("request-1") - - expect(store.active_requests).to eq(["request-2"]) - end - end - - describe "#cleanup_old_requests" do - it "removes requests older than specified age" do - old_request = "old-request" - new_request = "new-request" - - store.register_request(old_request) - store.instance_variable_get(:@requests)[old_request][:started_at] = Time.now - 400 - - store.register_request(new_request) - removed = store.cleanup_old_requests(300) - - aggregate_failures do - expect(removed).to include(old_request) - expect(store.get_request(old_request)).to be_nil - expect(store.get_request(new_request)).not_to be_nil - end - end - - it "returns list of removed request IDs" do - store.register_request("request-1") - store.register_request("request-2") - - old_time = Time.now - 400 - requests = store.instance_variable_get(:@requests) - requests["request-1"][:started_at] = old_time - requests["request-2"][:started_at] = old_time - - removed = store.cleanup_old_requests(300) - - expect(removed).to match_array(["request-1", "request-2"]) - end - end - describe "thread safety" do it "handles concurrent access safely" do threads = [] - request_ids = [] + request_ids = (0...10).map { |i| "concurrent-request-#{i}" } 10.times do |i| threads << Thread.new do - request_id = "concurrent-request-#{i}" - request_ids << request_id - store.register_request(request_id) - store.mark_cancelled(request_id) if i.even? + store.register_request(request_ids[i]) + store.mark_cancelled(request_ids[i]) if i.even? end end threads.each(&:join) - expect(store.active_requests.size).to eq(10) - cancelled_count = request_ids.count { |id| store.cancelled?(id) } expect(cancelled_count).to eq(5) end diff --git a/spec/lib/model_context_protocol/server/streamable_http_transport/event_counter_spec.rb b/spec/lib/model_context_protocol/server/streamable_http_transport/event_counter_spec.rb index f256183..fcd6a54 100644 --- a/spec/lib/model_context_protocol/server/streamable_http_transport/event_counter_spec.rb +++ b/spec/lib/model_context_protocol/server/streamable_http_transport/event_counter_spec.rb @@ -80,80 +80,6 @@ end end - describe "#current_count" do - it "returns 0 for new counter" do - expect(counter.current_count).to eq(0) - end - - it "returns current count after increments" do - 3.times { counter.next_event_id } - - expect(counter.current_count).to eq(3) - end - - it "returns correct count for existing counter" do - redis.set("event_counter:#{server_instance}", "42") - - expect(counter.current_count).to eq(42) - end - - it "handles non-existent counter gracefully" do - redis.del("event_counter:#{server_instance}") - - expect(counter.current_count).to eq(0) - end - end - - describe "#reset" do - it "resets counter to 0" do - 5.times { counter.next_event_id } - - counter.reset - - aggregate_failures do - expect(counter.current_count).to eq(0) - expect(counter.next_event_id).to eq("#{server_instance}-1") - end - end - - it "works on already zero counter" do - counter.reset - - expect(counter.current_count).to eq(0) - end - end - - describe "#set_count" do - it "sets the counter to a specific value" do - counter.set_count(100) - - aggregate_failures do - expect(counter.current_count).to eq(100) - expect(counter.next_event_id).to eq("#{server_instance}-101") - end - end - - it "handles string input" do - counter.set_count("50") - - expect(counter.current_count).to eq(50) - end - - it "handles zero value" do - 5.times { counter.next_event_id } - - counter.set_count(0) - - expect(counter.current_count).to eq(0) - end - - it "handles negative values by converting to positive" do - counter.set_count(-10) - - expect(counter.current_count).to eq(-10) - end - end - describe "thread safety" do it "maintains consistency under concurrent access" do counters = 5.times.map { described_class.new(redis, server_instance) } @@ -175,7 +101,6 @@ aggregate_failures do expect(ids.size).to eq(50) expect(ids.uniq.size).to eq(50) - expect(counter.current_count).to eq(50) end end end @@ -193,9 +118,6 @@ expect(id1).to eq("#{server_instance}-1") expect(id2).to eq("#{server2}-1") expect(id3).to eq("#{server_instance}-2") - - expect(counter.current_count).to eq(2) - expect(counter2.current_count).to eq(1) end end end diff --git a/spec/lib/model_context_protocol/server/streamable_http_transport/message_poller_spec.rb b/spec/lib/model_context_protocol/server/streamable_http_transport/message_poller_spec.rb index 0ebcb61..7de814a 100644 --- a/spec/lib/model_context_protocol/server/streamable_http_transport/message_poller_spec.rb +++ b/spec/lib/model_context_protocol/server/streamable_http_transport/message_poller_spec.rb @@ -410,12 +410,12 @@ before do allow(stream_registry).to receive(:get_all_local_streams).and_return({session_id => stream}) allow(stream_registry).to receive(:get_local_stream).with(session_id).and_return(stream) + end + it "polls and delivers queued messages" do queue = ModelContextProtocol::Server::StreamableHttpTransport::SessionMessageQueue.new(redis, session_id) messages.each { |msg| queue.push_message(msg) } - end - it "polls and delivers queued messages" do aggregate_failures do expect(message_delivery_block).to receive(:call).with(stream, messages[0]) expect(message_delivery_block).to receive(:call).with(stream, messages[1]) @@ -423,14 +423,10 @@ poller.send(:poll_and_deliver_messages) - queue = ModelContextProtocol::Server::StreamableHttpTransport::SessionMessageQueue.new(redis, session_id) expect(queue.has_messages?).to eq(false) end it "handles polling with no messages gracefully" do - queue = ModelContextProtocol::Server::StreamableHttpTransport::SessionMessageQueue.new(redis, session_id) - queue.clear - aggregate_failures do expect(message_delivery_block).not_to receive(:call) expect { poller.send(:poll_and_deliver_messages) }.not_to raise_error diff --git a/spec/lib/model_context_protocol/server/streamable_http_transport/notification_queue_spec.rb b/spec/lib/model_context_protocol/server/streamable_http_transport/notification_queue_spec.rb index 2932947..f54a7a6 100644 --- a/spec/lib/model_context_protocol/server/streamable_http_transport/notification_queue_spec.rb +++ b/spec/lib/model_context_protocol/server/streamable_http_transport/notification_queue_spec.rb @@ -41,10 +41,7 @@ def serialized_array(notifications) it "adds a notification to the queue" do queue.push(sample_notification) - aggregate_failures do - expect(queue.size).to eq(1) - expect(queue.pop).to eq(serialized(sample_notification)) - end + expect(queue.pop_all).to eq([serialized(sample_notification)]) end it "maintains FIFO order" do @@ -54,10 +51,7 @@ def serialized_array(notifications) queue.push(notification1) queue.push(notification2) - aggregate_failures do - expect(queue.pop).to eq(serialized(notification1)) - expect(queue.pop).to eq(serialized(notification2)) - end + expect(queue.pop_all).to eq([serialized(notification1), serialized(notification2)]) end it "enforces max_size by removing oldest items" do @@ -65,37 +59,12 @@ def serialized_array(notifications) queue.push({method: "notification_#{i}"}) end - expect(queue.size).to eq(max_size) - - popped = queue.pop - expect(popped["method"]).to eq("notification_2") - end - end - - describe "#pop" do - it "returns nil when queue is empty" do - expect(queue.pop).to be_nil - end - - it "removes and returns the oldest notification" do - queue.push(sample_notification) - - result = queue.pop - + results = queue.pop_all aggregate_failures do - expect(result).to eq(serialized(sample_notification)) - expect(queue.size).to eq(0) + expect(results.length).to eq(max_size) + expect(results.first["method"]).to eq("notification_2") end end - - it "maintains FIFO order across multiple operations" do - notifications = 3.times.map { |i| {method: "notification_#{i}"} } - notifications.each { |n| queue.push(n) } - - results = 3.times.map { queue.pop } - - expect(results).to eq(serialized_array(notifications)) - end end describe "#pop_all" do @@ -111,7 +80,7 @@ def serialized_array(notifications) aggregate_failures do expect(results).to eq(serialized_array(notifications)) - expect(queue.size).to eq(0) + expect(queue.pop_all).to eq([]) end end @@ -125,132 +94,7 @@ def serialized_array(notifications) aggregate_failures do expect(results.size).to be >= 1 - expect(queue.size).to eq(0) - end - end - end - - describe "#peek_all" do - it "returns empty array when queue is empty" do - expect(queue.peek_all).to eq([]) - end - - it "returns all notifications without removing them" do - notifications = 3.times.map { |i| {method: "notification_#{i}"} } - notifications.each { |n| queue.push(n) } - - results = queue.peek_all - - aggregate_failures do - expect(results).to eq(serialized_array(notifications)) - expect(queue.size).to eq(3) - end - end - - it "returns notifications in FIFO order" do - notification1 = {method: "first"} - notification2 = {method: "second"} - - queue.push(notification1) - queue.push(notification2) - - expect(queue.peek_all).to eq([serialized(notification1), serialized(notification2)]) - end - end - - describe "#size" do - it "returns 0 for empty queue" do - expect(queue.size).to eq(0) - end - - it "returns correct size after adding items" do - 3.times { queue.push(sample_notification) } - - expect(queue.size).to eq(3) - end - - it "updates correctly after popping items" do - 3.times { queue.push(sample_notification) } - queue.pop - - expect(queue.size).to eq(2) - end - end - - describe "#empty?" do - it "returns true for empty queue" do - expect(queue.empty?).to be true - end - - it "returns false for non-empty queue" do - queue.push(sample_notification) - - expect(queue.empty?).to be false - end - end - - describe "#clear" do - it "removes all notifications from the queue" do - 3.times { queue.push(sample_notification) } - - queue.clear - - aggregate_failures do - expect(queue.size).to eq(0) - expect(queue.empty?).to be true - end - end - - it "works on empty queue without error" do - aggregate_failures do - expect { queue.clear }.not_to raise_error - expect(queue.size).to eq(0) - end - end - end - - describe "#push_bulk" do - it "adds multiple notifications at once" do - notifications = 3.times.map { |i| {method: "notification_#{i}"} } - - queue.push_bulk(notifications) - - aggregate_failures do - expect(queue.size).to eq(3) - expect(queue.pop_all).to eq(serialized_array(notifications)) - end - end - - it "maintains FIFO order for bulk operations" do - batch1 = 2.times.map { |i| {method: "batch1_#{i}"} } - batch2 = 2.times.map { |i| {method: "batch2_#{i}"} } - - queue.push_bulk(batch1) - queue.push_bulk(batch2) - - results = queue.pop_all - expected = serialized_array(batch1 + batch2) - - expect(results).to eq(expected) - end - - it "enforces max_size for bulk operations" do - notifications = (max_size + 2).times.map { |i| {method: "notification_#{i}"} } - - queue.push_bulk(notifications) - - expect(queue.size).to eq(max_size) - - results = queue.pop_all - expected = serialized_array(notifications.last(max_size)) - - expect(results).to eq(expected) - end - - it "handles empty array gracefully" do - aggregate_failures do - expect { queue.push_bulk([]) }.not_to raise_error - expect(queue.size).to eq(0) + expect(queue.pop_all).to eq([]) end end end @@ -271,10 +115,10 @@ def serialized_array(notifications) } queue.push(complex_notification) - result = queue.pop + results = queue.pop_all expected = JSON.parse(complex_notification.to_json) - expect(result).to eq(expected) + expect(results).to eq([expected]) end end end diff --git a/spec/lib/model_context_protocol/server/streamable_http_transport/request_store_spec.rb b/spec/lib/model_context_protocol/server/streamable_http_transport/request_store_spec.rb index aaf6f4b..70e15ba 100644 --- a/spec/lib/model_context_protocol/server/streamable_http_transport/request_store_spec.rb +++ b/spec/lib/model_context_protocol/server/streamable_http_transport/request_store_spec.rb @@ -117,32 +117,6 @@ end end - describe "#get_cancellation_info" do - let(:request_id) { "test-request-123" } - let(:reason) { "User requested cancellation" } - - context "when request is not cancelled" do - it "returns nil" do - expect(store.get_cancellation_info(request_id)).to be_nil - end - end - - context "when request is cancelled" do - before do - store.mark_cancelled(request_id, reason) - end - - it "returns cancellation data" do - info = store.get_cancellation_info(request_id) - - expect(info).to include( - "cancelled_at" => be_a(Numeric), - "reason" => reason - ) - end - end - end - describe "#unregister_request" do let(:request_id) { "test-request-123" } let(:session_id) { "session-456" } @@ -187,53 +161,6 @@ end end - describe "#get_request" do - let(:request_id) { "test-request-123" } - let(:session_id) { "session-456" } - - context "when request exists" do - before do - store.register_request(request_id, session_id) - end - - it "returns request data" do - request = store.get_request(request_id) - - expect(request).to include( - "session_id" => session_id, - "server_instance" => server_instance, - "started_at" => be_a(Numeric) - ) - end - end - - context "when request does not exist" do - it "returns nil" do - expect(store.get_request(request_id)).to be_nil - end - end - end - - describe "#active?" do - let(:request_id) { "test-request-123" } - - context "when request is active" do - before do - store.register_request(request_id) - end - - it "returns true" do - expect(store.active?(request_id)).to be true - end - end - - context "when request is not active" do - it "returns false" do - expect(store.active?(request_id)).to be false - end - end - end - describe "#cleanup_session_requests" do let(:session_id) { "session-456" } let(:request_ids) { ["req-1", "req-2", "req-3"] } @@ -267,81 +194,6 @@ end end - describe "#get_session_requests" do - let(:session_id) { "session-456" } - let(:request_ids) { ["req-1", "req-2", "req-3"] } - - before do - request_ids.each { |req_id| store.register_request(req_id, session_id) } - store.register_request("other-req", "other-session") - end - - it "returns only requests for the specified session" do - session_requests = store.get_session_requests(session_id) - expect(session_requests).to match_array(request_ids) - end - - context "when session has no requests" do - it "returns empty array" do - session_requests = store.get_session_requests("nonexistent-session") - expect(session_requests).to eq([]) - end - end - end - - describe "#get_all_active_requests" do - let(:request_ids) { ["req-1", "req-2", "req-3"] } - - before do - request_ids.each { |req_id| store.register_request(req_id, "session-#{req_id}") } - end - - it "returns all active request IDs" do - active_requests = store.get_all_active_requests - expect(active_requests).to match_array(request_ids) - end - - context "when no requests are active" do - it "returns empty array" do - mock_redis.flushdb - expect(store.get_all_active_requests).to eq([]) - end - end - end - - describe "#refresh_request_ttl" do - let(:request_id) { "test-request-123" } - let(:session_id) { "session-456" } - - context "when request exists" do - before do - store.register_request(request_id, session_id) - store.mark_cancelled(request_id, "test") - end - - it "refreshes TTL for all related keys" do - mock_redis.expire("request:active:#{request_id}", 10) - mock_redis.expire("request:cancelled:#{request_id}", 10) - mock_redis.expire("request:session:#{session_id}:#{request_id}", 10) - - result = store.refresh_request_ttl(request_id) - - aggregate_failures do - expect(result).to be true - expect(mock_redis.ttl("request:active:#{request_id}")).to be > 10 - expect(mock_redis.ttl("request:cancelled:#{request_id}")).to be > 10 - expect(mock_redis.ttl("request:session:#{session_id}:#{request_id}")).to be > 10 - end - end - end - - context "when request does not exist" do - it "returns false" do - expect(store.refresh_request_ttl(request_id)).to be false - end - end - end - describe "TTL behavior" do let(:request_id) { "test-request-123" } let(:session_id) { "session-456" } diff --git a/spec/lib/model_context_protocol/server/streamable_http_transport/server_request_store_spec.rb b/spec/lib/model_context_protocol/server/streamable_http_transport/server_request_store_spec.rb index 1948d94..4145c8f 100644 --- a/spec/lib/model_context_protocol/server/streamable_http_transport/server_request_store_spec.rb +++ b/spec/lib/model_context_protocol/server/streamable_http_transport/server_request_store_spec.rb @@ -167,25 +167,6 @@ end end - describe "#cleanup_expired_requests" do - let(:request_id) { "ping-test-123" } - let(:session_id) { "session-456" } - - before do - allow(Time).to receive(:now).and_return(Time.at(1000)) - store.register_request(request_id, session_id, type: :ping) - end - - it "removes expired requests and returns their IDs" do - allow(Time).to receive(:now).and_return(Time.at(1020)) - - cleaned_ids = store.cleanup_expired_requests(10) - - expect(cleaned_ids).to eq([request_id]) - expect(store.pending?(request_id)).to be false - end - end - describe "#cleanup_session_requests" do let(:session_id) { "session-456" } let(:request_id1) { "ping-test-123" } @@ -213,58 +194,4 @@ expect(keys).to be_empty end end - - describe "#get_session_requests" do - let(:session_id) { "session-456" } - let(:request_id1) { "ping-test-123" } - let(:request_id2) { "ping-test-456" } - - before do - store.register_request(request_id1, session_id, type: :ping) - store.register_request(request_id2, session_id, type: :ping) - end - - it "returns all request IDs for a session" do - request_ids = store.get_session_requests(session_id) - - expect(request_ids).to contain_exactly(request_id1, request_id2) - end - end - - describe "#get_all_pending_requests" do - let(:request_id1) { "ping-test-123" } - let(:request_id2) { "ping-test-456" } - - before do - store.register_request(request_id1, "session-1", type: :ping) - store.register_request(request_id2, "session-2", type: :ping) - end - - it "returns all pending request IDs" do - request_ids = store.get_all_pending_requests - - expect(request_ids).to contain_exactly(request_id1, request_id2) - end - end - - describe "#refresh_request_ttl" do - let(:request_id) { "ping-test-123" } - let(:session_id) { "session-456" } - - context "when request exists" do - before do - store.register_request(request_id, session_id, type: :ping) - end - - it "returns true and refreshes TTL" do - expect(store.refresh_request_ttl(request_id)).to be true - end - end - - context "when request does not exist" do - it "returns false" do - expect(store.refresh_request_ttl("non-existent")).to be false - end - end - end end diff --git a/spec/lib/model_context_protocol/server/streamable_http_transport/session_message_queue_spec.rb b/spec/lib/model_context_protocol/server/streamable_http_transport/session_message_queue_spec.rb index 55b6ce6..a10273e 100644 --- a/spec/lib/model_context_protocol/server/streamable_http_transport/session_message_queue_spec.rb +++ b/spec/lib/model_context_protocol/server/streamable_http_transport/session_message_queue_spec.rb @@ -17,7 +17,6 @@ expect(queue.instance_variable_get(:@session_id)).to eq(session_id) expect(queue.instance_variable_get(:@ttl)).to eq(300) expect(queue.instance_variable_get(:@queue_key)).to eq("session_messages:#{session_id}") - expect(queue.instance_variable_get(:@lock_key)).to eq("session_lock:#{session_id}") end end @@ -33,10 +32,7 @@ it "adds a message to the queue" do queue.push_message(message) - aggregate_failures do - expect(queue.has_messages?).to eq(true) - expect(queue.message_count).to eq(1) - end + expect(queue.has_messages?).to eq(true) end it "serializes hash messages as JSON" do @@ -69,43 +65,7 @@ queue.push_message({"test" => i}) end - expect(queue.message_count).to eq(1000) - end - end - - describe "#push_messages" do - let(:messages) do - [ - {"method" => "test1", "params" => {"data" => "hello1"}}, - {"method" => "test2", "params" => {"data" => "hello2"}}, - {"method" => "test3", "params" => {"data" => "hello3"}} - ] - end - - it "adds multiple messages at once" do - queue.push_messages(messages) - - expect(queue.message_count).to eq(3) - end - - it "handles empty array gracefully" do - queue.push_messages([]) - - expect(queue.has_messages?).to eq(false) - end - - it "maintains FIFO order for bulk operations" do - queue.push_messages(messages) - - result = queue.poll_messages - expect(result).to eq(messages) - end - - it "enforces max_size for bulk operations" do - large_batch = Array.new(1010) { |i| {"test" => i} } - queue.push_messages(large_batch) - - expect(queue.message_count).to eq(1000) + expect(queue.poll_messages.length).to eq(1000) end end @@ -127,10 +87,7 @@ it "clears the queue after polling" do queue.poll_messages - aggregate_failures do - expect(queue.has_messages?).to eq(false) - expect(queue.message_count).to eq(0) - end + expect(queue.has_messages?).to eq(false) end it "is atomic - subsequent polls return empty" do @@ -159,46 +116,6 @@ end end - describe "#peek_messages" do - let(:message1) { {"method" => "test1", "params" => {"data" => "hello1"}} } - let(:message2) { {"method" => "test2", "params" => {"data" => "hello2"}} } - - context "when queue has messages" do - before do - queue.push_message(message1) - queue.push_message(message2) - end - - it "returns all messages without removing them" do - messages = queue.peek_messages - - aggregate_failures do - expect(messages).to eq([message1, message2]) - expect(queue.message_count).to eq(2) - end - end - - it "returns messages in FIFO order" do - messages = queue.peek_messages - expect(messages).to eq([message1, message2]) - end - end - - context "when queue is empty" do - it "returns empty array" do - messages = queue.peek_messages - expect(messages).to eq([]) - end - end - - it "handles Redis errors gracefully" do - allow(redis).to receive(:lrange).and_raise(MockRedis::ConnectionError) - - messages = queue.peek_messages - expect(messages).to eq([]) - end - end - describe "#has_messages?" do it "returns false when queue is empty" do expect(queue.has_messages?).to eq(false) @@ -216,119 +133,6 @@ end end - describe "#message_count" do - it "returns 0 for empty queue" do - expect(queue.message_count).to eq(0) - end - - it "returns correct count after adding messages" do - queue.push_message({"test1" => "message1"}) - queue.push_message({"test2" => "message2"}) - - expect(queue.message_count).to eq(2) - end - - it "updates correctly after polling messages" do - queue.push_message({"test" => "message"}) - expect(queue.message_count).to eq(1) - - queue.poll_messages - expect(queue.message_count).to eq(0) - end - - it "handles Redis errors gracefully" do - allow(redis).to receive(:llen).and_raise(MockRedis::ConnectionError) - - expect(queue.message_count).to eq(0) - end - end - - describe "#clear" do - before do - queue.push_message({"test1" => "message1"}) - queue.push_message({"test2" => "message2"}) - end - - it "removes all messages from the queue" do - expect(queue.message_count).to eq(2) - - queue.clear - - aggregate_failures do - expect(queue.message_count).to eq(0) - expect(queue.has_messages?).to eq(false) - end - end - - it "works on empty queue without error" do - queue.clear - queue.clear - - expect(queue.message_count).to eq(0) - end - - it "handles Redis errors gracefully" do - allow(redis).to receive(:del).and_raise(MockRedis::ConnectionError) - - expect { queue.clear }.not_to raise_error - end - end - - describe "#with_lock" do - it "acquires and releases lock successfully" do - result = queue.with_lock do - "locked operation" - end - - expect(result).to eq(true) - end - - it "executes the block when lock is acquired" do - executed = false - - queue.with_lock do - executed = true - end - - expect(executed).to eq(true) - end - - it "returns false when lock cannot be acquired" do - redis.set("session_lock:#{session_id}", "other-lock-id", nx: true, ex: 5) - - result = queue.with_lock(timeout: 1) do - "should not execute" - end - - expect(result).to eq(false) - end - - it "releases lock even if block raises error" do - expect do - queue.with_lock do - raise "test error" - end - end.to raise_error("test error") - - lock_key = "session_lock:#{session_id}" - expect(redis.exists(lock_key)).to eq(0) - - result = queue.with_lock do - "second attempt" - end - - expect(result).to eq(true) - end - - it "only releases lock if it owns it" do - result = queue.with_lock do - "locked operation" - end - - expect(result).to eq(true) - end - end - describe "JSON serialization" do it "properly serializes and deserializes complex data structures" do complex_message = { diff --git a/spec/lib/model_context_protocol/server/streamable_http_transport/session_store_spec.rb b/spec/lib/model_context_protocol/server/streamable_http_transport/session_store_spec.rb index e3ee0de..facc37b 100644 --- a/spec/lib/model_context_protocol/server/streamable_http_transport/session_store_spec.rb +++ b/spec/lib/model_context_protocol/server/streamable_http_transport/session_store_spec.rb @@ -98,10 +98,7 @@ it "marks session stream as active" do session_store.mark_stream_active(session_id, stream_server) - aggregate_failures do - expect(session_store.session_has_active_stream?(session_id)).to eq(true) - expect(session_store.get_session_server(session_id)).to eq(stream_server) - end + expect(session_store.session_has_active_stream?(session_id)).to eq(true) end it "updates last_activity timestamp" do @@ -145,10 +142,7 @@ it "marks session stream as inactive" do session_store.mark_stream_inactive(session_id) - aggregate_failures do - expect(session_store.session_has_active_stream?(session_id)).to eq(false) - expect(session_store.get_session_server(session_id)).to be_nil - end + expect(session_store.session_has_active_stream?(session_id)).to eq(false) end it "updates last_activity timestamp" do @@ -195,30 +189,6 @@ end end - describe "#get_session_server" do - before { session_store.create_session(session_id, session_data) } - - context "when session has no active stream" do - it "returns nil" do - expect(session_store.get_session_server(session_id)).to be_nil - end - end - - context "when session has active stream" do - before { session_store.mark_stream_active(session_id, server_instance) } - - it "returns the server instance" do - expect(session_store.get_session_server(session_id)).to eq(server_instance) - end - end - - context "when session does not exist" do - it "returns nil" do - expect(session_store.get_session_server("nonexistent")).to be_nil - end - end - end - describe "#get_session_context" do before { session_store.create_session(session_id, session_data) } @@ -336,81 +306,6 @@ end end - describe "#get_sessions_with_messages" do - let(:session_id_1) { SecureRandom.uuid } - let(:session_id_2) { SecureRandom.uuid } - let(:session_id_3) { SecureRandom.uuid } - - before do - session_store.create_session(session_id_1, session_data) - session_store.create_session(session_id_2, session_data) - session_store.create_session(session_id_3, session_data) - end - - it "returns only sessions that have pending messages" do - session_store.queue_message_for_session(session_id_1, {"test" => "msg1"}) - session_store.queue_message_for_session(session_id_3, {"test" => "msg3"}) - - sessions = session_store.get_sessions_with_messages - - aggregate_failures do - expect(sessions).to contain_exactly(session_id_1, session_id_3) - expect(sessions).not_to include(session_id_2) - end - end - - it "returns empty array when no sessions have messages" do - sessions = session_store.get_sessions_with_messages - expect(sessions).to eq([]) - end - end - - describe "#get_all_active_sessions" do - let(:session_id_1) { SecureRandom.uuid } - let(:session_id_2) { SecureRandom.uuid } - let(:session_id_3) { SecureRandom.uuid } - - before do - session_store.create_session(session_id_1, session_data) - session_store.create_session(session_id_2, session_data) - session_store.create_session(session_id_3, session_data) - - session_store.mark_stream_active(session_id_1, "server-1") - session_store.mark_stream_active(session_id_2, "server-2") - end - - it "returns only sessions with active streams" do - active_sessions = session_store.get_all_active_sessions - - aggregate_failures do - expect(active_sessions).to contain_exactly(session_id_1, session_id_2) - expect(active_sessions).not_to include(session_id_3) - end - end - - context "when no sessions have active streams" do - before do - session_store.mark_stream_inactive(session_id_1) - session_store.mark_stream_inactive(session_id_2) - end - - it "returns empty array" do - active_sessions = session_store.get_all_active_sessions - expect(active_sessions).to eq([]) - end - end - - context "when no sessions exist" do - it "returns empty array" do - fresh_redis = MockRedis.new - fresh_session_store = described_class.new(fresh_redis, ttl: 300) - - active_sessions = fresh_session_store.get_all_active_sessions - expect(active_sessions).to eq([]) - end - end - end - describe "#store_registered_handlers" do before { session_store.create_session(session_id, session_data) } @@ -534,18 +429,12 @@ end session_store.mark_stream_active(session_id_1, server_2) - aggregate_failures do - expect(session_store.session_has_active_stream?(session_id_1)).to eq(true) - expect(session_store.get_session_server(session_id_1)).to eq(server_2) - end + expect(session_store.session_has_active_stream?(session_id_1)).to eq(true) expect(session_store.queue_message_for_session(session_id_1, {"test" => "message"})).to eq(true) session_store.mark_stream_inactive(session_id_1) - aggregate_failures do - expect(session_store.session_has_active_stream?(session_id_1)).to eq(false) - expect(session_store.get_session_server(session_id_1)).to be_nil - end + expect(session_store.session_has_active_stream?(session_id_1)).to eq(false) session_store.cleanup_session(session_id_1) expect(session_store.session_exists?(session_id_1)).to eq(false) @@ -556,7 +445,7 @@ session_store.mark_stream_active(session_id_1, server_2) aggregate_failures do - expect(session_store.get_session_server(session_id_1)).to eq(server_2) + expect(session_store.session_has_active_stream?(session_id_1)).to eq(true) expect(session_store.queue_message_for_session(session_id_1, {"from" => "server_1"})).to eq(true) end end diff --git a/spec/lib/model_context_protocol/server/streamable_http_transport/stream_registry_spec.rb b/spec/lib/model_context_protocol/server/streamable_http_transport/stream_registry_spec.rb index 3f45105..a959d52 100644 --- a/spec/lib/model_context_protocol/server/streamable_http_transport/stream_registry_spec.rb +++ b/spec/lib/model_context_protocol/server/streamable_http_transport/stream_registry_spec.rb @@ -84,28 +84,6 @@ end end - describe "#get_stream_server" do - it "returns nil when stream doesn't exist in Redis" do - expect(registry.get_stream_server(session_id)).to be_nil - end - - it "returns the server instance when stream exists" do - registry.register_stream(session_id, mock_stream) - expect(registry.get_stream_server(session_id)).to eq(server_instance) - end - end - - describe "#stream_active?" do - it "returns false when stream doesn't exist in Redis" do - expect(registry.stream_active?(session_id)).to be false - end - - it "returns true when stream exists in Redis" do - registry.register_stream(session_id, mock_stream) - expect(registry.stream_active?(session_id)).to be true - end - end - describe "#refresh_heartbeat" do before do registry.register_stream(session_id, mock_stream) @@ -190,33 +168,4 @@ end end end - - describe "#get_stale_streams" do - it "returns streams with old heartbeats" do - registry.register_stream(session_id, mock_stream) - - old_time = Time.now.to_f - 120 - redis.set("stream:heartbeat:#{session_id}", old_time) - - stale_streams = registry.get_stale_streams(90) - - expect(stale_streams).to contain_exactly(session_id) - end - - it "returns empty array when no streams are stale" do - registry.register_stream(session_id, mock_stream) - - stale_streams = registry.get_stale_streams(90) - - expect(stale_streams).to be_empty - end - - it "handles missing heartbeat values gracefully" do - redis.set("stream:active:#{session_id}", server_instance) - - stale_streams = registry.get_stale_streams(90) - - expect(stale_streams).to be_empty - end - end end diff --git a/spec/lib/model_context_protocol/server/streamable_http_transport_spec.rb b/spec/lib/model_context_protocol/server/streamable_http_transport_spec.rb index 0c041a0..df8eabe 100644 --- a/spec/lib/model_context_protocol/server/streamable_http_transport_spec.rb +++ b/spec/lib/model_context_protocol/server/streamable_http_transport_spec.rb @@ -1153,9 +1153,9 @@ def rack_env }) notification_queue = transport.instance_variable_get(:@notification_queue) - expect(notification_queue.size).to eq(1) + queued_notifications = notification_queue.pop_all + expect(queued_notifications.length).to eq(1) - queued_notifications = notification_queue.peek_all queued_notification = queued_notifications.first aggregate_failures do expect(queued_notification["jsonrpc"]).to eq("2.0") @@ -1377,12 +1377,11 @@ def rack_env expect(request_store.cancelled?(request_id)).to be true end - it "stores cancellation reason" do + it "stores cancellation in request store" do transport.send(:handle_cancellation, cancellation_message, session_id) request_store = transport.instance_variable_get(:@request_store) - cancellation_info = request_store.get_cancellation_info(request_id) - expect(cancellation_info["reason"]).to eq(reason) + expect(request_store.cancelled?(request_id)).to be true end end @@ -1428,12 +1427,11 @@ def rack_env request_store.register_request(request_id, session_id) end - it "marks request as cancelled with nil reason" do + it "marks request as cancelled" do transport.send(:handle_cancellation, cancellation_without_reason, session_id) request_store = transport.instance_variable_get(:@request_store) - cancellation_info = request_store.get_cancellation_info(request_id) - expect(cancellation_info["reason"]).to be_nil + expect(request_store.cancelled?(request_id)).to be true end end @@ -1470,18 +1468,11 @@ def rack_env }) end - it "registers and unregisters requests during processing" do - transport.handle(env: rack_env) - - request_store = transport.instance_variable_get(:@request_store) - expect(request_store.active?(request_id)).to be false - end - - it "cleans up request from store after processing" do + it "unregisters requests after processing" do transport.handle(env: rack_env) request_store = transport.instance_variable_get(:@request_store) - expect(request_store.active?(request_id)).to be false + expect(request_store.cancelled?(request_id)).to be false end it "provides cancellation context to handlers" do @@ -2093,14 +2084,14 @@ def rack_env transport.send_notification("notifications/progress", {progress: 50}, session_id: target_session_id) notification_queue = transport.instance_variable_get(:@notification_queue) - expect(notification_queue.size).to eq(1) + expect(notification_queue.pop_all.length).to eq(1) end it "queues notification if targeted stream does not exist" do transport.send_notification("notifications/progress", {progress: 50}, session_id: "nonexistent-session") notification_queue = transport.instance_variable_get(:@notification_queue) - expect(notification_queue.size).to eq(1) + expect(notification_queue.pop_all.length).to eq(1) end end @@ -2195,16 +2186,19 @@ def rack_env it "registers ping request in server request store" do transport.send(:send_ping_to_stream, mock_stream, session_id) - pending_requests = server_request_store.get_all_pending_requests - expect(pending_requests.size).to eq(1) - expect(pending_requests.first).to start_with("ping-") + pending_keys = mock_redis.keys("server_request:pending:ping-*") + expect(pending_keys.size).to eq(1) + + ping_id = pending_keys.first.sub("server_request:pending:", "") + expect(ping_id).to start_with("ping-") end it "associates ping with session_id" do transport.send(:send_ping_to_stream, mock_stream, session_id) - pending_requests = server_request_store.get_all_pending_requests - request_info = server_request_store.get_request(pending_requests.first) + pending_keys = mock_redis.keys("server_request:pending:ping-*") + ping_id = pending_keys.first.sub("server_request:pending:", "") + request_info = server_request_store.get_request(ping_id) aggregate_failures do expect(request_info["session_id"]).to eq(session_id) @@ -2255,8 +2249,9 @@ def rack_env transport.send(:monitor_streams) # Simulate receiving ping response - pending_requests = server_request_store.get_all_pending_requests - pending_requests.each do |request_id| + pending_keys = mock_redis.keys("server_request:pending:ping-*") + pending_keys.each do |key| + request_id = key.sub("server_request:pending:", "") server_request_store.mark_completed(request_id) end