diff --git a/include/httpd.hrl b/include/httpd.hrl index 791b784..0ac3b2a 100644 --- a/include/httpd.hrl +++ b/include/httpd.hrl @@ -18,6 +18,7 @@ -define(INTERNAL_SERVER_ERROR, 500). -define(BAD_REQUEST, 400). -define(NOT_FOUND, 404). +-define(NOT_ALLOWED, 405). -define(OK, 200). -define(CONTINUE, 100). -define(SWITCHING_PROTOCOLS, 101). diff --git a/src/gen_tcp_server.erl b/src/gen_tcp_server.erl index 8acaf10..7802736 100644 --- a/src/gen_tcp_server.erl +++ b/src/gen_tcp_server.erl @@ -32,14 +32,23 @@ -callback handle_receive(Socket :: term(), Packet :: binary(), State :: term()) -> {reply, Packet :: iolist(), NewState :: term()} | {noreply, NewState :: term()} | {close, Packet :: iolist()} | close. --callback handle_tcp_closed(Socket :: term(), State :: term()) -> ok. +-callback handle_tcp_closed(Socket :: term(), State :: term()) -> NewState :: term(). + +%% Optional callback: invoked for messages that gen_tcp_server does not handle +%% itself (e.g. internal timer messages). Return {noreply, NewState} to keep +%% the connection open, or {close, Socket, NewState} to close a specific socket. +-callback handle_info(Msg :: term(), State :: term()) -> + {noreply, NewState :: term()} | {close, Socket :: term(), NewState :: term()}. +-optional_callbacks([handle_info/2]). % -define(TRACE_ENABLED, true). -include_lib("atomvm_httpd/include/trace.hrl"). -record(state, { handler, - handler_state + handler_state, + connections = #{}, + max_connections = 0 }). -define(DEFAULT_BIND_OPTIONS, #{ @@ -78,17 +87,21 @@ stop(Server) -> %% @hidden init({BindOptions, SocketOptions, Handler, Args}) -> Self = self(), + MaxConnections = maps:get(max_connections, SocketOptions, 0), + %% Strip max_connections before passing to set_socket_options/2 so that + %% socket:setopt/3 is never called with an unknown option key. + CleanSocketOptions = maps:remove(max_connections, SocketOptions), case socket:open(inet, stream, tcp) of {ok, Socket} -> - ok = set_socket_options(Socket, SocketOptions), + ok = set_socket_options(Socket, CleanSocketOptions), case socket:bind(Socket, BindOptions) of ok -> case socket:listen(Socket) of ok -> - spawn(fun() -> accept(Self, Socket) end), + spawn_link(fun() -> accept(Self, Socket) end), case Handler:init(Args) of {ok, HandlerState} -> - {ok, #state{handler = Handler, handler_state = HandlerState}}; + {ok, #state{handler = Handler, handler_state = HandlerState, max_connections = MaxConnections}}; HandlerError -> try_close(Socket), {stop, {handler_error, HandlerError}} @@ -108,7 +121,7 @@ init({Socket, Handler, Args}) -> Self = self(), case Handler:init(Args) of {ok, HandlerState} -> - spawn(fun() -> loop(Self, Socket) end), + spawn_link(fun() -> loop(Self, Socket) end), {ok, #state{handler = Handler, handler_state = HandlerState}}; HandlerError -> {stop, {handler_error, HandlerError}} @@ -123,14 +136,55 @@ handle_cast(_Msg, State) -> {noreply, State}. %% @hidden +handle_info({new_connection, Socket}, State) -> + #state{connections=Conns, max_connections=MaxConns} = State, + case MaxConns > 0 andalso map_size(Conns) >= MaxConns of + true -> + ?TRACE("Connection limit reached (~p), rejecting ~p at accept", [MaxConns, Socket]), + try_close(Socket), + {noreply, State}; + false -> + ?TRACE("Tracking new connection ~p (~p/~p)", [Socket, map_size(Conns) + 1, MaxConns]), + {noreply, State#state{connections = Conns#{Socket => true}}} + end; handle_info({tcp_closed, Socket}, State) -> ?TRACE("TCP Socket closed ~p", [Socket]), - #state{handler=Handler, handler_state=HandlerState} = State, + #state{handler=Handler, handler_state=HandlerState, connections=Conns} = State, NewHandlerState = Handler:handle_tcp_closed(Socket, HandlerState), - {noreply, State#state{handler_state=NewHandlerState}}; + NewConns = maps:remove(Socket, Conns), + {noreply, State#state{handler_state=NewHandlerState, connections=NewConns}}; handle_info({tcp, Socket, Packet}, State) -> - #state{handler=Handler, handler_state=HandlerState} = State, ?TRACE("received packet: len(~p) from ~p", [erlang:byte_size(Packet), socket:peername(Socket)]), + handle_tcp_data(Socket, Packet, State); +handle_info({'EXIT', _Pid, _Reason}, State) -> + ?TRACE("Linked process ~p exited: ~p", [_Pid, _Reason]), + {noreply, State}; +handle_info(Info, State) -> + %% Forward unrecognised messages to the handler if it exports handle_info/2. + %% The handler may return {noreply, NewState} or {close, Socket, NewState}. + #state{handler=Handler, handler_state=HandlerState} = State, + case erlang:function_exported(Handler, handle_info, 2) of + true -> + case Handler:handle_info(Info, HandlerState) of + {noreply, NewHandlerState} -> + {noreply, State#state{handler_state = NewHandlerState}}; + {close, Socket, NewHandlerState} -> + ?TRACE("handle_info requested close for socket ~p", [Socket]), + try_close(Socket), + {noreply, State#state{handler_state = NewHandlerState}} + end; + false -> + io:format("Received spurious info msg: ~p~n", [Info]), + {noreply, State} + end. + +%% @hidden +terminate(_Reason, _State) -> + ok. + +%% @private +handle_tcp_data(Socket, Packet, State) -> + #state{handler=Handler, handler_state=HandlerState} = State, case Handler:handle_receive(Socket, Packet, HandlerState) of {reply, ResponsePacket, ResponseState} -> ?TRACE("Sending reply to endpoint ~p", [socket:peername(Socket)]), @@ -153,7 +207,7 @@ handle_info({tcp, Socket, Packet}, State) -> ok -> try_close(Socket); {error, closed} -> - ok; %% Already closed, nothing to do + ok; {error, _Reason} -> try_close(Socket) end, @@ -166,14 +220,7 @@ handle_info({tcp, Socket, Packet}, State) -> ?TRACE("Unexpected response from handler ~p: ~p", [Handler, _SomethingElse]), try_close(Socket), {noreply, State} - end; -handle_info(Info, State) -> - io:format("Received spurious info msg: ~p~n", [Info]), - {noreply, State}. - -%% @hidden -terminate(_Reason, _State) -> - ok. + end. %% %% internal functions @@ -275,7 +322,11 @@ accept(ControllingProcess, ListenSocket) -> case socket:accept(ListenSocket) of {ok, Connection} -> ?TRACE("Accepted connection from ~p", [socket:peername(Connection)]), - spawn(fun() -> accept(ControllingProcess, ListenSocket) end), + %% Notify controlling process immediately so max_connections is enforced + %% at accept time (before any data arrives). The controlling process may + %% close the socket if the limit is exceeded; loop/2 will detect the close. + ControllingProcess ! {new_connection, Connection}, + spawn_link(fun() -> accept(ControllingProcess, ListenSocket) end), loop(ControllingProcess, Connection); _Error -> ?TRACE("Error accepting connection: ~p", [_Error]), @@ -295,6 +346,9 @@ loop(ControllingProcess, Connection) -> ?TRACE("Peer closed connection ~p", [Connection]), ControllingProcess ! {tcp_closed, Connection}, ok; + {error, timeout} -> + ?TRACE("Timeout on recv from ~p, retrying", [Connection]), + loop(ControllingProcess, Connection); {error, _SomethingElse} -> ?TRACE("Some other error occurred ~p: ~p", [Connection, _SomethingElse]), try_close(Connection) diff --git a/src/httpd.erl b/src/httpd.erl index 4d1dc76..7b227d4 100644 --- a/src/httpd.erl +++ b/src/httpd.erl @@ -17,8 +17,8 @@ -module(httpd). --export([start/2, start/3, start/4, start_link/2, start_link/3, start_link/4, stop/1]). --export([init/1, handle_receive/3, handle_tcp_closed/2]). +-export([start/2, start/3, start/4, start/5, start_link/2, start_link/3, start_link/4, start_link/5, stop/1]). +-export([init/1, handle_receive/3, handle_tcp_closed/2, handle_info/2]). -ifdef(TEST). -export([maybe_parse_http_request/1, handle_request_state/3, get_request_state/1]). @@ -43,7 +43,8 @@ query_params := query_params(), headers := #{binary() := binary()}, body := binary(), - socket := term() + socket := term(), + version := binary() }. -type handler_config() :: #{ module := module(), @@ -70,36 +71,50 @@ config, pending_request_map = #{}, ws_socket_map = #{}, - pending_buffer_map = #{} + pending_buffer_map = #{}, + pending_timer_map = #{}, + request_timeout = 30000 }). %% %% API %% +-type options() :: #{ + request_timeout => pos_integer() +}. + -spec start(Port :: portnum(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. start(Port, Config) -> - start(any, Port, #{}, Config). + start(any, Port, #{}, #{}, Config). -spec start(Address :: address(), Port :: portnum(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. start(Address, Port, Config) -> - start(Address, Port, #{}, Config). + start(Address, Port, #{}, #{}, Config). -spec start(Address :: address(), Port :: portnum(), SocketOptions :: map(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. start(Address, Port, SocketOptions, Config) -> - gen_tcp_server:start(#{addr => Address, port => Port}, SocketOptions, ?MODULE, Config). + start(Address, Port, SocketOptions, #{}, Config). + +-spec start(Address :: address(), Port :: portnum(), SocketOptions :: map(), Options :: options(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. +start(Address, Port, SocketOptions, Options, Config) -> + gen_tcp_server:start(#{addr => Address, port => Port}, SocketOptions, ?MODULE, {Options, Config}). -spec start_link(Port :: portnum(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. start_link(Port, Config) -> - start_link(any, Port, #{}, Config). + start_link(any, Port, #{}, #{}, Config). -spec start_link(Address :: address(), Port :: portnum(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. start_link(Address, Port, Config) -> - start_link(Address, Port, #{}, Config). + start_link(Address, Port, #{}, #{}, Config). -spec start_link(Address :: address(), Port :: portnum(), SocketOptions :: map(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. start_link(Address, Port, SocketOptions, Config) -> - gen_tcp_server:start_link(#{addr => Address, port => Port}, SocketOptions, ?MODULE, Config). + start_link(Address, Port, SocketOptions, #{}, Config). + +-spec start_link(Address :: address(), Port :: portnum(), SocketOptions :: map(), Options :: options(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. +start_link(Address, Port, SocketOptions, Options, Config) -> + gen_tcp_server:start_link(#{addr => Address, port => Port}, SocketOptions, ?MODULE, {Options, Config}). stop(Httpd) -> gen_tcp_server:stop(Httpd). @@ -109,7 +124,11 @@ stop(Httpd) -> %% %% @hidden +init({Options, Config}) -> + Timeout = maps:get(request_timeout, Options, 30000), + {ok, #state{config = Config, request_timeout = Timeout}}; init(Config) -> + %% Backwards-compatible: called with just Config (no Options). {ok, #state{config = Config}}. %% @hidden @@ -143,7 +162,7 @@ handle_http_request(Socket, Packet, State) -> case maybe_parse_http_request(AccumulatedPacket) of {more, IncompletePacket} -> NewBufferMap = BufferMap#{Socket => IncompletePacket}, - {noreply, State#state{pending_buffer_map = NewBufferMap}}; + {noreply, start_request_timer(Socket, State#state{pending_buffer_map = NewBufferMap})}; {ok, HttpRequest} -> CleanBufferMap = maps:remove(Socket, BufferMap), CleanState = State#state{pending_buffer_map = CleanBufferMap}, @@ -152,7 +171,11 @@ handle_http_request(Socket, Packet, State) -> method := Method, headers := Headers } = HttpRequest, - case get_protocol(Method, Headers) of + case Method of + undefined -> + {close, create_error(?NOT_ALLOWED, method_not_allowed)}; + _ -> + case get_protocol(Method, Headers) of http -> case init_handler(HttpRequest, CleanState) of {ok, {Handler, HandlerState, PathSuffix, HandlerConfig}} -> @@ -169,30 +192,37 @@ handle_http_request(Socket, Packet, State) -> end; ws -> ?TRACE("Protocol is ws", []), - Config = CleanState#state.config, - Path = maps:get(path, HttpRequest), - case get_handler(Path, Config) of - {ok, PathSuffix, EntryConfig} -> - WsHandler = maps:get(handler, EntryConfig), - ?TRACE("Got handler ~p", [WsHandler]), - HandlerConfig = maps:get(handler_config, EntryConfig, #{}), - case WsHandler:start(Socket, PathSuffix, HandlerConfig) of - {ok, WebSocket} -> - ?TRACE("Started web socket handler: ~p", [WebSocket]), - NewWebSocketMap = maps:put(Socket, WebSocket, CleanState#state.ws_socket_map), - NewState = CleanState#state{ws_socket_map = NewWebSocketMap}, - ReplyToken = get_reply_token(maps:get(headers, HttpRequest)), - ReplyHeaders = #{"Upgrade" => "websocket", "Connection" => "Upgrade", "Sec-WebSocket-Accept" => ReplyToken}, - Reply = create_reply(?SWITCHING_PROTOCOLS, ReplyHeaders, <<"">>), - ?TRACE("Sending web socket upgrade reply: ~p", [Reply]), - {reply, Reply, NewState}; + Headers = maps:get(headers, HttpRequest, #{}), + case get_ws_key(Headers) of + {ok, WebSocketKey} -> + ReplyToken = get_reply_token(WebSocketKey), + Config = CleanState#state.config, + Path = maps:get(path, HttpRequest), + case get_handler(Path, Config) of + {ok, PathSuffix, EntryConfig} -> + WsHandler = maps:get(handler, EntryConfig), + ?TRACE("Got handler ~p", [WsHandler]), + HandlerConfig = maps:get(handler_config, EntryConfig, #{}), + case WsHandler:start(Socket, PathSuffix, HandlerConfig) of + {ok, WebSocket} -> + ?TRACE("Started web socket handler: ~p", [WebSocket]), + NewWebSocketMap = maps:put(Socket, WebSocket, CleanState#state.ws_socket_map), + NewState = CleanState#state{ws_socket_map = NewWebSocketMap}, + ReplyHeaders = #{"Upgrade" => "websocket", "Connection" => "Upgrade", "Sec-WebSocket-Accept" => ReplyToken}, + Reply = create_reply(?SWITCHING_PROTOCOLS, ReplyHeaders, <<"">>), + ?TRACE("Sending web socket upgrade reply: ~p", [Reply]), + {reply, Reply, NewState}; + Error -> + ?TRACE("Web socket error: ~p", [Error]), + {close, create_error(?INTERNAL_SERVER_ERROR, {web_socket_error, Error})} + end; Error -> - ?TRACE("Web socket error: ~p", [Error]), {close, create_error(?INTERNAL_SERVER_ERROR, {web_socket_error, Error})} end; - Error -> - {close, create_error(?INTERNAL_SERVER_ERROR, {web_socket_error, Error})} + error -> + {close, create_error(?BAD_REQUEST, missing_websocket_key)} end + end end; {error, Reason} -> {close, create_error(?BAD_REQUEST, Reason)} @@ -232,28 +262,29 @@ handle_request_state(Socket, HttpRequest, State) -> complete -> ?TRACE("Request complete. Handling...", []), NewPendingRequestMap = maps:remove(Socket, PendingRequestMap), - call_http_req_handler(Socket, HttpRequest, State#state{pending_request_map = NewPendingRequestMap}); + CleanState = stop_request_timer(Socket, State#state{pending_request_map = NewPendingRequestMap}), + call_http_req_handler(Socket, HttpRequest, CleanState); expect_continue -> Headers = maps:get(headers, HttpRequest), - NewHeaders = maps:remove(<<"Expect">>, Headers), + NewHeaders = maps:remove(<<"expect">>, Headers), NewHttpRequest = HttpRequest#{headers := NewHeaders}, Reply = create_reply(?CONTINUE, #{}, <<"">>), NewPendingRequestMap = PendingRequestMap#{Socket => NewHttpRequest}, - {reply, Reply, State#state{pending_request_map = NewPendingRequestMap}}; + {reply, Reply, start_request_timer(Socket, State#state{pending_request_map = NewPendingRequestMap})}; wait_for_body -> NewPendingRequestMap = PendingRequestMap#{Socket => HttpRequest}, - {noreply, State#state{pending_request_map = NewPendingRequestMap}} + {noreply, start_request_timer(Socket, State#state{pending_request_map = NewPendingRequestMap})} end. %% @private get_request_state(HttpRequest) -> Headers = maps:get(headers, HttpRequest), - case maps:get(<<"Expect">>, Headers, undefined) of + case maps:get(<<"expect">>, Headers, undefined) of <<"100-continue">> -> ?TRACE("Expect: 100-continue", []), expect_continue; undefined -> - case maps:get(<<"Content-Length">>, Headers, undefined) of + case maps:get(<<"content-length">>, Headers, undefined) of undefined -> ?TRACE("No content length; request complete", []), complete; @@ -284,20 +315,27 @@ call_http_req_handler(Socket, HttpRequest, State) -> {noreply, NewHandlerState} -> NewState = update_state(Socket, HttpRequest, NewHandlerState, State), {noreply, NewState}; - %% reply + %% reply — always keeps the socket open (gen_tcp_server treats {reply,...} as keep-open). + %% NOTE: HTTP/1.0 default-close and Connection: close semantics are not yet implemented; + %% that negotiation is deferred to a follow-up. Handlers that need to force a close + %% should return {close, ...} instead. {reply, Reply, NewHandlerState} -> NewState = update_state(Socket, HttpRequest, NewHandlerState, State), {reply, create_reply(?OK, #{"Content-Type" => "application/octet-stream"}, Reply), NewState}; {reply, ReplyHeaders, Reply, NewHandlerState} -> NewState = update_state(Socket, HttpRequest, NewHandlerState, State), {reply, create_reply(?OK, ReplyHeaders, Reply), NewState}; - %% close + %% close — handler explicitly requests connection close; always honour it + %% regardless of the client's keep-alive preference, preserving the documented + %% httpd_handler contract that {close, ...} means "send response, close connection". close -> {close, State}; {close, Reply} -> - {close, create_reply(?OK, #{"Content-Type" => "application/octet-stream"}, Reply)}; + ReplyPacket = create_reply(?OK, #{"Content-Type" => "application/octet-stream"}, Reply), + {close, ReplyPacket}; {close, ReplyHeaders, Reply} -> - {close, create_reply(?OK, ReplyHeaders, Reply)}; + ReplyPacket = create_reply(?OK, ReplyHeaders, Reply), + {close, ReplyPacket}; %% errors {error, not_found} -> {close, create_error(?NOT_FOUND, not_found)}; @@ -319,9 +357,13 @@ update_state(Socket, HttpRequest, HandlerState, State) -> %% @hidden handle_tcp_closed(Socket, State) -> - NewPendingRequestMap = maps:remove(Socket, State#state.pending_request_map), - NewPendingBufferMap = maps:remove(Socket, State#state.pending_buffer_map), - CleanState = State#state{ + %% Cancel any pending request timer so it cannot fire after the socket is gone + %% and deliver a stale {request_timeout, Socket} message that might accidentally + %% close a future connection reusing the same socket term. + TimerCancelledState = stop_request_timer(Socket, State), + NewPendingRequestMap = maps:remove(Socket, TimerCancelledState#state.pending_request_map), + NewPendingBufferMap = maps:remove(Socket, TimerCancelledState#state.pending_buffer_map), + CleanState = TimerCancelledState#state{ pending_request_map = NewPendingRequestMap, pending_buffer_map = NewPendingBufferMap }, @@ -334,13 +376,43 @@ handle_tcp_closed(Socket, State) -> CleanState#state{ws_socket_map = NewWebSocketMap} end. +%% @hidden +%% Validate timer-tagged request timeout messages. Using the TimerRef in the +%% message tag makes this race-free: if stop_request_timer/2 cancelled the +%% timer before the message was delivered, the ref will no longer be in +%% pending_timer_map and we ignore the stale message; if the timer fired first, +%% the ref matches and we correctly close the socket. +handle_info({request_timeout, Socket, Tag}, State) -> + TimerMap = State#state.pending_timer_map, + case maps:get(Socket, TimerMap, undefined) of + {_TimerRef, Tag} -> + %% Tag matches the current timer — the request genuinely timed out. + ?TRACE("Request timeout confirmed for socket ~p (tag ~p)", [Socket, Tag]), + NewTimerMap = maps:remove(Socket, TimerMap), + NewState = State#state{pending_timer_map = NewTimerMap}, + {close, Socket, NewState}; + _ -> + %% Tag does not match: timer was cancelled and a new one installed + %% (keep-alive), or the entry was already removed (request completed). + %% Ignore the stale message. + ?TRACE("Ignoring stale request_timeout for socket ~p (tag ~p)", [Socket, Tag]), + {noreply, State} + end; +handle_info(_Msg, State) -> + {noreply, State}. + %% %% Internal functions %% %% @private -get_reply_token(Headers) -> - #{<<"Sec-WebSocket-Key">> := WebSocketKey} = Headers, +get_ws_key(#{<<"sec-websocket-key">> := Key}) -> + {ok, Key}; +get_ws_key(_) -> + error. + +%% @private +get_reply_token(WebSocketKey) -> MagicKey = <<"258EAFA5-E914-47DA-95CA-C5AB0DC85B11">>, PreImage = <>, ReplyToken = base64:encode(crypto:hash(sha, PreImage)), @@ -348,14 +420,14 @@ get_reply_token(Headers) -> ReplyToken. %% @private -parse_http_request(Packet) -> - {Heading, HeadingRest} = parse_heading(Packet, start, [], #{}), - {Headers, Body} = parse_header(HeadingRest, #{}), +parse_http_request(HeadingList, Body) -> + {Heading, _HeadingRest} = parse_heading(HeadingList, start, [], #{}), + {Headers, _} = parse_header(_HeadingRest, #{}), maps:merge( Heading, #{ headers => Headers, - body => erlang:list_to_binary(Body) + body => Body } ). @@ -363,9 +435,11 @@ maybe_parse_http_request(Packet) when is_binary(Packet) -> case find_header_delimiter(Packet) of nomatch -> {more, Packet}; - {_Pos, _Len} -> + {Pos, Len} -> try - {ok, parse_http_request(binary_to_list(Packet))} + HeaderEnd = Pos + Len, + <> = Packet, + {ok, parse_http_request(binary_to_list(HeadingPart), Body)} catch throw:Reason -> {error, Reason}; @@ -410,8 +484,13 @@ parse_heading([$\s|Rest], wait_version, Tmp, Accum) -> parse_heading(Packet, wait_version, Tmp, Accum) -> parse_heading(Packet, in_version, Tmp, Accum); %% in_version state -parse_heading([$\n|Rest], in_version, _Tmp, Accum) -> - {Accum, Rest}; +parse_heading([$\n|Rest], in_version, Tmp, Accum) -> + RawVersion = lists:reverse(Tmp), + Version = case RawVersion of + [$\r | Clean] -> list_to_binary(Clean); + _ -> list_to_binary(RawVersion) + end, + {Accum#{version => Version}, Rest}; parse_heading([C|Rest], in_version, Tmp, Accum) -> parse_heading(Rest, in_version, [C|Tmp], Accum); %% error state @@ -439,9 +518,12 @@ parse_line(_Packet, _Accum) -> %% @private split_header(Header) -> - [Key, Value] = string:split(Header, ":"), - %% TODO to_lower the key - {list_to_binary(string:trim(Key)), list_to_binary(string:trim(Value))}. + case string:split(Header, ":") of + [Key, Value] -> + {list_to_binary(string:to_lower(string:trim(Key))), list_to_binary(string:trim(Value))}; + _ -> + throw(bad_header) + end. normalize_uri(Uri) -> case string:split(Uri, "?", leading) of @@ -458,8 +540,17 @@ tokenize_path(Path) -> %% @private parse_query_params(QueryParamString) -> NVPairsStrings = string:split(QueryParamString, "&", all), - NVPairLists = [string:split(NVPairString, "=") || NVPairString <- NVPairsStrings], - maps:from_list([{list_to_atom(Key), url_decode(Value, [])} || [Key, Value] <- NVPairLists]). + maps:from_list([parse_query_param(NVPairString) || NVPairString <- NVPairsStrings]). + +parse_query_param(NVPairString) -> + case string:split(NVPairString, "=") of + [Key] -> + {list_to_binary(Key), <<"">>}; + [Key, Value] -> + %% url_decode/2 returns a charlist; convert to binary so all + %% query param values are binaries as declared in query_params(). + {list_to_binary(Key), list_to_binary(url_decode(Value, []))} + end. % from https://docs.microfocus.com/OMi/10.62/Content/OMi/ExtGuide/ExtApps/URL_encoding.htm url_decode([], Accum) -> @@ -558,7 +649,7 @@ starts_with([_H1|_], [_H2|_]) -> %% @private -get_protocol(get, #{<<"Upgrade">> := <<"websocket">>, <<"Connection">> := Upgrade, <<"Sec-WebSocket-Key">> := _, <<"Sec-WebSocket-Version">> := <<"13">>} = _Headers) -> +get_protocol(get, #{<<"upgrade">> := <<"websocket">>, <<"connection">> := Upgrade, <<"sec-websocket-key">> := _, <<"sec-websocket-version">> := <<"13">>} = _Headers) -> case str(string:to_upper(binary_to_list(Upgrade)), "UPGRADE") of 0 -> http; @@ -580,7 +671,12 @@ create_reply(StatusCode, ContentType, Reply) when is_list(ContentType) orelse is create_reply(StatusCode, #{"Content-Type" => ContentType}, Reply); create_reply(StatusCode, Headers, Reply) when is_map(Headers) -> ReplyLen = erlang:iolist_size(Reply), - HeadersWithLen = ensure_content_length(Headers, ReplyLen), + %% Normalize all header keys to lowercase binary before computing + %% Content-Length so that ensure_content_length/2 can reliably strip any + %% pre-existing content-length variant (e.g. "Content-Length", <<"Content-Length">>) + %% and avoid emitting duplicate headers. + NormalizedHeaders = normalize_headers(Headers), + HeadersWithLen = ensure_content_length(NormalizedHeaders, ReplyLen), [ <<"HTTP/1.1 ">>, erlang:integer_to_binary(StatusCode), <<" ">>, moniker(StatusCode), <<"\r\n">>, @@ -591,20 +687,33 @@ create_reply(StatusCode, Headers, Reply) when is_map(Headers) -> ]. %% @private -ensure_content_length(Headers, ReplyLen) -> - LenBin = erlang:integer_to_binary(ReplyLen), - CleanHeaders = remove_content_length_header(Headers), - CleanHeaders#{<<"Content-Length">> => LenBin}. +%% Rewrite every key in a response-header map to a lowercase binary so that +%% ensure_content_length/2 and to_headers_list/1 always operate on a uniform +%% representation regardless of whether the caller used strings, binaries, or +%% mixed-case atoms. +normalize_headers(Headers) -> + maps:fold( + fun(Key, Value, Acc) -> + NormKey = normalize_header_key(Key), + Acc#{NormKey => Value} + end, + #{}, + Headers + ). + +normalize_header_key(Key) when is_binary(Key) -> + list_to_binary(string:to_lower(binary_to_list(Key))); +normalize_header_key(Key) when is_list(Key) -> + list_to_binary(string:to_lower(Key)); +normalize_header_key(Key) when is_atom(Key) -> + list_to_binary(string:to_lower(atom_to_list(Key))). %% @private -remove_content_length_header(Headers) -> - KeysToRemove = [ - "Content-Length", - <<"Content-Length">>, - "content-length", - <<"content-length">> - ], - lists:foldl(fun(Key, Acc) -> maps:remove(Key, Acc) end, Headers, KeysToRemove). +ensure_content_length(Headers, ReplyLen) -> + LenBin = erlang:integer_to_binary(ReplyLen), + %% After normalize_headers/1 the key is always <<"content-length">>. + CleanHeaders = maps:remove(<<"content-length">>, Headers), + CleanHeaders#{<<"content-length">> => LenBin}. %% @private maybe_binary_to_string(Bin) when is_binary(Bin) -> @@ -640,6 +749,8 @@ moniker(?BAD_REQUEST) -> <<"BAD_REQUEST">>; moniker(?NOT_FOUND) -> <<"NOT_FOUND">>; +moniker(?NOT_ALLOWED) -> + <<"METHOD_NOT_ALLOWED">>; moniker(?CONTINUE) -> <<"Continue">>; moniker(?SWITCHING_PROTOCOLS) -> @@ -658,3 +769,34 @@ method_to_atom("DELETE") -> delete; method_to_atom(_) -> undefined. + +%% @private +%% Each timer entry is stored as {TimerRef, Tag} where Tag = make_ref(). +%% Tag is embedded in the {request_timeout, Socket, Tag} message so that +%% handle_info/2 can compare it against the current map entry and safely +%% ignore any stale messages that arrive after cancel_timer/1 was called +%% (they carry an old Tag that no longer matches). This makes the timeout +%% handling race-free without needing a receive-flush. +start_request_timer(Socket, State) -> + Timeout = State#state.request_timeout, + TimerMap = State#state.pending_timer_map, + %% Cancel any pre-existing timer for this socket. + case maps:get(Socket, TimerMap, undefined) of + undefined -> ok; + {OldRef, _OldTag} -> erlang:cancel_timer(OldRef) + end, + Tag = make_ref(), + TimerRef = erlang:send_after(Timeout, self(), {request_timeout, Socket, Tag}), + State#state{pending_timer_map = TimerMap#{Socket => {TimerRef, Tag}}}. + +%% @private +stop_request_timer(Socket, State) -> + TimerMap = State#state.pending_timer_map, + %% Cancel the timer. Any {request_timeout, Socket, Tag} already in the + %% mailbox carries the old Tag; handle_info/2 will ignore it because we + %% remove the entry from pending_timer_map here — no receive-flush needed. + case maps:get(Socket, TimerMap, undefined) of + undefined -> ok; + {Ref, _Tag} -> erlang:cancel_timer(Ref) + end, + State#state{pending_timer_map = maps:remove(Socket, TimerMap)}. diff --git a/src/httpd_env_api_handler.erl b/src/httpd_env_api_handler.erl index 8d2df4b..c251e55 100644 --- a/src/httpd_env_api_handler.erl +++ b/src/httpd_env_api_handler.erl @@ -31,64 +31,69 @@ handle_api_request(get, [Application, Param | Rest], _HttpRequest, _Args) -> ?TRACE("Application: ~p Param: ~p, Rest: ~p", [Application, Param, Rest]), - ApplicationAtom = bin_to_atom(Application), - ParamAtom = bin_to_atom(Param), - Result = case avm_application:get_env(ApplicationAtom, ParamAtom) of - undefined -> - undefined; - {ok, Value} -> - find_value_in_path(Value, Rest) - end, - case Result of - undefined -> - {error, not_found}; - _ -> - {ok, Result} + case to_existing_atoms(Application, Param) of + {ok, ApplicationAtom, ParamAtom} -> + Result = case avm_application:get_env(ApplicationAtom, ParamAtom) of + undefined -> + undefined; + {ok, Value} -> + find_value_in_path(Value, Rest) + end, + case Result of + undefined -> + {error, not_found}; + _ -> + {ok, Result} + end; + error -> + {error, not_found} end; handle_api_request(post, [Application, Param | Rest], HttpRequest, _Args) -> ?TRACE("Application: ~p Param: ~p, Rest: ~p", [Application, Param, Rest]), - QueryParams = maps:get(query_params, HttpRequest, #{}), - ?TRACE("QueryParams: ~p", [QueryParams]), - - ApplicationAtom = bin_to_atom(Application), - ParamAtom = bin_to_atom(Param), - - NewValue = create_value(Rest, QueryParams, #{}), - ?TRACE("NewValue: ~p", [NewValue]), - MergedValue = case avm_application:get_env(ApplicationAtom, ParamAtom) of - undefined -> - NewValue; - {ok, OldValue} -> - ?TRACE("merging OldValue: ~p NewValue: ~p", [OldValue, NewValue]), - map_utils:deep_maps_merge(OldValue, NewValue) - end, - - ?TRACE("QueryParams: ~p MergedValue: ~p", [QueryParams, MergedValue]), - ok = avm_application:set_env(ApplicationAtom, ParamAtom, MergedValue, [{persistent, true}]); + case to_existing_atoms(Application, Param) of + {ok, ApplicationAtom, ParamAtom} -> + QueryParams = maps:get(query_params, HttpRequest, #{}), + ?TRACE("QueryParams: ~p", [QueryParams]), + + NewValue = create_value(Rest, QueryParams, #{}), + ?TRACE("NewValue: ~p", [NewValue]), + MergedValue = case avm_application:get_env(ApplicationAtom, ParamAtom) of + undefined -> + NewValue; + {ok, OldValue} -> + ?TRACE("merging OldValue: ~p NewValue: ~p", [OldValue, NewValue]), + map_utils:deep_maps_merge(OldValue, NewValue) + end, + + ?TRACE("QueryParams: ~p MergedValue: ~p", [QueryParams, MergedValue]), + ok = avm_application:set_env(ApplicationAtom, ParamAtom, MergedValue, [{persistent, true}]); + error -> + {error, not_found} + end; handle_api_request(delete, [Application, Param | Rest], _HttpRequest, _Args) -> ?TRACE("Application: ~p Param: ~p, Rest: ~p", [Application, Param, Rest]), - ApplicationAtom = bin_to_atom(Application), - ParamAtom = bin_to_atom(Param), - Result = case avm_application:get_env(ApplicationAtom, ParamAtom) of - undefined -> - undefined; - {ok, Env} -> - %% TODO memory leak - Path = [bin_to_atom(P) || P <- Rest], - ?TRACE("Removing path ~p from env ~p", [Path, Env]), - map_utils:remove_entry_in_path(Env, Path) - end, - case Result of - undefined -> - {error, not_found}; - NewEnv -> - ?TRACE("NewEnv: ~p", [NewEnv]), - avm_application:set_env(ApplicationAtom, ParamAtom, NewEnv), - ok + case to_existing_atoms(Application, Param) of + {ok, ApplicationAtom, ParamAtom} -> + Result = case avm_application:get_env(ApplicationAtom, ParamAtom) of + undefined -> + undefined; + {ok, Env} -> + map_utils:remove_entry_in_path(Env, Rest) + end, + case Result of + undefined -> + {error, not_found}; + NewEnv -> + ?TRACE("NewEnv: ~p", [NewEnv]), + avm_application:set_env(ApplicationAtom, ParamAtom, NewEnv), + ok + end; + error -> + {error, not_found} end; handle_api_request(Method, Path, _HttpRequest, _Args) -> @@ -98,21 +103,37 @@ handle_api_request(Method, Path, _HttpRequest, _Args) -> find_value_in_path(Map, []) -> Map; find_value_in_path(Value, [H | T]) when is_map(Value) -> - %% TODO binary to atom here is bad - case maps:get(bin_to_atom(H), Value, undefined) of + case maps:get(H, Value, undefined) of undefined -> - undefined; + case to_existing_atom(H) of + {ok, Atom} -> find_value_in_path(maps:get(Atom, Value, undefined), T); + error -> undefined + end; V -> find_value_in_path(V, T) end; find_value_in_path(_Value, _Path) -> undefined. -bin_to_atom(Bin) -> - list_to_atom(binary_to_list(Bin)). - create_value([], QueryParams, Accum) -> maps:merge(Accum, QueryParams); create_value([H | T], QueryParams, Accum) -> - %% TODO binary to atom here is bad - #{bin_to_atom(H) => create_value(T, QueryParams, Accum)}. + #{H => create_value(T, QueryParams, Accum)}. + +to_existing_atoms(A, B) -> + case to_existing_atom(A) of + {ok, AtomA} -> + case to_existing_atom(B) of + {ok, AtomB} -> {ok, AtomA, AtomB}; + error -> error + end; + error -> + error + end. + +to_existing_atom(Bin) -> + try list_to_existing_atom(binary_to_list(Bin)) of + Atom -> {ok, Atom} + catch + error:badarg -> error + end. diff --git a/src/httpd_handler.erl b/src/httpd_handler.erl index d6a1ba0..dab32f8 100644 --- a/src/httpd_handler.erl +++ b/src/httpd_handler.erl @@ -31,7 +31,8 @@ method => http_method(), path => http_path(), headers => http_headers(), - body => binary() + body => binary(), + version => binary() }. %% diff --git a/src/httpd_ota_handler.erl b/src/httpd_ota_handler.erl index 45b5475..baad923 100644 --- a/src/httpd_ota_handler.erl +++ b/src/httpd_ota_handler.erl @@ -76,4 +76,4 @@ handle_http_req(_HttpRequest, _State) -> get_content_length(Headers) -> %% TODO handle case - erlang:binary_to_integer(maps:get(<<"Content-Length">>, Headers, <<"0">>)). + erlang:binary_to_integer(maps:get(<<"content-length">>, Headers, <<"0">>)). diff --git a/src/httpd_ws_handler.erl b/src/httpd_ws_handler.erl index 0493503..1fa6f01 100644 --- a/src/httpd_ws_handler.erl +++ b/src/httpd_ws_handler.erl @@ -22,7 +22,7 @@ -export([send/2]). -behavior(gen_server). --export([init/1, handle_cast/2, handle_call/3, handle_info/2, terminate/2]). +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2]). % -define(TRACE_ENABLED, true). -include_lib("atomvm_httpd/include/trace.hrl"). @@ -65,12 +65,7 @@ handle_web_socket_message(WebSocket, Packet) -> gen_server:cast(WebSocket, {message, Packet}). send(WebSocket, Packet) -> - case self() of - WebSocket -> - throw(badarg); - _ -> - gen_server:call(WebSocket, {send, Packet}) - end. + gen_server:cast(WebSocket, {send, Packet}). %% @@ -138,13 +133,17 @@ handle_cast({message, Packet}, State) -> ?TRACE("ParseFrameError: ~p", [ParseFrameError]), socket:close(Socket), {stop, ParseFrameError, State} - end. - + end; %% @hidden -handle_call({send, Packet}, _From, State) -> +handle_cast({send, Packet}, State) -> ?TRACE("Sending packet ~p", [Packet]), - Reply = do_send(State#state.socket, Packet, text), - {reply, Reply, State}. + do_send(State#state.socket, Packet, text), + {noreply, State}. + +%% @hidden +handle_call(_Request, _From, State) -> + {reply, ok, State}. + %% @hidden handle_info(_Msg, State) -> @@ -242,26 +241,26 @@ extract_payload(Mask, PayloadLen, Data) -> end. %% @private -unmask(MaskingKey, MaskedPayload) -> - unmask(MaskingKey, MaskedPayload, 0, []). - -unmask(_MaskingKey, <<"">>, _I, Accum) -> - % ?TRACE("unmasked Accum: ~p", [Accum]), - list_to_binary(lists:reverse(Accum)); -unmask(MaskingKey, <>, I, Accum) -> - MaskingOctet = octet(MaskingKey, I rem 4), - % ?TRACE("H: ~p, MaskingOctet: ~p", [H, MaskingOctet]), - unmask(MaskingKey, T, I + 1, [MaskingOctet bxor H | Accum]). +unmask(<>, Payload) -> + Size = byte_size(Payload), + FullChunks = Size bsr 2, + Rem = Size band 3, + case Rem of + 0 -> + << <<(A bxor K0), (B bxor K1), (C bxor K2), (D bxor K3)>> || + <> <= Payload >>; + _ -> + ChunkSize = FullChunks bsl 2, + <> = Payload, + Unmasked = << <<(A bxor K0), (B bxor K1), (C bxor K2), (D bxor K3)>> || + <> <= ChunkedPart >>, + <> + end. -%% @private -octet(<>, 0) -> - First; -octet(<<_:1/binary, Second:8, _/binary>>, 1) -> - Second; -octet(<<_:2/binary, Third:8, _/binary>>, 2) -> - Third; -octet(<<_:3/binary, Fourth:8, _/binary>>, 3) -> - Fourth. +unmask_rem(<<>>, _, _, _, _, 0) -> <<>>; +unmask_rem(<>, K0, _, _, _, 1) -> <<(B bxor K0)>>; +unmask_rem(<>, K0, K1, _, _, 2) -> <<(B bxor K0), (C bxor K1)>>; +unmask_rem(<>, K0, K1, K2, _, 3) -> <<(B bxor K0), (C bxor K1), (D bxor K2)>>. %% @private do_send(Socket, Packet, Mode) -> @@ -277,7 +276,7 @@ frame(Packet, Mode) when is_binary(Packet) -> Opcode = case Mode of text -> 16#01; binary -> 16#02; _ -> 16#01 end, FinOpcode = Fin bor Opcode, PayloadLen = erlang:byte_size(Packet), - case {PayloadLen =< 125, PayloadLen =< 65536} of + case {PayloadLen =< 125, PayloadLen < 65536} of {true, _} -> NoMask = 16#7F, MaskLen = NoMask band PayloadLen, diff --git a/test/atomvm_httpd_test.exs b/test/atomvm_httpd_test.exs index a51c2aa..49c71f0 100644 --- a/test/atomvm_httpd_test.exs +++ b/test/atomvm_httpd_test.exs @@ -13,7 +13,7 @@ defmodule HttpdUnitTest do assert :post = Map.fetch!(http_request, :method) headers = Map.fetch!(http_request, :headers) - assert <<"11">> = Map.fetch!(headers, <<"Content-Length">>) + assert <<"11">> = Map.fetch!(headers, <<"content-length">>) assert <<"hello=world">> = Map.fetch!(http_request, :body) end @@ -23,16 +23,31 @@ defmodule HttpdUnitTest do assert {:ok, http_request} = :httpd.maybe_parse_http_request(request) headers = Map.fetch!(http_request, :headers) - assert <<"value200">> = Map.fetch!(headers, <<"X-Test-200">>) + assert <<"value200">> = Map.fetch!(headers, <<"x-test-200">>) end test "handle_request_state stores partial body until complete" do socket = make_ref() - http_request = %{headers: %{<<"Content-Length">> => <<"5">>}, body: <<"12">>} - state = {:state, [], %{}, %{}, %{}} - - assert {:noreply, {:state, [], %{^socket => ^http_request}, %{}, %{}}} = - :httpd.handle_request_state(socket, http_request, state) + http_request = %{headers: %{<<"content-length">> => <<"5">>}, body: <<"12">>} + state = {:state, [], %{}, %{}, %{}, %{}, 30000} + + assert {:noreply, result_state} = :httpd.handle_request_state(socket, http_request, state) + + # Destructure the result state tuple: {state, config, pending_request_map, + # ws_socket_map, pending_buffer_map, pending_timer_map, request_timeout} + {:state, _config, pending_request_map, _ws, _buf, pending_timer_map, _timeout} = result_state + + # Partial request should be stored in the pending map + assert %{^socket => ^http_request} = pending_request_map + + # A request timer should have been started for the socket. + # The entry is {TimerRef, Tag} — both are opaque references. + timer_entry = Map.get(pending_timer_map, socket) + assert is_tuple(timer_entry) and tuple_size(timer_entry) == 2, + "expected a {timer_ref, tag} tuple in pending_timer_map for the socket" + {t_ref, t_tag} = timer_entry + assert is_reference(t_ref) + assert is_reference(t_tag) assert :wait_for_body = :httpd.get_request_state(http_request) end diff --git a/test/httpd_integration_test.exs b/test/httpd_integration_test.exs index 496e2a6..9ca5068 100644 --- a/test/httpd_integration_test.exs +++ b/test/httpd_integration_test.exs @@ -38,7 +38,7 @@ defmodule HttpdIntegrationTest do try do request_chunks = [ - "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 11\r\n\r\nhe", + "POST / HTTP/1.1\r\nHost: example.com\r\ncontent-length: 11\r\n\r\nhe", "llo=", "world" ] @@ -69,7 +69,7 @@ defmodule HttpdIntegrationTest do assert_receive {:http_request, request}, @receive_timeout headers = Map.fetch!(request, :headers) - assert <<"value123">> = Map.fetch!(headers, <<"X-Custom-Header">>) + assert <<"value123">> = Map.fetch!(headers, <<"x-custom-header">>) assert {:ok, response} = :gen_tcp.recv(socket, 0, @receive_timeout) assert response =~ "HTTP/1.1 200 OK" @@ -114,7 +114,7 @@ defmodule HttpdIntegrationTest do [headers, body] = :binary.split(response, <<"\r\n\r\n">>) assert String.contains?(headers, "HTTP/1.1 200 OK") - assert String.contains?(headers, "Content-Length: " <> @large_iolist_len) + assert String.contains?(headers, "content-length: " <> @large_iolist_len) assert byte_size(body) == :erlang.iolist_size(@large_iolist) after :gen_tcp.close(socket) @@ -140,8 +140,8 @@ defmodule HttpdIntegrationTest do assert String.contains?(headers, "HTTP/1.1 200 OK") expected_length = :erlang.iolist_size(iolist) - assert String.contains?(headers, "Content-Length: #{expected_length}"), - "Expected Content-Length: #{expected_length}" + assert String.contains?(headers, "content-length: #{expected_length}"), + "Expected content-length: #{expected_length}" assert body == expected_body assert byte_size(body) == expected_length @@ -169,8 +169,8 @@ defmodule HttpdIntegrationTest do assert String.contains?(headers, "HTTP/1.1 200 OK") expected_length = :erlang.iolist_size(iolist) - assert String.contains?(headers, "Content-Length: #{expected_length}"), - "Expected Content-Length: #{expected_length}" + assert String.contains?(headers, "content-length: #{expected_length}"), + "Expected content-length: #{expected_length}" assert body == expected_body assert byte_size(body) == expected_length @@ -198,8 +198,8 @@ defmodule HttpdIntegrationTest do assert String.contains?(headers, "HTTP/1.1 200 OK") expected_length = :erlang.iolist_size(iolist) - assert String.contains?(headers, "Content-Length: #{expected_length}"), - "Expected Content-Length: #{expected_length}" + assert String.contains?(headers, "content-length: #{expected_length}"), + "Expected content-length: #{expected_length}" assert body == expected_body assert byte_size(body) == expected_length @@ -227,8 +227,8 @@ defmodule HttpdIntegrationTest do assert String.contains?(headers, "HTTP/1.1 200 OK") expected_length = :erlang.iolist_size(iolist) - assert String.contains?(headers, "Content-Length: #{expected_length}"), - "Expected Content-Length: #{expected_length}" + assert String.contains?(headers, "content-length: #{expected_length}"), + "Expected content-length: #{expected_length}" assert body == expected_body assert byte_size(body) == expected_length @@ -256,8 +256,8 @@ defmodule HttpdIntegrationTest do assert String.contains?(headers, "HTTP/1.1 200 OK") expected_length = :erlang.iolist_size(iolist) - assert String.contains?(headers, "Content-Length: #{expected_length}"), - "Expected Content-Length: #{expected_length}" + assert String.contains?(headers, "content-length: #{expected_length}"), + "Expected content-length: #{expected_length}" assert body == expected_body assert byte_size(body) == expected_length @@ -285,8 +285,8 @@ defmodule HttpdIntegrationTest do assert String.contains?(headers, "HTTP/1.1 200 OK") expected_length = :erlang.iolist_size(iolist) - assert String.contains?(headers, "Content-Length: #{expected_length}"), - "Expected Content-Length: #{expected_length}" + assert String.contains?(headers, "content-length: #{expected_length}"), + "Expected content-length: #{expected_length}" assert body == expected_body assert byte_size(body) == expected_length @@ -295,6 +295,55 @@ defmodule HttpdIntegrationTest do end end + describe "request timeout" do + # A deliberately short timeout so tests finish quickly. + @timeout_ms 300 + + # Start a second httpd with the short timeout; override :port in context. + setup do + port = find_free_tcp_port() + config = [{[], %{handler: TestEchoHandler, handler_config: %{test_pid: self()}}}] + + {:ok, server} = + :httpd.start_link(:any, port, %{}, %{request_timeout: @timeout_ms}, config) + + Process.sleep(20) + + on_exit(fn -> + if Process.alive?(server), do: :httpd.stop(server) + end) + + {:ok, port: port} + end + + test "closes socket when request headers are never completed", %{port: port} do + {:ok, socket} = connect(port) + # Send an incomplete request — no \r\n\r\n header terminator, so the + # server buffers and starts the request timer. + :ok = :gen_tcp.send(socket, "GET / HTTP/1.1\r\nHost: example.com\r\n") + + # Wait well past the configured timeout and expect the server to close. + Process.sleep(@timeout_ms + 200) + assert {:error, :closed} = :gen_tcp.recv(socket, 0, 500) + :gen_tcp.close(socket) + end + + test "closes socket when declared body is never fully delivered", %{port: port} do + {:ok, socket} = connect(port) + # Headers are complete, but Content-Length claims 100 bytes and we only + # send 5. The server should start waiting for the rest and time out. + :ok = + :gen_tcp.send( + socket, + "POST / HTTP/1.1\r\nHost: example.com\r\ncontent-length: 100\r\n\r\nhello" + ) + + Process.sleep(@timeout_ms + 200) + assert {:error, :closed} = :gen_tcp.recv(socket, 0, 500) + :gen_tcp.close(socket) + end + end + defp connect(port) do :gen_tcp.connect(~c"localhost", port, [:binary, active: false, packet: :raw]) end diff --git a/test/httpd_websocket_test.exs b/test/httpd_websocket_test.exs index 3216a04..dae8bd5 100644 --- a/test/httpd_websocket_test.exs +++ b/test/httpd_websocket_test.exs @@ -57,9 +57,10 @@ defmodule HttpdWebsocketTest do # Receive complete upgrade response response = read_http_response(socket) assert response =~ "HTTP/1.1 101 Switching Protocols" - assert response =~ "Upgrade: websocket" - assert response =~ "Connection: Upgrade" - assert response =~ "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" + # Response headers are normalized to lowercase keys by the server. + assert response =~ "upgrade: websocket" + assert response =~ "connection: Upgrade" + assert response =~ "sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" # Verify handler received init assert_receive {:ws_init, _websocket, _path}, @receive_timeout