Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions echokit/d3authjwtauthenticator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package echokit

import (
"errors"
"net/http"
"net/url"
"strings"
"time"

"github.com/auth0/go-jwt-middleware/v2/jwks"
"github.com/auth0/go-jwt-middleware/v2/validator"
"github.com/half-ogre/go-kit/kit"
"github.com/labstack/echo/v4"
)

const (
d3AuthJWTAuthenticatorContextKey = "go-kit-echokit-d3auth-jwt-authenticated-user"
)

type D3AuthConfig struct {
BaseURL string
Audience string
}

type D3AuthJWTAuthenticator struct {
config D3AuthConfig
jwtValidator *validator.Validator
}

func NewD3AuthJWTAuthenticator(config D3AuthConfig) (Authenticator, error) {
issuerURL := strings.TrimRight(config.BaseURL, "/")

jwksURL, err := url.Parse(issuerURL)
if err != nil {
return nil, kit.WrapError(err, "failed to parse d3-auth base URL")
}

provider := jwks.NewCachingProvider(jwksURL, 5*time.Minute)

jwtValidator, err := validator.New(
provider.KeyFunc,
validator.RS256,
issuerURL,
[]string{config.Audience},
validator.WithCustomClaims(
func() validator.CustomClaims {
return &Auth0CustomClaims{}
},
),
validator.WithAllowedClockSkew(time.Minute),
)
if err != nil {
return nil, kit.WrapError(err, "failed to create d3-auth JWT validator")
}

return &D3AuthJWTAuthenticator{
config: config,
jwtValidator: jwtValidator,
}, nil
}

func (a *D3AuthJWTAuthenticator) AuthenticateRequest(c echo.Context) error {
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
return nil
}

authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return nil
}

validateResult, err := a.jwtValidator.ValidateToken(c.Request().Context(), authHeaderParts[1])
if err != nil {
return err
}

validatedClaims, ok := validateResult.(*validator.ValidatedClaims)
if !ok {
return errors.New("failed to cast to ValidatedClaims")
}

customClaims, ok := validatedClaims.CustomClaims.(*Auth0CustomClaims)
if !ok {
return errors.New("failed to cast custom claims")
}

authenticatedUser := AuthenticatedUser{
Sub: validatedClaims.RegisteredClaims.Subject,
Name: customClaims.Name,
GivenName: customClaims.GivenName,
FamilyName: customClaims.FamilyName,
MiddleName: customClaims.MiddleName,
Nickname: customClaims.Nickname,
PreferredUsername: customClaims.PreferredUsername,
Email: customClaims.Email,
EmailVerified: customClaims.EmailVerified,
Picture: customClaims.Picture,
UpdatedAt: customClaims.UpdatedAt,
Permissions: customClaims.Permissions,
}

c.Set(d3AuthJWTAuthenticatorContextKey, &authenticatedUser)

return nil
}

func (a *D3AuthJWTAuthenticator) GetAuthenticatedUser(c echo.Context) (*AuthenticatedUser, error) {
user, ok := c.Get(d3AuthJWTAuthenticatorContextKey).(*AuthenticatedUser)
if !ok || user == nil {
return nil, errors.New("no authenticated user")
}
return user, nil
}

func (a *D3AuthJWTAuthenticator) HandleNotAuthenticated(c echo.Context) error {
return c.NoContent(http.StatusUnauthorized)
}

func (a *D3AuthJWTAuthenticator) IsAuthenticated(c echo.Context) (bool, error) {
user := c.Get(d3AuthJWTAuthenticatorContextKey)
return user != nil, nil
}
31 changes: 22 additions & 9 deletions echokit/staticfiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func (m *StaticFilesMiddleware) build() error {
// Other files are fingerprinted only if they are referenced from HTML, JS, or CSS.
importFromRegex := regexp.MustCompile(`(from\s+['"])(\./[^'"]+|\.\.\/[^'"]+)(['"])`)
importSideEffectRegex := regexp.MustCompile(`(import\s+['"])(\./[^'"]+|\.\.\/[^'"]+)(['"])`)
dynamicImportRegex := regexp.MustCompile(`(import\s*\(\s*['"])(\./[^'"]+|\.\.\/[^'"]+)(['"]\s*\))`)
dynamicImportRegex := regexp.MustCompile(`([^{]import\s*\(\s*['"])(\./[^'"]+|\.\.\/[^'"]+)(['"]\s*\))`)
cssURLRegex := regexp.MustCompile(`(url\s*\(\s*['"]?)(\./[^'")]+|\.\.\/[^'")]+)(['"]?\s*\))`)
htmlSrcRegex := regexp.MustCompile(`(?:src|href)="(/[^"]+)"`)

Expand Down Expand Up @@ -356,16 +356,29 @@ func (m *StaticFilesMiddleware) build() error {
}
}
}
// Add any remaining files (cycles or disconnected)
for path := range rawFiles {
found := false
for _, p := range order {
if p == path {
found = true
break
// Detect and warn about circular dependencies — files stuck in cycles
// will have their imports only partially rewritten (fingerprinting may be wrong).
ordered := make(map[string]bool, len(order))
for _, p := range order {
ordered[p] = true
}
for path, deg := range inDegree {
if deg > 0 && !ordered[path] {
// Find which of this file's dependencies are also stuck
var cycle []string
for _, child := range deps[path] {
if !ordered[child] {
cycle = append(cycle, child)
}
}
slog.Warn("static file has circular dependency — fingerprinted imports may be incomplete",
"file", path, "unresolved_deps", cycle)
}
if !found {
}

// Add remaining files (cycles or disconnected) in arbitrary order
for path := range rawFiles {
if !ordered[path] {
order = append(order, path)
}
}
Expand Down
58 changes: 58 additions & 0 deletions echokit/staticfiles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package echokit

import (
"context"
"log/slog"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -538,6 +539,63 @@ func extractFingerprintedPath(html, prefix string) string {
return html[start:end]
}

func TestStaticFilesMiddleware_CircularDependencyWarning(t *testing.T) {
t.Run("logs_warning_for_circular_imports", func(t *testing.T) {
dir := t.TempDir()
os.MkdirAll(filepath.Join(dir, "lib"), 0755)

// Create a circular dependency: a.js imports b.js, b.js imports a.js
os.WriteFile(filepath.Join(dir, "index.html"), []byte(`<html><body><script type="module" src="/lib/app.js"></script></body></html>`), 0644)
os.WriteFile(filepath.Join(dir, "lib", "a.js"), []byte(`import { B } from './b.js'; export const A = 'a' + B;`), 0644)
os.WriteFile(filepath.Join(dir, "lib", "b.js"), []byte(`import { A } from './a.js'; export const B = 'b' + A;`), 0644)
os.WriteFile(filepath.Join(dir, "lib", "app.js"), []byte(`import { A } from './a.js'; console.log(A);`), 0644)

// Capture log output
var logBuf strings.Builder
oldLogger := slog.Default()
slog.SetDefault(slog.New(slog.NewTextHandler(&logBuf, &slog.HandlerOptions{Level: slog.LevelWarn})))
defer slog.SetDefault(oldLogger)

m := NewStaticFilesMiddleware(dir, false)
defer m.Close()
e := echo.New()
e.Use(m.Handler())

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, logBuf.String(), "circular dependency")
})

t.Run("no_warning_for_acyclic_imports", func(t *testing.T) {
dir := t.TempDir()
os.MkdirAll(filepath.Join(dir, "lib"), 0755)

os.WriteFile(filepath.Join(dir, "index.html"), []byte(`<html><body><script type="module" src="/lib/app.js"></script></body></html>`), 0644)
os.WriteFile(filepath.Join(dir, "lib", "utils.js"), []byte("export const U = 1;"), 0644)
os.WriteFile(filepath.Join(dir, "lib", "app.js"), []byte(`import { U } from './utils.js'; console.log(U);`), 0644)

var logBuf strings.Builder
oldLogger := slog.Default()
slog.SetDefault(slog.New(slog.NewTextHandler(&logBuf, &slog.HandlerOptions{Level: slog.LevelWarn})))
defer slog.SetDefault(oldLogger)

m := NewStaticFilesMiddleware(dir, false)
defer m.Close()
e := echo.New()
e.Use(m.Handler())

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)
assert.NotContains(t, logBuf.String(), "circular dependency")
})
}

func TestStaticFilesMiddleware_LiveReload(t *testing.T) {
t.Run("triggers_reload_on_file_change", func(t *testing.T) {
dir := t.TempDir()
Expand Down
Loading