diff --git a/README.md b/README.md index d76a541..4c4449a 100644 --- a/README.md +++ b/README.md @@ -6,11 +6,7 @@ > one server. no resistance. -A persistent LSP process manager daemon for Neovim. Fixes memory bloat, stuck diagnostics, monorepo server duplication, and session degradation — the recurring pain points in Neovim's LSP lifecycle. - -## The Problem - -Neovim starts a new LSP server per session, leaks memory, leaves stuck diagnostics on detach, and spawns duplicate servers in monorepos. ohm solves it at the daemon layer. +A persistent LSP process manager daemon for Neovim. Neovim starts a fresh server per session — ohm replaces that with one shared server per `{root_dir, language}` pair, fixing memory bloat, stuck diagnostics, and monorepo duplication at the daemon layer. ## How It Works @@ -28,7 +24,7 @@ Neovim instances (any number) - **Grace period** — when refs hit 0, waits 10s before killing. Reopen a file within the window to cancel. - **Diagnostic fence** — sends `textDocument/didClose` before detach to prevent stuck diagnostics. - **Respawn** — crashed servers are automatically restarted without losing the proxy socket. -- **Watchdog** — kills servers exceeding 1500MB RSS or frozen for 5+ minutes. +- **Watchdog** — kills runaway or frozen servers automatically. - **Shutdown interception** — intercepts client `shutdown`/`exit` so individual session closes don't kill the shared server. ## Requirements @@ -146,6 +142,10 @@ mkdir -p tmp && go run . tmp/ohm.sock mkdir -p tmp && go run . --debug tmp/ohm.sock ``` +## Architecture + +See [docs/architecture.md](docs/architecture.md) for a deep dive: two-socket design, request flow, ID rewriting, initialize caching, respawn, and the concurrency model. 
+ ## License MIT diff --git a/daemon/daemon.go b/daemon/daemon.go index bb80ac9..9ee2794 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -135,6 +135,16 @@ func (d *Daemon) handleAttach(msg AttachMsg) (string, error) { } func (d *Daemon) respawnServer(key ServerKey) { + // Cancel any pending kill timer — it was set for the crashed process and + // would otherwise fire on the newly-spawned one. + d.mu.Lock() + if timer, ok := d.pendingKill[key]; ok { + timer.Stop() + delete(d.pendingKill, key) + slog.Info("respawn: cancelled pending kill", "lang", key.LanguageID) + } + d.mu.Unlock() + server, ok := d.registry.Get(key) if !ok { return @@ -179,6 +189,8 @@ func captureStderr(proc *Process, lang string) { } } +// proxySocketPath returns a stable socket path for the per-server LSP proxy. +// The 4-byte hash prefix is for uniqueness across root+lang pairs, not security. func (d *Daemon) proxySocketPath(key ServerKey) string { h := sha256.Sum256([]byte(key.RootDir + "|" + key.LanguageID)) name := fmt.Sprintf("ohm-%s-%x.sock", key.LanguageID, h[:4]) @@ -276,6 +288,11 @@ func (d *Daemon) handleConn(conn net.Conn) { } switch msg.Method { case "attach": + if len(msg.Params) == 0 { + slog.Error("attach: missing params") + h.WriteResponse(conn, msg.MsgID, nil) + continue + } var a AttachMsg if err := h.DecodeParam(&a, msg.Params[0]); err != nil { slog.Error("decode attach", "err", err) @@ -291,6 +308,10 @@ func (d *Daemon) handleConn(conn net.Conn) { h.WriteResponse(conn, msg.MsgID, socketPath) case "detach": + if len(msg.Params) == 0 { + slog.Error("detach: missing params") + continue + } var a DetachMsg if err := h.DecodeParam(&a, msg.Params[0]); err != nil { slog.Error("decode detach", "err", err) diff --git a/daemon/frame_test.go b/daemon/frame_test.go new file mode 100644 index 0000000..f188177 --- /dev/null +++ b/daemon/frame_test.go @@ -0,0 +1,118 @@ +package daemon + +import ( + "bufio" + "bytes" + "strings" + "testing" +) + +func TestWriteReadFrame_roundtrip(t 
*testing.T) { + body := []byte(`{"jsonrpc":"2.0","method":"initialized","params":{}}`) + + var buf bytes.Buffer + if err := WriteFrame(&buf, body); err != nil { + t.Fatalf("WriteFrame: %v", err) + } + + got, err := ReadFrame(bufio.NewReader(&buf)) + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + if !bytes.Equal(got, body) { + t.Errorf("roundtrip mismatch\ngot: %s\nwant: %s", got, body) + } +} + +func TestWriteFrame_header(t *testing.T) { + body := []byte(`{}`) + var buf bytes.Buffer + if err := WriteFrame(&buf, body); err != nil { + t.Fatalf("WriteFrame: %v", err) + } + s := buf.String() + if !strings.HasPrefix(s, "Content-Length: 2\r\n\r\n") { + t.Errorf("unexpected header: %q", s) + } +} + +func TestWriteReadFrame_empty(t *testing.T) { + body := []byte{} + var buf bytes.Buffer + if err := WriteFrame(&buf, body); err != nil { + t.Fatalf("WriteFrame: %v", err) + } + got, err := ReadFrame(bufio.NewReader(&buf)) + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + if len(got) != 0 { + t.Errorf("expected empty body, got %q", got) + } +} + +func TestReadFrame_missingContentLength(t *testing.T) { + // Headers with no Content-Length, then blank line + raw := "X-Custom: foo\r\n\r\n" + _, err := ReadFrame(bufio.NewReader(strings.NewReader(raw))) + if err == nil { + t.Fatal("expected error for missing Content-Length, got nil") + } + if !strings.Contains(err.Error(), "missing Content-Length") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestReadFrame_truncatedBody(t *testing.T) { + // Content-Length says 100 bytes but only 3 bytes follow + raw := "Content-Length: 100\r\n\r\nabc" + _, err := ReadFrame(bufio.NewReader(strings.NewReader(raw))) + if err == nil { + t.Fatal("expected error for truncated body, got nil") + } +} + +func TestReadFrame_multipleHeaders(t *testing.T) { + // LSP spec allows extra headers before Content-Length + body := []byte(`{"id":1}`) + var buf bytes.Buffer + buf.WriteString("Content-Type: application/vscode-jsonrpc; 
charset=utf-8\r\n") + if err := WriteFrame(&buf, body); err != nil { + t.Fatalf("WriteFrame: %v", err) + } + // Rewrite: put extra header before Content-Length in a fresh buffer + combined := "Content-Type: application/vscode-jsonrpc\r\nContent-Length: 8\r\n\r\n{\"id\":1}" + got, err := ReadFrame(bufio.NewReader(strings.NewReader(combined))) + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + if string(got) != `{"id":1}` { + t.Errorf("got %q", got) + } +} + +func TestWriteReadFrame_multipleMessages(t *testing.T) { + msgs := [][]byte{ + []byte(`{"id":1,"method":"initialize"}`), + []byte(`{"id":2,"method":"shutdown"}`), + []byte(`{"method":"exit"}`), + } + + var buf bytes.Buffer + for _, m := range msgs { + if err := WriteFrame(&buf, m); err != nil { + t.Fatalf("WriteFrame: %v", err) + } + } + + r := bufio.NewReader(&buf) + for i, want := range msgs { + got, err := ReadFrame(r) + if err != nil { + t.Fatalf("msg %d: ReadFrame: %v", i, err) + } + if !bytes.Equal(got, want) { + t.Errorf("msg %d: got %s, want %s", i, got, want) + } + } +} diff --git a/daemon/lsp_client.go b/daemon/lsp_client.go index 442506b..1508e07 100644 --- a/daemon/lsp_client.go +++ b/daemon/lsp_client.go @@ -89,6 +89,7 @@ func RunClient(args []string) error { } func parseClientArgs(args []string) (socket, root, lang string, cmd []string, err error) { + hasSep := false for i := 0; i < len(args); i++ { switch args[i] { case "--socket": @@ -110,8 +111,9 @@ func parseClientArgs(args []string) (socket, root, lang string, cmd []string, er } lang = args[i] case "--": + hasSep = true cmd = args[i+1:] - return + i = len(args) // consumed; exit loop to run validation below } } if socket == "" { @@ -120,7 +122,7 @@ func parseClientArgs(args []string) (socket, root, lang string, cmd []string, er err = fmt.Errorf("missing --root") } else if lang == "" { err = fmt.Errorf("missing --lang") - } else { + } else if !hasSep { err = fmt.Errorf("missing -- ") } return diff --git a/daemon/lsp_client_test.go 
b/daemon/lsp_client_test.go new file mode 100644 index 0000000..fb108bf --- /dev/null +++ b/daemon/lsp_client_test.go @@ -0,0 +1,97 @@ +package daemon + +import ( + "testing" +) + +func TestParseClientArgs_valid(t *testing.T) { + args := []string{"--socket", "/tmp/ohm.sock", "--root", "/srv/proj", "--lang", "go", "--", "gopls", "-v"} + socket, root, lang, cmd, err := parseClientArgs(args) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if socket != "/tmp/ohm.sock" { + t.Errorf("socket: got %q", socket) + } + if root != "/srv/proj" { + t.Errorf("root: got %q", root) + } + if lang != "go" { + t.Errorf("lang: got %q", lang) + } + if len(cmd) != 2 || cmd[0] != "gopls" || cmd[1] != "-v" { + t.Errorf("cmd: got %v", cmd) + } +} + +func TestParseClientArgs_differentOrder(t *testing.T) { + args := []string{"--lang", "rust", "--root", "/ws", "--socket", "/var/ohm.sock", "--", "rust-analyzer"} + socket, root, lang, cmd, err := parseClientArgs(args) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if socket != "/var/ohm.sock" { + t.Errorf("socket: got %q", socket) + } + if root != "/ws" { + t.Errorf("root: got %q", root) + } + if lang != "rust" { + t.Errorf("lang: got %q", lang) + } + if len(cmd) != 1 || cmd[0] != "rust-analyzer" { + t.Errorf("cmd: got %v", cmd) + } +} + +func TestParseClientArgs_noCmd(t *testing.T) { + // -- present but no command after it + args := []string{"--socket", "/s", "--root", "/r", "--lang", "go", "--"} + _, _, _, cmd, err := parseClientArgs(args) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cmd) != 0 { + t.Errorf("expected empty cmd, got %v", cmd) + } +} + +func TestParseClientArgs_missingSocket(t *testing.T) { + args := []string{"--root", "/r", "--lang", "go", "--", "gopls"} + _, _, _, _, err := parseClientArgs(args) + if err == nil { + t.Fatal("expected error for missing --socket") + } +} + +func TestParseClientArgs_missingRoot(t *testing.T) { + args := []string{"--socket", "/s", "--lang", 
"go", "--", "gopls"} + _, _, _, _, err := parseClientArgs(args) + if err == nil { + t.Fatal("expected error for missing --root") + } +} + +func TestParseClientArgs_missingLang(t *testing.T) { + args := []string{"--socket", "/s", "--root", "/r", "--", "gopls"} + _, _, _, _, err := parseClientArgs(args) + if err == nil { + t.Fatal("expected error for missing --lang") + } +} + +func TestParseClientArgs_missingDoubleDash(t *testing.T) { + args := []string{"--socket", "/s", "--root", "/r", "--lang", "go"} + _, _, _, _, err := parseClientArgs(args) + if err == nil { + t.Fatal("expected error for missing -- separator") + } +} + +func TestParseClientArgs_socketMissingValue(t *testing.T) { + args := []string{"--socket"} + _, _, _, _, err := parseClientArgs(args) + if err == nil { + t.Fatal("expected error when --socket has no value") + } +} diff --git a/daemon/multiplexer.go b/daemon/multiplexer.go index 5b4d2b7..ebbfb02 100644 --- a/daemon/multiplexer.go +++ b/daemon/multiplexer.go @@ -34,6 +34,10 @@ func (c *Client) sendLoop() { defer c.conn.Close() for body := range c.sendCh { if err := WriteFrame(c.conn, body); err != nil { + // Close conn immediately so serveClient's ReadFrame returns, + // which triggers removeClient → shutdown → channel close. + // Without this, the drain loop below would block indefinitely. + c.conn.Close() for range c.sendCh { } return @@ -84,18 +88,23 @@ type Mux struct { lastNs atomic.Int64 // UnixNano of last LSP response, read by supervisor - // initResponse caches the initialize response body (with global ID). - // Once set, new clients get this instead of a real initialize round-trip. - initMu sync.RWMutex + // initResponse caches the initialize response body (stored with the global + // rewritten ID; the original client ID is substituted on each send). + // initInFlight gates the single real initialize round-trip; concurrent + // callers wait on initReady, which is closed when the response arrives. 
+ initMu sync.Mutex initResponse []byte + initInFlight bool + initReady chan struct{} // closed when initResponse is populated onExit func() // called when LSP stdout closes } func newMux(proc *Process) *Mux { m := &Mux{ - proc: proc, - pending: make(map[uint64]*pendingReq), + proc: proc, + pending: make(map[uint64]*pendingReq), + initReady: make(chan struct{}), } m.lastNs.Store(time.Now().UnixNano()) return m @@ -154,15 +163,31 @@ func (m *Mux) serveClient(c *Client) { continue } - // initialize: if already done, return cached response immediately. - m.initMu.RLock() - cached := m.initResponse - m.initMu.RUnlock() - - if p.method == "initialize" && cached != nil { - out := rewriteIDRaw(cached, p.rawID) - c.write(out) - continue + // initialize: serialize all callers so only one round-trip reaches + // the LSP server. Concurrent callers wait on initReady. + if p.method == "initialize" { + m.initMu.Lock() + switch { + case m.initResponse != nil: + cached := m.initResponse + m.initMu.Unlock() + c.write(rewriteIDRaw(cached, p.rawID)) + continue + case m.initInFlight: + m.initMu.Unlock() + <-m.initReady + m.initMu.Lock() + cached := m.initResponse + m.initMu.Unlock() + if cached != nil { + c.write(rewriteIDRaw(cached, p.rawID)) + } + continue + default: + m.initInFlight = true + m.initMu.Unlock() + // fall through: forward this one initialize to the server + } } globalID := m.nextID.Add(1) @@ -217,11 +242,12 @@ func (m *Mux) Broadcast() { if ok { if req.client != nil { out := rewriteIDRaw(body, req.originalID) - // Cache the initialize response for future clients. + // Cache the initialize response and unblock any concurrent waiters. 
if req.method == "initialize" { m.initMu.Lock() if m.initResponse == nil { m.initResponse = body // keep global ID; rewrite on send + close(m.initReady) } m.initMu.Unlock() } diff --git a/daemon/multiplexer_test.go b/daemon/multiplexer_test.go new file mode 100644 index 0000000..7e219c3 --- /dev/null +++ b/daemon/multiplexer_test.go @@ -0,0 +1,393 @@ +package daemon + +import ( + "bufio" + "encoding/json" + "net" + "sync" + "testing" + "time" +) + +// --- helpers --- + +// testMux wires a Mux to in-process net.Pipe connections so tests can act as +// both the LSP server and any number of Neovim clients. +type testMux struct { + mux *Mux + serverIn net.Conn // read what the mux forwarded to the LSP server + serverOut net.Conn // write LSP server responses into the mux +} + +func newTestMux(t *testing.T) *testMux { + t.Helper() + // stdinPair: mux writes to stdinServer; test reads from stdinClient + stdinClient, stdinServer := net.Pipe() + // stdoutPair: mux reads from stdoutClient; test writes to stdoutServer + stdoutClient, stdoutServer := net.Pipe() + + proc := &Process{ + Stdin: stdinServer, + Stdout: stdoutClient, + PID: 9999, + } + + mux := newMux(proc) + go mux.Broadcast() + + t.Cleanup(func() { + stdoutServer.Close() + stdinClient.Close() + }) + + return &testMux{ + mux: mux, + serverIn: stdinClient, + serverOut: stdoutServer, + } +} + +// addClient connects a new client to the mux, returning the client-side conn. 
+func (tm *testMux) addClient(t *testing.T) (clientConn net.Conn, r *bufio.Reader) { + t.Helper() + clientConn, muxConn := net.Pipe() + tm.mux.AddClient(muxConn) + return clientConn, bufio.NewReader(clientConn) +} + +func sendFrameConn(t *testing.T, conn net.Conn, body []byte) { + t.Helper() + conn.SetDeadline(time.Now().Add(2 * time.Second)) + if err := WriteFrame(conn, body); err != nil { + t.Fatalf("sendFrame: %v", err) + } +} + +func recvFrameConn(t *testing.T, r *bufio.Reader, conn net.Conn) []byte { + t.Helper() + conn.SetDeadline(time.Now().Add(2 * time.Second)) + body, err := ReadFrame(r) + if err != nil { + t.Fatalf("recvFrame: %v", err) + } + return body +} + +func mustMarshal(v interface{}) []byte { + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + return b +} + +func lspRequest(id interface{}, method string) []byte { + return mustMarshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": map[string]interface{}{}, + }) +} + +func lspResponse(id interface{}, result interface{}) []byte { + return mustMarshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": result, + }) +} + +func lspNotification(method string) []byte { + return mustMarshal(map[string]interface{}{ + "jsonrpc": "2.0", + "method": method, + "params": map[string]interface{}{}, + }) +} + +func extractID(t *testing.T, body []byte) json.RawMessage { + t.Helper() + var obj struct { + ID json.RawMessage `json:"id"` + } + if err := json.Unmarshal(body, &obj); err != nil { + t.Fatalf("extractID unmarshal: %v", err) + } + return obj.ID +} + +func extractMethod(t *testing.T, body []byte) string { + t.Helper() + var obj struct { + Method string `json:"method"` + } + if err := json.Unmarshal(body, &obj); err != nil { + t.Fatalf("extractMethod: %v", err) + } + return obj.Method +} + +// --- peekLSP --- + +func TestPeekLSP_request(t *testing.T) { + body := lspRequest(42, "textDocument/hover") + p := peekLSP(body) + if !p.hasID { + 
t.Error("expected hasID") + } + if !p.hasMethod { + t.Error("expected hasMethod") + } + if p.method != "textDocument/hover" { + t.Errorf("method: got %q", p.method) + } +} + +func TestPeekLSP_response(t *testing.T) { + body := lspResponse(7, map[string]interface{}{"result": "ok"}) + p := peekLSP(body) + if !p.hasID { + t.Error("expected hasID") + } + if p.hasMethod { + t.Error("expected no method") + } +} + +func TestPeekLSP_notification(t *testing.T) { + body := lspNotification("textDocument/publishDiagnostics") + p := peekLSP(body) + if p.hasID { + t.Error("expected no id") + } + if !p.hasMethod { + t.Error("expected hasMethod") + } + if p.method != "textDocument/publishDiagnostics" { + t.Errorf("method: got %q", p.method) + } +} + +// --- rewriteID --- + +func TestRewriteIDUint(t *testing.T) { + body := lspRequest(1, "initialize") + out := rewriteIDUint(body, 999) + + var obj map[string]json.RawMessage + json.Unmarshal(out, &obj) + var got uint64 + json.Unmarshal(obj["id"], &got) + if got != 999 { + t.Errorf("id: want 999, got %d", got) + } +} + +func TestRewriteIDRaw(t *testing.T) { + body := lspResponse(999, nil) // global id + origID := json.RawMessage(`"client-req-1"`) + out := rewriteIDRaw(body, origID) + + var obj map[string]json.RawMessage + json.Unmarshal(out, &obj) + if string(obj["id"]) != `"client-req-1"` { + t.Errorf("id: got %s", obj["id"]) + } +} + +func TestRewriteIDUint_preservesOtherFields(t *testing.T) { + body := lspRequest(1, "shutdown") + out := rewriteIDUint(body, 42) + if extractMethod(t, out) != "shutdown" { + t.Error("method lost after rewrite") + } +} + +// --- mux routing --- + +// TestMux_requestIDRewriting verifies that the mux assigns a global ID when +// forwarding a request and restores the original ID in the response. +func TestMux_requestIDRewriting(t *testing.T) { + tm := newTestMux(t) + clientConn, clientR := tm.addClient(t) + + // Send a hover request with client id=5. 
+ sendFrameConn(t, clientConn, lspRequest(5, "textDocument/hover")) + + // Mux should forward to server with a new global id (not 5). + serverBody := recvFrameConn(t, bufio.NewReader(tm.serverIn), tm.serverIn) + globalID := extractID(t, serverBody) + if string(globalID) == "5" { + t.Error("mux should rewrite client ID to a global ID") + } + + // Server responds with global id. + var gid uint64 + json.Unmarshal(globalID, &gid) + sendFrameConn(t, tm.serverOut, lspResponse(gid, map[string]string{"result": "hover"})) + + // Client should receive response with its original id=5. + resp := recvFrameConn(t, clientR, clientConn) + if string(extractID(t, resp)) != "5" { + t.Errorf("client id not restored, got: %s", extractID(t, resp)) + } +} + +// TestMux_shutdownIntercept verifies that shutdown is answered locally and +// never forwarded to the LSP server, keeping gopls alive. +func TestMux_shutdownIntercept(t *testing.T) { + tm := newTestMux(t) + clientConn, clientR := tm.addClient(t) + + sendFrameConn(t, clientConn, lspRequest(1, "shutdown")) + + // Client gets an immediate fake response. + resp := recvFrameConn(t, clientR, clientConn) + if string(extractID(t, resp)) != "1" { + t.Errorf("fake shutdown response id wrong: %s", extractID(t, resp)) + } + + // Nothing should reach the server. + tm.serverIn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + var buf [1]byte + if _, err := tm.serverIn.Read(buf[:]); err == nil { + t.Error("shutdown was forwarded to server, should have been intercepted") + } +} + +// TestMux_exitDropped verifies that exit notifications from clients are +// dropped and never forwarded — gopls must not be killed by a client disconnect. 
+func TestMux_exitDropped(t *testing.T) { + tm := newTestMux(t) + clientConn, _ := tm.addClient(t) + + sendFrameConn(t, clientConn, lspNotification("exit")) + + tm.serverIn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + var buf [1]byte + if _, err := tm.serverIn.Read(buf[:]); err == nil { + t.Error("exit notification was forwarded to server, should have been dropped") + } +} + +// TestMux_broadcastNotification verifies that a server-pushed notification +// reaches all connected clients. +func TestMux_broadcastNotification(t *testing.T) { + tm := newTestMux(t) + + connA, rA := tm.addClient(t) + connB, rB := tm.addClient(t) + + diag := lspNotification("textDocument/publishDiagnostics") + sendFrameConn(t, tm.serverOut, diag) + + bodyA := recvFrameConn(t, rA, connA) + bodyB := recvFrameConn(t, rB, connB) + + if extractMethod(t, bodyA) != "textDocument/publishDiagnostics" { + t.Errorf("client A: got method %q", extractMethod(t, bodyA)) + } + if extractMethod(t, bodyB) != "textDocument/publishDiagnostics" { + t.Errorf("client B: got method %q", extractMethod(t, bodyB)) + } +} + +// TestMux_initializeCaching verifies that the second client receives the +// cached initialize response and no second initialize reaches the server. +func TestMux_initializeCaching(t *testing.T) { + tm := newTestMux(t) + + // --- Client A: full initialize round-trip --- + connA, rA := tm.addClient(t) + sendFrameConn(t, connA, lspRequest(1, "initialize")) + + // Mux forwards to server; read and capture global id. + serverFrame := recvFrameConn(t, bufio.NewReader(tm.serverIn), tm.serverIn) + var gid uint64 + json.Unmarshal(extractID(t, serverFrame), &gid) + + // Server responds. + sendFrameConn(t, tm.serverOut, lspResponse(gid, map[string]interface{}{"capabilities": map[string]bool{}})) + + // Client A gets response with its original id=1. 
+ respA := recvFrameConn(t, rA, connA) + if string(extractID(t, respA)) != "1" { + t.Errorf("client A: id wrong: %s", extractID(t, respA)) + } + + // --- Client B: should get cached response, nothing forwarded to server --- + connB, rB := tm.addClient(t) + sendFrameConn(t, connB, lspRequest(2, "initialize")) + + respB := recvFrameConn(t, rB, connB) + if string(extractID(t, respB)) != "2" { + t.Errorf("client B: id wrong: %s", extractID(t, respB)) + } + + // Nothing new should arrive at the server. + tm.serverIn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + var buf [1]byte + if _, err := tm.serverIn.Read(buf[:]); err == nil { + t.Error("second initialize was forwarded to server; should have used cache") + } +} + +// TestMux_concurrentInitialize verifies that when two clients send initialize +// simultaneously on a fresh mux, only one initialize reaches the server. +// This exercises the initInFlight / initReady synchronization (Bug D fix). +func TestMux_concurrentInitialize(t *testing.T) { + tm := newTestMux(t) + + connA, rA := tm.addClient(t) + connB, rB := tm.addClient(t) + + // Both clients fire initialize at roughly the same time. + var wg sync.WaitGroup + wg.Add(2) + go func() { defer wg.Done(); sendFrameConn(t, connA, lspRequest(1, "initialize")) }() + go func() { defer wg.Done(); sendFrameConn(t, connB, lspRequest(2, "initialize")) }() + wg.Wait() + + // Exactly one initialize should reach the server. + serverFrame := recvFrameConn(t, bufio.NewReader(tm.serverIn), tm.serverIn) + if extractMethod(t, serverFrame) != "initialize" { + t.Fatalf("expected initialize, got %q", extractMethod(t, serverFrame)) + } + + // Verify no second initialize arrives within a short window. + tm.serverIn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + var buf [1]byte + if _, err := tm.serverIn.Read(buf[:]); err == nil { + t.Error("two initializes reached the server; only one should be forwarded") + } + + // Respond so both clients get their answer. 
+ var gid uint64 + json.Unmarshal(extractID(t, serverFrame), &gid) + sendFrameConn(t, tm.serverOut, lspResponse(gid, map[string]interface{}{})) + + // Both clients must receive a response with their original IDs. + done := make(chan string, 2) + go func() { + resp := recvFrameConn(t, rA, connA) + done <- string(extractID(t, resp)) + }() + go func() { + resp := recvFrameConn(t, rB, connB) + done <- string(extractID(t, resp)) + }() + + ids := map[string]bool{} + for i := 0; i < 2; i++ { + select { + case id := <-done: + ids[id] = true + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for initialize responses") + } + } + if !ids["1"] || !ids["2"] { + t.Errorf("expected ids 1 and 2, got: %v", ids) + } +} diff --git a/daemon/supervisor.go b/daemon/supervisor.go index dda51a7..3479242 100644 --- a/daemon/supervisor.go +++ b/daemon/supervisor.go @@ -6,11 +6,14 @@ import ( ) const ( - memLimitMB = 1500 - frozenThreshold = 5 * time.Minute + memLimitMB = 1500 // kill server if RSS exceeds this + frozenThreshold = 5 * time.Minute // kill server if no response received in this window watchInterval = 30 * time.Second ) +// startWatchdog launches a background goroutine that periodically checks each +// registered LSP server for runaway memory usage or a frozen response stream. +// Servers that fail either check are killed; mux.onExit triggers a respawn. func (d *Daemon) startWatchdog() { go func() { for { diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..24164f8 --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,230 @@ +# ohm — Architecture + +## Why ohm exists + +Neovim's built-in LSP client starts a fresh server process per session. This creates several recurring problems in real-world use: + +- **Memory bloat** — gopls for a large Go repo can use 500–1500MB. Three Neovim sessions = three copies. 
+- **Monorepo duplication** — opening the same root directory in multiple windows spawns redundant servers that index the same files independently. +- **Stuck diagnostics** — when a Neovim session closes, the LSP server receives no `textDocument/didClose`, leaving stale diagnostics on the next open. +- **Session degradation** — long-running LSP servers accumulate state; a server spawned fresh each session never benefits from warmup. + +ohm moves the LSP server lifecycle out of Neovim and into a persistent daemon. One server per `{root_dir, language}` pair, shared across every Neovim session, for the lifetime of the workstation session. + +--- + +## Two-socket design + +ohm uses two distinct Unix sockets with different protocols: + +``` +┌─────────────────────────────────────────────────────┐ +│ Neovim instance │ +│ │ +│ lspconfig (ohm --client per buffer) │ +│ │ stdio LSP JSON-RPC │ +│ ▼ │ +│ ohm --client bridge process │ +│ │ unix socket LSP JSON-RPC │ +│ ▼ │ +│ [proxy socket] ◄──── per server, raw LSP │ +└─────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────┐ +│ ohm daemon │ +│ │ +│ Mux (fan-out, ID rewriting) │ +│ │ stdio LSP JSON-RPC │ +│ ▼ │ +│ LSP server process (gopls, rust-analyzer, ...) │ +└─────────────────────────────────────────────────────┘ + + [control socket] ◄──── msgpack-rpc, persistent channel + │ + ▼ + Neovim shim (client.lua rpcrequest/rpcnotify) +``` + +**Control socket** (`ohm.sock`) — speaks msgpack-rpc (Neovim's native RPC protocol). Used by the Lua plugin to send `attach`/`detach`/`status` commands. One persistent connection per Neovim instance. + +**Proxy socket** (e.g. `ohm-go-a3f1b2c4.sock`) — speaks raw LSP JSON-RPC. One socket per registered `{root_dir, language}` server. The `ohm --client` bridge connects here and forwards bytes to/from Neovim's lspconfig. 
+ +Keeping the protocols separate avoids the corruption that plagued V1, where LSP frames and msgpack-rpc frames shared a single socket. + +--- + +## Request flow + +### Attach (new buffer opened) + +1. Neovim opens `main.go` in a Go project. +2. lspconfig calls `vim.lsp.rpc.start`, which ohm's `wire_lspconfig` hook has overridden to launch `ohm --client --socket <control-socket> --root <root_dir> --lang go -- gopls`. +3. `ohm --client` connects to the control socket and sends a msgpack-rpc `attach` request. +4. The daemon looks up `{root_dir="...", lang="go"}` in the registry: + - **Hit** — increments ref count, returns existing proxy socket path. + - **Miss** — spawns a new `gopls` process, creates a `Mux`, binds a proxy socket, registers the server, returns the proxy socket path. +5. `ohm --client` disconnects from the control socket, connects to the proxy socket, and begins bridging `stdin ↔ proxy` bidirectionally. +6. Neovim's lspconfig now believes it is speaking directly to gopls. + +### LSP request (e.g. hover) + +``` +Neovim → ohm --client (stdin→proxy) + → proxy socket → Mux.serveClient + → ID rewritten (client id → global id) + → WriteFrame to gopls stdin + → gopls processes request + → gopls stdout → Mux.Broadcast + → pending map lookup (global id → original client id + conn) + → ID restored + → WriteFrame to proxy socket + → ohm --client (proxy→stdout) → Neovim +``` + +### Server-pushed notification (e.g. publishDiagnostics) + +``` +gopls stdout → Mux.Broadcast + → no id, has method → broadcast path + → WriteFrame to every connected client + → all Neovim instances receive diagnostics +``` + +--- + +## Request ID rewriting + +LSP request IDs are chosen by each client (the protocol allows integers or strings). With multiple Neovim sessions sharing one gopls, their IDs collide (every session starts at 1). + +The Mux maintains a global atomic counter (`nextID`). On each incoming request: + +1. The original client ID is saved in a `pending` map keyed by `globalID`. +2. 
The message body is rewritten with `globalID` before forwarding to gopls. +3. When gopls responds, `Broadcast` looks up `globalID` in the pending map, rewrites the ID back to the original client value, and routes the response to that client's connection only. + +Notifications (no ID field) are broadcast to all clients since they are not responses to a specific request. + +--- + +## initialize caching + +The LSP `initialize` handshake is expensive: it triggers full project indexing in gopls. ohm ensures it happens exactly once per server lifetime. + +``` +First client Concurrent client Later client +────────────── ───────────────── ──────────── +send initialize + initInFlight = true ─────► sees initInFlight=true + forward to gopls block on <-initReady +gopls responds + cache initResponse + close(initReady) ──────► unblocked + send to client A rewrite ID, send cached send cached immediately +``` + +Three states tracked under `initMu`: +- `initResponse == nil`, `initInFlight == false` → first caller; forward to server +- `initResponse == nil`, `initInFlight == true` → concurrent caller; wait on `initReady` channel +- `initResponse != nil` → cached; return immediately with ID rewrite + +The `initialized` notification (sent after `initialize` succeeds) is only forwarded for the first client. Subsequent clients skip it via the same caching path. + +--- + +## Shutdown interception + +When a Neovim session closes, lspconfig sends `shutdown` then `exit`. Forwarding these to gopls would kill the shared server. + +`serveClient` intercepts both: + +- **`shutdown`** — a fake `{"result": null}` response is sent back to the client immediately. The request is never forwarded. +- **`exit`** (notification, no ID) — silently dropped. + +gopls never sees either message and stays running. + +--- + +## Ref counting and grace period + +Each `LSPServer` tracks a `Refs` count — the number of `ohm --client` bridge processes currently connected to its proxy socket. 
+ +- `attach` → `IncrRef` +- `detach` → `DecrRef` + +When `Refs` reaches 0, a 10-second timer starts (`pendingKill`). If a new `attach` arrives within the window the timer is cancelled and the server is reused immediately. After 10 seconds the server is shut down gracefully. + +This handles the common case of closing and immediately reopening a file, or switching between splits. + +--- + +## Respawn + +When a server process exits unexpectedly (gopls crash, OOM kill), `Mux.Broadcast` reads EOF from the process stdout and calls `mux.onExit`, which triggers `respawnServer`. + +`respawnServer`: +1. Cancels any in-flight `pendingKill` timer for the key (a crash during the grace period must not let the timer kill the new process). +2. Spawns a fresh LSP process. +3. Creates a new `Mux` for the new process. +4. Swaps `server.Process` and `server.mux` in place under `server.mu`. +5. Starts `Broadcast` on the new mux. + +The proxy socket listener (`listenProxy`) keeps running throughout — it holds no reference to the old mux. New connections arriving after the swap go to the new mux automatically. + +Existing Neovim clients connected to the old mux will see their `serveClient` goroutines exit (write errors to a dead process), disconnect, and reconnect on the next LSP request via lspconfig. + +--- + +## Watchdog + +A goroutine wakes every 30 seconds and checks every registered server: + +| Check | Threshold | Action | +|---|---|---| +| RSS memory | > 1500 MB | graceful shutdown + remove | +| Last response age | > 5 minutes | graceful shutdown + remove | + +Memory is read from `/proc/{pid}/status` (VmRSS). Last response time is an atomic timestamp updated on every message received from the server's stdout. + +Both checks call `server.Close()` which sends a graceful LSP `shutdown`+`exit` sequence before killing the process. `mux.onExit` then triggers a respawn. + +--- + +## Graceful shutdown sequence + +`Mux.GracefulShutdown`: + +1. 
Register a synthetic internal `pending` entry (no client, just a `done` channel). +2. Send `{"method":"shutdown","id":<id>}` to the server, where `<id>` is the ID of that synthetic entry. +3. Wait up to 5 seconds for the response on `done`. +4. Send `{"method":"exit"}` notification. +5. Wait up to 2 seconds for the process to exit. +6. If still running after step 5, `Kill()`. + +--- + +## Concurrency model + +| Goroutine | Lifetime | What it does | +|---|---|---| +| `handleConn` | per control connection | decodes msgpack-rpc, dispatches attach/detach/status | +| `Broadcast` | per LSP server | reads server stdout, routes to clients | +| `serveClient` | per proxy client | reads client frames, rewrites IDs, writes to server stdin | +| `sendLoop` | per proxy client | drains send channel, writes frames to client conn | +| `listenProxy` | per LSP server | accepts new proxy connections | +| `captureStderr` | per LSP server | pipes server stderr to slog | +| `watchdog` | singleton | periodic memory + frozen checks | +| `respawnServer` | on crash | runs as `go m.onExit()` from Broadcast | + +Shared state and its lock: + +| State | Lock | +|---|---| +| `registry.servers` | `registry.mu` | +| `daemon.pendingKill` | `daemon.mu` | +| `mux.clients` | `mux.clientsMu` (RWMutex) | +| `mux.pending` | `mux.pendingMu` | +| `mux.initResponse / initInFlight` | `mux.initMu` | +| `server.mux / server.Process` | `server.mu` (RWMutex) | +| `mux.lastNs` | atomic | +| `mux.nextID` | atomic | diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go index 3127e2f..c9b5a7a 100644 --- a/rpc/rpc_test.go +++ b/rpc/rpc_test.go @@ -7,6 +7,105 @@ import ( "github.com/ugorji/go/codec" ) +func TestWriteResponse_roundtrip(t *testing.T) { + h := NewHandler() + var buf bytes.Buffer + + if err := h.WriteResponse(&buf, 42, "proxy-path"); err != nil { + t.Fatalf("WriteResponse: %v", err) + } + + var raw []interface{} + dec := codec.NewDecoder(&buf, &h.mh) + if err := dec.Decode(&raw); err != nil { + t.Fatalf("decode: %v", err) + } + if len(raw) != 4 { +
t.Fatalf("expected 4 elements, got %d", len(raw)) + } + typ, _ := toUint64(raw[0]) + if typ != TypeResponse { + t.Errorf("type: want %d, got %d", TypeResponse, typ) + } + msgID, _ := toUint64(raw[1]) + if msgID != 42 { + t.Errorf("msgid: want 42, got %d", msgID) + } + if raw[2] != nil { + t.Errorf("error field: want nil, got %v", raw[2]) + } + if raw[3] != "proxy-path" { + t.Errorf("result: want proxy-path, got %v", raw[3]) + } +} + +func TestWriteResponse_nilResult(t *testing.T) { + h := NewHandler() + var buf bytes.Buffer + if err := h.WriteResponse(&buf, 1, nil); err != nil { + t.Fatalf("WriteResponse: %v", err) + } + var raw []interface{} + dec := codec.NewDecoder(&buf, &h.mh) + if err := dec.Decode(&raw); err != nil { + t.Fatalf("decode: %v", err) + } + if raw[3] != nil { + t.Errorf("result: want nil, got %v", raw[3]) + } +} + +func TestToUint64(t *testing.T) { + cases := []struct { + in interface{} + want uint64 + ok bool + }{ + {uint64(10), 10, true}, + {uint32(10), 10, true}, + {uint16(10), 10, true}, + {uint8(10), 10, true}, + {int64(10), 10, true}, + {int32(10), 10, true}, + {int16(10), 10, true}, + {int8(10), 10, true}, + {int(10), 10, true}, + {int64(-1), 0, false}, + {int8(-1), 0, false}, + {"nope", 0, false}, + {nil, 0, false}, + } + for _, tc := range cases { + got, ok := toUint64(tc.in) + if ok != tc.ok { + t.Errorf("toUint64(%T(%v)): ok=%v, want %v", tc.in, tc.in, ok, tc.ok) + } + if ok && got != tc.want { + t.Errorf("toUint64(%T(%v)): got %d, want %d", tc.in, tc.in, got, tc.want) + } + } +} + +func TestDecode_UnknownType(t *testing.T) { + raw := []interface{}{uint64(9), "method", []interface{}{}} + r := encodeMsg(t, raw) + h := NewHandler() + _, err := h.Decode(r) + if err == nil { + t.Fatal("expected error for unknown message type") + } +} + +func TestDecode_TooShort(t *testing.T) { + raw := []interface{}{uint64(0), uint64(1)} // request missing method+params + r := encodeMsg(t, raw) + h := NewHandler() + _, err := h.Decode(r) + if err == nil { + 
t.Fatal("expected error for too-short request") + } +} + func encodeMsg(t *testing.T, v interface{}) *bytes.Reader { t.Helper() var mh codec.MsgpackHandle