diff --git a/handler.go b/handler.go index da49041..9b638ab 100644 --- a/handler.go +++ b/handler.go @@ -9,6 +9,33 @@ import ( "time" ) +// responseWriter はステータスコードを記録するためのラッパー。 +type responseWriter struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} + +// requestLogger はリクエストのメソッド・パス・ステータス・所要時間をログ出力するミドルウェア。 +func requestLogger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + next.ServeHTTP(rw, r) + slog.Info("リクエスト処理", + "method", r.Method, + "path", r.URL.Path, + "status", rw.statusCode, + "duration", time.Since(start).String(), + "remote", r.RemoteAddr, + ) + }) +} + // startTime はサーバー起動時刻を記録する。 var startTime = time.Now() diff --git a/handler_test.go b/handler_test.go index fdf77a0..dabde8f 100644 --- a/handler_test.go +++ b/handler_test.go @@ -350,3 +350,26 @@ func TestIndexPageNotFound(t *testing.T) { t.Fatalf("expected 404 for unknown path, got %d", w.Code) } } + +func TestRequestLoggerMiddleware(t *testing.T) { + mux, _ := setupTestServer() + handler := requestLogger(mux) + + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 through middleware, got %d", w.Code) + } +} + +func TestResponseWriterStatusCode(t *testing.T) { + w := httptest.NewRecorder() + rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + rw.WriteHeader(http.StatusNotFound) + if rw.statusCode != http.StatusNotFound { + t.Fatalf("expected statusCode=404, got %d", rw.statusCode) + } +} diff --git a/main.go b/main.go index b4f58e8..7860a12 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,14 @@ package main import ( + "context" + "errors" "log/slog" "net/http" "os" + "os/signal" + "syscall" + "time" ) // getEnv は環境変数を取得し、未設定の場合はデフォルト値を返す。 @@ -20,9 +25,35 @@ func main() { registerRoutes(mux, store) addr := ":" + getEnv("PORT", "8080") - slog.Info("BringItサーバーを起動しました", "addr", "http://localhost"+addr) - if err := http.ListenAndServe(addr, mux); err != nil { - slog.Error("サーバーの起動に失敗しました", "error", err) + srv := &http.Server{ + Addr: addr, + Handler: requestLogger(mux), + ReadTimeout: 10 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + } + + // シグナルを受信してグレースフルシャットダウンするチャネル + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + go func() { + slog.Info("BringItサーバーを起動しました", "addr", "http://localhost"+addr) + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + slog.Error("サーバーの起動に失敗しました", "error", err) + os.Exit(1) + } + }() + + <-quit + slog.Info("シャットダウンシグナルを受信しました") + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + if err := srv.Shutdown(ctx); err != nil { + slog.Error("グレースフルシャットダウンに失敗しました", "error", err) os.Exit(1) } + slog.Info("サーバーを正常に停止しました") }