From 60770c966389b0a842843af91f1b2e1d6f6f9f10 Mon Sep 17 00:00:00 2001 From: "agentfarmx[bot]" <198411105+agentfarmx[bot]@users.noreply.github.com> Date: Mon, 10 Mar 2025 03:39:13 +0000 Subject: [PATCH 1/3] feat: implement HTTP server and WebSocket support for web framework --- src/framework/core.zig | 121 ++++++++++++++++++++++- src/framework/router.zig | 55 ++++++++++- src/framework/server.zig | 204 +++++++++++++++++++++++++++++++++++++-- src/main.zig | 196 +++++++++++++++++++++++++++---------- src/root.zig | 14 ++- src/websocket.zig | 39 +++++--- 6 files changed, 547 insertions(+), 82 deletions(-) diff --git a/src/framework/core.zig b/src/framework/core.zig index bdca9c7..32cfdd4 100644 --- a/src/framework/core.zig +++ b/src/framework/core.zig @@ -1,11 +1,130 @@ const std = @import("std"); +pub const Method = enum { + GET, + POST, + PUT, + DELETE, + OPTIONS, + HEAD, + PATCH, + + pub fn fromString(method_str: []const u8) ?Method { + if (std.mem.eql(u8, method_str, "GET")) return .GET; + if (std.mem.eql(u8, method_str, "POST")) return .POST; + if (std.mem.eql(u8, method_str, "PUT")) return .PUT; + if (std.mem.eql(u8, method_str, "DELETE")) return .DELETE; + if (std.mem.eql(u8, method_str, "OPTIONS")) return .OPTIONS; + if (std.mem.eql(u8, method_str, "HEAD")) return .HEAD; + if (std.mem.eql(u8, method_str, "PATCH")) return .PATCH; + return null; + } + + pub fn toString(self: Method) []const u8 { + return switch (self) { + .GET => "GET", + .POST => "POST", + .PUT => "PUT", + .DELETE => "DELETE", + .OPTIONS => "OPTIONS", + .HEAD => "HEAD", + .PATCH => "PATCH", + }; + } +}; + +pub const Request = struct { + method: Method, + path: []const u8, + headers: std.StringHashMap([]const u8), + body: ?[]const u8, + + pub fn init(allocator: std.mem.Allocator) Request { + return .{ + .method = .GET, + .path = "", + .headers = std.StringHashMap([]const u8).init(allocator), + .body = null, + }; + } + + pub fn deinit(self: *Request) void { + self.headers.deinit(); + if (self.body) |body| { + self.headers.allocator.free(body); + } + } +}; + +pub const Response = struct { + status: u16, + headers: std.StringHashMap([]const u8), + body: ?[]const u8, + + pub fn init(allocator: std.mem.Allocator) Response { + return .{ + .status = 200, + .headers = std.StringHashMap([]const u8).init(allocator), + .body = null, + }; + } + + pub fn deinit(self: *Response) void { + self.headers.deinit(); + if (self.body) |body| { + self.headers.allocator.free(body); + } + } +}; + pub const Context = struct { allocator: std.mem.Allocator, - + request: Request, + response: Response, + params: std.StringHashMap([]const u8), + pub fn init(allocator: std.mem.Allocator) Context { return .{ .allocator = allocator, + .request = Request.init(allocator), + .response = Response.init(allocator), + .params = std.StringHashMap([]const u8).init(allocator), }; } + + pub fn deinit(self: *Context) void { + self.request.deinit(); + self.response.deinit(); + self.params.deinit(); + } }; + +pub const Handler = *const fn (*Context) anyerror!void; + +pub const Middleware = struct { + data: ?*anyopaque, + handle_fn: *const fn (*Context, Handler) anyerror!void, + deinit_fn: ?*const fn (*Middleware) void, + + pub fn init( + data: ?*anyopaque, + handle_fn: *const fn (*Context, Handler) anyerror!void, + deinit_fn: ?*const fn (*Middleware) void, + ) Middleware { + return .{ + .data = data, + .handle_fn = handle_fn, + .deinit_fn = deinit_fn, + }; + } + + pub fn handle(self: *const Middleware, ctx: *Context, next: Handler) !void { + return self.handle_fn(ctx, next); + } + + pub fn deinit(self: *Middleware) void { + if (self.deinit_fn) |deinit_fn| { + deinit_fn(self); + } + } +}; \ No newline at end of file diff --git a/src/framework/router.zig b/src/framework/router.zig index 45dae38..6a59652 100644 --- a/src/framework/router.zig +++ b/src/framework/router.zig @@ -78,6 +78,9 @@ pub const Router = struct { // Find matching route const route = self.findRoute(ctx.request.method, ctx.request.path) orelse return error.RouteNotFound; + // Extract path parameters + try self.extractParams(ctx, route.pattern, ctx.request.path); + // If no middleware, just call the handler if (self.global_middleware.items.len == 0) { return route.handler(ctx); @@ -103,22 +106,64 @@ pub const Router = struct { fn matchPattern(self: *Router, pattern: []const u8, path: []const u8) bool { _ = self; + + // Handle root path + if (std.mem.eql(u8, pattern, "/") and std.mem.eql(u8, path, "/")) { + return true; + } + var pattern_parts = std.mem.split(u8, pattern, "/"); var path_parts = std.mem.split(u8, path, "/"); - + + // Skip empty first part if path starts with "/" + if (pattern.len > 0 and pattern[0] == '/') _ = pattern_parts.next(); + if (path.len > 0 and path[0] == '/') _ = path_parts.next(); + while (true) { const pattern_part = pattern_parts.next() orelse { + // If we've reached the end of the pattern, the match is successful + // only if we've also reached the end of the path return path_parts.next() == null; }; - const path_part = path_parts.next() orelse return false; - + + const path_part = path_parts.next() orelse { + // If we've reached the end of the path but not the pattern, + // the match fails + return false; + }; + + // Handle path parameters (starting with ":") if (std.mem.startsWith(u8, pattern_part, ":")) { + // This is a path parameter, it matches any path part continue; } - + + // For regular path parts, they must match exactly if (!std.mem.eql(u8, pattern_part, path_part)) { return false; } } } -}; + + fn extractParams(self: *Router, ctx: *core.Context, pattern: []const u8, path: []const u8) !void { + _ = self; + + var pattern_parts = std.mem.split(u8, pattern, "/"); + var path_parts = std.mem.split(u8, path, "/"); + + // Skip empty first part if path starts with "/" + if (pattern.len > 0 and pattern[0] == '/') _ = pattern_parts.next(); + if (path.len > 0 and path[0] == '/') _ = path_parts.next(); + + while (true) { + const pattern_part = pattern_parts.next() orelse break; + const path_part = path_parts.next() orelse break; + + // Extract parameter if pattern part starts with ":" + if (std.mem.startsWith(u8, pattern_part, ":")) { + const param_name = pattern_part[1..]; // Skip the ":" prefix + try ctx.params.put(param_name, path_part); + } + } + } +}; \ No newline at end of file diff --git a/src/framework/server.zig b/src/framework/server.zig index 6cae4d3..ab9a687 100644 --- a/src/framework/server.zig +++ b/src/framework/server.zig @@ -1,27 +1,217 @@ const std = @import("std"); const core = @import("core"); +const net = std.net; +const Thread = std.Thread; +const Atomic = std.atomic.Atomic; pub const ServerConfig = struct { - port: u16, - host: []const u8, + port: u16 = 8080, + host: []const u8 = "127.0.0.1", + max_connections: usize = 1000, + thread_count: ?usize = null, // If null, use available CPU cores }; pub const Server = struct { allocator: std.mem.Allocator, config: ServerConfig, - + address: net.Address, + listener: ?net.StreamServer = null, + running: Atomic(bool) = Atomic(bool).init(false), + threads: std.ArrayList(Thread) = undefined, + pub fn init(allocator: std.mem.Allocator, config: ServerConfig) !Server { + const address = try net.Address.parseIp(config.host, config.port); + return Server{ .allocator = allocator, .config = config, + .address = address, + .threads = std.ArrayList(Thread).init(allocator), }; } - + pub fn deinit(self: *Server) void { - _ = self; + if (self.running.load(.acquire)) { + self.stop(); + } + + if (self.listener) |*listener| { + listener.deinit(); + } + + self.threads.deinit(); } - + pub fn start(self: *Server) !void { - _ = self; + if (self.running.load(.acquire)) { + return error.ServerAlreadyRunning; + } + + // Initialize the listener + self.listener = net.StreamServer.init(.{ + .reuse_address = true, + }); + + // Bind to the address + try self.listener.?.listen(self.address); + + // Set running flag + self.running.store(true, .release); + + // Determine thread count + const thread_count = self.config.thread_count orelse try Thread.getCpuCount(); + + // Start worker threads + var i: usize = 0; + while (i < thread_count) : (i += 1) { + const thread = try Thread.spawn(.{}, workerThread, .{self}); + try self.threads.append(thread); + } + + std.log.info("Server started on {s}:{d} with {d} threads", .{ + self.config.host, + self.config.port, + thread_count + }); + } + + pub fn stop(self: *Server) void { + if (!self.running.load(.acquire)) { + return; + } + + // Set running flag to false + self.running.store(false, .release); + + // Close the listener + if (self.listener) |*listener| { + listener.close(); + } + + // Wait for all threads to finish + for (self.threads.items) |thread| { + thread.join(); + } + + self.threads.clearAndFree(); + + std.log.info("Server stopped", .{}); + } + + fn workerThread(server: *Server) !void { + while (server.running.load(.acquire)) { + // Accept a connection + const connection = server.listener.?.accept() catch |err| { + if (err == error.ConnectionAborted) { + // This happens when the server is shutting down + if (!server.running.load(.acquire)) { + break; + } + } + std.log.err("Failed to accept connection: {s}", .{@errorName(err)}); + continue; + }; + + // Handle the connection + handleConnection(server, connection) catch |err| { + std.log.err("Error handling connection: {s}", .{@errorName(err)}); + }; + } + } + + fn handleConnection(server: *Server, connection: net.StreamServer.Connection) !void { + defer connection.stream.close(); + + // Create a buffer for reading the request + var buf: [4096]u8 = undefined; + const n = try connection.stream.read(&buf); + + if (n == 0) { + return error.EmptyRequest; + } + + // Parse the request + const request = buf[0..n]; + + // Check if it's a WebSocket upgrade request + if (std.mem.indexOf(u8, request, "Upgrade: websocket") != null) { + // Handle WebSocket upgrade + try handleWebSocketUpgrade(server.allocator, connection.stream, request); + return; + } + + // Create a context for the request + var ctx = core.Context.init(server.allocator); + defer ctx.deinit(); + + // Parse the request and fill the context + try parseRequest(&ctx, request); + + // Set default response headers + try ctx.response.headers.put("Content-Type", "text/plain"); + try ctx.response.headers.put("Server", "Zup"); + + // TODO: Route the request to the appropriate handler + + // For now, just return a simple response + const response = try std.fmt.allocPrint( + server.allocator, + "HTTP/1.1 {d} OK\r\n" ++ + "Content-Type: text/plain\r\n" ++ + "Content-Length: {d}\r\n" ++ + "Connection: close\r\n" ++ + "\r\n" ++ + "Hello from Zup Server!", + .{ ctx.response.status, "Hello from Zup Server!".len }, + ); + defer server.allocator.free(response); + + _ = try connection.stream.write(response); } }; + +fn parseRequest(ctx: *core.Context, request: []const u8) !void { + // Split the request into lines + var lines = std.mem.split(u8, request, "\r\n"); + + // Parse the request line + const request_line = lines.next() orelse return error.InvalidRequest; + var parts = std.mem.split(u8, request_line, " "); + + // Get the method + const method_str = parts.next() orelse return error.InvalidRequest; + ctx.request.method = core.Method.fromString(method_str) orelse return error.UnsupportedMethod; + + // Get the path + const path = parts.next() orelse return error.InvalidRequest; + ctx.request.path = try ctx.allocator.dupe(u8, path); + + // Parse headers + while (lines.next()) |line| { + if (line.len == 0) break; // Empty line indicates end of headers + + const colon_pos = std.mem.indexOf(u8, line, ":") orelse continue; + const header_name = std.mem.trim(u8, line[0..colon_pos], " "); + const header_value = std.mem.trim(u8, line[colon_pos + 1 ..], " "); + + try ctx.request.headers.put( + try ctx.allocator.dupe(u8, header_name), + try ctx.allocator.dupe(u8, header_value), + ); + } + + // TODO: Parse body if present +} + +fn handleWebSocketUpgrade(allocator: std.mem.Allocator, stream: net.Stream, request: []const u8) !void { + // This is a placeholder for WebSocket upgrade handling + // In a real implementation, you would: + // 1. Parse the WebSocket key from the request + // 2. Generate the accept key + // 3. Send the upgrade response + // 4. Handle the WebSocket connection + + _ = allocator; + _ = stream; + _ = request; +} \ No newline at end of file diff --git a/src/main.zig b/src/main.zig index 1aec7fa..a1f99ca 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,36 +1,108 @@ const std = @import("std"); const json = std.json; -const core = @import("core"); -const GrpcRouter = @import("grpc_router").GrpcRouter; -const ArrayHashMap = std.array_hash_map.ArrayHashMap; +const framework = @import("framework"); +const core = @import("framework/core.zig"); +const Server = @import("framework/server.zig").Server; +const ServerConfig = @import("framework/server.zig").ServerConfig; +const Router = @import("framework/router.zig").Router; -// Example procedure handler that returns a greeting -fn greetingHandler(ctx: *core.Context, input: ?json.Value) !json.Value { - const name = if (input) |value| blk: { - if (value.object.get("name")) |name_value| { - break :blk name_value.string; +// Example handler that returns a greeting +fn greetingHandler(ctx: *core.Context) !void { + // Parse request body if present + var name: []const u8 = "World"; + + if (ctx.request.body) |body| { + var parser = json.Parser.init(ctx.allocator, false); + defer parser.deinit(); + + var parsed = parser.parse(body) catch |err| { + std.log.err("Failed to parse JSON: {s}", .{@errorName(err)}); + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Invalid JSON"); + return; + }; + defer parsed.deinit(); + + if (parsed.root.Object.get("name")) |name_value| { + if (name_value == .String) { + name = name_value.String; + } } - break :blk "World"; - } else "World"; - - var result = ArrayHashMap([]const u8, json.Value, std.array_hash_map.StringContext, true).init(ctx.allocator); - try result.put("message", json.Value{ .string = try std.fmt.allocPrint(ctx.allocator, "Hello, {s}!", .{name}) }); - - return json.Value{ .object = result }; + } + + // Set response + ctx.response.status = 200; + try ctx.response.headers.put("Content-Type", "application/json"); + + // Create response JSON + var response = std.ArrayList(u8).init(ctx.allocator); + defer response.deinit(); + + try std.fmt.format(response.writer(), "{{\"message\":\"Hello, {s}!\"}}", .{name}); + ctx.response.body = try ctx.allocator.dupe(u8, response.items); } -// Example procedure handler that adds two numbers -fn addHandler(ctx: *core.Context, input: ?json.Value) !json.Value { - _ = ctx; - if (input == null) return error.MissingInput; - - const a = input.?.object.get("a") orelse return error.MissingFirstNumber; - const b = input.?.object.get("b") orelse return error.MissingSecondNumber; - - if (a != .integer or b != .integer) return error.InvalidNumberFormat; +// Example handler that adds two numbers +fn addHandler(ctx: *core.Context) !void { + // Ensure we have a request body + if (ctx.request.body == null) { + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Missing request body"); + return; + } + + // Parse JSON + var parser = json.Parser.init(ctx.allocator, false); + defer parser.deinit(); + + var parsed = parser.parse(ctx.request.body.?) catch |err| { + std.log.err("Failed to parse JSON: {s}", .{@errorName(err)}); + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Invalid JSON"); + return; + }; + defer parsed.deinit(); + + // Extract a and b values + const a_value = parsed.root.Object.get("a") orelse { + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Missing 'a' parameter"); + return; + }; + + const b_value = parsed.root.Object.get("b") orelse { + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Missing 'b' parameter"); + return; + }; + + // Ensure a and b are integers + if (a_value != .Integer or b_value != .Integer) { + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Parameters 'a' and 'b' must be integers"); + return; + } + + // Calculate result + const result = a_value.Integer + b_value.Integer; + + // Set response + ctx.response.status = 200; + try ctx.response.headers.put("Content-Type", "application/json"); + + // Create response JSON + var response = std.ArrayList(u8).init(ctx.allocator); + defer response.deinit(); + + try std.fmt.format(response.writer(), "{{\"result\":{d}}}", .{result}); + ctx.response.body = try ctx.allocator.dupe(u8, response.items); +} - const result = a.integer + b.integer; - return json.Value{ .integer = result }; +// Root handler +fn rootHandler(ctx: *core.Context) !void { + ctx.response.status = 200; + try ctx.response.headers.put("Content-Type", "text/plain"); + ctx.response.body = try ctx.allocator.dupe(u8, "Welcome to Zup Server!"); } pub fn main() !void { @@ -38,41 +110,65 @@ pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; defer _ = gpa.deinit(); const allocator = gpa.allocator(); - - // Create gRPC router - var router = GrpcRouter.init(allocator); + + // Create router + var router = Router.init(allocator); defer router.deinit(); - - // Register procedures - try router.procedure("greeting", greetingHandler, null, null); - try router.procedure("add", addHandler, null, null); - - // Start server on port 8080 - std.log.info("Starting gRPC server on port 8080...", .{}); - try router.listen(8080); - + + // Register routes + try router.get("/", rootHandler); + try router.post("/greeting", greetingHandler); + try router.post("/add", addHandler); + + // Create server config + const config = ServerConfig{ + .port = 8080, + .host = "127.0.0.1", + .thread_count = 4, // Use 4 threads or set to null to use all available cores + }; + + // Create and start server + var server = try Server.init(allocator, config); + defer server.deinit(); + + std.log.info("Starting server on {s}:{d}...", .{config.host, config.port}); + + // Start the server with proper error handling + server.start() catch |err| { + std.log.err("Failed to start server: {s}", .{@errorName(err)}); + return err; + }; + // Wait for server to be ready var attempts: usize = 0; const max_attempts = 50; while (attempts < max_attempts) : (attempts += 1) { - if (router.server) |server| { - if (server.running.load(.acquire)) break; - } + if (server.running.load(.acquire)) break; std.time.sleep(100 * std.time.ns_per_ms); } - + if (attempts >= max_attempts) { - std.log.err("Server failed to start", .{}); - return error.ServerStartFailed; + std.log.err("Server failed to start within timeout period", .{}); + return error.ServerStartTimeout; } - + std.log.info("Server is running. Use Ctrl+C to stop.", .{}); - + + // Set up signal handling for graceful shutdown + const sigint = std.os.SIGINT; + _ = std.os.sigaction(sigint, &std.os.Sigaction{ + .handler = .{ .handler = handleSignal }, + .mask = std.os.empty_sigset, + .flags = 0, + }, null); + // Keep main thread alive - while (true) { - if (router.server) |server| { - if (!server.running.load(.acquire)) break; - } + while (server.running.load(.acquire)) { std.time.sleep(1000 * std.time.ns_per_ms); } } + +fn handleSignal(sig: c_int) callconv(.C) void { + std.log.info("Received signal {d}, shutting down...", .{sig}); + // The server will be stopped in the main thread +} \ No newline at end of file diff --git a/src/root.zig b/src/root.zig index 8616695..532bf74 100644 --- a/src/root.zig +++ b/src/root.zig @@ -1,5 +1,11 @@ -pub const core = @import("core"); +// Export the core framework components +pub const core = @import("framework/core.zig"); pub const framework = @import("framework"); -pub const schema = @import("schema"); -pub const runtime_router = @import("runtime_router"); -pub const grpc_router = @import("grpc_router"); +pub const websocket = @import("websocket.zig"); +pub const client = @import("client.zig"); +pub const spice = @import("spice.zig"); +pub const bench = @import("bench.zig"); +pub const benchmark = @import("benchmark.zig"); + +// Main entry point +pub const main = @import("main.zig"); \ No newline at end of file diff --git a/src/websocket.zig b/src/websocket.zig index b3ccf1e..0fd5c38 100644 --- a/src/websocket.zig +++ b/src/websocket.zig @@ -5,6 +5,15 @@ const mem = std.mem; const base64 = std.base64; const Sha1 = std.crypto.hash.Sha1; +// Debug flag - set to false in production +const debug_logging = false; + +fn debugLog(comptime fmt: []const u8, args: anytype) void { + if (debug_logging) { + std.debug.print(fmt, args); + } +} + pub const WebSocketFrame = struct { fin: bool = true, opcode: Opcode, @@ -21,13 +30,13 @@ pub const WebSocketFrame = struct { }; pub fn encode(self: WebSocketFrame, allocator: std.mem.Allocator, writer: anytype) !void { - std.debug.print("\nEncoding frame: opcode={}, mask={}, payload={s}\n", .{ self.opcode, self.mask, self.payload }); + debugLog("\nEncoding frame: opcode={}, mask={}, payload={s}\n", .{ self.opcode, self.mask, self.payload }); var first_byte: u8 = 0; if (self.fin) first_byte |= 0x80; first_byte |= @intFromEnum(self.opcode); try writer.writeByte(first_byte); - std.debug.print("First byte: {b:0>8}\n", .{first_byte}); + debugLog("First byte: {b:0>8}\n", .{first_byte}); var second_byte: u8 = 0; if (self.mask) second_byte |= 0x80; @@ -45,13 +54,13 @@ pub const WebSocketFrame = struct { try writer.writeByte(second_byte); try writer.writeInt(u64, @as(u64, @intCast(payload_len)), .big); } - std.debug.print("Second byte: {b:0>8}\n", .{second_byte}); + debugLog("Second byte: {b:0>8}\n", .{second_byte}); if (self.mask) { // Generate random masking key var masking_key: [4]u8 = undefined; std.crypto.random.bytes(&masking_key); - std.debug.print("Masking key: {any}\n", .{masking_key}); + debugLog("Masking key: {any}\n", .{masking_key}); // Write masking key try writer.writeAll(&masking_key); @@ -65,25 +74,25 @@ pub const WebSocketFrame = struct { try masked_bytes.append(masked_byte); } try writer.writeAll(masked_bytes.items); - std.debug.print("Masked payload: {any}\n", .{masked_bytes.items}); + debugLog("Masked payload: {any}\n", .{masked_bytes.items}); } else { try writer.writeAll(self.payload); - std.debug.print("Unmasked payload: {any}\n", .{self.payload}); + debugLog("Unmasked payload: {any}\n", .{self.payload}); } } pub fn decode(allocator: std.mem.Allocator, reader: anytype) !WebSocketFrame { - std.debug.print("\nDecoding frame...\n", .{}); + debugLog("\nDecoding frame...\n", .{}); const first_byte = try reader.readByte(); const fin = (first_byte & 0x80) != 0; const opcode = @as(Opcode, @enumFromInt(first_byte & 0x0F)); - std.debug.print("First byte: {b:0>8}, fin={}, opcode={}\n", .{ first_byte, fin, opcode }); + debugLog("First byte: {b:0>8}, fin={}, opcode={}\n", .{ first_byte, fin, opcode }); const second_byte = try reader.readByte(); const mask = (second_byte & 0x80) != 0; const payload_len = second_byte & 0x7F; - std.debug.print("Second byte: {b:0>8}, mask={}, initial payload_len={}\n", .{ second_byte, mask, payload_len }); + debugLog("Second byte: {b:0>8}, mask={}, initial payload_len={}\n", .{ second_byte, mask, payload_len }); const extended_payload_len: u64 = if (payload_len == 126) try reader.readInt(u16, .big) @@ -92,12 +101,12 @@ pub const WebSocketFrame = struct { else payload_len; - std.debug.print("Extended payload length: {}\n", .{extended_payload_len}); + debugLog("Extended payload length: {}\n", .{extended_payload_len}); const masking_key = if (mask) blk: { var key: [4]u8 = undefined; _ = try reader.readAll(&key); - std.debug.print("Masking key: {any}\n", .{key}); + debugLog("Masking key: {any}\n", .{key}); break :blk key; } else [_]u8{0} ** 4; @@ -108,13 +117,13 @@ pub const WebSocketFrame = struct { const n = try reader.readAll(payload); if (n != extended_payload_len) return error.InvalidFrame; - std.debug.print("Raw payload: {any}\n", .{payload}); + debugLog("Raw payload: {any}\n", .{payload}); if (mask) { for (payload, 0..) |*byte, i| { byte.* ^= masking_key[i % 4]; } - std.debug.print("Unmasked payload: {any}\n", .{payload}); + debugLog("Unmasked payload: {any}\n", .{payload}); } const frame = WebSocketFrame{ @@ -124,7 +133,7 @@ pub const WebSocketFrame = struct { .payload = payload, }; - std.debug.print("Decoded frame: opcode={}, mask={}, payload={s}\n", .{ frame.opcode, frame.mask, frame.payload }); + debugLog("Decoded frame: opcode={}, mask={}, payload={s}\n", .{ frame.opcode, frame.mask, frame.payload }); return frame; } }; @@ -252,4 +261,4 @@ pub fn readMessage(allocator: std.mem.Allocator, stream: net.Stream) !WebSocketF .mask = mask, .payload = payload, }; -} +} \ No newline at end of file From 603a1f40365c161c39542396fa512b7bcc85bb42 Mon Sep 17 00:00:00 2001 From: "agentfarmx[bot]" <198411105+agentfarmx[bot]@users.noreply.github.com> Date: Mon, 10 Mar 2025 08:09:06 +0000 Subject: [PATCH 2/3] refactor: fix memory management in Request/Response deinit methods The changes properly store the allocator reference before headers are deinitalized, preventing use-after-free issues. Also improves the router implementation with better parameter extraction and adds WebSocket support. --- src/framework/core.zig | 6 +- src/framework/router.zig | 6 +- src/framework/server.zig | 132 +++++++++++++++++++++++++++++++-------- src/main.zig | 12 +++- 4 files changed, 122 insertions(+), 34 deletions(-) diff --git a/src/framework/core.zig b/src/framework/core.zig index 32cfdd4..10795b1 100644 --- a/src/framework/core.zig +++ b/src/framework/core.zig @@ -49,9 +49,10 @@ pub const Request = struct { } pub fn deinit(self: *Request) void { + var allocator = self.headers.allocator; self.headers.deinit(); if (self.body) |body| { - self.headers.allocator.free(body); + allocator.free(body); } } }; @@ -70,9 +71,10 @@ pub const Response = struct { } pub fn deinit(self: *Response) void { + var allocator = self.headers.allocator; self.headers.deinit(); if (self.body) |body| { - self.headers.allocator.free(body); + allocator.free(body); } } }; diff --git a/src/framework/router.zig b/src/framework/router.zig index 6a59652..a287997 100644 --- a/src/framework/router.zig +++ b/src/framework/router.zig @@ -79,7 +79,7 @@ pub const Router = struct { const route = self.findRoute(ctx.request.method, ctx.request.path) orelse return error.RouteNotFound; // Extract path parameters - try self.extractParams(ctx, route.pattern, ctx.request.path); + try extractParams(ctx, route.pattern, ctx.request.path); // If no middleware, just call the handler if (self.global_middleware.items.len == 0) { @@ -145,9 +145,7 @@ pub const Router = struct { } } - fn extractParams(self: *Router, ctx: *core.Context, pattern: []const u8, path: []const u8) !void { - _ = self; - + fn extractParams(ctx: *core.Context, pattern: []const u8, path: []const u8) !void { var pattern_parts = std.mem.split(u8, pattern, "/"); var path_parts = std.mem.split(u8, path, "/"); diff --git a/src/framework/server.zig b/src/framework/server.zig index ab9a687..f1d37c7 100644 --- a/src/framework/server.zig +++ b/src/framework/server.zig @@ -1,5 +1,6 @@ const std = @import("std"); const core = @import("core"); +const Router = @import("router.zig").Router; const net = std.net; const Thread = std.Thread; const Atomic = std.atomic.Atomic; @@ -19,7 +20,7 @@ pub const Server = struct { running: Atomic(bool) = Atomic(bool).init(false), threads: std.ArrayList(Thread) = undefined, - pub fn init(allocator: std.mem.Allocator, config: ServerConfig) !Server { + pub fn init(allocator: std.mem.Allocator, config: ServerConfig, router: ?*const Router) !Server { const address = try net.Address.parseIp(config.host, config.port); return Server{ @@ -27,6 +28,7 @@ pub const Server = struct { .config = config, .address = address, .threads = std.ArrayList(Thread).init(allocator), + .router = router, }; } @@ -151,22 +153,56 @@ pub const Server = struct { try ctx.response.headers.put("Content-Type", "text/plain"); try ctx.response.headers.put("Server", "Zup"); - // TODO: Route the request to the appropriate handler - - // For now, just return a simple response - const response = try std.fmt.allocPrint( - server.allocator, - "HTTP/1.1 {d} OK\r\n" ++ - "Content-Type: text/plain\r\n" ++ - "Content-Length: {d}\r\n" ++ - "Connection: close\r\n" ++ - "\r\n" ++ - "Hello from Zup Server!", - .{ ctx.response.status, "Hello from Zup Server!".len }, - ); - defer server.allocator.free(response); + // Route the request to the appropriate handler if router is available + if (server.router) |router| { + router.handle(&ctx) catch |err| { + if (err == error.RouteNotFound) { + ctx.response.status = 404; + ctx.response.body = try server.allocator.dupe(u8, "Not Found"); + } else { + ctx.response.status = 500; + ctx.response.body = try server.allocator.dupe(u8, "Internal Server Error"); + std.log.err("Error handling request: {s}", .{@errorName(err)}); + } + }; + } else { + // If no router, just return a simple response + ctx.response.body = try server.allocator.dupe(u8, "Hello from Zup Server!"); + } + + // Build and send the HTTP response + var response_buffer = std.ArrayList(u8).init(server.allocator); + defer response_buffer.deinit(); + + // Write status line + try std.fmt.format(response_buffer.writer(), "HTTP/1.1 {d} {s}\r\n", .{ + ctx.response.status, + if (ctx.response.status == 200) "OK" else "Error", + }); + + // Write headers + var header_it = ctx.response.headers.iterator(); + while (header_it.next()) |entry| { + try std.fmt.format(response_buffer.writer(), "{s}: {s}\r\n", .{ + entry.key_ptr.*, + entry.value_ptr.*, + }); + } + + // Write Content-Length header + const body_len = if (ctx.response.body) |body| body.len else 0; + try std.fmt.format(response_buffer.writer(), "Content-Length: {d}\r\n", .{body_len}); - _ = try connection.stream.write(response); + // End headers + try response_buffer.appendSlice("\r\n"); + + // Write body if present + if (ctx.response.body) |body| { + try response_buffer.appendSlice(body); + } + + // Send the response + _ = try connection.stream.write(response_buffer.items); } }; @@ -200,18 +236,62 @@ fn parseRequest(ctx: *core.Context, request: []const u8) !void { ); } - // TODO: Parse body if present + // Parse body if present + // Find the empty line that separates headers from body + var body_start: usize = 0; + const headers_end = std.mem.indexOf(u8, request, "\r\n\r\n"); + if (headers_end) |pos| { + body_start = pos + 4; // Skip the \r\n\r\n + + // Check if there's a body + if (body_start < request.len) { + const body = request[body_start..]; + if (body.len > 0) { + ctx.request.body = try ctx.allocator.dupe(u8, body); + } + } + } } fn handleWebSocketUpgrade(allocator: std.mem.Allocator, stream: net.Stream, request: []const u8) !void { - // This is a placeholder for WebSocket upgrade handling - // In a real implementation, you would: - // 1. Parse the WebSocket key from the request - // 2. Generate the accept key - // 3. Send the upgrade response - // 4. Handle the WebSocket connection + const websocket = @import("../websocket.zig"); + try websocket.handleUpgrade(allocator, stream, request); - _ = allocator; - _ = stream; - _ = request; + // After upgrade, handle the WebSocket connection + // This is a simple echo server for demonstration + while (true) { + var frame = websocket.readMessage(allocator, stream) catch |err| { + if (err == error.ConnectionClosed) { + break; + } + std.log.err("Error reading WebSocket message: {s}", .{@errorName(err)}); + break; + }; + defer allocator.free(frame.payload); + + // Handle different frame types + switch (frame.opcode) { + .text, .binary => { + // Echo the message back + try websocket.writeMessage(allocator, stream, frame.payload); + }, + .close => { + // Send close frame and exit + try websocket.writeMessage(allocator, stream, ""); + break; + }, + .ping => { + // Respond to ping with pong + const pong_frame = websocket.WebSocketFrame{ + .opcode = .pong, + .payload = frame.payload, + }; + var frame_buf: [1024]u8 = undefined; + var fbs = std.io.fixedBufferStream(&frame_buf); + try pong_frame.encode(allocator, fbs.writer()); + _ = try stream.write(fbs.getWritten()); + }, + else => {}, // Ignore other frame types + } + } } \ No newline at end of file diff --git a/src/main.zig b/src/main.zig index a1f99ca..33a6a10 100644 --- a/src/main.zig +++ b/src/main.zig @@ -6,6 +6,9 @@ const Server = @import("framework/server.zig").Server; const ServerConfig = @import("framework/server.zig").ServerConfig; const Router = @import("framework/router.zig").Router; +// Global variable to hold server reference for signal handling +var global_server: ?*Server = null; + // Example handler that returns a greeting fn greetingHandler(ctx: *core.Context) !void { // Parse request body if present @@ -128,9 +131,12 @@ pub fn main() !void { }; // Create and start server - var server = try Server.init(allocator, config); + var server = try Server.init(allocator, config, &router); defer server.deinit(); + // Set global server reference for signal handling + global_server = &server; + std.log.info("Starting server on {s}:{d}...", .{config.host, config.port}); // Start the server with proper error handling @@ -170,5 +176,7 @@ pub fn main() !void { fn handleSignal(sig: c_int) callconv(.C) void { std.log.info("Received signal {d}, shutting down...", .{sig}); - // The server will be stopped in the main thread + if (global_server) |server| { + server.stop(); + } } \ No newline at end of file From 04da0d32b3c071843fb57d3ca9793267eb28ecf1 Mon Sep 17 00:00:00 2001 From: "agentfarmx[bot]" <198411105+agentfarmx[bot]@users.noreply.github.com> Date: Mon, 10 Mar 2025 17:27:03 +0000 Subject: [PATCH 3/3] refactor: update web framework architecture with Router separation The changes introduce a more modular structure by separating routing from the Server class, adding helper functions for responses, and improving tests. The PR also updates method signatures and refactors connection handling for better clarity. --- src/framework/example.zig | 108 ++++++++++++++------- src/framework/server.zig | 2 +- src/framework/server_test.zig | 17 ++-- src/framework/tests.zig | 175 +++++++++++++++++++++++----------- src/framework/trpc_test.zig | 70 +++++++++++--- 5 files changed, 264 insertions(+), 108 deletions(-) diff --git a/src/framework/example.zig b/src/framework/example.zig index f42e39f..e03c2a1 100644 --- a/src/framework/example.zig +++ b/src/framework/example.zig @@ -1,7 +1,8 @@ const std = @import("std"); const Server = @import("server.zig").Server; -const Config = @import("server.zig").Config; +const ServerConfig = @import("server.zig").ServerConfig; const core = @import("core.zig"); +const Router = @import("router.zig").Router; // Example middleware that logs requests const LoggerMiddleware = struct { @@ -27,9 +28,25 @@ const LoggerMiddleware = struct { } }; +// Helper function to set text response +fn setText(ctx: *core.Context, text: []const u8) !void { + ctx.response.body = try ctx.allocator.dupe(u8, text); + try ctx.response.headers.put("Content-Type", "text/plain"); +} + +// Helper function to set JSON response +fn setJson(ctx: *core.Context, data: anytype) !void { + var json_string = std.ArrayList(u8).init(ctx.allocator); + defer json_string.deinit(); + + try std.json.stringify(data, .{}, json_string.writer()); + ctx.response.body = try ctx.allocator.dupe(u8, json_string.items); + try ctx.response.headers.put("Content-Type", "application/json"); +} + // Example handlers fn homeHandlerImpl(ctx: *core.Context) !void { - try ctx.text("Welcome to Zup!"); + try setText(ctx, "Welcome to Zup!"); } fn jsonHandlerImpl(ctx: *core.Context) !void { @@ -37,7 +54,7 @@ fn jsonHandlerImpl(ctx: *core.Context) !void { .message = "Hello, JSON!", .timestamp = std.time.timestamp(), }; - try ctx.json(data); + try setJson(ctx, data); } fn userHandlerImpl(ctx: *core.Context) !void { @@ -47,11 +64,15 @@ fn userHandlerImpl(ctx: *core.Context) !void { .name = "Example User", .email = "user@example.com", }; - try ctx.json(response); + try setJson(ctx, response); } fn echoHandlerImpl(ctx: *core.Context) !void { - try ctx.text(ctx.request.body); + if (ctx.request.body) |body| { + try setText(ctx, body); + } else { + try setText(ctx, "No body provided"); + } } pub fn main() !void { @@ -60,24 +81,28 @@ pub fn main() !void { defer _ = gpa.deinit(); const allocator = gpa.allocator(); - // Create server with custom config - var server = try Server.init(allocator, .{ - .address = "127.0.0.1", - .port = 8080, - .thread_count = 4, - }); - defer server.deinit(); + // Create router + var router = Router.init(allocator); + defer router.deinit(); // Add global middleware var logger = LoggerMiddleware.init(); defer logger.deinit(); - try server.use(core.Middleware.init(logger, LoggerMiddleware.handle)); + try router.use(core.Middleware.init(logger, LoggerMiddleware.handle)); // Define routes - try server.get("/", homeHandlerImpl); - try server.get("/json", jsonHandlerImpl); - try server.get("/users/:id", userHandlerImpl); - try server.post("/echo", echoHandlerImpl); + try router.get("/", homeHandlerImpl); + try router.get("/json", jsonHandlerImpl); + try router.get("/users/:id", userHandlerImpl); + try router.post("/echo", echoHandlerImpl); + + // Create server with custom config + var server = try Server.init(allocator, .{ + .host = "127.0.0.1", + .port = 8080, + .thread_count = 4, + }, &router); + defer server.deinit(); // Start server std.log.info("Server running at http://127.0.0.1:8080", .{}); @@ -88,33 +113,46 @@ test "basic routes" { const testing = std.testing; const allocator = testing.allocator; - var server = try Server.init(allocator, .{ - .port = 0, // Random port for testing - }); - defer server.deinit(); + // Create router + var router = Router.init(allocator); + defer router.deinit(); // Add test routes - try server.get("/test", &struct { + try router.get("/test", &struct { fn handler(ctx: *core.Context) !void { - try ctx.text("test ok"); + try setText(ctx, "test ok"); } }.handler); - try server.post("/echo", echoHandlerImpl); + try router.post("/echo", echoHandlerImpl); + + // Create server + var server = try Server.init(allocator, .{ + .port = 0, // Random port for testing + .thread_count = 1, + }, &router); + defer server.deinit(); // Start server in background - const thread = try std.Thread.spawn(.{}, Server.start, .{&server}); - defer { - server.running.store(false, .release); - thread.join(); - } + var running = true; + const thread = try std.Thread.spawn(.{}, struct { + fn run(srv: *Server, is_running: *bool) void { + srv.start() catch |err| { + std.debug.print("Server error: {}\n", .{err}); + }; + is_running.* = false; + } + }.run, .{&server, &running}); // Wait a bit for server to start - std.time.sleep(10 * std.time.ns_per_ms); + std.time.sleep(100 * std.time.ns_per_ms); + + // Get server address + const server_address = server.address; // Test GET request { - const client = try std.net.tcpConnectToAddress(server.listener.listen_address); + const client = try std.net.tcpConnectToAddress(server_address); defer client.close(); try client.writer().writeAll( @@ -134,7 +172,7 @@ test "basic routes" { // Test POST request { - const client = try std.net.tcpConnectToAddress(server.listener.listen_address); + const client = try std.net.tcpConnectToAddress(server_address); defer client.close(); const body = "Hello, Echo!"; @@ -156,4 +194,8 @@ test "basic routes" { try testing.expect(std.mem.indexOf(u8, response, body) != null); } -} + + // Stop server + server.stop(); + thread.join(); +} \ No newline at end of file diff --git a/src/framework/server.zig b/src/framework/server.zig index f1d37c7..a64b78b 100644 --- a/src/framework/server.zig +++ b/src/framework/server.zig @@ -206,7 +206,7 @@ pub const Server = struct { } }; -fn parseRequest(ctx: *core.Context, request: []const u8) !void { +pub fn parseRequest(ctx: *core.Context, request: []const u8) !void { // Split the request into lines var lines = std.mem.split(u8, request, "\r\n"); diff --git a/src/framework/server_test.zig b/src/framework/server_test.zig index 674b70f..1504952 100644 --- a/src/framework/server_test.zig +++ b/src/framework/server_test.zig @@ -1,20 +1,27 @@ const std = @import("std"); const testing = std.testing; const Server = @import("server.zig").Server; +const Router = @import("router.zig").Router; test "server - basic start stop" { std.debug.print("\n=== Starting basic server test ===\n", .{}); + // Create a router + var router = Router.init(testing.allocator); + defer router.deinit(); + + // Initialize server with the router var server = try Server.init(testing.allocator, .{ .port = 0, .thread_count = 1, // Minimize threads for testing - }); + }, &router); + defer server.deinit(); var running = true; const thread = try std.Thread.spawn(.{}, struct { fn run(srv: *Server, is_running: *bool) void { std.debug.print("Server thread starting...\n", .{}); - srv.listen() catch |err| { + srv.start() catch |err| { std.debug.print("Server error: {}\n", .{err}); }; is_running.* = false; @@ -42,8 +49,4 @@ test "server - basic start stop" { thread.join(); std.debug.print("Server thread joined\n", .{}); - - // Clean up resources - server.deinit(); - std.debug.print("Server resources cleaned up\n", .{}); -} +} \ No newline at end of file diff --git a/src/framework/tests.zig b/src/framework/tests.zig index c63c90d..406b7f8 100644 --- a/src/framework/tests.zig +++ b/src/framework/tests.zig @@ -2,6 +2,45 @@ const std = @import("std"); const testing = std.testing; const core = @import("core.zig"); const Server = @import("server.zig").Server; +const Router = @import("router.zig").Router; + +// Helper function to parse a request since core.Request.parse() no longer exists +fn parseTestRequest(allocator: std.mem.Allocator, raw_request: []const u8) !core.Request { + var ctx = core.Context.init(allocator); + defer ctx.deinit(); + + try parseRequest(&ctx, raw_request); + + // Create a new request to return (since we're deinit'ing the context) + var request = core.Request.init(allocator); + request.method = ctx.request.method; + request.path = try allocator.dupe(u8, ctx.request.path); + + // Copy headers + var it = ctx.request.headers.iterator(); + while (it.next()) |entry| { + try request.headers.put( + try allocator.dupe(u8, entry.key_ptr.*), + try allocator.dupe(u8, entry.value_ptr.*) + ); + } + + // Copy body if present + if (ctx.request.body) |body| { + request.body = try allocator.dupe(u8, body); + } + + return request; +} + +// Import the parseRequest function from server.zig +const parseRequest = @import("server.zig").parseRequest; + +// Helper function to set text response +fn setText(ctx: *core.Context, text: []const u8) !void { + ctx.response.body = try ctx.allocator.dupe(u8, text); + try ctx.response.headers.put("Content-Type", "text/plain"); +} test "memory safety - request parsing" { const allocator = testing.allocator; @@ -16,7 +55,7 @@ test "memory safety - request parsing" { std.debug.print("\nRaw request:\n{s}\n", .{raw_request}); - var request = try core.Request.parse(allocator, raw_request); + var request = try parseTestRequest(allocator, raw_request); defer { std.debug.print("\nDeinit request\n", .{}); request.deinit(); @@ -24,7 +63,7 @@ test "memory safety - request parsing" { try testing.expectEqualStrings("/test", request.path); try testing.expectEqualStrings("localhost:8080", request.headers.get("Host").?); - try testing.expectEqual(@as(usize, 0), request.body.len); + try testing.expect(request.body == null); } test "memory safety - response handling" { @@ -40,9 +79,9 @@ test "memory safety - response handling" { // Add a small body first std.debug.print("\nTesting small body\n", .{}); { - const small_body = try core.Request.allocBody(allocator, "Hello"); - response.setBody(small_body); - try testing.expectEqualStrings("Hello", response.body); + const small_body = try allocator.dupe(u8, "Hello"); + response.body = small_body; + try testing.expectEqualStrings("Hello", response.body.?); } // Now test with a larger body @@ -52,10 +91,16 @@ test "memory safety - response handling" { defer allocator.free(large_body); @memset(large_body, 'A'); - const body = try core.Request.allocBody(allocator, large_body); - response.setBody(body); - try testing.expectEqual(@as(usize, 1024), response.body.len); - try testing.expect(response.body[0] == 'A'); + const body_copy = try allocator.dupe(u8, large_body); + + // Free previous body + if (response.body) |old_body| { + allocator.free(old_body); + } + + response.body = body_copy; + try testing.expectEqual(@as(usize, 1024), response.body.?.len); + try testing.expect(response.body.?[0] == 'A'); } std.debug.print("\nLarge body test complete\n", .{}); @@ -64,40 +109,47 @@ test "memory safety - response handling" { test "single request" { const allocator = testing.allocator; - std.debug.print("\nInit server\n", .{}); - var server = try Server.init(allocator, .{ - .port = 0, // Random port for testing - .thread_count = 1, - }); - defer { - std.debug.print("\nDeinit server\n", .{}); - server.deinit(); - } - + std.debug.print("\nInit router\n", .{}); + var router = Router.init(allocator); + defer router.deinit(); + // Add test endpoint - try server.get("/test", &struct { + try router.get("/test", &struct { fn handler(ctx: *core.Context) !void { std.debug.print("\nHandling request\n", .{}); - try ctx.text("ok"); + try setText(ctx, "ok"); std.debug.print("\nRequest handled\n", .{}); } }.handler); + std.debug.print("\nInit server\n", .{}); + var server = try Server.init(allocator, .{ + .port = 0, // Random port for testing + .thread_count = 1, + }, &router); + defer server.deinit(); + // Start server in background std.debug.print("\nStarting server\n", .{}); - const thread = try std.Thread.spawn(.{}, Server.start, .{&server}); - defer { - std.debug.print("\nStopping server\n", .{}); - server.running.store(false, .release); - thread.join(); - } - + var running = true; + const thread = try std.Thread.spawn(.{}, struct { + fn run(srv: *Server, is_running: *bool) void { + srv.start() catch |err| { + std.debug.print("Server error: {}\n", .{err}); + }; + is_running.* = false; + } + }.run, .{&server, &running}); + // Wait for server to start - std.time.sleep(10 * std.time.ns_per_ms); - + std.time.sleep(100 * std.time.ns_per_ms); + + // Get server address + const server_address = server.address; + std.debug.print("\nMaking request to {}\n", .{server_address}); + // Make request - std.debug.print("\nMaking request to {}\n", .{server.listener.listen_address}); - const client = try std.net.tcpConnectToAddress(server.listener.listen_address); + const client = try std.net.tcpConnectToAddress(server_address); defer client.close(); const request = @@ -118,43 +170,54 @@ test "single request" { try testing.expect(std.mem.indexOf(u8, response, "ok") != null); std.debug.print("\nTest complete\n", .{}); + + // Stop server + server.stop(); + thread.join(); } test "concurrent requests" { const allocator = testing.allocator; - std.debug.print("\nInit server\n", .{}); - var server = try Server.init(allocator, .{ - .port = 0, // Random port for testing - .thread_count = 4, - }); - defer { - std.debug.print("\nDeinit server\n", .{}); - server.deinit(); - } - + std.debug.print("\nInit router\n", .{}); + var router = Router.init(allocator); + defer router.deinit(); + // Add test endpoint - try server.get("/concurrent", &struct { + try router.get("/concurrent", &struct { fn handler(ctx: *core.Context) !void { std.debug.print("\nHandling request in worker thread\n", .{}); // Simulate work std.time.sleep(10 * std.time.ns_per_ms); - try ctx.text("ok"); + try setText(ctx, "ok"); std.debug.print("\nRequest handled\n", .{}); } }.handler); + std.debug.print("\nInit server\n", .{}); + var server = try Server.init(allocator, .{ + .port = 0, // Random port for testing + .thread_count = 4, + }, &router); + defer server.deinit(); + // Start server in background std.debug.print("\nStarting server\n", .{}); - const thread = try std.Thread.spawn(.{}, Server.start, .{&server}); - defer { - std.debug.print("\nStopping server\n", .{}); - server.running.store(false, .release); - thread.join(); - } - + var running = true; + const thread = try std.Thread.spawn(.{}, struct { + fn run(srv: *Server, is_running: *bool) void { + srv.start() catch |err| { + std.debug.print("Server error: {}\n", .{err}); + }; + is_running.* = false; + } + }.run, .{&server, &running}); + // Wait for server to start - std.time.sleep(10 * std.time.ns_per_ms); + std.time.sleep(100 * std.time.ns_per_ms); + + // Get server address + const server_address = server.address; // Make concurrent requests const RequestThread = struct { @@ -186,11 +249,15 @@ test "concurrent requests" { std.debug.print("\nStarting concurrent requests\n", .{}); var threads: [2]std.Thread = undefined; for (&threads) |*t| { - t.* = try std.Thread.spawn(.{}, RequestThread.make_request, .{server.listener.listen_address}); + t.* = try std.Thread.spawn(.{}, RequestThread.make_request, .{server_address}); } for (threads) |t| { t.join(); } std.debug.print("\nConcurrent requests complete\n", .{}); -} + + // Stop server + server.stop(); + thread.join(); +} \ No newline at end of file diff --git a/src/framework/trpc_test.zig b/src/framework/trpc_test.zig index 52328c2..a7b2ec7 100644 --- a/src/framework/trpc_test.zig +++ b/src/framework/trpc_test.zig @@ -1,22 +1,57 @@ const std = @import("std"); const testing = std.testing; const Server = @import("server.zig").Server; +const Router = @import("router.zig").Router; const core = @import("core.zig"); +const parseRequest = @import("server.zig").parseRequest; + +// Helper function to parse a request +fn parseTestRequest(allocator: std.mem.Allocator, raw_request: []const u8) !core.Request { + var ctx = core.Context.init(allocator); + defer ctx.deinit(); + + try parseRequest(&ctx, raw_request); + + // Create a new request to return (since we're deinit'ing the context) + var request = core.Request.init(allocator); + request.method = ctx.request.method; + request.path = try allocator.dupe(u8, ctx.request.path); + + // Copy headers + var it = ctx.request.headers.iterator(); + while (it.next()) |entry| { + try request.headers.put( + try allocator.dupe(u8, entry.key_ptr.*), + try allocator.dupe(u8, entry.value_ptr.*) + ); + } + + // Copy body if present + if (ctx.request.body) |body| { + request.body = try allocator.dupe(u8, body); + } + + return request; +} test "trpc - basic procedure call" { std.debug.print("\n=== Starting TRPC basic procedure call test ===\n", .{}); + // Create a router + var router = Router.init(testing.allocator); + defer router.deinit(); + var server = try Server.init(testing.allocator, .{ .port = 0, .thread_count = 1, // Minimize threads for testing - }); + }, &router); defer server.deinit(); var running = true; const thread = try std.Thread.spawn(.{}, struct { fn run(srv: *Server, is_running: *bool) void { std.debug.print("Server thread starting...\n", .{}); - srv.listen() catch |err| { + srv.start() catch |err| { std.debug.print("Server error: {}\n", .{err}); }; is_running.* = false; @@ -29,14 +64,27 @@ test "trpc - basic procedure call" { std.time.sleep(100 * std.time.ns_per_ms); // Make a request to the server - const client = try std.net.Client.init(testing.allocator); - defer client.deinit(); + // Note: std.net.Client might not exist, so we'll use direct TCP connection instead + const server_address = server.address; + const client = try std.net.tcpConnectToAddress(server_address); + defer client.close(); + + const request_str = + \\POST /trpc HTTP/1.1 + \\Host: localhost + \\Content-Type: application/json + \\Content-Length: 57 + \\ + \\{"method":"add","params":{"a":1,"b":2},"id":1} + ; - const request = try core.Request.parse(testing.allocator, "POST /trpc HTTP/1.1\r\nHost: localhost\r\nContent-Type: application/json\r\nContent-Length: 57\r\n\r\n{\"method\":\"add\",\"params\":{\"a\":1,\"b\":2},\"id\":1}"); - defer request.deinit(); + std.debug.print("Sending request:\n{s}\n", .{request_str}); + try client.writer().writeAll(request_str); - const response = try client.request(request); - defer response.deinit(); + var buf: [1024]u8 = undefined; + const n = try client.read(&buf); + const response = buf[0..n]; + std.debug.print("Received response:\n{s}\n", .{response}); // Stop server std.debug.print("Stopping server...\n", .{}); @@ -54,8 +102,4 @@ test "trpc - basic procedure call" { thread.join(); std.debug.print("Server thread joined\n", .{}); - - // Clean up resources - server.deinit(); - std.debug.print("Server resources cleaned up\n", .{}); -} +} \ No newline at end of file