From 20cf5acf34457862311f19f8961f03a631f52279 Mon Sep 17 00:00:00 2001 From: magali Date: Thu, 7 Dec 2023 12:38:08 +0100 Subject: [PATCH] feat(collation): Add basic collation on strings --- bsonkit/compare.go | 23 +++++++++-------- bsonkit/compare_test.go | 8 +++--- bsonkit/index.go | 2 +- bsonkit/lists.go | 4 +-- bsonkit/schema.go | 36 +++++++++++++-------------- bsonkit/sort.go | 10 +++++--- bsonkit/sort_test.go | 20 +++++++-------- bucket.go | 2 +- collection.go | 27 +++++++++++++++++--- collection_test.go | 2 +- engine.go | 6 ++--- mongokit/apply.go | 8 +++--- mongokit/collection.go | 55 +++++++++++++++++++++++++++++++++++------ mongokit/index.go | 4 +-- mongokit/match.go | 10 ++++---- mongokit/project.go | 4 +-- mongokit/sort.go | 5 ++-- mongokit/sort_test.go | 14 +++++------ store_test.go | 4 +-- transaction.go | 10 ++++---- 20 files changed, 159 insertions(+), 95 deletions(-) diff --git a/bsonkit/compare.go b/bsonkit/compare.go index 034debd..0a488ac 100644 --- a/bsonkit/compare.go +++ b/bsonkit/compare.go @@ -8,12 +8,13 @@ import ( "github.com/shopspring/decimal" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "golang.org/x/text/collate" ) // Compare will compare two bson values and return their order according to the // BSON type comparison order specification: // https://docs.mongodb.com/manual/reference/bson-type-comparison-order. -func Compare(lv, rv interface{}) int { +func Compare(lv, rv interface{}, collator *collate.Collator) int { // get types lc, _ := Inspect(lv) rc, _ := Inspect(rv) @@ -32,11 +33,11 @@ func Compare(lv, rv interface{}) int { case Number: return compareNumbers(lv, rv) case String: - return compareStrings(lv, rv) + return compareStrings(lv, rv, collator) case Document: - return compareDocuments(lv, rv) + return compareDocuments(lv, rv, collator) case Array: - return compareArrays(lv, rv) + return compareArrays(lv, rv, collator) case Binary: return compareBinaries(lv, rv) case ObjectID: @@ -105,18 +106,20 @@ func compareNumbers(lv, rv interface{}) int { panic("bsonkit: unreachable") } -func compareStrings(lv, rv interface{}) int { +func compareStrings(lv, rv interface{}, collator *collate.Collator) int { // get strings l := lv.(string) r := rv.(string) // compare strings res := strings.Compare(l, r) - + if collator != nil { + res = collator.Compare([]byte(l), []byte(r)) + } return res } -func compareDocuments(lv, rv interface{}) int { +func compareDocuments(lv, rv interface{}, collator *collate.Collator) int { // get documents l := lv.(bson.D) r := rv.(bson.D) @@ -150,14 +153,14 @@ func compareDocuments(lv, rv interface{}) int { } // compare values - res = Compare(l[i].Value, r[i].Value) + res = Compare(l[i].Value, r[i].Value, collator) if res != 0 { return res } } } -func compareArrays(lv, rv interface{}) int { +func compareArrays(lv, rv interface{}, collator *collate.Collator) int { // get array l := lv.(bson.A) r := rv.(bson.A) @@ -185,7 +188,7 @@ func compareArrays(lv, rv interface{}) int { } // compare elements - res := Compare(l[i], r[i]) + res := Compare(l[i], r[i], collator) if res != 0 { return res } diff --git a/bsonkit/compare_test.go b/bsonkit/compare_test.go index 93fd594..9b363bd 100644 --- a/bsonkit/compare_test.go +++ b/bsonkit/compare_test.go @@ -10,16 +10,16 @@ import ( func TestCompare(t *testing.T) { // equality - assert.Equal(t, 0, Compare(bson.D{}, bson.D{})) + assert.Equal(t, 0, Compare(bson.D{}, bson.D{}, nil)) // less than - assert.Equal(t, -1, Compare("foo", false)) + assert.Equal(t, -1, Compare("foo", false, nil)) // greater than - assert.Equal(t, 1, Compare(false, "foo")) + assert.Equal(t, 1, Compare(false, "foo", nil)) // decimal dec, err := primitive.ParseDecimal128("3.14") assert.NoError(t, err) - assert.Equal(t, 1, Compare(5.0, dec)) + assert.Equal(t, 1, Compare(5.0, dec, nil)) } diff --git a/bsonkit/index.go b/bsonkit/index.go index 4b35c91..425ff0b 100644 --- a/bsonkit/index.go +++ b/bsonkit/index.go @@ -12,7 +12,7 @@ type Index struct { func NewIndex(unique bool, columns []Column) *Index { return &Index{ btree: btree.NewBTreeG[Doc](func(a, b Doc) bool { - return Order(a, b, columns, !unique) < 0 + return Order(a, b, columns, !unique, nil) < 0 }), } } diff --git a/bsonkit/lists.go b/bsonkit/lists.go index 7431d02..0c11a42 100644 --- a/bsonkit/lists.go +++ b/bsonkit/lists.go @@ -98,7 +98,7 @@ func Collect(list List, path string, compact, merge, flatten, distinct bool) bso // sort results sort.Slice(result, func(i, j int) bool { - return Compare(result[i], result[j]) < 0 + return Compare(result[i], result[j], nil) < 0 }) // prepare distincts @@ -108,7 +108,7 @@ func Collect(list List, path string, compact, merge, flatten, distinct bool) bso var prevValue interface{} for _, value := range result { // check if same as previous value - if len(distincts) > 0 && Compare(prevValue, value) == 0 { + if len(distincts) > 0 && Compare(prevValue, value, nil) == 0 { continue } diff --git a/bsonkit/schema.go b/bsonkit/schema.go index b7c885c..bf698e1 100644 --- a/bsonkit/schema.go +++ b/bsonkit/schema.go @@ -185,7 +185,7 @@ func (s *Schema) evaluateGeneric(value interface{}, valueClass Class, valueType } var ok bool for _, enum := range kv { - if Compare(enum, value) == 0 { + if Compare(enum, value, nil) == 0 { ok = true } } @@ -318,10 +318,10 @@ func (s *Schema) evaluateNumber(num interface{}) error { case "multipleOf": switch kv := keyword.Value.(type) { case int32, int64, float64, primitive.Decimal128: - if Compare(kv, int32(0)) <= 0 { + if Compare(kv, int32(0), nil) <= 0 { return fmt.Errorf("invalid multipleOf value: %v", kv) } - if Compare(Mod(num, kv), int32(0)) != 0 { + if Compare(Mod(num, kv), int32(0), nil) != 0 { return ErrValidationFailed } default: @@ -330,7 +330,7 @@ func (s *Schema) evaluateNumber(num interface{}) error { case "minimum": switch kv := keyword.Value.(type) { case int32, int64, float64, primitive.Decimal128: - res := Compare(num, kv) + res := Compare(num, kv, nil) if exclusiveMinimum && res <= 0 { return ErrValidationFailed } else if !exclusiveMinimum && res < 0 { @@ -342,7 +342,7 @@ func (s *Schema) evaluateNumber(num interface{}) error { case "maximum": switch kv := keyword.Value.(type) { case int32, int64, float64, primitive.Decimal128: - res := Compare(num, kv) + res := Compare(num, kv, nil) if exclusiveMaximum && res >= 0 { return ErrValidationFailed } else if !exclusiveMaximum && res > 0 { @@ -364,10 +364,10 @@ func (s *Schema) evaluateString(str string) error { case "minLength": switch kv := keyword.Value.(type) { case int32, int64: - if Compare(kv, int32(0)) < 0 { + if Compare(kv, int32(0), nil) < 0 { return fmt.Errorf("invalid minLength value: %v", kv) } - if Compare(int64(len(str)), kv) < 0 { + if Compare(int64(len(str)), kv, nil) < 0 { return ErrValidationFailed } default: @@ -376,10 +376,10 @@ func (s *Schema) evaluateString(str string) error { case "maxLength": switch kv := keyword.Value.(type) { case int32, int64: - if Compare(kv, int32(0)) < 0 { + if Compare(kv, int32(0), nil) < 0 { return fmt.Errorf("invalid maxLength value: %v", kv) } - if Compare(int64(len(str)), kv) > 0 { + if Compare(int64(len(str)), kv, nil) > 0 { return ErrValidationFailed } default: @@ -430,10 +430,10 @@ func (s *Schema) evaluateDocument(doc bson.D) error { case "minProperties": switch kv := keyword.Value.(type) { case int32, int64: - if Compare(kv, int32(0)) < 0 { + if Compare(kv, int32(0), nil) < 0 { return fmt.Errorf("invalid minProperties value: %v", kv) } - if Compare(int64(len(doc)), kv) < 0 { + if Compare(int64(len(doc)), kv, nil) < 0 { return ErrValidationFailed } default: @@ -442,10 +442,10 @@ func (s *Schema) evaluateDocument(doc bson.D) error { case "maxProperties": switch kv := keyword.Value.(type) { case int32, int64: - if Compare(kv, int32(0)) < 0 { + if Compare(kv, int32(0), nil) < 0 { return fmt.Errorf("invalid maxProperties value: %v", kv) } - if Compare(int64(len(doc)), kv) > 0 { + if Compare(int64(len(doc)), kv, nil) > 0 { return ErrValidationFailed } default: @@ -602,10 +602,10 @@ func (s *Schema) evaluateArray(arr bson.A) error { case "minItems": switch kv := keyword.Value.(type) { case int32, int64: - if Compare(kv, int32(0)) < 0 { + if Compare(kv, int32(0), nil) < 0 { return fmt.Errorf("invalid minItems value: %v", kv) } - if Compare(int64(len(arr)), kv) < 0 { + if Compare(int64(len(arr)), kv, nil) < 0 { return ErrValidationFailed } default: @@ -614,10 +614,10 @@ func (s *Schema) evaluateArray(arr bson.A) error { case "maxItems": switch kv := keyword.Value.(type) { case int32, int64: - if Compare(kv, int32(0)) < 0 { + if Compare(kv, int32(0), nil) < 0 { return fmt.Errorf("invalid maxItems value: %v", kv) } - if Compare(int64(len(arr)), kv) > 0 { + if Compare(int64(len(arr)), kv, nil) > 0 { return ErrValidationFailed } default: @@ -629,7 +629,7 @@ func (s *Schema) evaluateArray(arr bson.A) error { if kv { for i := 0; i < len(arr)-1; i++ { for j := i + 1; j < len(arr); j++ { - if Compare(arr[i], arr[j]) == 0 { + if Compare(arr[i], arr[j], nil) == 0 { return ErrValidationFailed } } diff --git a/bsonkit/sort.go b/bsonkit/sort.go index 8f6a7d6..33d3bc2 100644 --- a/bsonkit/sort.go +++ b/bsonkit/sort.go @@ -3,6 +3,8 @@ package bsonkit import ( "sort" "unsafe" + + "golang.org/x/text/collate" ) // Column defines a column for ordering. @@ -12,22 +14,22 @@ type Column struct { } // Sort will sort the list of documents in-place based on the specified columns. -func Sort(list List, columns []Column, identity bool) { +func Sort(list List, columns []Column, identity bool, collator *collate.Collator) { // sort slice by comparing values sort.Slice(list, func(i, j int) bool { - return Order(list[i], list[j], columns, identity) < 0 + return Order(list[i], list[j], columns, identity, collator) < 0 }) } // Order will return the order of documents based on the specified columns. -func Order(l, r Doc, columns []Column, identity bool) int { +func Order(l, r Doc, columns []Column, identity bool, collator *collate.Collator) int { for _, column := range columns { // get values a := Get(l, column.Path) b := Get(r, column.Path) // compare values - res := Compare(a, b) + res := Compare(a, b, collator) // continue if equal if res == 0 { diff --git a/bsonkit/sort_test.go b/bsonkit/sort_test.go index 77eca95..388a2f4 100644 --- a/bsonkit/sort_test.go +++ b/bsonkit/sort_test.go @@ -16,14 +16,14 @@ func TestSort(t *testing.T) { list := List{a3, a1, a2} Sort(list, []Column{ {Path: "a", Reverse: false}, - }, false) + }, false, nil) assert.Equal(t, List{a1, a2, a3}, list) // sort backwards single list = List{a3, a1, a2} Sort(list, []Column{ {Path: "a", Reverse: true}, - }, false) + }, false, nil) assert.Equal(t, List{a3, a2, a1}, list) // sort forwards multiple @@ -31,7 +31,7 @@ func TestSort(t *testing.T) { Sort(list, []Column{ {Path: "b", Reverse: false}, {Path: "a", Reverse: false}, - }, false) + }, false, nil) assert.Equal(t, List{a2, a1, a3}, list) // sort backwards multiple @@ -39,7 +39,7 @@ func TestSort(t *testing.T) { Sort(list, []Column{ {Path: "b", Reverse: true}, {Path: "a", Reverse: true}, - }, false) + }, false, nil) assert.Equal(t, List{a3, a1, a2}, list) // sort mixed @@ -47,7 +47,7 @@ func TestSort(t *testing.T) { Sort(list, []Column{ {Path: "b", Reverse: false}, {Path: "a", Reverse: true}, - }, false) + }, false, nil) assert.Equal(t, List{a2, a3, a1}, list) } @@ -61,14 +61,14 @@ func TestSortIdentity(t *testing.T) { list := List{a3, a1, a4, a2} Sort(list, []Column{ {Path: "a", Reverse: false}, - }, true) + }, true, nil) assert.Equal(t, List{a1, a2, a3, a4}, list) // sort backwards single list = List{a3, a1, a4, a2} Sort(list, []Column{ {Path: "a", Reverse: true}, - }, true) + }, true, nil) assert.Equal(t, List{a4, a3, a2, a1}, list) // sort forwards multiple @@ -76,7 +76,7 @@ func TestSortIdentity(t *testing.T) { Sort(list, []Column{ {Path: "b", Reverse: false}, {Path: "a", Reverse: false}, - }, true) + }, true, nil) assert.Equal(t, List{a2, a3, a1, a4}, list) // sort backwards multiple @@ -84,7 +84,7 @@ func TestSortIdentity(t *testing.T) { Sort(list, []Column{ {Path: "b", Reverse: true}, {Path: "a", Reverse: true}, - }, true) + }, true, nil) assert.Equal(t, List{a4, a1, a2, a3}, list) // sort mixed @@ -92,6 +92,6 @@ func TestSortIdentity(t *testing.T) { Sort(list, []Column{ {Path: "b", Reverse: false}, {Path: "a", Reverse: true}, - }, true) + }, true, nil) assert.Equal(t, List{a2, a3, a4, a1}, list) } diff --git a/bucket.go b/bucket.go index fb83ffe..1c6d5b4 100644 --- a/bucket.go +++ b/bucket.go @@ -675,7 +675,7 @@ func (b *Bucket) hasIndex(ctx context.Context, coll ICollection, model mongo.Ind // check if index with same keys already exists for _, index := range indexes { - if bsonkit.Compare(index.Keys, model.Keys) == 0 { + if bsonkit.Compare(index.Keys, model.Keys, nil) == 0 { return true, nil } } diff --git a/collection.go b/collection.go index 4978026..5384c63 100644 --- a/collection.go +++ b/collection.go @@ -258,7 +258,7 @@ func (c *Collection) CountDocuments(ctx context.Context, filter interface{}, opt // find documents res, err := useTransaction(ctx, c.engine, false, func(txn *Transaction) (interface{}, error) { - return txn.Find(c.handle, query, nil, skip, limit) + return txn.Find(c.handle, query, nil, nil, skip, limit) }) if err != nil { return 0, err @@ -376,7 +376,7 @@ func (c *Collection) Distinct(ctx context.Context, field string, filter interfac // find documents res, err := useTransaction(ctx, c.engine, false, func(txn *Transaction) (interface{}, error) { - return txn.Find(c.handle, query, nil, 0, 0) + return txn.Find(c.handle, query, nil, nil, 0, 0) }) if err != nil { return nil, err @@ -456,6 +456,7 @@ func (c *Collection) Find(ctx context.Context, filter interface{}, opts ...*opti "Skip": supported, "Snapshot": ignored, "Sort": supported, + "Collation": ignored, }) // check filer @@ -487,6 +488,15 @@ func (c *Collection) Find(ctx context.Context, filter interface{}, opts ...*opti } } + // get collation + var collation bsonkit.Doc + if opt.Collation != nil { + collation, err = bsonkit.Transform(opt.Collation) + if err != nil { + return nil, err + } + } + // get skip var skip int if opt.Skip != nil { @@ -501,7 +511,7 @@ func (c *Collection) Find(ctx context.Context, filter interface{}, opts ...*opti // find documents res, err := useTransaction(ctx, c.engine, false, func(txn *Transaction) (interface{}, error) { - return txn.Find(c.handle, query, sort, skip, limit) + return txn.Find(c.handle, query, sort, collation, skip, limit) }) if err != nil { return nil, err @@ -560,6 +570,15 @@ func (c *Collection) FindOne(ctx context.Context, filter interface{}, opts ...*o } } + // get collation + var collation bsonkit.Doc + if opt.Collation != nil { + collation, err = bsonkit.Transform(opt.Collation) + if err != nil { + return &SingleResult{err: err} + } + } + // get skip var skip int if opt.Skip != nil { @@ -577,7 +596,7 @@ func (c *Collection) FindOne(ctx context.Context, filter interface{}, opts ...*o // find documents res, err := useTransaction(ctx, c.engine, false, func(txn *Transaction) (interface{}, error) { - return txn.Find(c.handle, query, sort, skip, 1) + return txn.Find(c.handle, query, sort, collation, skip, 1) }) if err != nil { return &SingleResult{err: err} diff --git a/collection_test.go b/collection_test.go index 1077d88..05fbcab 100644 --- a/collection_test.go +++ b/collection_test.go @@ -1708,7 +1708,7 @@ func TestCollectionUpdateMany(t *testing.T) { }, dumpCollection(c, false)) // invalid _id mutation - res2, err = c.UpdateMany(nil, bson.M{ + _, err = c.UpdateMany(nil, bson.M{ "_id": id1, }, bson.M{ "$set": bson.M{ diff --git a/engine.go b/engine.go index 89c29c5..017bddd 100644 --- a/engine.go +++ b/engine.go @@ -265,7 +265,7 @@ func (e *Engine) Watch(handle Handle, pipeline bsonkit.List, resumeAfter, startA if resumeAfter != nil { resumed := false for _, event := range oplog.List { - res := bsonkit.Compare(*resumeAfter, bsonkit.Get(event, "_id")) + res := bsonkit.Compare(*resumeAfter, bsonkit.Get(event, "_id"), nil) if res == 0 { last = event resumed = true @@ -281,7 +281,7 @@ func (e *Engine) Watch(handle Handle, pipeline bsonkit.List, resumeAfter, startA if startAfter != nil { resumed := false for _, event := range oplog.List { - res := bsonkit.Compare(*startAfter, bsonkit.Get(event, "_id")) + res := bsonkit.Compare(*startAfter, bsonkit.Get(event, "_id"), nil) if res == 0 { last = event resumed = true @@ -297,7 +297,7 @@ func (e *Engine) Watch(handle Handle, pipeline bsonkit.List, resumeAfter, startA if startAt != nil { resumed := false for i, event := range oplog.List { - res := bsonkit.Compare(*startAt, bsonkit.Get(event, "clusterTime")) + res := bsonkit.Compare(*startAt, bsonkit.Get(event, "clusterTime"), nil) if res == 0 { if i > 0 { last = oplog.List[i-1] diff --git a/mongokit/apply.go b/mongokit/apply.go index 2d92426..942a748 100644 --- a/mongokit/apply.go +++ b/mongokit/apply.go @@ -244,7 +244,7 @@ func applyMax(ctx Context, doc bsonkit.Doc, _, path string, v interface{}) error } // replace value if smaller - if bsonkit.Compare(value, v) < 0 { + if bsonkit.Compare(value, v, nil) < 0 { // replace value _, err := bsonkit.Put(doc, path, v, false) if err != nil { @@ -281,7 +281,7 @@ func applyMin(ctx Context, doc bsonkit.Doc, _, path string, v interface{}) error } // replace value if bigger - if bsonkit.Compare(value, v) > 0 { + if bsonkit.Compare(value, v, nil) > 0 { // replace value _, err := bsonkit.Put(doc, path, v, false) if err != nil { @@ -384,9 +384,9 @@ func applyPush(ctx Context, doc bsonkit.Doc, _, path string, v interface{}) erro func applyPop(ctx Context, doc bsonkit.Doc, name, path string, v interface{}) error { // check value last := false - if bsonkit.Compare(v, int64(1)) == 0 { + if bsonkit.Compare(v, int64(1), nil) == 0 { last = true - } else if bsonkit.Compare(v, int64(-1)) != 0 { + } else if bsonkit.Compare(v, int64(-1), nil) != 0 { return fmt.Errorf("%s: expected 1 or -1", name) } diff --git a/mongokit/collection.go b/mongokit/collection.go index 2d22ff0..42d03f3 100644 --- a/mongokit/collection.go +++ b/mongokit/collection.go @@ -3,10 +3,18 @@ package mongokit import ( "fmt" + "github.com/256dpi/lungo/bsonkit" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo/options" + "golang.org/x/text/collate" + "golang.org/x/text/language" +) - "github.com/256dpi/lungo/bsonkit" +const ( + collationFieldNameLocale = "locale" ) // TODO: Test Collection. @@ -61,14 +69,45 @@ func NewCollection(idIndex bool) *Collection { } // Find will look up the documents that match the specified query. -func (c *Collection) Find(query, sort bsonkit.Doc, skip, limit int) (*Result, error) { +func (c *Collection) Find(query, sort bsonkit.Doc, collation bsonkit.Doc, skip, limit int) (*Result, error) { // get documents list := c.Documents.List // sort documents var err error if sort != nil && len(*sort) > 0 { - list, err = Sort(list, sort) + + var collator *collate.Collator + if collation != nil && len(*collation) > 0 { + + // Marshal the BSON document that contains the locale of the collation. + data, err := bson.Marshal(collation) + if err != nil { + panic(err) + } + + // Create a Decoder that reads the marshaled BSON document and use it to + // unmarshal the document into a mongo-driver options.Collation struct. + decoder, err := bson.NewDecoder(bsonrw.NewBSONDocumentReader(data)) + if err != nil { + panic(err) + } + + var res *options.Collation + err = decoder.Decode(&res) + if err != nil { + panic(err) + } + + l := language.English + if res != nil && res.Locale != "" { + l = language.Make(fmt.Sprint(res.Locale)) + } + + collator = collate.New(l) + } + + list, err = Sort(list, sort, collator) if err != nil { return nil, err } @@ -136,7 +175,7 @@ func (c *Collection) Replace(query, repl, sort bsonkit.Doc) (*Result, error) { // sort documents var err error if sort != nil && len(*sort) > 0 { - list, err = Sort(list, sort) + list, err = Sort(list, sort, nil) if err != nil { return nil, err } @@ -203,7 +242,7 @@ func (c *Collection) Update(query, update, sort bsonkit.Doc, skip, limit int, ar // sort documents var err error if sort != nil && len(*sort) > 0 { - list, err = Sort(list, sort) + list, err = Sort(list, sort, nil) if err != nil { return nil, err } @@ -308,7 +347,7 @@ func (c *Collection) Upsert(query, repl, update bsonkit.Doc, arrayFilters bsonki // check ids if queryID != bsonkit.Missing && replID != bsonkit.Missing { - if bsonkit.Compare(replID, queryID) != 0 { + if bsonkit.Compare(replID, queryID, nil) != 0 { return nil, fmt.Errorf("query _id and replacement _id must match") } } @@ -374,7 +413,7 @@ func (c *Collection) Delete(query, sort bsonkit.Doc, skip, limit int) (*Result, // sort documents var err error if sort != nil && len(*sort) > 0 { - list, err = Sort(list, sort) + list, err = Sort(list, sort, nil) if err != nil { return nil, err } @@ -446,7 +485,7 @@ func (c *Collection) CreateIndex(name string, config IndexConfig) (string, error // check duplicate for name, index := range c.Indexes { - if bsonkit.Compare(*config.Key, *index.Config().Key) == 0 { + if bsonkit.Compare(*config.Key, *index.Config().Key, nil) == 0 { return "", fmt.Errorf("existing index %q has same key", name) } } diff --git a/mongokit/index.go b/mongokit/index.go index a3cdc8c..923a381 100644 --- a/mongokit/index.go +++ b/mongokit/index.go @@ -31,7 +31,7 @@ type IndexConfig struct { // Equal will compare to configurations and return whether they are equal. func (c IndexConfig) Equal(d IndexConfig) bool { // check key - if bsonkit.Compare(*c.Key, *d.Key) != 0 { + if bsonkit.Compare(*c.Key, *d.Key, nil) != 0 { return false } @@ -48,7 +48,7 @@ func (c IndexConfig) Equal(d IndexConfig) bool { if d.Partial != nil { p2 = *d.Partial } - if bsonkit.Compare(p1, p2) != 0 { + if bsonkit.Compare(p1, p2, nil) != 0 { return false } diff --git a/mongokit/match.go b/mongokit/match.go index 0f38a67..19eff60 100644 --- a/mongokit/match.go +++ b/mongokit/match.go @@ -140,7 +140,7 @@ func matchComp(_ Context, doc bsonkit.Doc, op, path string, v interface{}) error comp := lc == rc // compare field with value - res := bsonkit.Compare(field, v) + res := bsonkit.Compare(field, v, nil) // check operator var ok bool @@ -205,7 +205,7 @@ func matchIn(_ Context, doc bsonkit.Doc, name, path string, v interface{}) error // check if field is in array for _, item := range array { - if bsonkit.Compare(field, item) == 0 { + if bsonkit.Compare(field, item, nil) == 0 { return nil } } @@ -340,7 +340,7 @@ func matchAll(_ Context, doc bsonkit.Doc, name, path string, v interface{}) erro for _, value := range array { ok := false for _, element := range arr { - if bsonkit.Compare(value, element) == 0 { + if bsonkit.Compare(value, element, nil) == 0 { ok = true } } @@ -355,7 +355,7 @@ func matchAll(_ Context, doc bsonkit.Doc, name, path string, v interface{}) erro // check if field is in array for _, item := range array { - if bsonkit.Compare(field, item) != 0 { + if bsonkit.Compare(field, item, nil) != 0 { return ErrNotMatched } } @@ -375,7 +375,7 @@ func matchSize(_ Context, doc bsonkit.Doc, name, path string, v interface{}) err // compare length if array array, ok := field.(bson.A) if ok { - if bsonkit.Compare(int64(len(array)), v) == 0 { + if bsonkit.Compare(int64(len(array)), v, nil) == 0 { return nil } } diff --git a/mongokit/project.go b/mongokit/project.go index fe543c4..7342614 100644 --- a/mongokit/project.go +++ b/mongokit/project.go @@ -135,9 +135,9 @@ func projectCondition(ctx Context, _ bsonkit.Doc, _, path string, v interface{}) state := ctx.Value.(*projectState) // handle inclusion or exclusion - if bsonkit.Compare(v, int64(1)) == 0 { + if bsonkit.Compare(v, int64(1), nil) == 0 { state.include = append(state.include, path) - } else if bsonkit.Compare(v, int64(0)) == 0 { + } else if bsonkit.Compare(v, int64(0), nil) == 0 { if path == "_id" { state.hideID = true } else { diff --git a/mongokit/sort.go b/mongokit/sort.go index 6ecd0cf..24ab48d 100644 --- a/mongokit/sort.go +++ b/mongokit/sort.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/256dpi/lungo/bsonkit" + "golang.org/x/text/collate" ) // Columns will return columns from a MongoDB sort document. @@ -41,7 +42,7 @@ func Columns(doc bsonkit.Doc) ([]bsonkit.Column, error) { // Sort will sort a list based on a MongoDB sort document and return a new // list with sorted documents. -func Sort(list bsonkit.List, doc bsonkit.Doc) (bsonkit.List, error) { +func Sort(list bsonkit.List, doc bsonkit.Doc, collator *collate.Collator) (bsonkit.List, error) { // copy list result := make(bsonkit.List, len(list)) copy(result, list) @@ -53,7 +54,7 @@ func Sort(list bsonkit.List, doc bsonkit.Doc) (bsonkit.List, error) { } // sort list - bsonkit.Sort(result, columns, true) + bsonkit.Sort(result, columns, true, collator) return result, nil } diff --git a/mongokit/sort_test.go b/mongokit/sort_test.go index 238df2e..f5fc25b 100644 --- a/mongokit/sort_test.go +++ b/mongokit/sort_test.go @@ -19,28 +19,28 @@ func TestSort(t *testing.T) { // invalid document list, err := Sort(bsonkit.List{a3, a1, a2}, &bson.D{ bson.E{Key: "a", Value: "0"}, - }) + }, nil) assert.Error(t, err) assert.Nil(t, list) // invalid document list, err = Sort(bsonkit.List{a3, a1, a2}, &bson.D{ bson.E{Key: "a", Value: 0}, - }) + }, nil) assert.Error(t, err) assert.Nil(t, list) // sort forwards single list, err = Sort(bsonkit.List{a3, a1, a2}, &bson.D{ bson.E{Key: "a", Value: int64(1)}, - }) + }, nil) assert.NoError(t, err) assert.Equal(t, bsonkit.List{a1, a2, a3}, list) // sort backwards single list, err = Sort(bsonkit.List{a3, a1, a2}, &bson.D{ bson.E{Key: "a", Value: int64(-1)}, - }) + }, nil) assert.NoError(t, err) assert.Equal(t, bsonkit.List{a3, a2, a1}, list) @@ -48,7 +48,7 @@ func TestSort(t *testing.T) { list, err = Sort(bsonkit.List{a3, a1, a2}, &bson.D{ bson.E{Key: "b", Value: int64(1)}, bson.E{Key: "a", Value: int64(1)}, - }) + }, nil) assert.NoError(t, err) assert.Equal(t, bsonkit.List{a2, a1, a3}, list) @@ -56,7 +56,7 @@ func TestSort(t *testing.T) { list, err = Sort(bsonkit.List{a3, a1, a2}, &bson.D{ bson.E{Key: "b", Value: int64(-1)}, bson.E{Key: "a", Value: int64(-1)}, - }) + }, nil) assert.NoError(t, err) assert.Equal(t, bsonkit.List{a3, a1, a2}, list) @@ -64,7 +64,7 @@ func TestSort(t *testing.T) { list, err = Sort(bsonkit.List{a3, a1, a2}, &bson.D{ bson.E{Key: "b", Value: int64(1)}, bson.E{Key: "a", Value: int64(-1)}, - }) + }, nil) assert.NoError(t, err) assert.Equal(t, bsonkit.List{a2, a3, a1}, list) } diff --git a/store_test.go b/store_test.go index fc411a7..83fdce8 100644 --- a/store_test.go +++ b/store_test.go @@ -148,7 +148,7 @@ func TestFileStore(t *testing.T) { txn, err = engine.Begin(nil, false) assert.NoError(t, err) - res, err = txn.Find(handle, bsonkit.MustConvert(bson.M{}), nil, 0, 0) + res, err = txn.Find(handle, bsonkit.MustConvert(bson.M{}), nil, nil, 0, 0) assert.NoError(t, err) assert.Equal(t, bsonkit.List{ bsonkit.MustConvert(bson.M{ @@ -162,7 +162,7 @@ func TestFileStore(t *testing.T) { }, res.Matched) databases, err := txn.ListDatabases(bsonkit.MustConvert(bson.M{})) - bsonkit.Sort(databases, []bsonkit.Column{{Path: "name"}}, true) + bsonkit.Sort(databases, []bsonkit.Column{{Path: "name"}}, true, nil) assert.NoError(t, err) assert.Equal(t, bson.A{ "foo", diff --git a/transaction.go b/transaction.go index 381525c..c5346d9 100644 --- a/transaction.go +++ b/transaction.go @@ -127,7 +127,7 @@ func (t *Transaction) Create(handle Handle) error { // Find will query documents from a namespace. Sort, skip and limit may be // supplied to modify the result. The returned results will contain the matched // list of documents. -func (t *Transaction) Find(handle Handle, query, sort bsonkit.Doc, skip, limit int) (*Result, error) { +func (t *Transaction) Find(handle Handle, query, sort bsonkit.Doc, collation bsonkit.Doc, skip, limit int) (*Result, error) { // acquire read lock t.mutex.RLock() defer t.mutex.RUnlock() @@ -144,7 +144,7 @@ func (t *Transaction) Find(handle Handle, query, sort bsonkit.Doc, skip, limit i } // find documents - res, err := t.catalog.Namespaces[handle].Find(query, sort, skip, limit) + res, err := t.catalog.Namespaces[handle].Find(query, sort, collation, skip, limit) if err != nil { return nil, err } @@ -901,7 +901,7 @@ func (t *Transaction) ListIndexes(handle Handle) (bsonkit.List, error) { // sort list bsonkit.Sort(list, []bsonkit.Column{ {Path: "name"}, - }, true) + }, true, nil) return list, nil } @@ -1046,8 +1046,8 @@ func (t *Transaction) Clean(minSize, maxSize int, minAge, maxAge time.Duration) ts := bsonkit.Get(doc, "_id.ts") // determine inclusion - afterMin := i < minIndex && bsonkit.Compare(ts, minTimestamp) < 0 - beyondMax := i < maxIndex || bsonkit.Compare(ts, maxTimestamp) < 0 + afterMin := i < minIndex && bsonkit.Compare(ts, minTimestamp, nil) < 0 + beyondMax := i < maxIndex || bsonkit.Compare(ts, maxTimestamp, nil) < 0 // remove event if below threshold or timestamp if afterMin && beyondMax {