diff --git a/src/framework/core.zig b/src/framework/core.zig index bdca9c7..10795b1 100644 --- a/src/framework/core.zig +++ b/src/framework/core.zig @@ -1,11 +1,132 @@ const std = @import("std"); +pub const Method = enum { + GET, + POST, + PUT, + DELETE, + OPTIONS, + HEAD, + PATCH, + + pub fn fromString(method_str: []const u8) ?Method { + if (std.mem.eql(u8, method_str, "GET")) return .GET; + if (std.mem.eql(u8, method_str, "POST")) return .POST; + if (std.mem.eql(u8, method_str, "PUT")) return .PUT; + if (std.mem.eql(u8, method_str, "DELETE")) return .DELETE; + if (std.mem.eql(u8, method_str, "OPTIONS")) return .OPTIONS; + if (std.mem.eql(u8, method_str, "HEAD")) return .HEAD; + if (std.mem.eql(u8, method_str, "PATCH")) return .PATCH; + return null; + } + + pub fn toString(self: Method) []const u8 { + return switch (self) { + .GET => "GET", + .POST => "POST", + .PUT => "PUT", + .DELETE => "DELETE", + .OPTIONS => "OPTIONS", + .HEAD => "HEAD", + .PATCH => "PATCH", + }; + } +}; + +pub const Request = struct { + method: Method, + path: []const u8, + headers: std.StringHashMap([]const u8), + body: ?[]const u8, + + pub fn init(allocator: std.mem.Allocator) Request { + return .{ + .method = .GET, + .path = "", + .headers = std.StringHashMap([]const u8).init(allocator), + .body = null, + }; + } + + pub fn deinit(self: *Request) void { + var allocator = self.headers.allocator; + self.headers.deinit(); + if (self.body) |body| { + allocator.free(body); + } + } +}; + +pub const Response = struct { + status: u16, + headers: std.StringHashMap([]const u8), + body: ?[]const u8, + + pub fn init(allocator: std.mem.Allocator) Response { + return .{ + .status = 200, + .headers = std.StringHashMap([]const u8).init(allocator), + .body = null, + }; + } + + pub fn deinit(self: *Response) void { + var allocator = self.headers.allocator; + self.headers.deinit(); + if (self.body) |body| { + allocator.free(body); + } + } +}; + pub const Context = struct { allocator: std.mem.Allocator, - + request: Request, + response: Response, + params: std.StringHashMap([]const u8), + pub fn init(allocator: std.mem.Allocator) Context { return .{ .allocator = allocator, + .request = Request.init(allocator), + .response = Response.init(allocator), + .params = std.StringHashMap([]const u8).init(allocator), }; } + + pub fn deinit(self: *Context) void { + self.request.deinit(); + self.response.deinit(); + self.params.deinit(); + } }; + +pub const Handler = *const fn (*Context) anyerror!void; + +pub const Middleware = struct { + data: ?*anyopaque, + handle_fn: *const fn (*Context, Handler) anyerror!void, + deinit_fn: ?*const fn (*Middleware) void, + + pub fn init( + data: ?*anyopaque, + handle_fn: *const fn (*Context, Handler) anyerror!void, + deinit_fn: ?*const fn (*Middleware) void, + ) Middleware { + return .{ + .data = data, + .handle_fn = handle_fn, + .deinit_fn = deinit_fn, + }; + } + + pub fn handle(self: *const Middleware, ctx: *Context, next: Handler) !void { + return self.handle_fn(ctx, next); + } + + pub fn deinit(self: *Middleware) void { + if (self.deinit_fn) |deinit_fn| { + deinit_fn(self); + } + } +}; \ No newline at end of file diff --git a/src/framework/example.zig b/src/framework/example.zig index f42e39f..e03c2a1 100644 --- a/src/framework/example.zig +++ b/src/framework/example.zig @@ -1,7 +1,8 @@ const std = @import("std"); const Server = @import("server.zig").Server; -const Config = @import("server.zig").Config; +const ServerConfig = @import("server.zig").ServerConfig; const core = @import("core.zig"); +const Router = @import("router.zig").Router; // Example middleware that logs requests const LoggerMiddleware = struct { @@ -27,9 +28,25 @@ const LoggerMiddleware = struct { } }; +// Helper function to set text response +fn setText(ctx: *core.Context, text: []const u8) !void { + ctx.response.body = try ctx.allocator.dupe(u8, text); + try ctx.response.headers.put("Content-Type", "text/plain"); +} + +// Helper function to set JSON response +fn setJson(ctx: *core.Context, data: anytype) !void { + var json_string = std.ArrayList(u8).init(ctx.allocator); + defer json_string.deinit(); + + try std.json.stringify(data, .{}, json_string.writer()); + ctx.response.body = try ctx.allocator.dupe(u8, json_string.items); + try ctx.response.headers.put("Content-Type", "application/json"); +} + // Example handlers fn homeHandlerImpl(ctx: *core.Context) !void { - try ctx.text("Welcome to Zup!"); + try setText(ctx, "Welcome to Zup!"); } fn jsonHandlerImpl(ctx: *core.Context) !void { @@ -37,7 +54,7 @@ fn jsonHandlerImpl(ctx: *core.Context) !void { .message = "Hello, JSON!", .timestamp = std.time.timestamp(), }; - try ctx.json(data); + try setJson(ctx, data); } fn userHandlerImpl(ctx: *core.Context) !void { @@ -47,11 +64,15 @@ fn userHandlerImpl(ctx: *core.Context) !void { .name = "Example User", .email = "user@example.com", }; - try ctx.json(response); + try setJson(ctx, response); } fn echoHandlerImpl(ctx: *core.Context) !void { - try ctx.text(ctx.request.body); + if (ctx.request.body) |body| { + try setText(ctx, body); + } else { + try setText(ctx, "No body provided"); + } } pub fn main() !void { @@ -60,24 +81,28 @@ pub fn main() !void { defer _ = gpa.deinit(); const allocator = gpa.allocator(); - // Create server with custom config - var server = try Server.init(allocator, .{ - .address = "127.0.0.1", - .port = 8080, - .thread_count = 4, - }); - defer server.deinit(); + // Create router + var router = Router.init(allocator); + defer router.deinit(); // Add global middleware var logger = LoggerMiddleware.init(); defer logger.deinit(); - try server.use(core.Middleware.init(logger, LoggerMiddleware.handle)); + try router.use(core.Middleware.init(logger, LoggerMiddleware.handle)); // Define routes - try server.get("/", homeHandlerImpl); - try server.get("/json", jsonHandlerImpl); - try server.get("/users/:id", userHandlerImpl); - try server.post("/echo", echoHandlerImpl); + try router.get("/", homeHandlerImpl); + try router.get("/json", jsonHandlerImpl); + try router.get("/users/:id", userHandlerImpl); + try router.post("/echo", echoHandlerImpl); + + // Create server with custom config + var server = try Server.init(allocator, .{ + .host = "127.0.0.1", + .port = 8080, + .thread_count = 4, + }, &router); + defer server.deinit(); // Start server std.log.info("Server running at http://127.0.0.1:8080", .{}); @@ -88,33 +113,46 @@ test "basic routes" { const testing = std.testing; const allocator = testing.allocator; - var server = try Server.init(allocator, .{ - .port = 0, // Random port for testing - }); - defer server.deinit(); + // Create router + var router = Router.init(allocator); + defer router.deinit(); // Add test routes - try server.get("/test", &struct { + try router.get("/test", &struct { fn handler(ctx: *core.Context) !void { - try ctx.text("test ok"); + try setText(ctx, "test ok"); } }.handler); - try server.post("/echo", echoHandlerImpl); + try router.post("/echo", echoHandlerImpl); + + // Create server + var server = try Server.init(allocator, .{ + .port = 0, // Random port for testing + .thread_count = 1, + }, &router); + defer server.deinit(); // Start server in background - const thread = try std.Thread.spawn(.{}, Server.start, .{&server}); - defer { - server.running.store(false, .release); - thread.join(); - } + var running = true; + const thread = try std.Thread.spawn(.{}, struct { + fn run(srv: *Server, is_running: *bool) void { + srv.start() catch |err| { + std.debug.print("Server error: {}\n", .{err}); + }; + is_running.* = false; + } + }.run, .{&server, &running}); // Wait a bit for server to start - std.time.sleep(10 * std.time.ns_per_ms); + std.time.sleep(100 * std.time.ns_per_ms); + + // Get server address + const server_address = server.address; // Test GET request { - const client = try std.net.tcpConnectToAddress(server.listener.listen_address); + const client = try std.net.tcpConnectToAddress(server_address); defer client.close(); try client.writer().writeAll( @@ -134,7 +172,7 @@ test "basic routes" { // Test POST request { - const client = try std.net.tcpConnectToAddress(server.listener.listen_address); + const client = try std.net.tcpConnectToAddress(server_address); defer client.close(); const body = "Hello, Echo!"; @@ -156,4 +194,8 @@ test "basic routes" { try testing.expect(std.mem.indexOf(u8, response, body) != null); } -} + + // Stop server + server.stop(); + thread.join(); +} \ No newline at end of file diff --git a/src/framework/router.zig b/src/framework/router.zig index 45dae38..a287997 100644 --- a/src/framework/router.zig +++ b/src/framework/router.zig @@ -78,6 +78,9 @@ pub const Router = struct { // Find matching route const route = self.findRoute(ctx.request.method, ctx.request.path) orelse return error.RouteNotFound; + // Extract path parameters + try extractParams(ctx, route.pattern, ctx.request.path); + // If no middleware, just call the handler if (self.global_middleware.items.len == 0) { return route.handler(ctx); @@ -103,22 +106,62 @@ pub const Router = struct { fn matchPattern(self: *Router, pattern: []const u8, path: []const u8) bool { _ = self; + + // Handle root path + if (std.mem.eql(u8, pattern, "/") and std.mem.eql(u8, path, "/")) { + return true; + } + var pattern_parts = std.mem.split(u8, pattern, "/"); var path_parts = std.mem.split(u8, path, "/"); - + + // Skip empty first part if path starts with "/" + if (pattern.len > 0 and pattern[0] == '/') _ = pattern_parts.next(); + if (path.len > 0 and path[0] == '/') _ = path_parts.next(); + while (true) { const pattern_part = pattern_parts.next() orelse { + // If we've reached the end of the pattern, the match is successful + // only if we've also reached the end of the path return path_parts.next() == null; }; - const path_part = path_parts.next() orelse return false; - + + const path_part = path_parts.next() orelse { + // If we've reached the end of the path but not the pattern, + // the match fails + return false; + }; + + // Handle path parameters (starting with ":") if (std.mem.startsWith(u8, pattern_part, ":")) { + // This is a path parameter, it matches any path part continue; } - + + // For regular path parts, they must match exactly if (!std.mem.eql(u8, pattern_part, path_part)) { return false; } } } -}; + + fn extractParams(ctx: *core.Context, pattern: []const u8, path: []const u8) !void { + var pattern_parts = std.mem.split(u8, pattern, "/"); + var path_parts = std.mem.split(u8, path, "/"); + + // Skip empty first part if path starts with "/" + if (pattern.len > 0 and pattern[0] == '/') _ = pattern_parts.next(); + if (path.len > 0 and path[0] == '/') _ = path_parts.next(); + + while (true) { + const pattern_part = pattern_parts.next() orelse break; + const path_part = path_parts.next() orelse break; + + // Extract parameter if pattern part starts with ":" + if (std.mem.startsWith(u8, pattern_part, ":")) { + const param_name = pattern_part[1..]; // Skip the ":" prefix + try ctx.params.put(param_name, path_part); + } + } + } +}; \ No newline at end of file diff --git a/src/framework/server.zig b/src/framework/server.zig index 6cae4d3..a64b78b 100644 --- a/src/framework/server.zig +++ b/src/framework/server.zig @@ -1,27 +1,297 @@ const std = @import("std"); const core = @import("core"); +const Router = @import("router.zig").Router; +const net = std.net; +const Thread = std.Thread; +const Atomic = std.atomic.Atomic; pub const ServerConfig = struct { - port: u16, - host: []const u8, + port: u16 = 8080, + host: []const u8 = "127.0.0.1", + max_connections: usize = 1000, + thread_count: ?usize = null, // If null, use available CPU cores }; pub const Server = struct { allocator: std.mem.Allocator, config: ServerConfig, - - pub fn init(allocator: std.mem.Allocator, config: ServerConfig) !Server { + address: net.Address, + listener: ?net.StreamServer = null, + running: Atomic(bool) = Atomic(bool).init(false), + threads: std.ArrayList(Thread) = undefined, + + pub fn init(allocator: std.mem.Allocator, config: ServerConfig, router: ?*const Router) !Server { + const address = try net.Address.parseIp(config.host, config.port); + return Server{ .allocator = allocator, .config = config, + .address = address, + .threads = std.ArrayList(Thread).init(allocator), + .router = router, }; } - + pub fn deinit(self: *Server) void { - _ = self; + if (self.running.load(.acquire)) { + self.stop(); + } + + if (self.listener) |*listener| { + listener.deinit(); + } + + self.threads.deinit(); } - + pub fn start(self: *Server) !void { - _ = self; + if (self.running.load(.acquire)) { + return error.ServerAlreadyRunning; + } + + // Initialize the listener + self.listener = net.StreamServer.init(.{ + .reuse_address = true, + }); + + // Bind to the address + try self.listener.?.listen(self.address); + + // Set running flag + self.running.store(true, .release); + + // Determine thread count + const thread_count = self.config.thread_count orelse try Thread.getCpuCount(); + + // Start worker threads + var i: usize = 0; + while (i < thread_count) : (i += 1) { + const thread = try Thread.spawn(.{}, workerThread, .{self}); + try self.threads.append(thread); + } + + std.log.info("Server started on {s}:{d} with {d} threads", .{ + self.config.host, + self.config.port, + thread_count + }); + } + + pub fn stop(self: *Server) void { + if (!self.running.load(.acquire)) { + return; + } + + // Set running flag to false + self.running.store(false, .release); + + // Close the listener + if (self.listener) |*listener| { + listener.close(); + } + + // Wait for all threads to finish + for (self.threads.items) |thread| { + thread.join(); + } + + self.threads.clearAndFree(); + + std.log.info("Server stopped", .{}); + } + + fn workerThread(server: *Server) !void { + while (server.running.load(.acquire)) { + // Accept a connection + const connection = server.listener.?.accept() catch |err| { + if (err == error.ConnectionAborted) { + // This happens when the server is shutting down + if (!server.running.load(.acquire)) { + break; + } + } + std.log.err("Failed to accept connection: {s}", .{@errorName(err)}); + continue; + }; + + // Handle the connection + handleConnection(server, connection) catch |err| { + std.log.err("Error handling connection: {s}", .{@errorName(err)}); + }; + } + } + + fn handleConnection(server: *Server, connection: net.StreamServer.Connection) !void { + defer connection.stream.close(); + + // Create a buffer for reading the request + var buf: [4096]u8 = undefined; + const n = try connection.stream.read(&buf); + + if (n == 0) { + return error.EmptyRequest; + } + + // Parse the request + const request = buf[0..n]; + + // Check if it's a WebSocket upgrade request + if (std.mem.indexOf(u8, request, "Upgrade: websocket") != null) { + // Handle WebSocket upgrade + try handleWebSocketUpgrade(server.allocator, connection.stream, request); + return; + } + + // Create a context for the request + var ctx = core.Context.init(server.allocator); + defer ctx.deinit(); + + // Parse the request and fill the context + try parseRequest(&ctx, request); + + // Set default response headers + try ctx.response.headers.put("Content-Type", "text/plain"); + try ctx.response.headers.put("Server", "Zup"); + + // Route the request to the appropriate handler if router is available + if (server.router) |router| { + router.handle(&ctx) catch |err| { + if (err == error.RouteNotFound) { + ctx.response.status = 404; + ctx.response.body = try server.allocator.dupe(u8, "Not Found"); + } else { + ctx.response.status = 500; + ctx.response.body = try server.allocator.dupe(u8, "Internal Server Error"); + std.log.err("Error handling request: {s}", .{@errorName(err)}); + } + }; + } else { + // If no router, just return a simple response + ctx.response.body = try server.allocator.dupe(u8, "Hello from Zup Server!"); + } + + // Build and send the HTTP response + var response_buffer = std.ArrayList(u8).init(server.allocator); + defer response_buffer.deinit(); + + // Write status line + try std.fmt.format(response_buffer.writer(), "HTTP/1.1 {d} {s}\r\n", .{ + ctx.response.status, + if (ctx.response.status == 200) "OK" else "Error", + }); + + // Write headers + var header_it = ctx.response.headers.iterator(); + while (header_it.next()) |entry| { + try std.fmt.format(response_buffer.writer(), "{s}: {s}\r\n", .{ + entry.key_ptr.*, + entry.value_ptr.*, + }); + } + + // Write Content-Length header + const body_len = if (ctx.response.body) |body| body.len else 0; + try std.fmt.format(response_buffer.writer(), "Content-Length: {d}\r\n", .{body_len}); + + // End headers + try response_buffer.appendSlice("\r\n"); + + // Write body if present + if (ctx.response.body) |body| { + try response_buffer.appendSlice(body); + } + + // Send the response + _ = try connection.stream.write(response_buffer.items); } }; + +pub fn parseRequest(ctx: *core.Context, request: []const u8) !void { + // Split the request into lines + var lines = std.mem.split(u8, request, "\r\n"); + + // Parse the request line + const request_line = lines.next() orelse return error.InvalidRequest; + var parts = std.mem.split(u8, request_line, " "); + + // Get the method + const method_str = parts.next() orelse return error.InvalidRequest; + ctx.request.method = core.Method.fromString(method_str) orelse return error.UnsupportedMethod; + + // Get the path + const path = parts.next() orelse return error.InvalidRequest; + ctx.request.path = try ctx.allocator.dupe(u8, path); + + // Parse headers + while (lines.next()) |line| { + if (line.len == 0) break; // Empty line indicates end of headers + + const colon_pos = std.mem.indexOf(u8, line, ":") orelse continue; + const header_name = std.mem.trim(u8, line[0..colon_pos], " "); + const header_value = std.mem.trim(u8, line[colon_pos + 1 ..], " "); + + try ctx.request.headers.put( + try ctx.allocator.dupe(u8, header_name), + try ctx.allocator.dupe(u8, header_value), + ); + } + + // Parse body if present + // Find the empty line that separates headers from body + var body_start: usize = 0; + const headers_end = std.mem.indexOf(u8, request, "\r\n\r\n"); + if (headers_end) |pos| { + body_start = pos + 4; // Skip the \r\n\r\n + + // Check if there's a body + if (body_start < request.len) { + const body = request[body_start..]; + if (body.len > 0) { + ctx.request.body = try ctx.allocator.dupe(u8, body); + } + } + } +} + +fn handleWebSocketUpgrade(allocator: std.mem.Allocator, stream: net.Stream, request: []const u8) !void { + const websocket = @import("../websocket.zig"); + try websocket.handleUpgrade(allocator, stream, request); + + // After upgrade, handle the WebSocket connection + // This is a simple echo server for demonstration + while (true) { + var frame = websocket.readMessage(allocator, stream) catch |err| { + if (err == error.ConnectionClosed) { + break; + } + std.log.err("Error reading WebSocket message: {s}", .{@errorName(err)}); + break; + }; + defer allocator.free(frame.payload); + + // Handle different frame types + switch (frame.opcode) { + .text, .binary => { + // Echo the message back + try websocket.writeMessage(allocator, stream, frame.payload); + }, + .close => { + // Send close frame and exit + try websocket.writeMessage(allocator, stream, ""); + break; + }, + .ping => { + // Respond to ping with pong + const pong_frame = websocket.WebSocketFrame{ + .opcode = .pong, + .payload = frame.payload, + }; + var frame_buf: [1024]u8 = undefined; + var fbs = std.io.fixedBufferStream(&frame_buf); + try pong_frame.encode(allocator, fbs.writer()); + _ = try stream.write(fbs.getWritten()); + }, + else => {}, // Ignore other frame types + } + } +} \ No newline at end of file diff --git a/src/framework/server_test.zig b/src/framework/server_test.zig index 674b70f..1504952 100644 --- a/src/framework/server_test.zig +++ b/src/framework/server_test.zig @@ -1,20 +1,27 @@ const std = @import("std"); const testing = std.testing; const Server = @import("server.zig").Server; +const Router = @import("router.zig").Router; test "server - basic start stop" { std.debug.print("\n=== Starting basic server test ===\n", .{}); + // Create a router + var router = Router.init(testing.allocator); + defer router.deinit(); + + // Initialize server with the router var server = try Server.init(testing.allocator, .{ .port = 0, .thread_count = 1, // Minimize threads for testing - }); + }, &router); + defer server.deinit(); var running = true; const thread = try std.Thread.spawn(.{}, struct { fn run(srv: *Server, is_running: *bool) void { std.debug.print("Server thread starting...\n", .{}); - srv.listen() catch |err| { + srv.start() catch |err| { std.debug.print("Server error: {}\n", .{err}); }; is_running.* = false; @@ -42,8 +49,4 @@ test "server - basic start stop" { thread.join(); std.debug.print("Server thread joined\n", .{}); - - // Clean up resources - server.deinit(); - std.debug.print("Server resources cleaned up\n", .{}); -} +} \ No newline at end of file diff --git a/src/framework/tests.zig b/src/framework/tests.zig index c63c90d..406b7f8 100644 --- a/src/framework/tests.zig +++ b/src/framework/tests.zig @@ -2,6 +2,45 @@ const std = @import("std"); const testing = std.testing; const core = @import("core.zig"); const Server = @import("server.zig").Server; +const Router = @import("router.zig").Router; + +// Helper function to parse a request since core.Request.parse() no longer exists +fn parseTestRequest(allocator: std.mem.Allocator, raw_request: []const u8) !core.Request { + var ctx = core.Context.init(allocator); + defer ctx.deinit(); + + try parseRequest(&ctx, raw_request); + + // Create a new request to return (since we're deinit'ing the context) + var request = core.Request.init(allocator); + request.method = ctx.request.method; + request.path = try allocator.dupe(u8, ctx.request.path); + + // Copy headers + var it = ctx.request.headers.iterator(); + while (it.next()) |entry| { + try request.headers.put( + try allocator.dupe(u8, entry.key_ptr.*), + try allocator.dupe(u8, entry.value_ptr.*) + ); + } + + // Copy body if present + if (ctx.request.body) |body| { + request.body = try allocator.dupe(u8, body); + } + + return request; +} + +// Import the parseRequest function from server.zig +const parseRequest = @import("server.zig").parseRequest; + +// Helper function to set text response +fn setText(ctx: *core.Context, text: []const u8) !void { + ctx.response.body = try ctx.allocator.dupe(u8, text); + try ctx.response.headers.put("Content-Type", "text/plain"); +} test "memory safety - request parsing" { const allocator = testing.allocator; @@ -16,7 +55,7 @@ test "memory safety - request parsing" { std.debug.print("\nRaw request:\n{s}\n", .{raw_request}); - var request = try core.Request.parse(allocator, raw_request); + var request = try parseTestRequest(allocator, raw_request); defer { std.debug.print("\nDeinit request\n", .{}); request.deinit(); @@ -24,7 +63,7 @@ test "memory safety - request parsing" { try testing.expectEqualStrings("/test", request.path); try testing.expectEqualStrings("localhost:8080", request.headers.get("Host").?); - try testing.expectEqual(@as(usize, 0), request.body.len); + try testing.expect(request.body == null); } test "memory safety - response handling" { @@ -40,9 +79,9 @@ test "memory safety - response handling" { // Add a small body first std.debug.print("\nTesting small body\n", .{}); { - const small_body = try core.Request.allocBody(allocator, "Hello"); - response.setBody(small_body); - try testing.expectEqualStrings("Hello", response.body); + const small_body = try allocator.dupe(u8, "Hello"); + response.body = small_body; + try testing.expectEqualStrings("Hello", response.body.?); } // Now test with a larger body @@ -52,10 +91,16 @@ test "memory safety - response handling" { defer allocator.free(large_body); @memset(large_body, 'A'); - const body = try core.Request.allocBody(allocator, large_body); - response.setBody(body); - try testing.expectEqual(@as(usize, 1024), response.body.len); - try testing.expect(response.body[0] == 'A'); + const body_copy = try allocator.dupe(u8, large_body); + + // Free previous body + if (response.body) |old_body| { + allocator.free(old_body); + } + + response.body = body_copy; + try testing.expectEqual(@as(usize, 1024), response.body.?.len); + try testing.expect(response.body.?[0] == 'A'); } std.debug.print("\nLarge body test complete\n", .{}); @@ -64,40 +109,47 @@ test "memory safety - response handling" { test "single request" { const allocator = testing.allocator; - std.debug.print("\nInit server\n", .{}); - var server = try Server.init(allocator, .{ - .port = 0, // Random port for testing - .thread_count = 1, - }); - defer { - std.debug.print("\nDeinit server\n", .{}); - server.deinit(); - } - + std.debug.print("\nInit router\n", .{}); + var router = Router.init(allocator); + defer router.deinit(); + // Add test endpoint - try server.get("/test", &struct { + try router.get("/test", &struct { fn handler(ctx: *core.Context) !void { std.debug.print("\nHandling request\n", .{}); - try ctx.text("ok"); + try setText(ctx, "ok"); std.debug.print("\nRequest handled\n", .{}); } }.handler); + std.debug.print("\nInit server\n", .{}); + var server = try Server.init(allocator, .{ + .port = 0, // Random port for testing + .thread_count = 1, + }, &router); + defer server.deinit(); + // Start server in background std.debug.print("\nStarting server\n", .{}); - const thread = try std.Thread.spawn(.{}, Server.start, .{&server}); - defer { - std.debug.print("\nStopping server\n", .{}); - server.running.store(false, .release); - thread.join(); - } - + var running = true; + const thread = try std.Thread.spawn(.{}, struct { + fn run(srv: *Server, is_running: *bool) void { + srv.start() catch |err| { + std.debug.print("Server error: {}\n", .{err}); + }; + is_running.* = false; + } + }.run, .{&server, &running}); + // Wait for server to start - std.time.sleep(10 * std.time.ns_per_ms); - + std.time.sleep(100 * std.time.ns_per_ms); + + // Get server address + const server_address = server.address; + std.debug.print("\nMaking request to {}\n", .{server_address}); + // Make request - std.debug.print("\nMaking request to {}\n", .{server.listener.listen_address}); - const client = try std.net.tcpConnectToAddress(server.listener.listen_address); + const client = try std.net.tcpConnectToAddress(server_address); defer client.close(); const request = @@ -118,43 +170,54 @@ test "single request" { try testing.expect(std.mem.indexOf(u8, response, "ok") != null); std.debug.print("\nTest complete\n", .{}); + + // Stop server + server.stop(); + thread.join(); } test "concurrent requests" { const allocator = testing.allocator; - std.debug.print("\nInit server\n", .{}); - var server = try Server.init(allocator, .{ - .port = 0, // Random port for testing - .thread_count = 4, - }); - defer { - std.debug.print("\nDeinit server\n", .{}); - server.deinit(); - } - + std.debug.print("\nInit router\n", .{}); + var router = Router.init(allocator); + defer router.deinit(); + // Add test endpoint - try server.get("/concurrent", &struct { + try router.get("/concurrent", &struct { fn handler(ctx: *core.Context) !void { std.debug.print("\nHandling request in worker thread\n", .{}); // Simulate work std.time.sleep(10 * std.time.ns_per_ms); - try ctx.text("ok"); + try setText(ctx, "ok"); std.debug.print("\nRequest handled\n", .{}); } }.handler); + std.debug.print("\nInit server\n", .{}); + var server = try Server.init(allocator, .{ + .port = 0, // Random port for testing + .thread_count = 4, + }, &router); + defer server.deinit(); + // Start server in background std.debug.print("\nStarting server\n", .{}); - const thread = try std.Thread.spawn(.{}, Server.start, .{&server}); - defer { - std.debug.print("\nStopping server\n", .{}); - server.running.store(false, .release); - thread.join(); - } - + var running = true; + const thread = try std.Thread.spawn(.{}, struct { + fn run(srv: *Server, is_running: *bool) void { + srv.start() catch |err| { + std.debug.print("Server error: {}\n", .{err}); + }; + is_running.* = false; + } + }.run, .{&server, &running}); + // Wait for server to start - std.time.sleep(10 * std.time.ns_per_ms); + std.time.sleep(100 * std.time.ns_per_ms); + + // Get server address + const server_address = server.address; // Make concurrent requests const RequestThread = struct { @@ -186,11 +249,15 @@ test "concurrent requests" { std.debug.print("\nStarting concurrent requests\n", .{}); var threads: [2]std.Thread = undefined; for (&threads) |*t| { - t.* = try std.Thread.spawn(.{}, RequestThread.make_request, .{server.listener.listen_address}); + t.* = try std.Thread.spawn(.{}, RequestThread.make_request, .{server_address}); } for (threads) |t| { t.join(); } std.debug.print("\nConcurrent requests complete\n", .{}); -} + + // Stop server + server.stop(); + thread.join(); +} \ No newline at end of file diff --git a/src/framework/trpc_test.zig b/src/framework/trpc_test.zig index 52328c2..a7b2ec7 100644 --- a/src/framework/trpc_test.zig +++ b/src/framework/trpc_test.zig @@ -1,22 +1,57 @@ const std = @import("std"); const testing = std.testing; const Server = @import("server.zig").Server; +const Router = @import("router.zig").Router; const core = @import("core.zig"); +const parseRequest = @import("server.zig").parseRequest; + +// Helper function to parse a request +fn parseTestRequest(allocator: std.mem.Allocator, raw_request: []const u8) !core.Request { + var ctx = core.Context.init(allocator); + defer ctx.deinit(); + + try parseRequest(&ctx, raw_request); + + // Create a new request to return (since we're deinit'ing the context) + var request = core.Request.init(allocator); + request.method = ctx.request.method; + request.path = try allocator.dupe(u8, ctx.request.path); + + // Copy headers + var it = ctx.request.headers.iterator(); + while (it.next()) |entry| { + try request.headers.put( + try allocator.dupe(u8, entry.key_ptr.*), + try allocator.dupe(u8, entry.value_ptr.*) + ); + } + + // Copy body if present + if (ctx.request.body) |body| { + request.body = try allocator.dupe(u8, body); + } + + return request; +} test "trpc - basic procedure call" { std.debug.print("\n=== Starting TRPC basic procedure call test ===\n", .{}); + // Create a router + var router = Router.init(testing.allocator); + defer router.deinit(); + var server = try Server.init(testing.allocator, .{ .port = 0, .thread_count = 1, // Minimize threads for testing - }); + }, &router); defer server.deinit(); var running = true; const thread = try std.Thread.spawn(.{}, struct { fn run(srv: *Server, is_running: *bool) void { std.debug.print("Server thread starting...\n", .{}); - srv.listen() catch |err| { + srv.start() catch |err| { std.debug.print("Server error: {}\n", .{err}); }; is_running.* = false; @@ -29,14 +64,27 @@ test "trpc - basic procedure call" { std.time.sleep(100 * std.time.ns_per_ms); // Make a request to the server - const client = try std.net.Client.init(testing.allocator); - defer client.deinit(); + // Note: std.net.Client might not exist, so we'll use direct TCP connection instead + const server_address = server.address; + const client = try std.net.tcpConnectToAddress(server_address); + defer client.close(); + + const request_str = + \\POST /trpc HTTP/1.1 + \\Host: localhost + \\Content-Type: application/json + \\Content-Length: 57 + \\ + \\{"method":"add","params":{"a":1,"b":2},"id":1} + ; - const request = try core.Request.parse(testing.allocator, "POST /trpc HTTP/1.1\r\nHost: localhost\r\nContent-Type: application/json\r\nContent-Length: 57\r\n\r\n{\"method\":\"add\",\"params\":{\"a\":1,\"b\":2},\"id\":1}"); - defer request.deinit(); + std.debug.print("Sending request:\n{s}\n", .{request_str}); + try client.writer().writeAll(request_str); - const response = try client.request(request); - defer response.deinit(); + var buf: [1024]u8 = undefined; + const n = try client.read(&buf); + const response = buf[0..n]; + std.debug.print("Received response:\n{s}\n", .{response}); // Stop server std.debug.print("Stopping server...\n", .{}); @@ -54,8 +102,4 @@ test "trpc - basic procedure call" { thread.join(); std.debug.print("Server thread joined\n", .{}); - - // Clean up resources - server.deinit(); - std.debug.print("Server resources cleaned up\n", .{}); -} +} \ No newline at end of file diff --git a/src/main.zig b/src/main.zig index 1aec7fa..33a6a10 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,36 +1,111 @@ const std = @import("std"); const json = std.json; -const core = @import("core"); -const GrpcRouter = @import("grpc_router").GrpcRouter; -const ArrayHashMap = std.array_hash_map.ArrayHashMap; +const framework = @import("framework"); +const core = @import("framework/core.zig"); +const Server = @import("framework/server.zig").Server; +const ServerConfig = @import("framework/server.zig").ServerConfig; +const Router = @import("framework/router.zig").Router; -// Example procedure handler that returns a greeting -fn greetingHandler(ctx: *core.Context, input: ?json.Value) !json.Value { - const name = if (input) |value| blk: { - if (value.object.get("name")) |name_value| { - break :blk name_value.string; - } - break :blk "World"; - } else "World"; - - var result = ArrayHashMap([]const u8, json.Value, std.array_hash_map.StringContext, true).init(ctx.allocator); - try result.put("message", json.Value{ .string = try std.fmt.allocPrint(ctx.allocator, "Hello, {s}!", .{name}) }); +// Global variable to hold server reference for signal handling +var global_server: ?*Server = null; - return json.Value{ .object = result }; +// Example handler that returns a greeting +fn greetingHandler(ctx: *core.Context) !void { + // Parse request body if present + var name: []const u8 = "World"; + + if (ctx.request.body) |body| { + var parser = json.Parser.init(ctx.allocator, false); + defer parser.deinit(); + + var parsed = parser.parse(body) catch |err| { + std.log.err("Failed to parse JSON: {s}", .{@errorName(err)}); + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Invalid JSON"); + return; + }; + defer parsed.deinit(); + + if (parsed.root.Object.get("name")) |name_value| { + if (name_value == .String) { + name = name_value.String; + } + } + } + + // Set response + ctx.response.status = 200; + try ctx.response.headers.put("Content-Type", "application/json"); + + // Create response JSON + var response = std.ArrayList(u8).init(ctx.allocator); + defer response.deinit(); + + try std.fmt.format(response.writer(), "{{\"message\":\"Hello, {s}!\"}}", .{name}); + ctx.response.body = try ctx.allocator.dupe(u8, response.items); } -// Example procedure handler that adds two numbers -fn addHandler(ctx: *core.Context, input: ?json.Value) !json.Value { - _ = ctx; - if (input == null) return error.MissingInput; - - const a = input.?.object.get("a") orelse return error.MissingFirstNumber; - const b = input.?.object.get("b") orelse return error.MissingSecondNumber; - - if (a != .integer or b != .integer) return error.InvalidNumberFormat; +// Example handler that adds two numbers +fn addHandler(ctx: *core.Context) !void { + // Ensure we have a request body + if (ctx.request.body == null) { + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Missing request body"); + return; + } + + // Parse JSON + var parser = json.Parser.init(ctx.allocator, false); + defer parser.deinit(); + + var parsed = parser.parse(ctx.request.body.?) catch |err| { + std.log.err("Failed to parse JSON: {s}", .{@errorName(err)}); + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Invalid JSON"); + return; + }; + defer parsed.deinit(); + + // Extract a and b values + const a_value = parsed.root.Object.get("a") orelse { + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Missing 'a' parameter"); + return; + }; + + const b_value = parsed.root.Object.get("b") orelse { + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Missing 'b' parameter"); + return; + }; + + // Ensure a and b are integers + if (a_value != .Integer or b_value != .Integer) { + ctx.response.status = 400; + ctx.response.body = try ctx.allocator.dupe(u8, "Parameters 'a' and 'b' must be integers"); + return; + } + + // Calculate result + const result = a_value.Integer + b_value.Integer; + + // Set response + ctx.response.status = 200; + try ctx.response.headers.put("Content-Type", "application/json"); + + // Create response JSON + var response = std.ArrayList(u8).init(ctx.allocator); + defer response.deinit(); + + try std.fmt.format(response.writer(), "{{\"result\":{d}}}", .{result}); + ctx.response.body = try ctx.allocator.dupe(u8, response.items); +} - const result = a.integer + b.integer; - return json.Value{ .integer = result }; +// Root handler +fn rootHandler(ctx: *core.Context) !void { + ctx.response.status = 200; + try ctx.response.headers.put("Content-Type", "text/plain"); + ctx.response.body = try ctx.allocator.dupe(u8, "Welcome to Zup Server!"); } pub fn main() !void { @@ -38,41 +113,70 @@ pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; defer _ = gpa.deinit(); const allocator = gpa.allocator(); - - // Create gRPC router - var router = GrpcRouter.init(allocator); + + // Create router + var router = Router.init(allocator); defer router.deinit(); - - // Register procedures - try router.procedure("greeting", greetingHandler, null, null); - try router.procedure("add", addHandler, null, null); - - // Start server on port 8080 - std.log.info("Starting gRPC server on port 8080...", .{}); - try router.listen(8080); - + + // Register routes + try router.get("/", rootHandler); + try router.post("/greeting", greetingHandler); + try router.post("/add", addHandler); + + // Create server config + const config = ServerConfig{ + .port = 8080, + .host = "127.0.0.1", + .thread_count = 4, // Use 4 threads or set to null to use all available cores + }; + + // Create and start server + var server = try Server.init(allocator, config, &router); + defer server.deinit(); + + // Set global server reference for signal handling + global_server = &server; + + std.log.info("Starting server on {s}:{d}...", .{config.host, config.port}); + + // Start the server with proper error handling + server.start() catch |err| { + std.log.err("Failed to start server: {s}", .{@errorName(err)}); + return err; + }; + // Wait for server to be ready var attempts: usize = 0; const max_attempts = 50; while (attempts < max_attempts) : (attempts += 1) { - if (router.server) |server| { - if (server.running.load(.acquire)) break; - } + if (server.running.load(.acquire)) break; std.time.sleep(100 * std.time.ns_per_ms); } - + if (attempts >= max_attempts) { - std.log.err("Server failed to start", .{}); - return error.ServerStartFailed; + std.log.err("Server failed to start within timeout period", .{}); + return error.ServerStartTimeout; } - + std.log.info("Server is running. Use Ctrl+C to stop.", .{}); - + + // Set up signal handling for graceful shutdown + const sigint = std.os.SIGINT; + _ = std.os.sigaction(sigint, &std.os.Sigaction{ + .handler = .{ .handler = handleSignal }, + .mask = std.os.empty_sigset, + .flags = 0, + }, null); + // Keep main thread alive - while (true) { - if (router.server) |server| { - if (!server.running.load(.acquire)) break; - } + while (server.running.load(.acquire)) { std.time.sleep(1000 * std.time.ns_per_ms); } } + +fn handleSignal(sig: c_int) callconv(.C) void { + std.log.info("Received signal {d}, shutting down...", .{sig}); + if (global_server) |server| { + server.stop(); + } +} \ No newline at end of file diff --git a/src/root.zig b/src/root.zig index 8616695..532bf74 100644 --- a/src/root.zig +++ b/src/root.zig @@ -1,5 +1,11 @@ -pub const core = @import("core"); +// Export the core framework components +pub const core = @import("framework/core.zig"); pub const framework = @import("framework"); -pub const schema = @import("schema"); -pub const runtime_router = @import("runtime_router"); -pub const grpc_router = @import("grpc_router"); +pub const websocket = @import("websocket.zig"); +pub const client = @import("client.zig"); +pub const spice = @import("spice.zig"); +pub const bench = @import("bench.zig"); +pub const benchmark = @import("benchmark.zig"); + +// Main entry point +pub const main = @import("main.zig"); \ No newline at end of file diff --git a/src/websocket.zig b/src/websocket.zig index b3ccf1e..0fd5c38 100644 --- a/src/websocket.zig +++ b/src/websocket.zig @@ -5,6 +5,15 @@ const mem = std.mem; const base64 = std.base64; const Sha1 = std.crypto.hash.Sha1; +// Debug flag - set to false in production +const debug_logging = false; + +fn debugLog(comptime fmt: []const u8, args: anytype) void { + if (debug_logging) { + std.debug.print(fmt, args); + } +} + pub const WebSocketFrame = struct { fin: bool = true, opcode: Opcode, @@ -21,13 +30,13 @@ pub const WebSocketFrame = struct { }; pub fn encode(self: WebSocketFrame, allocator: std.mem.Allocator, writer: anytype) !void { - std.debug.print("\nEncoding frame: opcode={}, mask={}, payload={s}\n", .{ self.opcode, self.mask, self.payload }); + debugLog("\nEncoding frame: opcode={}, mask={}, payload={s}\n", .{ self.opcode, self.mask, self.payload }); var first_byte: u8 = 0; if (self.fin) first_byte |= 0x80; first_byte |= @intFromEnum(self.opcode); try writer.writeByte(first_byte); - std.debug.print("First byte: {b:0>8}\n", .{first_byte}); + debugLog("First byte: {b:0>8}\n", .{first_byte}); var second_byte: u8 = 0; if (self.mask) second_byte |= 0x80; @@ -45,13 +54,13 @@ pub const WebSocketFrame = struct { try writer.writeByte(second_byte); try writer.writeInt(u64, @as(u64, @intCast(payload_len)), .big); } - std.debug.print("Second byte: {b:0>8}\n", .{second_byte}); + debugLog("Second byte: {b:0>8}\n", .{second_byte}); if (self.mask) { // Generate random masking key var masking_key: [4]u8 = undefined; std.crypto.random.bytes(&masking_key); - std.debug.print("Masking key: {any}\n", .{masking_key}); + debugLog("Masking key: {any}\n", .{masking_key}); // Write masking key try writer.writeAll(&masking_key); @@ -65,25 +74,25 @@ pub const WebSocketFrame = struct { try masked_bytes.append(masked_byte); } try writer.writeAll(masked_bytes.items); - std.debug.print("Masked payload: {any}\n", .{masked_bytes.items}); + debugLog("Masked payload: {any}\n", .{masked_bytes.items}); } else { try writer.writeAll(self.payload); - std.debug.print("Unmasked payload: {any}\n", .{self.payload}); + debugLog("Unmasked payload: {any}\n", .{self.payload}); } } pub fn decode(allocator: std.mem.Allocator, reader: anytype) !WebSocketFrame { - std.debug.print("\nDecoding frame...\n", .{}); + debugLog("\nDecoding frame...\n", .{}); const first_byte = try reader.readByte(); const fin = (first_byte & 0x80) != 0; const opcode = @as(Opcode, @enumFromInt(first_byte & 0x0F)); - std.debug.print("First byte: {b:0>8}, fin={}, opcode={}\n", .{ first_byte, fin, opcode }); + debugLog("First byte: {b:0>8}, fin={}, opcode={}\n", .{ first_byte, fin, opcode }); const second_byte = try reader.readByte(); const mask = (second_byte & 0x80) != 0; const payload_len = second_byte & 0x7F; - std.debug.print("Second byte: {b:0>8}, mask={}, initial payload_len={}\n", .{ second_byte, mask, payload_len }); + debugLog("Second byte: {b:0>8}, mask={}, initial payload_len={}\n", .{ second_byte, mask, payload_len }); const extended_payload_len: u64 = if (payload_len == 126) try reader.readInt(u16, .big) @@ -92,12 +101,12 @@ pub const WebSocketFrame = struct { else payload_len; - std.debug.print("Extended payload length: {}\n", .{extended_payload_len}); + debugLog("Extended payload length: {}\n", .{extended_payload_len}); const masking_key = if (mask) blk: { var key: [4]u8 = undefined; _ = try reader.readAll(&key); - std.debug.print("Masking key: {any}\n", .{key}); + debugLog("Masking key: {any}\n", .{key}); break :blk key; } else [_]u8{0} ** 4; @@ -108,13 +117,13 @@ pub const WebSocketFrame = struct { const n = try reader.readAll(payload); if (n != extended_payload_len) return error.InvalidFrame; - std.debug.print("Raw payload: {any}\n", .{payload}); + debugLog("Raw payload: {any}\n", .{payload}); if (mask) { for (payload, 0..) |*byte, i| { byte.* ^= masking_key[i % 4]; } - std.debug.print("Unmasked payload: {any}\n", .{payload}); + debugLog("Unmasked payload: {any}\n", .{payload}); } const frame = WebSocketFrame{ @@ -124,7 +133,7 @@ pub const WebSocketFrame = struct { .payload = payload, }; - std.debug.print("Decoded frame: opcode={}, mask={}, payload={s}\n", .{ frame.opcode, frame.mask, frame.payload }); + debugLog("Decoded frame: opcode={}, mask={}, payload={s}\n", .{ frame.opcode, frame.mask, frame.payload }); return frame; } }; @@ -252,4 +261,4 @@ pub fn readMessage(allocator: std.mem.Allocator, stream: net.Stream) !WebSocketF .mask = mask, .payload = payload, }; -} +} \ No newline at end of file