Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 40 additions & 45 deletions deep.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"log"
"reflect"
"slices"
"strings"
)

Expand Down Expand Up @@ -76,7 +77,7 @@ type cmp struct {
diff []string
buff []string
floatFormat string
flag map[byte]bool
conf Differ
}

var errorType = reflect.TypeOf((*error)(nil)).Elem()
Expand All @@ -92,38 +93,32 @@ var errorType = reflect.TypeOf((*error)(nil)).Elem()
// When comparing a struct, if a field has the tag `deep:"-"` then it will be
// ignored.
func Equal(a, b interface{}, flags ...interface{}) []string {
aVal := reflect.ValueOf(a)
bVal := reflect.ValueOf(b)
c := &cmp{
diff: []string{},
buff: []string{},
floatFormat: fmt.Sprintf("%%.%df", FloatPrecision),
flag: map[byte]bool{},
}
for i := range flags {
c.flag[flags[i].(byte)] = true
}
if a == nil && b == nil {
return nil
} else if a == nil && b != nil {
c.saveDiff("<nil pointer>", b)
} else if a != nil && b == nil {
c.saveDiff(a, "<nil pointer>")
}
if len(c.diff) > 0 {
return c.diff
}
// error ignored to preserve API
differ, _ := New(
WithCompareFunctions(CompareFunctions),
WithCompareUnexportedFields(CompareUnexportedFields),
WithFloatPrecision(FloatPrecision),
WithIgnoreSliceOrder(hasFlag(flags, FLAG_IGNORE_SLICE_ORDER)),
WithLogErrors(LogErrors),
WithMaxDepth(MaxDepth),
WithMaxDiff(MaxDiff),
WithNilMapsAreEmpty(NilMapsAreEmpty),
WithNilPointersAreZero(NilPointersAreZero),
WithNilSlicesAreEmpty(NilSlicesAreEmpty),
)
return differ.Compare(a, b)
}

c.equals(aVal, bVal, 0)
if len(c.diff) > 0 {
return c.diff // diffs
}
return nil // no diffs
func hasFlag(flags []interface{}, flag byte) bool {
return slices.ContainsFunc(flags, func(i interface{}) bool {
v, ok := i.(byte)
return ok && v == flag
})
}

func (c *cmp) equals(a, b reflect.Value, level int) {
if MaxDepth > 0 && level > MaxDepth {
logError(ErrMaxRecursion)
if c.conf.maxDepth > 0 && level > c.conf.maxDepth {
c.logError(ErrMaxRecursion)
return
}

Expand Down Expand Up @@ -153,7 +148,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
bFullType := bType.PkgPath() + "." + bType.Name()
c.saveDiff(aFullType, bFullType)
}
logError(ErrTypeMismatch)
c.logError(ErrTypeMismatch)
return
}

Expand Down Expand Up @@ -193,10 +188,10 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
if bElem {
b = b.Elem()
}
if aElem && NilPointersAreZero && !a.IsValid() && b.IsValid() {
if aElem && c.conf.nilPointersAreZero && !a.IsValid() && b.IsValid() {
a = reflect.Zero(b.Type())
}
if bElem && NilPointersAreZero && !b.IsValid() && a.IsValid() {
if bElem && c.conf.nilPointersAreZero && !b.IsValid() && a.IsValid() {
b = reflect.Zero(a.Type())
}
c.equals(a, b, level+1)
Expand Down Expand Up @@ -244,7 +239,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
}

for i := 0; i < a.NumField(); i++ {
if aType.Field(i).PkgPath != "" && !CompareUnexportedFields {
if aType.Field(i).PkgPath != "" && !c.conf.compareUnexportedFields {
continue // skip unexported field, e.g. s in type T struct {s string}
}

Expand All @@ -264,7 +259,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {

c.pop() // pop field name from buff

if len(c.diff) >= MaxDiff {
if len(c.diff) >= c.conf.maxDiff {
break
}
}
Expand All @@ -285,7 +280,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
*/

if a.IsNil() || b.IsNil() {
if NilMapsAreEmpty {
if c.conf.nilMapsAreEmpty {
if a.IsNil() && b.Len() != 0 {
c.saveDiff("<nil map>", b)
return
Expand Down Expand Up @@ -320,7 +315,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {

c.pop()

if len(c.diff) >= MaxDiff {
if len(c.diff) >= c.conf.maxDiff {
return
}
}
Expand All @@ -333,7 +328,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
c.push(fmt.Sprintf("map[%v]", key))
c.saveDiff("<does not have key>", b.MapIndex(key))
c.pop()
if len(c.diff) >= MaxDiff {
if len(c.diff) >= c.conf.maxDiff {
return
}
}
Expand All @@ -343,12 +338,12 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
c.push(fmt.Sprintf("array[%d]", i))
c.equals(a.Index(i), b.Index(i), level+1)
c.pop()
if len(c.diff) >= MaxDiff {
if len(c.diff) >= c.conf.maxDiff {
break
}
}
case reflect.Slice:
if NilSlicesAreEmpty {
if c.conf.nilSlicesAreEmpty {
if a.IsNil() && b.Len() != 0 {
c.saveDiff("<nil slice>", b)
return
Expand Down Expand Up @@ -378,7 +373,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
return
}

if c.flag[FLAG_IGNORE_SLICE_ORDER] {
if c.conf.ignoreSliceOrder {
// Compare slices by value and value count; ignore order.
// Value equality is impliclity established by the maps:
// any value v1 will hash to the same map value if it's equal
Expand Down Expand Up @@ -411,7 +406,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
c.saveDiff("<no value>", b.Index(i))
}
c.pop()
if len(c.diff) >= MaxDiff {
if len(c.diff) >= c.conf.maxDiff {
break
}
}
Expand Down Expand Up @@ -451,7 +446,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
c.saveDiff(a.String(), b.String())
}
case reflect.Func:
if CompareFunctions {
if c.conf.compareFunctions {
if !a.IsNil() || !b.IsNil() {
aVal, bVal := "nil func", "nil func"
if !a.IsNil() {
Expand All @@ -464,7 +459,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
}
}
default:
logError(ErrNotHandled)
c.logError(ErrNotHandled)
}
}

Expand Down Expand Up @@ -506,8 +501,8 @@ func (c *cmp) cmpMapValueCounts(a, b reflect.Value, am, bm map[interface{}]int,
}
}

func logError(err error) {
if LogErrors {
func (c *cmp) logError(err error) {
if c.conf.logErrors {
log.Println(err)
}
}
186 changes: 186 additions & 0 deletions differ.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// Package deep provides function deep.Equal which is like reflect.DeepEqual but
// returns a list of differences. This is helpful when comparing complex types
// like structures and maps.
package deep

import (
"fmt"
"reflect"
)

type Opt func(*Differ) error

// WithFloatPrecision is the number of decimal places to round float values
// to when comparing.
func WithFloatPrecision(p int) Opt {
return func(d *Differ) error {
d.floatPrecision = p
return nil
}
}

// WithMaxDiff specifies the maximum number of differences to return.
func WithMaxDiff(m int) Opt {
return func(d *Differ) error {
d.maxDiff = m
return nil
}
}

// WithMaxDepth specifies the maximum levels of a struct to recurse into,
// if greater than zero. If zero, there is no limit.
func WithMaxDepth(m int) Opt {
return func(d *Differ) error {
d.maxDepth = m
return nil
}
}

// WithLogErrors causes errors to be logged to STDERR when true.
func WithLogErrors(b bool) Opt {
return func(differ *Differ) error {
differ.logErrors = b
return nil
}
}

// WithCompareUnexportedFields causes unexported struct fields, like s in
// T{s int}, to be compared when true. This does not work for comparing
// error or Time types on unexported fields because methods on unexported
// fields cannot be called.
func WithCompareUnexportedFields(b bool) Opt {
return func(differ *Differ) error {
differ.compareUnexportedFields = b
return nil
}
}

// WithCompareFunctions compares functions the same as reflect.DeepEqual:
// only two nil functions are equal. Every other combination is not equal.
// This is disabled by default because previous versions of this package
// ignored functions. Enabling it can possibly report new diffs.
func WithCompareFunctions(b bool) Opt {
return func(differ *Differ) error {
differ.compareFunctions = b
return nil
}
}

// WithNilSlicesAreEmpty causes a nil slice to be equal to an empty slice.
func WithNilSlicesAreEmpty(b bool) Opt {
return func(differ *Differ) error {
differ.nilSlicesAreEmpty = b
return nil
}
}

// WithNilMapsAreEmpty causes a nil map to be equal to an empty map.
func WithNilMapsAreEmpty(b bool) Opt {
return func(differ *Differ) error {
differ.nilMapsAreEmpty = b
return nil
}
}

// WithNilPointersAreZero causes a nil pointer to be equal to a zero value.
func WithNilPointersAreZero(b bool) Opt {
return func(differ *Differ) error {
differ.nilPointersAreZero = b
return nil
}
}

// WithIgnoreSliceOrder causes Equal to ignore slice order so that
// []int{1, 2} and []int{2, 1} are equal. Only slices of primitive scalars
// like numbers and strings are supported. Slices of complex types,
// like []T where T is a struct, are undefined because Equal does not
// recurse into the slice value when this flag is enabled.
func WithIgnoreSliceOrder(b bool) Opt {
return func(differ *Differ) error {
differ.ignoreSliceOrder = b
return nil
}
}

type Differ struct {
compareFunctions bool
compareUnexportedFields bool
floatPrecision int
ignoreSliceOrder bool
logErrors bool
maxDepth int
maxDiff int
nilMapsAreEmpty bool
nilPointersAreZero bool
nilSlicesAreEmpty bool
}

func New(opts ...Opt) (d Differ, err error) {
d = Differ{
// options where zero-value equals default value are omitted
floatPrecision: 10,
maxDiff: 10,
}
for opt := range opts {
err = opts[opt](&d)
if err != nil {
return d, fmt.Errorf("invalid option: %w", err)
}
}
return d, nil
}

type Delta []string

func (d Delta) Equal(other Delta) bool {
if len(d) != len(other) {
return false
}
for i := range d {
if d[i] != other[i] {
return false
}
}
return true
}

func (d Delta) ToSlice() []string {
return d
}

// Compare compares variables a and b, recursing into their structure up to
// MaxDepth levels deep (if greater than zero), and returns a list of differences,
// or nil if there are none. Some differences may not be found if an error is
// also returned.
//
// If a type has an Equal method, like time.Equal, it is called to check for
// equality.
//
// When comparing a struct, if a field has the tag `deep:"-"` then it will be
// ignored.
func (d Differ) Compare(a, b any) Delta {
aVal := reflect.ValueOf(a)
bVal := reflect.ValueOf(b)
c := &cmp{
conf: d,
diff: []string{},
buff: []string{},
floatFormat: fmt.Sprintf("%%.%df", d.floatPrecision),
}
if a == nil && b == nil {
return nil
} else if a == nil && b != nil {
c.saveDiff("<nil pointer>", b)
} else if a != nil && b == nil {
c.saveDiff(a, "<nil pointer>")
}
if len(c.diff) > 0 {
return c.diff
}

c.equals(aVal, bVal, 0)
if len(c.diff) > 0 {
return c.diff // diffs
}
return nil // no diffs
}