From 1985ad23777297fb09db0338a8993a89920b4ced Mon Sep 17 00:00:00 2001 From: Graham Clark Date: Fri, 29 Jun 2018 12:00:28 -0400 Subject: [PATCH] Allow Equal() to be used by different packages in one program. Some packages may require Equal()'s parameters to be set in a particular way that is incompatible with other users within the same program. The global configuration parameters can be changed and restored, but this could lead to bugs due to race conditions. This commit makes the parameters that control Equal()'s operation part of a structure, Comparer, for which Equal() is now a method. Users can configure their own Comparer struct if desired. To preserve the existing package interface, the package-level Equals() method will use a default Comparer object that relies on pointers to the current global configuration parameters (pointers so that the operation of the global Equals() function will change immediately upon changing the value of any global configuration setting). --- deep.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 74 insertions(+), 18 deletions(-) diff --git a/deep.go b/deep.go index ea6265c..4fc9566 100644 --- a/deep.go +++ b/deep.go @@ -41,13 +41,55 @@ var ( ErrNotHandled = errors.New("cannot compare the reflect.Kind") ) -type cmp struct { - diff []string - buff []string - floatFormat string +// Comparer is a struct capturing the configuration used for Equals(). The package +// Equal() function uses a default Comparer struct which references the global +// variables that control the execution of the equality algorithm. +type Comparer struct { + FloatPrecision *int + MaxDiff *int + MaxDepth *int + LogErrors *bool + CompareUnexportedFields *bool + ErrMaxRecursion *error + ErrTypeMismatch *error + ErrNotHandled *error } -var errorType = reflect.TypeOf((*error)(nil)).Elem() +// MakeComparer returns a Comparer struct with nil fields initialized to point +// to the global settings. +func MakeComparer(c Comparer) Comparer { + if c.FloatPrecision == nil { + c.FloatPrecision = &FloatPrecision + } + if c.MaxDiff == nil { + c.MaxDiff = &MaxDiff + } + if c.MaxDepth == nil { + c.MaxDepth = &MaxDepth + } + if c.LogErrors == nil { + c.LogErrors = &LogErrors + } + if c.CompareUnexportedFields == nil { + c.CompareUnexportedFields = &CompareUnexportedFields + } + if c.ErrMaxRecursion == nil { + c.ErrMaxRecursion = &ErrMaxRecursion + } + if c.ErrTypeMismatch == nil { + c.ErrTypeMismatch = &ErrTypeMismatch + } + if c.ErrNotHandled == nil { + c.ErrNotHandled = &ErrNotHandled + } + return c +} + +func makeDefaultComparer() Comparer { + return MakeComparer(Comparer{}) +} + +var DefaultComparer = makeDefaultComparer() // Equal compares variables a and b, recursing into their structure up to // MaxDepth levels deep, and returns a list of differences, or nil if there are @@ -56,12 +98,26 @@ var errorType = reflect.TypeOf((*error)(nil)).Elem() // If a type has an Equal method, like time.Equal, it is called to check for // equality. func Equal(a, b interface{}) []string { + return DefaultComparer.Equal(a, b) +} + +type cmp struct { + diff []string + buff []string + floatFormat string + *Comparer +} + +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +func (cp *Comparer) Equal(a, b interface{}) []string { aVal := reflect.ValueOf(a) bVal := reflect.ValueOf(b) c := &cmp{ diff: []string{}, buff: []string{}, - floatFormat: fmt.Sprintf("%%.%df", FloatPrecision), + floatFormat: fmt.Sprintf("%%.%df", *cp.FloatPrecision), + Comparer: cp, } if a == nil && b == nil { return nil @@ -82,8 +138,8 @@ func Equal(a, b interface{}) []string { } func (c *cmp) equals(a, b reflect.Value, level int) { - if level > MaxDepth { - logError(ErrMaxRecursion) + if level > *c.MaxDepth { + c.logError(*c.ErrMaxRecursion) return } @@ -102,7 +158,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { bType := b.Type() if aType != bType { c.saveDiff(aType, bType) - logError(ErrTypeMismatch) + c.logError(*c.ErrTypeMismatch) return } @@ -181,7 +237,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { } for i := 0; i < a.NumField(); i++ { - if aType.Field(i).PkgPath != "" && !CompareUnexportedFields { + if aType.Field(i).PkgPath != "" && !*c.CompareUnexportedFields { continue // skip unexported field, e.g. s in type T struct {s string} } @@ -197,7 +253,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.pop() // pop field name from buff - if len(c.diff) >= MaxDiff { + if len(c.diff) >= *c.MaxDiff { break } } @@ -243,7 +299,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= *c.MaxDiff { return } } @@ -256,7 +312,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.push(fmt.Sprintf("map[%s]", key)) c.saveDiff("", b.MapIndex(key)) c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= *c.MaxDiff { return } } @@ -266,7 +322,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.push(fmt.Sprintf("array[%d]", i)) c.equals(a.Index(i), b.Index(i), level+1) c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= *c.MaxDiff { break } } @@ -300,7 +356,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { c.saveDiff("", b.Index(i)) } c.pop() - if len(c.diff) >= MaxDiff { + if len(c.diff) >= *c.MaxDiff { break } } @@ -335,7 +391,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) { } default: - logError(ErrNotHandled) + c.logError(*c.ErrNotHandled) } } @@ -358,8 +414,8 @@ func (c *cmp) saveDiff(aval, bval interface{}) { } } -func logError(err error) { - if LogErrors { +func (c *cmp) logError(err error) { + if *c.LogErrors { log.Println(err) } }