Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 133 additions & 80 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,24 @@ import (
"syscall"
"time"

_ "github.com/go-sql-driver/mysql"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
logrus "github.com/sirupsen/logrus"

_ "github.com/go-sql-driver/mysql"
_ "modernc.org/sqlite"
)

// ///////////////////
// GLOBAL CONFIG
// ///////////////////
var (
dbDSN string
dbDriver string
serverAddr string
allowedOrigins []string

pongWait = 60 * time.Second
pingPeriod = 30 * time.Second

dbDSN string
dbDriver string
serverAddr string
allowedOrigins []string
pongWait = 60 * time.Second
pingPeriod = 30 * time.Second
offlineTTL time.Duration
maxQueuedMessagesPerPlayer int
maxConnectionsPerPlayer int
Expand Down Expand Up @@ -101,6 +98,9 @@ var (
// Rate limiting
limiters = make(map[string]*limiter)
lm sync.Mutex

// logFileHandle holds the open log file so it can be closed on shutdown.
logFileHandle *os.File
)

type limiter struct {
Expand Down Expand Up @@ -250,7 +250,8 @@ func registerConnection(playerID string, c *websocket.Conn, token string) {
flushPendingMessages(playerID, c)
}

// unregisterConnection removes a websocket connection for a player and decrements the active connections metric.
// unregisterConnection removes a websocket connection for a player, explicitly closes the underlying
// websocket (releasing the file descriptor), and decrements the active connections metric.
// If no connections remain for the player, the player's entry is removed from the players map.
func unregisterConnection(playerID string, c *websocket.Conn) {
mu.Lock()
Expand All @@ -259,6 +260,8 @@ func unregisterConnection(playerID string, c *websocket.Conn) {
if len(players[playerID]) == 0 {
delete(players, playerID)
}

_ = c.Close()
connections.Dec()
}

Expand Down Expand Up @@ -317,8 +320,8 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
// check connection limit
mu.Lock()
current := len(players[playerID])
mu.Unlock()
if current >= maxConnectionsPerPlayer {
mu.Unlock()
logrus.WithFields(logrus.Fields{
"player_id": playerID,
"current": current,
Expand All @@ -327,6 +330,7 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "too many connections", http.StatusTooManyRequests)
return
}
mu.Unlock()

// upgrade to WebSocket
conn, err := upgrader.Upgrade(w, r, nil)
Expand All @@ -353,9 +357,18 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {

ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()

done := make(chan struct{})
defer close(done)

go func() {
for range ticker.C {
_ = conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second))
for {
select {
case <-done:
return
case <-ticker.C:
_ = conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second))
}
}
}()

Expand Down Expand Up @@ -499,28 +512,105 @@ func initMetrics() {
prometheus.MustRegister(connections, messagesPublished, messagesDelivered)
}

// startOfflineMessageCleanup periodically removes expired pending messages from the offline queue.
// It stops when stopCh is closed.
func startOfflineMessageCleanup(stopCh <-chan struct{}) {
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()

for {
select {
case <-stopCh:
return
case <-ticker.C:
now := time.Now()
pendingMu.Lock()
for pid, msgs := range pendingMessages {
filtered := msgs[:0]
for _, pm := range msgs {
if now.Sub(pm.timestamp) <= offlineTTL {
filtered = append(filtered, pm)
}
}
if len(filtered) == 0 {
delete(pendingMessages, pid)
} else {
pendingMessages[pid] = filtered
}
}
pendingMu.Unlock()
}
}
}()
}

// buildServer constructs and returns the HTTP server and its ServeMux with all routes registered.
func buildServer() *http.Server {
mux := http.NewServeMux()
mux.HandleFunc("/ws", wsHandler)
mux.HandleFunc("/publish", publishHandler)
mux.HandleFunc("/broadcast", broadcastHandler)
mux.Handle("/metrics", promhttp.Handler())

return &http.Server{Addr: serverAddr, Handler: mux}
}

// runServer starts the HTTP server and blocks until it shuts down.
// It listens for SIGINT/SIGTERM, closes stopCh to signal background goroutines,
// then performs a graceful HTTP shutdown followed by closing all websocket connections.
// Returns an error if the server exits unexpectedly.
func runServer(server *http.Server, stopCh chan struct{}) error {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)

go func() {
<-quit
logrus.Info("Shutting down server...")
// Signal all background goroutines (cleanup, revalidation) to stop.
close(stopCh)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = server.Shutdown(ctx)
closeAllConnections()
}()

logrus.Infof("Server listening on %s", serverAddr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
return fmt.Errorf("server error: %w", err)
}
return nil
}

// startTokenRevalidation periodically validates all active websocket tokens.
// Invalid tokens cause connections to be closed and removed.
func startTokenRevalidation(interval time.Duration) {
// The provided stopCh can be closed to stop the revalidation loop and its ticker.
func startTokenRevalidation(interval time.Duration, stopCh <-chan struct{}) {
ticker := time.NewTicker(interval)
go func() {
for range ticker.C {
mu.Lock()
for playerID, conns := range players {
for c, wc := range conns {
_, valid := validateToken(wc.token, false)
if !valid {
logrus.WithFields(logrus.Fields{
"player_id": playerID,
}).Info("Token invalid, closing connection")
_ = wc.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "token expired"))
_ = wc.conn.Close()
delete(conns, c)
connections.Dec()
defer ticker.Stop()
for {
select {
case <-stopCh:
return
case <-ticker.C:
mu.Lock()
for playerID, conns := range players {
for c, wc := range conns {
_, valid := validateToken(wc.token, false)
if !valid {
logrus.WithFields(logrus.Fields{
"player_id": playerID,
}).Info("Token invalid, closing connection")
_ = wc.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "token expired"))
_ = wc.conn.Close()
delete(conns, c)
connections.Dec()
}
}
}
mu.Unlock()
}
mu.Unlock()
}
}()
}
Expand Down Expand Up @@ -588,6 +678,7 @@ func daemonizeSelf() error {
// setupLogging configures logrus logging for the application.
// It sets the output destination and log level based on global flags.
// Returns an error if the log file cannot be opened or if the log level is invalid.
// The opened log file handle is stored in logFileHandle so it can be closed on shutdown.
func setupLogging() error {
logrus.SetFormatter(&logrus.JSONFormatter{})

Expand All @@ -596,6 +687,8 @@ func setupLogging() error {
if err != nil {
return fmt.Errorf("failed to open log file %s: %w", logFile, err)
}

logFileHandle = f
logrus.SetOutput(f)
} else {
logrus.SetOutput(os.Stdout)
Expand Down Expand Up @@ -657,11 +750,17 @@ func run() error {
return fmt.Errorf("failed to setup logging: %w", err)
}

if logFileHandle != nil {
defer logFileHandle.Close()
}

// Initialize DB
if err := initDB(); err != nil {
return fmt.Errorf("failed to init DB: %w", err)
}

defer db.Close()

// Daemonize if needed
if daemonize {
if err := daemonizeSelf(); err != nil {
Expand All @@ -671,57 +770,11 @@ func run() error {

initMetrics()

// Start offline message cleanup
go func() {
ticker := time.NewTicker(30 * time.Second)
for range ticker.C {
now := time.Now()
pendingMu.Lock()
for pid, msgs := range pendingMessages {
filtered := msgs[:0]
for _, pm := range msgs {
if now.Sub(pm.timestamp) <= offlineTTL {
filtered = append(filtered, pm)
}
}
if len(filtered) == 0 {
delete(pendingMessages, pid)
} else {
pendingMessages[pid] = filtered
}
}
pendingMu.Unlock()
}
}()
// stopCh is closed on shutdown to signal background goroutines to exit.
stopCh := make(chan struct{})

// start WS token revalidation
startTokenRevalidation(tokenRevalidationPeriod)
startOfflineMessageCleanup(stopCh)
startTokenRevalidation(tokenRevalidationPeriod, stopCh)

mux := http.NewServeMux()
mux.HandleFunc("/ws", wsHandler)
mux.HandleFunc("/publish", publishHandler)
mux.HandleFunc("/broadcast", broadcastHandler)
mux.Handle("/metrics", promhttp.Handler())

server := &http.Server{Addr: serverAddr, Handler: mux}

// Graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-quit
logrus.Info("Shutting down server...")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = server.Shutdown(ctx)
closeAllConnections()
}()

logrus.Infof("Server listening on %s", serverAddr)
err := server.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
return fmt.Errorf("server error: %w", err)
}

return nil
return runServer(buildServer(), stopCh)
}