Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 157 additions & 4 deletions pkg/ffapi/openapi3.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/getkin/kin-openapi/openapi3"
Expand All @@ -48,6 +49,9 @@ type SwaggerGenOptions struct {
// this is useful to ensure that all fields are documented
PanicOnMissingDescription bool

// if set to true, the generator will panic if a oneOf is unresolved
PanicOnUnresolvedOneOfs bool

SupportFieldRedaction bool
DefaultRequestTimeout time.Duration

Expand Down Expand Up @@ -76,6 +80,36 @@ type BaseURLVariable struct {

var customRegexRemoval = regexp.MustCompile(`{(\w+)\:[^}]+}`)

var registeredSchemaTypes = map[string]interface{}{}
var registeredSchemaTypesLock sync.RWMutex

// RegisterSchemaType registers a schema name to a Go value for oneOf resolution.
func RegisterSchemaType(name string, value interface{}) {
if name == "" || value == nil {
return
}
registeredSchemaTypesLock.Lock()
defer registeredSchemaTypesLock.Unlock()
registeredSchemaTypes[name] = value
}

// UnregisterSchemaType removes a registered schema name.
func UnregisterSchemaType(name string) {
if name == "" {
return
}
registeredSchemaTypesLock.Lock()
defer registeredSchemaTypesLock.Unlock()
delete(registeredSchemaTypes, name)
}

func getRegisteredSchemaType(name string) (interface{}, bool) {
registeredSchemaTypesLock.RLock()
defer registeredSchemaTypesLock.RUnlock()
v, ok := registeredSchemaTypes[name]
return v, ok
}

type SwaggerGen struct {
options *SwaggerGenOptions
}
Expand Down Expand Up @@ -167,7 +201,7 @@ func (sg *SwaggerGen) ffInputTagHandler(ctx context.Context, route *Route, name
if sg.isTrue(tag.Get("ffexcludeinput")) {
return &openapi3gen.ExcludeSchemaSentinel{}
}
if taggedRoutes, ok := tag.Lookup("ffexcludeinput"); ok {
if taggedRoutes, ok := tag.Lookup("ffexcludeinput"); ok && route != nil {
for _, r := range strings.Split(taggedRoutes, ",") {
if route.Name == r {
return &openapi3gen.ExcludeSchemaSentinel{}
Expand All @@ -191,7 +225,7 @@ func (sg *SwaggerGen) ffTagHandler(ctx context.Context, route *Route, name strin
if sg.isTrue(tag.Get("ffexclude")) {
return &openapi3gen.ExcludeSchemaSentinel{}
}
if taggedRoutes, ok := tag.Lookup("ffexclude"); ok {
if taggedRoutes, ok := tag.Lookup("ffexclude"); ok && route != nil {
for _, r := range strings.Split(taggedRoutes, ",") {
if route.Name == r {
return &openapi3gen.ExcludeSchemaSentinel{}
Expand All @@ -203,11 +237,19 @@ func (sg *SwaggerGen) ffTagHandler(ctx context.Context, route *Route, name strin
key := fmt.Sprintf("%s.%s", structName, name)
description := i18n.Expand(ctx, i18n.MessageKey(key))
if description == key && sg.options.PanicOnMissingDescription {
return i18n.NewError(ctx, i18n.MsgFieldDescriptionMissing, key, route.Name)
routeName := ""
if route != nil {
routeName = route.Name
}
return i18n.NewError(ctx, i18n.MsgFieldDescriptionMissing, key, routeName)
}
schema.Description = description
} else if sg.options.PanicOnMissingDescription {
return i18n.NewError(ctx, i18n.MsgFFStructTagMissing, name, route.Name)
routeName := ""
if route != nil {
routeName = route.Name
}
return i18n.NewError(ctx, i18n.MsgFFStructTagMissing, name, routeName)
}
}
return nil
Expand Down Expand Up @@ -237,11 +279,119 @@ func (sg *SwaggerGen) addCustomType(t reflect.Type, schema *openapi3.Schema) {
}
}

func splitCSV(value string) []string {
if value == "" {
return nil
}
parts := strings.Split(value, ",")
res := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
res = append(res, p)
}
}
return res
}

func componentNameFromRef(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
if strings.HasPrefix(value, "#/components/schemas/") {
return strings.TrimPrefix(value, "#/components/schemas/")
}
if strings.Contains(value, "://") || strings.Contains(value, "#/") {
// If the value contains :// (external URL) or #/ (other internal fragments like #/paths/...), it’s not a schema component, so return "" to prevent misinterpreting the ref.
return ""
}
return value
}

func parseOneOfNames(value string) []string {
items := splitCSV(value)
out := make([]string, 0, len(items))
for _, item := range items {
name := componentNameFromRef(item)
if name != "" {
out = append(out, name)
}
}
return out
}

func (sg *SwaggerGen) applyOneOfTag(ctx context.Context, tag reflect.StructTag, componentSchemas openapi3.Schemas, schema *openapi3.Schema) error {
value := tag.Get("ffoneof")
if value == "" || schema == nil {
return nil
}
names := parseOneOfNames(value)
if len(names) == 0 {
return nil
}

for _, name := range names {
registeredSchemaType, ok := getRegisteredSchemaType(name)
if !ok && sg.options.PanicOnUnresolvedOneOfs {
return i18n.NewError(ctx, i18n.MsgFFOneOfReferencesUnregisteredSchema, name)
}
if ok {
sg.addRegisteredSchemaType(ctx, componentSchemas, name, registeredSchemaType)
}
}
if len(schema.OneOf) == 0 {
refs := make([]*openapi3.SchemaRef, 0, len(names))
for _, name := range names {
refs = append(refs, &openapi3.SchemaRef{Ref: "#/components/schemas/" + name})
}
schema.OneOf = refs
}
return nil
}

func (sg *SwaggerGen) ensureComponentAlias(components openapi3.Schemas, name string, schemaRef *openapi3.SchemaRef) {
// makes sure a schema exists in components.schemas under a given name, without overwriting anything
if name == "" || components == nil || schemaRef == nil {
return
}
if _, ok := components[name]; ok {
return
}
if schemaRef.Ref != "" {
components[name] = &openapi3.SchemaRef{Ref: schemaRef.Ref}
return
}
if schemaRef.Value != nil {
components[name] = &openapi3.SchemaRef{Value: schemaRef.Value}
}
}

func (sg *SwaggerGen) addRegisteredSchemaType(ctx context.Context, componentSchemas openapi3.Schemas, registeredSchemaTypeName string, registeredSchemaTypeValue interface{}) {

schemaCustomizer := func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error {
sg.addCustomType(t, schema)
if err := sg.applyOneOfTag(ctx, tag, componentSchemas, schema); err != nil {
return err
}
return sg.ffTagHandler(ctx, nil, name, tag, schema)
}
schemaRef, err := openapi3gen.NewSchemaRefForValue(registeredSchemaTypeValue, componentSchemas, openapi3gen.SchemaCustomizer(schemaCustomizer))
if err != nil {
panic(fmt.Sprintf("invalid schema registration for %s: %s", registeredSchemaTypeName, err))
}
sg.ensureComponentAlias(componentSchemas, registeredSchemaTypeName, schemaRef)

}

func (sg *SwaggerGen) addInput(ctx context.Context, doc *openapi3.T, route *Route, op *openapi3.Operation) {
var schemaRef *openapi3.SchemaRef
var err error
schemaCustomizer := func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error {
sg.addCustomType(t, schema)
if err := sg.applyOneOfTag(ctx, tag, doc.Components.Schemas, schema); err != nil {
return err
}
return sg.ffInputTagHandler(ctx, route, name, tag, schema)
}
switch {
Expand Down Expand Up @@ -338,6 +488,9 @@ func (sg *SwaggerGen) addOutput(ctx context.Context, doc *openapi3.T, route *Rou
s := i18n.Expand(ctx, i18n.APISuccessResponse)
schemaCustomizer := func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error {
sg.addCustomType(t, schema)
if err := sg.applyOneOfTag(ctx, tag, doc.Components.Schemas, schema); err != nil {
return err
}
return sg.ffOutputTagHandler(ctx, route, name, tag, schema)
}
switch {
Expand Down
Loading