Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions server/ast/resolvable_type_reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,24 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference,
case *types.T:
columnTypeName = columnType.SQLStandardName()
if columnType.Family() == types.ArrayFamily {
_, baseResolvedType, err := nodeResolvableTypeReference(ctx, columnType.ArrayContents(), mayBeTrigger)
if err != nil {
return nil, nil, err
}
if baseResolvedType.IsResolvedType() {
// currently the built-in types will be resolved, so it can retrieve its array type
doltgresType = baseResolvedType.ToArrayType()
} else {
// TODO: handle array type of non-built-in types
baseResolvedType.TypCategory = pgtypes.TypeCategory_ArrayTypes
doltgresType = baseResolvedType
switch columnType.Oid() {
case oid.T_int2vector:
doltgresType = pgtypes.Int16vector
case oid.T_oidvector:
doltgresType = pgtypes.Oidvector
default:
_, baseResolvedType, err := nodeResolvableTypeReference(ctx, columnType.ArrayContents(), mayBeTrigger)
if err != nil {
return nil, nil, err
}
if baseResolvedType.IsResolvedType() {
// currently the built-in types will be resolved, so it can retrieve its array type
doltgresType = baseResolvedType.ToArrayType()
} else {
// TODO: handle array type of non-built-in types
baseResolvedType.TypCategory = pgtypes.TypeCategory_ArrayTypes
doltgresType = baseResolvedType
}
}
} else if columnType.Family() == types.GeometryFamily {
return nil, nil, errors.Errorf("geometry types are not yet supported")
Expand Down Expand Up @@ -109,6 +116,8 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference,
doltgresType = pgtypes.Float64
case oid.T_int2:
doltgresType = pgtypes.Int16
case oid.T_int2vector:
doltgresType = pgtypes.Int16vector
case oid.T_int4:
doltgresType = pgtypes.Int32
case oid.T_int8:
Expand All @@ -132,6 +141,8 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference,
}
case oid.T_oid:
doltgresType = pgtypes.Oid
case oid.T_oidvector:
doltgresType = pgtypes.Oidvector
case oid.T_regclass:
doltgresType = pgtypes.Regclass
case oid.T_regproc:
Expand Down
148 changes: 78 additions & 70 deletions server/functions/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,34 +169,7 @@ var array_recv = framework.Function3{
baseType := pgtypes.IDToBuiltInDoltgresType[id.Type(baseTypeOid)]
typmod := val3.(int32)
baseType = baseType.WithAttTypMod(typmod)
// Check for the nil value, then ensure the minimum length of the slice
if len(data) == 0 {
return nil, nil
}
if len(data) < 4 {
return nil, errors.Errorf("deserializing non-nil array value has invalid length of %d", len(data))
}
// Grab the number of elements and construct an output slice of the appropriate size
elementCount := binary.LittleEndian.Uint32(data)
output := make([]any, elementCount)
// Read all elements
for i := uint32(0); i < elementCount; i++ {
// We read from i+1 to account for the element count at the beginning
offset := binary.LittleEndian.Uint32(data[(i+1)*4:])
// If the value is null, then we can skip it, since the output slice default initializes all values to nil
if data[offset] == 1 {
continue
}
// The element data is everything from the offset to the next offset, excluding the null determinant
nextOffset := binary.LittleEndian.Uint32(data[(i+2)*4:])
o, err := baseType.DeserializeValue(ctx, data[offset+1:nextOffset])
if err != nil {
return nil, err
}
output[i] = o
}
// Returns all read elements
return output, nil
return deserializeArray(ctx, data, baseType)
},
}

Expand All @@ -207,49 +180,9 @@ var array_send = framework.Function1{
Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyArray},
Strict: true,
Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) {
arrType := t[0]
baseType := arrType.ArrayBaseType()
vals := val.([]any)

bb := bytes.Buffer{}
// Write the element count to a buffer. We're using an array since it's stack-allocated, so no need for pooling.
var elementCount [4]byte
binary.LittleEndian.PutUint32(elementCount[:], uint32(len(vals)))
bb.Write(elementCount[:])
// Create an array that contains the offsets for each value. Since we can't update the offset portion of the buffer
// as we determine the offsets, we have to track them outside the buffer. We'll overwrite the buffer later with the
// correct offsets. The last offset represents the end of the slice, which simplifies the logic for reading elements
// using the "current offset to next offset" strategy. We use a byte slice since the buffer only works with byte
// slices.
offsets := make([]byte, (len(vals)+1)*4)
bb.Write(offsets)
// The starting offset for the first element is Count(uint32) + (NumberOfElementOffsets * sizeof(uint32))
currentOffset := uint32(4 + (len(vals)+1)*4)
for i := range vals {
// Write the current offset
binary.LittleEndian.PutUint32(offsets[i*4:], currentOffset)
// Handle serialization of the value
// TODO: ARRAYs may be multidimensional, such as ARRAY[[4,2],[6,3]], which isn't accounted for here
serializedVal, err := baseType.SerializeValue(ctx, vals[i])
if err != nil {
return nil, err
}
// Handle the nil case and non-nil case
if serializedVal == nil {
bb.WriteByte(1)
currentOffset += 1
} else {
bb.WriteByte(0)
bb.Write(serializedVal)
currentOffset += 1 + uint32(len(serializedVal))
}
}
// Write the final offset, which will equal the length of the serialized slice
binary.LittleEndian.PutUint32(offsets[len(offsets)-4:], currentOffset)
// Get the final output, and write the updated offsets to it
outputBytes := bb.Bytes()
copy(outputBytes[4:], offsets)
return outputBytes, nil
arrType := t[0]
return serializeArray(ctx, vals, arrType.ArrayBaseType())
},
}

Expand Down Expand Up @@ -301,3 +234,78 @@ var array_subscript_handler = framework.Function1{
return []byte{}, nil
},
}

// deserializeArray serializes an array of given base type.
func serializeArray(ctx *sql.Context, vals []any, baseType *pgtypes.DoltgresType) ([]byte, error) {
bb := bytes.Buffer{}
// Write the element count to a buffer. We're using an array since it's stack-allocated, so no need for pooling.
var elementCount [4]byte
binary.LittleEndian.PutUint32(elementCount[:], uint32(len(vals)))
bb.Write(elementCount[:])
// Create an array that contains the offsets for each value. Since we can't update the offset portion of the buffer
// as we determine the offsets, we have to track them outside the buffer. We'll overwrite the buffer later with the
// correct offsets. The last offset represents the end of the slice, which simplifies the logic for reading elements
// using the "current offset to next offset" strategy. We use a byte slice since the buffer only works with byte
// slices.
offsets := make([]byte, (len(vals)+1)*4)
bb.Write(offsets)
// The starting offset for the first element is Count(uint32) + (NumberOfElementOffsets * sizeof(uint32))
currentOffset := uint32(4 + (len(vals)+1)*4)
for i := range vals {
// Write the current offset
binary.LittleEndian.PutUint32(offsets[i*4:], currentOffset)
// Handle serialization of the value
// TODO: ARRAYs may be multidimensional, such as ARRAY[[4,2],[6,3]], which isn't accounted for here
serializedVal, err := baseType.SerializeValue(ctx, vals[i])
if err != nil {
return nil, err
}
// Handle the nil case and non-nil case
if serializedVal == nil {
bb.WriteByte(1)
currentOffset += 1
} else {
bb.WriteByte(0)
bb.Write(serializedVal)
currentOffset += 1 + uint32(len(serializedVal))
}
}
// Write the final offset, which will equal the length of the serialized slice
binary.LittleEndian.PutUint32(offsets[len(offsets)-4:], currentOffset)
// Get the final output, and write the updated offsets to it
outputBytes := bb.Bytes()
copy(outputBytes[4:], offsets)
return outputBytes, nil
}

// deserializeArray deserializes an array of given base type.
func deserializeArray(ctx *sql.Context, data []byte, baseType *pgtypes.DoltgresType) ([]any, error) {
// Check for the nil value, then ensure the minimum length of the slice
if len(data) == 0 {
return nil, nil
}
if len(data) < 4 {
return nil, errors.Errorf("deserializing non-nil array value has invalid length of %d", len(data))
}
// Grab the number of elements and construct an output slice of the appropriate size
elementCount := binary.LittleEndian.Uint32(data)
output := make([]any, elementCount)
// Read all elements
for i := uint32(0); i < elementCount; i++ {
// We read from i+1 to account for the element count at the beginning
offset := binary.LittleEndian.Uint32(data[(i+1)*4:])
// If the value is null, then we can skip it, since the output slice default initializes all values to nil
if data[offset] == 1 {
continue
}
// The element data is everything from the offset to the next offset, excluding the null determinant
nextOffset := binary.LittleEndian.Uint32(data[(i+2)*4:])
o, err := baseType.DeserializeValue(ctx, data[offset+1:nextOffset])
if err != nil {
return nil, err
}
output[i] = o
}
// Returns all read elements
return output, nil
}
32 changes: 19 additions & 13 deletions server/functions/binary/equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func initBinaryEqual() {
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, nameeqtext)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, numeric_eq)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, oideq)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, oidvectoreq)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, texteqname)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, text_eq)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, record_eq)
Expand Down Expand Up @@ -469,25 +470,30 @@ var numeric_eq = framework.Function2{
Callable: numeric_eq_callable,
}

// oideq_callable is the callable logic for the oideq function.
// This method doesn't use DotlgresType.Compare because it's on the critical path for many tooling queries that
// examine the pg_catalog tables.
func oideq_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
if val1 == nil || val2 == nil {
return false, nil
}

val1id, val2id := val1.(id.Id), val2.(id.Id)
return val1id == val2id, nil
}

// oideq represents the PostgreSQL function of the same name, taking the same parameters.
var oideq = framework.Function2{
Name: "oideq",
Return: pgtypes.Bool,
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oid, pgtypes.Oid},
Strict: true,
Callable: oideq_callable,
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
// This method doesn't use DoltgresType.Compare because it's on the critical path for many tooling queries that
// examine the pg_catalog tables.
val1id, val2id := val1.(id.Id), val2.(id.Id)
return val1id == val2id, nil
},
}

// oidvectoreq represents the PostgreSQL function of the same name, taking the same parameters.
var oidvectoreq = framework.Function2{
Name: "oidvectoreq",
Return: pgtypes.Bool,
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oidvector, pgtypes.Oidvector},
Strict: true,
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
res, err := pgtypes.Oidvector.Compare(ctx, val1.([]any), val2.([]any))
return res == 0, err
},
}

// texteqname_callable is the callable logic for the texteqname function.
Expand Down
13 changes: 13 additions & 0 deletions server/functions/binary/greater.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func initBinaryGreaterThan() {
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, namegttext)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, numeric_gt)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, oidgt)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, oidvectorgt)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, textgtname)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, text_gt)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterThan, time_gt)
Expand Down Expand Up @@ -399,6 +400,18 @@ var oidgt = framework.Function2{
},
}

// oidvectorgt represents the PostgreSQL function of the same name, taking the same parameters.
var oidvectorgt = framework.Function2{
Name: "oidvectorgt",
Return: pgtypes.Bool,
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oidvector, pgtypes.Oidvector},
Strict: true,
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
res, err := pgtypes.Oidvector.Compare(ctx, val1.([]any), val2.([]any))
return res == 1, err
},
}

// textgtname represents the PostgreSQL function of the same name, taking the same parameters.
var textgtname = framework.Function2{
Name: "textgtname",
Expand Down
13 changes: 13 additions & 0 deletions server/functions/binary/greater_equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func initBinaryGreaterOrEqual() {
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, namegetext)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, numeric_ge)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, oidge)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, oidvectorge)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, textgename)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, text_ge)
framework.RegisterBinaryFunction(framework.Operator_BinaryGreaterOrEqual, time_ge)
Expand Down Expand Up @@ -399,6 +400,18 @@ var oidge = framework.Function2{
},
}

// oidvectorge represents the PostgreSQL function of the same name, taking the same parameters.
var oidvectorge = framework.Function2{
Name: "oidvectorge",
Return: pgtypes.Bool,
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oidvector, pgtypes.Oidvector},
Strict: true,
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
res, err := pgtypes.Oidvector.Compare(ctx, val1.([]any), val2.([]any))
return res >= 0, err
},
}

// textgename represents the PostgreSQL function of the same name, taking the same parameters.
var textgename = framework.Function2{
Name: "textgename",
Expand Down
13 changes: 13 additions & 0 deletions server/functions/binary/less.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func initBinaryLessThan() {
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, namelttext)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, numeric_lt)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, oidlt)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, oidvectorlt)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, textltname)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, text_lt)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessThan, time_lt)
Expand Down Expand Up @@ -399,6 +400,18 @@ var oidlt = framework.Function2{
},
}

// oidvectorlt represents the PostgreSQL function of the same name, taking the same parameters.
var oidvectorlt = framework.Function2{
Name: "oidvectorlt",
Return: pgtypes.Bool,
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oidvector, pgtypes.Oidvector},
Strict: true,
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
res, err := pgtypes.Oidvector.Compare(ctx, val1.([]any), val2.([]any))
return res == -1, err
},
}

// textltname represents the PostgreSQL function of the same name, taking the same parameters.
var textltname = framework.Function2{
Name: "textltname",
Expand Down
13 changes: 13 additions & 0 deletions server/functions/binary/less_equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func initBinaryLessOrEqual() {
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, nameletext)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, numeric_le)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, oidle)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, oidvectorle)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, textlename)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, text_le)
framework.RegisterBinaryFunction(framework.Operator_BinaryLessOrEqual, time_le)
Expand Down Expand Up @@ -399,6 +400,18 @@ var oidle = framework.Function2{
},
}

// oidvectorle represents the PostgreSQL function of the same name, taking the same parameters.
var oidvectorle = framework.Function2{
Name: "oidvectorle",
Return: pgtypes.Bool,
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Oidvector, pgtypes.Oidvector},
Strict: true,
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
res, err := pgtypes.Oidvector.Compare(ctx, val1.([]any), val2.([]any))
return res <= 0, err
},
}

// textlename represents the PostgreSQL function of the same name, taking the same parameters.
var textlename = framework.Function2{
Name: "textlename",
Expand Down
Loading