diff --git a/deep.go b/deep.go index 4be3e1f..936d8ee 100644 --- a/deep.go +++ b/deep.go @@ -8,6 +8,7 @@ import ( "fmt" "log" "reflect" + "slices" "strings" ) @@ -76,7 +77,7 @@ type cmp struct { diff []string buff []string floatFormat string - flag map[byte]bool + conf Differ } var errorType = reflect.TypeOf((*error)(nil)).Elem() @@ -92,38 +93,32 @@ var errorType = reflect.TypeOf((*error)(nil)).Elem() // When comparing a struct, if a field has the tag `deep:"-"` then it will be // ignored. func Equal(a, b interface{}, flags ...interface{}) []string { - aVal := reflect.ValueOf(a) - bVal := reflect.ValueOf(b) - c := &cmp{ - diff: []string{}, - buff: []string{}, - floatFormat: fmt.Sprintf("%%.%df", FloatPrecision), - flag: map[byte]bool{}, - } - for i := range flags { - c.flag[flags[i].(byte)] = true - } - if a == nil && b == nil { - return nil - } else if a == nil && b != nil { - c.saveDiff("", b) - } else if a != nil && b == nil { - c.saveDiff(a, "") - } - if len(c.diff) > 0 { - return c.diff - } + // error ignored to preserve API + differ, _ := New( + WithCompareFunctions(CompareFunctions), + WithCompareUnexportedFields(CompareUnexportedFields), + WithFloatPrecision(FloatPrecision), + WithIgnoreSliceOrder(hasFlag(flags, FLAG_IGNORE_SLICE_ORDER)), + WithLogErrors(LogErrors), + WithMaxDepth(MaxDepth), + WithMaxDiff(MaxDiff), + WithNilMapsAreEmpty(NilMapsAreEmpty), + WithNilPointersAreZero(NilPointersAreZero), + WithNilSlicesAreEmpty(NilSlicesAreEmpty), + ) + return differ.Compare(a, b) +} - c.equals(aVal, bVal, 0) - if len(c.diff) > 0 { - return c.diff // diffs - } - return nil // no diffs +func hasFlag(flags []interface{}, flag byte) bool { + return slices.ContainsFunc(flags, func(i interface{}) bool { + v, ok := i.(byte) + return ok && v == flag + }) } func (c *cmp) equals(a, b reflect.Value, level int) { - if MaxDepth > 0 && level > MaxDepth { - logError(ErrMaxRecursion) + if c.conf.maxDepth > 0 && level > c.conf.maxDepth { + c.logError(ErrMaxRecursion) return } @@ -153,7 +148,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { bFullType := bType.PkgPath() + "." + bType.Name() c.saveDiff(aFullType, bFullType) } - logError(ErrTypeMismatch) + c.logError(ErrTypeMismatch) return } @@ -193,10 +188,10 @@ func (c *cmp) equals(a, b reflect.Value, level int) { if bElem { b = b.Elem() } - if aElem && NilPointersAreZero && !a.IsValid() && b.IsValid() { + if aElem && c.conf.nilPointersAreZero && !a.IsValid() && b.IsValid() { a = reflect.Zero(b.Type()) } - if bElem && NilPointersAreZero && !b.IsValid() && a.IsValid() { + if bElem && c.conf.nilPointersAreZero && !b.IsValid() && a.IsValid() { b = reflect.Zero(a.Type()) } c.equals(a, b, level+1) @@ -244,7 +239,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { } for i := 0; i < a.NumField(); i++ { - if aType.Field(i).PkgPath != "" && !CompareUnexportedFields { + if aType.Field(i).PkgPath != "" && !c.conf.compareUnexportedFields { continue // skip unexported field, e.g. s in type T struct {s string} } @@ -264,7 +259,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.pop() // pop field name from buff - if len(c.diff) >= MaxDiff { + if len(c.diff) >= c.conf.maxDiff { break } } @@ -285,7 +280,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { */ if a.IsNil() || b.IsNil() { - if NilMapsAreEmpty { + if c.conf.nilMapsAreEmpty { if a.IsNil() && b.Len() != 0 { c.saveDiff("", b) return @@ -320,7 +315,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= c.conf.maxDiff { return } } @@ -333,7 +328,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.push(fmt.Sprintf("map[%v]", key)) c.saveDiff("", b.MapIndex(key)) c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= c.conf.maxDiff { return } } @@ -343,12 +338,12 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.push(fmt.Sprintf("array[%d]", i)) c.equals(a.Index(i), b.Index(i), level+1) c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= c.conf.maxDiff { break } } case reflect.Slice: - if NilSlicesAreEmpty { + if c.conf.nilSlicesAreEmpty { if a.IsNil() && b.Len() != 0 { c.saveDiff("", b) return @@ -378,7 +373,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { return } - if c.flag[FLAG_IGNORE_SLICE_ORDER] { + if c.conf.ignoreSliceOrder { // Compare slices by value and value count; ignore order. // Value equality is impliclity established by the maps: // any value v1 will hash to the same map value if it's equal @@ -411,7 +406,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.saveDiff("", b.Index(i)) } c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= c.conf.maxDiff { break } } @@ -451,7 +446,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.saveDiff(a.String(), b.String()) } case reflect.Func: - if CompareFunctions { + if c.conf.compareFunctions { if !a.IsNil() || !b.IsNil() { aVal, bVal := "nil func", "nil func" if !a.IsNil() { @@ -464,7 +459,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { } } default: - logError(ErrNotHandled) + c.logError(ErrNotHandled) } } @@ -506,8 +501,8 @@ func (c *cmp) cmpMapValueCounts(a, b reflect.Value, am, bm map[interface{}]int, } } -func logError(err error) { - if LogErrors { +func (c *cmp) logError(err error) { + if c.conf.logErrors { log.Println(err) } } diff --git a/differ.go b/differ.go new file mode 100644 index 0000000..8060577 --- /dev/null +++ b/differ.go @@ -0,0 +1,186 @@ +// Package deep provides function deep.Equal which is like reflect.DeepEqual but +// returns a list of differences. This is helpful when comparing complex types +// like structures and maps. +package deep + +import ( + "fmt" + "reflect" +) + +type Opt func(*Differ) error + +// WithFloatPrecision is the number of decimal places to round float values +// to when comparing. +func WithFloatPrecision(p int) Opt { + return func(d *Differ) error { + d.floatPrecision = p + return nil + } +} + +// WithMaxDiff specifies the maximum number of differences to return. +func WithMaxDiff(m int) Opt { + return func(d *Differ) error { + d.maxDiff = m + return nil + } +} + +// WithMaxDepth specifies the maximum levels of a struct to recurse into, +// if greater than zero. If zero, there is no limit. +func WithMaxDepth(m int) Opt { + return func(d *Differ) error { + d.maxDepth = m + return nil + } +} + +// WithLogErrors causes errors to be logged to STDERR when true. +func WithLogErrors(b bool) Opt { + return func(differ *Differ) error { + differ.logErrors = b + return nil + } +} + +// WithCompareUnexportedFields causes unexported struct fields, like s in +// T{s int}, to be compared when true. This does not work for comparing +// error or Time types on unexported fields because methods on unexported +// fields cannot be called. +func WithCompareUnexportedFields(b bool) Opt { + return func(differ *Differ) error { + differ.compareUnexportedFields = b + return nil + } +} + +// WithCompareFunctions compares functions the same as reflect.DeepEqual: +// only two nil functions are equal. Every other combination is not equal. +// This is disabled by default because previous versions of this package +// ignored functions. Enabling it can possibly report new diffs. +func WithCompareFunctions(b bool) Opt { + return func(differ *Differ) error { + differ.compareFunctions = b + return nil + } +} + +// WithNilSlicesAreEmpty causes a nil slice to be equal to an empty slice. +func WithNilSlicesAreEmpty(b bool) Opt { + return func(differ *Differ) error { + differ.nilSlicesAreEmpty = b + return nil + } +} + +// WithNilMapsAreEmpty causes a nil map to be equal to an empty map. +func WithNilMapsAreEmpty(b bool) Opt { + return func(differ *Differ) error { + differ.nilMapsAreEmpty = b + return nil + } +} + +// WithNilPointersAreZero causes a nil pointer to be equal to a zero value. +func WithNilPointersAreZero(b bool) Opt { + return func(differ *Differ) error { + differ.nilPointersAreZero = b + return nil + } +} + +// WithIgnoreSliceOrder causes Equal to ignore slice order so that +// []int{1, 2} and []int{2, 1} are equal. Only slices of primitive scalars +// like numbers and strings are supported. Slices of complex types, +// like []T where T is a struct, are undefined because Equal does not +// recurse into the slice value when this flag is enabled. +func WithIgnoreSliceOrder(b bool) Opt { + return func(differ *Differ) error { + differ.ignoreSliceOrder = b + return nil + } +} + +type Differ struct { + compareFunctions bool + compareUnexportedFields bool + floatPrecision int + ignoreSliceOrder bool + logErrors bool + maxDepth int + maxDiff int + nilMapsAreEmpty bool + nilPointersAreZero bool + nilSlicesAreEmpty bool +} + +func New(opts ...Opt) (d Differ, err error) { + d = Differ{ + // options where zero-value equals default value are omitted + floatPrecision: 10, + maxDiff: 10, + } + for opt := range opts { + err = opts[opt](&d) + if err != nil { + return d, fmt.Errorf("invalid option: %w", err) + } + } + return d, nil +} + +type Delta []string + +func (d Delta) Equal(other Delta) bool { + if len(d) != len(other) { + return false + } + for i := range d { + if d[i] != other[i] { + return false + } + } + return true +} + +func (d Delta) ToSlice() []string { + return d +} + +// Compare compares variables a and b, recursing into their structure up to +// MaxDepth levels deep (if greater than zero), and returns a list of differences, +// or nil if there are none. Some differences may not be found if an error is +// also returned. +// +// If a type has an Equal method, like time.Equal, it is called to check for +// equality. +// +// When comparing a struct, if a field has the tag `deep:"-"` then it will be +// ignored. +func (d Differ) Compare(a, b any) Delta { + aVal := reflect.ValueOf(a) + bVal := reflect.ValueOf(b) + c := &cmp{ + conf: d, + diff: []string{}, + buff: []string{}, + floatFormat: fmt.Sprintf("%%.%df", d.floatPrecision), + } + if a == nil && b == nil { + return nil + } else if a == nil && b != nil { + c.saveDiff("", b) + } else if a != nil && b == nil { + c.saveDiff(a, "") + } + if len(c.diff) > 0 { + return c.diff + } + + c.equals(aVal, bVal, 0) + if len(c.diff) > 0 { + return c.diff // diffs + } + return nil // no diffs +}