Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions bsonkit/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ import (
"github.com/shopspring/decimal"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"golang.org/x/text/collate"
)

// Compare will compare two bson values and return their order according to the
// BSON type comparison order specification:
// https://docs.mongodb.com/manual/reference/bson-type-comparison-order.
func Compare(lv, rv interface{}) int {
func Compare(lv, rv interface{}, collator *collate.Collator) int {
// get types
lc, _ := Inspect(lv)
rc, _ := Inspect(rv)
Expand All @@ -32,11 +33,11 @@ func Compare(lv, rv interface{}) int {
case Number:
return compareNumbers(lv, rv)
case String:
return compareStrings(lv, rv)
return compareStrings(lv, rv, collator)
case Document:
return compareDocuments(lv, rv)
return compareDocuments(lv, rv, collator)
case Array:
return compareArrays(lv, rv)
return compareArrays(lv, rv, collator)
case Binary:
return compareBinaries(lv, rv)
case ObjectID:
Expand Down Expand Up @@ -105,18 +106,20 @@ func compareNumbers(lv, rv interface{}) int {
panic("bsonkit: unreachable")
}

func compareStrings(lv, rv interface{}) int {
func compareStrings(lv, rv interface{}, collator *collate.Collator) int {
// get strings
l := lv.(string)
r := rv.(string)

// compare strings
res := strings.Compare(l, r)

if collator != nil {
res = collator.Compare([]byte(l), []byte(r))
}
return res
}

func compareDocuments(lv, rv interface{}) int {
func compareDocuments(lv, rv interface{}, collator *collate.Collator) int {
// get documents
l := lv.(bson.D)
r := rv.(bson.D)
Expand Down Expand Up @@ -150,14 +153,14 @@ func compareDocuments(lv, rv interface{}) int {
}

// compare values
res = Compare(l[i].Value, r[i].Value)
res = Compare(l[i].Value, r[i].Value, collator)
if res != 0 {
return res
}
}
}

func compareArrays(lv, rv interface{}) int {
func compareArrays(lv, rv interface{}, collator *collate.Collator) int {
// get array
l := lv.(bson.A)
r := rv.(bson.A)
Expand Down Expand Up @@ -185,7 +188,7 @@ func compareArrays(lv, rv interface{}) int {
}

// compare elements
res := Compare(l[i], r[i])
res := Compare(l[i], r[i], collator)
if res != 0 {
return res
}
Expand Down
8 changes: 4 additions & 4 deletions bsonkit/compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ import (

func TestCompare(t *testing.T) {
// equality
assert.Equal(t, 0, Compare(bson.D{}, bson.D{}))
assert.Equal(t, 0, Compare(bson.D{}, bson.D{}, nil))

// less than
assert.Equal(t, -1, Compare("foo", false))
assert.Equal(t, -1, Compare("foo", false, nil))

// greater than
assert.Equal(t, 1, Compare(false, "foo"))
assert.Equal(t, 1, Compare(false, "foo", nil))

// decimal
dec, err := primitive.ParseDecimal128("3.14")
assert.NoError(t, err)
assert.Equal(t, 1, Compare(5.0, dec))
assert.Equal(t, 1, Compare(5.0, dec, nil))
}
2 changes: 1 addition & 1 deletion bsonkit/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type Index struct {
func NewIndex(unique bool, columns []Column) *Index {
return &Index{
btree: btree.NewBTreeG[Doc](func(a, b Doc) bool {
return Order(a, b, columns, !unique) < 0
return Order(a, b, columns, !unique, nil) < 0
}),
}
}
Expand Down
4 changes: 2 additions & 2 deletions bsonkit/lists.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func Collect(list List, path string, compact, merge, flatten, distinct bool) bso

// sort results
sort.Slice(result, func(i, j int) bool {
return Compare(result[i], result[j]) < 0
return Compare(result[i], result[j], nil) < 0
})

// prepare distincts
Expand All @@ -108,7 +108,7 @@ func Collect(list List, path string, compact, merge, flatten, distinct bool) bso
var prevValue interface{}
for _, value := range result {
// check if same as previous value
if len(distincts) > 0 && Compare(prevValue, value) == 0 {
if len(distincts) > 0 && Compare(prevValue, value, nil) == 0 {
continue
}

Expand Down
36 changes: 18 additions & 18 deletions bsonkit/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func (s *Schema) evaluateGeneric(value interface{}, valueClass Class, valueType
}
var ok bool
for _, enum := range kv {
if Compare(enum, value) == 0 {
if Compare(enum, value, nil) == 0 {
ok = true
}
}
Expand Down Expand Up @@ -318,10 +318,10 @@ func (s *Schema) evaluateNumber(num interface{}) error {
case "multipleOf":
switch kv := keyword.Value.(type) {
case int32, int64, float64, primitive.Decimal128:
if Compare(kv, int32(0)) <= 0 {
if Compare(kv, int32(0), nil) <= 0 {
return fmt.Errorf("invalid multipleOf value: %v", kv)
}
if Compare(Mod(num, kv), int32(0)) != 0 {
if Compare(Mod(num, kv), int32(0), nil) != 0 {
return ErrValidationFailed
}
default:
Expand All @@ -330,7 +330,7 @@ func (s *Schema) evaluateNumber(num interface{}) error {
case "minimum":
switch kv := keyword.Value.(type) {
case int32, int64, float64, primitive.Decimal128:
res := Compare(num, kv)
res := Compare(num, kv, nil)
if exclusiveMinimum && res <= 0 {
return ErrValidationFailed
} else if !exclusiveMinimum && res < 0 {
Expand All @@ -342,7 +342,7 @@ func (s *Schema) evaluateNumber(num interface{}) error {
case "maximum":
switch kv := keyword.Value.(type) {
case int32, int64, float64, primitive.Decimal128:
res := Compare(num, kv)
res := Compare(num, kv, nil)
if exclusiveMaximum && res >= 0 {
return ErrValidationFailed
} else if !exclusiveMaximum && res > 0 {
Expand All @@ -364,10 +364,10 @@ func (s *Schema) evaluateString(str string) error {
case "minLength":
switch kv := keyword.Value.(type) {
case int32, int64:
if Compare(kv, int32(0)) < 0 {
if Compare(kv, int32(0), nil) < 0 {
return fmt.Errorf("invalid minLength value: %v", kv)
}
if Compare(int64(len(str)), kv) < 0 {
if Compare(int64(len(str)), kv, nil) < 0 {
return ErrValidationFailed
}
default:
Expand All @@ -376,10 +376,10 @@ func (s *Schema) evaluateString(str string) error {
case "maxLength":
switch kv := keyword.Value.(type) {
case int32, int64:
if Compare(kv, int32(0)) < 0 {
if Compare(kv, int32(0), nil) < 0 {
return fmt.Errorf("invalid maxLength value: %v", kv)
}
if Compare(int64(len(str)), kv) > 0 {
if Compare(int64(len(str)), kv, nil) > 0 {
return ErrValidationFailed
}
default:
Expand Down Expand Up @@ -430,10 +430,10 @@ func (s *Schema) evaluateDocument(doc bson.D) error {
case "minProperties":
switch kv := keyword.Value.(type) {
case int32, int64:
if Compare(kv, int32(0)) < 0 {
if Compare(kv, int32(0), nil) < 0 {
return fmt.Errorf("invalid minProperties value: %v", kv)
}
if Compare(int64(len(doc)), kv) < 0 {
if Compare(int64(len(doc)), kv, nil) < 0 {
return ErrValidationFailed
}
default:
Expand All @@ -442,10 +442,10 @@ func (s *Schema) evaluateDocument(doc bson.D) error {
case "maxProperties":
switch kv := keyword.Value.(type) {
case int32, int64:
if Compare(kv, int32(0)) < 0 {
if Compare(kv, int32(0), nil) < 0 {
return fmt.Errorf("invalid maxProperties value: %v", kv)
}
if Compare(int64(len(doc)), kv) > 0 {
if Compare(int64(len(doc)), kv, nil) > 0 {
return ErrValidationFailed
}
default:
Expand Down Expand Up @@ -602,10 +602,10 @@ func (s *Schema) evaluateArray(arr bson.A) error {
case "minItems":
switch kv := keyword.Value.(type) {
case int32, int64:
if Compare(kv, int32(0)) < 0 {
if Compare(kv, int32(0), nil) < 0 {
return fmt.Errorf("invalid minItems value: %v", kv)
}
if Compare(int64(len(arr)), kv) < 0 {
if Compare(int64(len(arr)), kv, nil) < 0 {
return ErrValidationFailed
}
default:
Expand All @@ -614,10 +614,10 @@ func (s *Schema) evaluateArray(arr bson.A) error {
case "maxItems":
switch kv := keyword.Value.(type) {
case int32, int64:
if Compare(kv, int32(0)) < 0 {
if Compare(kv, int32(0), nil) < 0 {
return fmt.Errorf("invalid maxItems value: %v", kv)
}
if Compare(int64(len(arr)), kv) > 0 {
if Compare(int64(len(arr)), kv, nil) > 0 {
return ErrValidationFailed
}
default:
Expand All @@ -629,7 +629,7 @@ func (s *Schema) evaluateArray(arr bson.A) error {
if kv {
for i := 0; i < len(arr)-1; i++ {
for j := i + 1; j < len(arr); j++ {
if Compare(arr[i], arr[j]) == 0 {
if Compare(arr[i], arr[j], nil) == 0 {
return ErrValidationFailed
}
}
Expand Down
10 changes: 6 additions & 4 deletions bsonkit/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package bsonkit
import (
"sort"
"unsafe"

"golang.org/x/text/collate"
)

// Column defines a column for ordering.
Expand All @@ -12,22 +14,22 @@ type Column struct {
}

// Sort will sort the list of documents in-place based on the specified columns.
func Sort(list List, columns []Column, identity bool) {
func Sort(list List, columns []Column, identity bool, collator *collate.Collator) {
// sort slice by comparing values
sort.Slice(list, func(i, j int) bool {
return Order(list[i], list[j], columns, identity) < 0
return Order(list[i], list[j], columns, identity, collator) < 0
})
}

// Order will return the order of documents based on the specified columns.
func Order(l, r Doc, columns []Column, identity bool) int {
func Order(l, r Doc, columns []Column, identity bool, collator *collate.Collator) int {
for _, column := range columns {
// get values
a := Get(l, column.Path)
b := Get(r, column.Path)

// compare values
res := Compare(a, b)
res := Compare(a, b, collator)

// continue if equal
if res == 0 {
Expand Down
20 changes: 10 additions & 10 deletions bsonkit/sort_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,38 @@ func TestSort(t *testing.T) {
list := List{a3, a1, a2}
Sort(list, []Column{
{Path: "a", Reverse: false},
}, false)
}, false, nil)
assert.Equal(t, List{a1, a2, a3}, list)

// sort backwards single
list = List{a3, a1, a2}
Sort(list, []Column{
{Path: "a", Reverse: true},
}, false)
}, false, nil)
assert.Equal(t, List{a3, a2, a1}, list)

// sort forwards multiple
list = List{a3, a1, a2}
Sort(list, []Column{
{Path: "b", Reverse: false},
{Path: "a", Reverse: false},
}, false)
}, false, nil)
assert.Equal(t, List{a2, a1, a3}, list)

// sort backwards multiple
list = List{a3, a1, a2}
Sort(list, []Column{
{Path: "b", Reverse: true},
{Path: "a", Reverse: true},
}, false)
}, false, nil)
assert.Equal(t, List{a3, a1, a2}, list)

// sort mixed
list = List{a3, a1, a2}
Sort(list, []Column{
{Path: "b", Reverse: false},
{Path: "a", Reverse: true},
}, false)
}, false, nil)
assert.Equal(t, List{a2, a3, a1}, list)
}

Expand All @@ -61,37 +61,37 @@ func TestSortIdentity(t *testing.T) {
list := List{a3, a1, a4, a2}
Sort(list, []Column{
{Path: "a", Reverse: false},
}, true)
}, true, nil)
assert.Equal(t, List{a1, a2, a3, a4}, list)

// sort backwards single
list = List{a3, a1, a4, a2}
Sort(list, []Column{
{Path: "a", Reverse: true},
}, true)
}, true, nil)
assert.Equal(t, List{a4, a3, a2, a1}, list)

// sort forwards multiple
list = List{a3, a1, a4, a2}
Sort(list, []Column{
{Path: "b", Reverse: false},
{Path: "a", Reverse: false},
}, true)
}, true, nil)
assert.Equal(t, List{a2, a3, a1, a4}, list)

// sort backwards multiple
list = List{a3, a1, a4, a2}
Sort(list, []Column{
{Path: "b", Reverse: true},
{Path: "a", Reverse: true},
}, true)
}, true, nil)
assert.Equal(t, List{a4, a1, a2, a3}, list)

// sort mixed
list = List{a3, a1, a4, a2}
Sort(list, []Column{
{Path: "b", Reverse: false},
{Path: "a", Reverse: true},
}, true)
}, true, nil)
assert.Equal(t, List{a2, a3, a4, a1}, list)
}
2 changes: 1 addition & 1 deletion bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ func (b *Bucket) hasIndex(ctx context.Context, coll ICollection, model mongo.Ind

// check if index with same keys already exists
for _, index := range indexes {
if bsonkit.Compare(index.Keys, model.Keys) == 0 {
if bsonkit.Compare(index.Keys, model.Keys, nil) == 0 {
return true, nil
}
}
Expand Down
Loading