Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"reflect"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)

type kv struct {
Expand Down Expand Up @@ -155,6 +158,35 @@ func TestDecoder_scan(t *testing.T) {
}

func TestDecoder_errors(t *testing.T) {
tests := []struct {
data string
dec func(string) *Decoder
want error
}{
{
data: "a=1\nb=2",
dec: func(s string) *Decoder {
dec := NewDecoderSize(strings.NewReader(s), 1)
return dec
},
want: bufio.ErrTooLong,
},
}

for _, test := range tests {
dec := test.dec(test.data)

for dec.ScanRecord() {
for dec.ScanKeyval() {
}
}
if diff := cmp.Diff(test.want, dec.Err(), cmpopts.EquateErrors()); diff != "" {
t.Errorf("%#v: Decoder.Err() value mismatch (-want,+got):\n%s", test.data, diff)
}
}
}

func TestDecoder_SyntaxError(t *testing.T) {
defaultDecoder := func(s string) *Decoder { return NewDecoder(strings.NewReader(s)) }
tests := []struct {
data string
Expand Down Expand Up @@ -231,14 +263,6 @@ func TestDecoder_errors(t *testing.T) {
dec: defaultDecoder,
want: &SyntaxError{Msg: "invalid key", Line: 1, Pos: 2},
},
{
data: "a=1\nb=2",
dec: func(s string) *Decoder {
dec := NewDecoderSize(strings.NewReader(s), 1)
return dec
},
want: bufio.ErrTooLong,
},
}

for _, test := range tests {
Expand All @@ -248,8 +272,16 @@ func TestDecoder_errors(t *testing.T) {
for dec.ScanKeyval() {
}
}
if got, want := dec.Err(), test.want; !reflect.DeepEqual(got, want) {
t.Errorf("got: %v, want: %v", got, want)

switch got := dec.Err().(type) {
case nil:
t.Errorf("%#v: dec.Err() == nil, want: *SyntaxError", test.data)
case *SyntaxError:
if diff := cmp.Diff(test.want, got); diff != "" {
t.Errorf("%#v: dec.Err() mismatch (-want,+got):\n%s", test.data, diff)
}
default:
t.Errorf("%#v: dec.Err().(type) == %T, want: *SyntaxError", test.data, got)
}
}
}
Expand Down
20 changes: 10 additions & 10 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

// MarshalKeyvals returns the logfmt encoding of keyvals, a variadic sequence
// of alternating keys and values.
func MarshalKeyvals(keyvals ...interface{}) ([]byte, error) {
func MarshalKeyvals(keyvals ...any) ([]byte, error) {
buf := &bytes.Buffer{}
if err := NewEncoder(buf).EncodeKeyvals(keyvals...); err != nil {
return nil, err
Expand Down Expand Up @@ -45,7 +45,7 @@ var (
// EncodeKeyval writes the logfmt encoding of key and value to the stream. A
// single space is written before the second and subsequent keys in a record.
// Nothing is written if a non-nil error is returned.
func (enc *Encoder) EncodeKeyval(key, value interface{}) error {
func (enc *Encoder) EncodeKeyval(key, value any) error {
enc.scratch.Reset()
if enc.needSep {
if _, err := enc.scratch.Write(space); err != nil {
Expand All @@ -72,7 +72,7 @@ func (enc *Encoder) EncodeKeyval(key, value interface{}) error {
// unsupported type or that cause a MarshalerError are replaced by their error
// but do not cause EncodeKeyvals to return an error. If a non-nil error is
// returned some key/value pairs may not have be written.
func (enc *Encoder) EncodeKeyvals(keyvals ...interface{}) error {
func (enc *Encoder) EncodeKeyvals(keyvals ...any) error {
if len(keyvals) == 0 {
return nil
}
Expand Down Expand Up @@ -122,7 +122,7 @@ var ErrUnsupportedKeyType = errors.New("unsupported key type")
// unsupported type.
var ErrUnsupportedValueType = errors.New("unsupported value type")

func writeKey(w io.Writer, key interface{}) error {
func writeKey(w io.Writer, key any) error {
if key == nil {
return ErrNilKey
}
Expand Down Expand Up @@ -155,7 +155,7 @@ func writeKey(w io.Writer, key interface{}) error {
switch rkey.Kind() {
case reflect.Array, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Struct:
return ErrUnsupportedKeyType
case reflect.Ptr:
case reflect.Pointer:
if rkey.IsNil() {
return ErrNilKey
}
Expand Down Expand Up @@ -194,7 +194,7 @@ func writeBytesKey(w io.Writer, key []byte) error {
return err
}

func writeValue(w io.Writer, value interface{}) error {
func writeValue(w io.Writer, value any) error {
switch v := value.(type) {
case nil:
return writeBytesValue(w, null)
Expand Down Expand Up @@ -222,7 +222,7 @@ func writeValue(w io.Writer, value interface{}) error {
switch rvalue.Kind() {
case reflect.Array, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Struct:
return ErrUnsupportedValueType
case reflect.Ptr:
case reflect.Pointer:
if rvalue.IsNil() {
return writeBytesValue(w, null)
}
Expand Down Expand Up @@ -276,7 +276,7 @@ func (enc *Encoder) Reset() {
func safeError(err error) (s string, ok bool) {
defer func() {
if panicVal := recover(); panicVal != nil {
if v := reflect.ValueOf(err); v.Kind() == reflect.Ptr && v.IsNil() {
if v := reflect.ValueOf(err); v.Kind() == reflect.Pointer && v.IsNil() {
s, ok = "null", false
} else {
s, ok = fmt.Sprintf("PANIC:%v", panicVal), false
Expand All @@ -290,7 +290,7 @@ func safeError(err error) (s string, ok bool) {
func safeString(str fmt.Stringer) (s string, ok bool) {
defer func() {
if panicVal := recover(); panicVal != nil {
if v := reflect.ValueOf(str); v.Kind() == reflect.Ptr && v.IsNil() {
if v := reflect.ValueOf(str); v.Kind() == reflect.Pointer && v.IsNil() {
s, ok = "null", false
} else {
s, ok = fmt.Sprintf("PANIC:%v", panicVal), true
Expand All @@ -304,7 +304,7 @@ func safeString(str fmt.Stringer) (s string, ok bool) {
func safeMarshal(tm encoding.TextMarshaler) (b []byte, err error) {
defer func() {
if panicVal := recover(); panicVal != nil {
if v := reflect.ValueOf(tm); v.Kind() == reflect.Ptr && v.IsNil() {
if v := reflect.ValueOf(tm); v.Kind() == reflect.Pointer && v.IsNil() {
b, err = nil, nil
} else {
b, err = nil, fmt.Errorf("panic when marshalling: %s", panicVal)
Expand Down
101 changes: 80 additions & 21 deletions encode_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"io"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)

func TestSafeString(t *testing.T) {
Expand All @@ -29,23 +32,23 @@ func TestSafeMarshal(t *testing.T) {
func TestWriteKeyStrings(t *testing.T) {
keygen := []struct {
name string
fn func(string) interface{}
fn func(string) any
}{
{
name: "string",
fn: func(s string) interface{} { return s },
fn: func(s string) any { return s },
},
{
name: "named-string",
fn: func(s string) interface{} { return stringData(s) },
fn: func(s string) any { return stringData(s) },
},
{
name: "Stringer",
fn: func(s string) interface{} { return stringStringer(s) },
fn: func(s string) any { return stringStringer(s) },
},
{
name: "TextMarshaler",
fn: func(s string) interface{} { return stringMarshaler(s) },
fn: func(s string) any { return stringMarshaler(s) },
},
}

Expand Down Expand Up @@ -99,7 +102,7 @@ func TestWriteKey(t *testing.T) {
)

data := []struct {
key interface{}
key any
want string
err error
}{
Expand All @@ -110,7 +113,6 @@ func TestWriteKey(t *testing.T) {
{key: (*stringerMarshaler)(nil), err: ErrNilKey},
{key: ptr, want: "1"},

{key: errorMarshaler{}, err: &MarshalerError{Type: reflect.TypeOf(errorMarshaler{}), Err: errMarshaling}},
{key: make(chan int), err: ErrUnsupportedKeyType},
{key: []int{}, err: ErrUnsupportedKeyType},
{key: map[int]int{}, err: ErrUnsupportedKeyType},
Expand All @@ -122,8 +124,8 @@ func TestWriteKey(t *testing.T) {
for _, d := range data {
w := &bytes.Buffer{}
err := writeKey(w, d.key)
if !reflect.DeepEqual(err, d.err) {
t.Errorf("%#v: got error: %v, want error: %v", d.key, err, d.err)
if diff := cmp.Diff(d.err, err, cmpopts.EquateErrors()); diff != "" {
t.Errorf("%#v: error value mismatch (-want,+got):\n%s", d.key, diff)
}
if err != nil {
continue
Expand All @@ -134,13 +136,42 @@ func TestWriteKey(t *testing.T) {
}
}

func TestWriteKeyMarshalError(t *testing.T) {
data := []struct {
key any
want string
err error
}{
{key: errorMarshaler{}, err: &MarshalerError{Type: reflect.TypeOf(errorMarshaler{}), Err: errMarshaling}},
}

for _, d := range data {
w := &bytes.Buffer{}
err := writeKey(w, d.key)

switch err := err.(type) {
case nil:
t.Errorf("%#v: err == nil, want: not nil", d.key)
case *MarshalerError:
if got, want := err.Type, reflect.TypeOf(errorMarshaler{}); got != want {
t.Errorf("%#v: MarshalerError.Type == %v, want: %v", d.key, got, want)
}
if diff := cmp.Diff(errMarshaling, err.Err, cmpopts.EquateErrors()); diff != "" {
t.Errorf("%#v: MarshalerError.Err value mismatch (-want,+got):\n%s", d.key, diff)
}
default:
t.Errorf("%#v: unexpected error, got: %q, want: a MarshalerError", d.key, err)
}
}
}

func TestWriteValueStrings(t *testing.T) {
keygen := []func(string) interface{}{
func(s string) interface{} { return s },
func(s string) interface{} { return errors.New(s) },
func(s string) interface{} { return stringData(s) },
func(s string) interface{} { return stringStringer(s) },
func(s string) interface{} { return stringMarshaler(s) },
keygen := []func(string) any{
func(s string) any { return s },
func(s string) any { return errors.New(s) },
func(s string) any { return stringData(s) },
func(s string) any { return stringStringer(s) },
func(s string) any { return stringMarshaler(s) },
}

data := []struct {
Expand Down Expand Up @@ -188,7 +219,7 @@ func TestWriteValue(t *testing.T) {
)

data := []struct {
value interface{}
value any
want string
err error
}{
Expand All @@ -199,7 +230,6 @@ func TestWriteValue(t *testing.T) {
{value: (*stringerMarshaler)(nil), want: "null"},
{value: ptr, want: "1"},

{value: errorMarshaler{}, err: &MarshalerError{Type: reflect.TypeOf(errorMarshaler{}), Err: errMarshaling}},
{value: make(chan int), err: ErrUnsupportedValueType},
{value: []int{}, err: ErrUnsupportedValueType},
{value: map[int]int{}, err: ErrUnsupportedValueType},
Expand All @@ -211,8 +241,8 @@ func TestWriteValue(t *testing.T) {
for _, d := range data {
w := &bytes.Buffer{}
err := writeValue(w, d.value)
if !reflect.DeepEqual(err, d.err) {
t.Errorf("%#v: got error: %v, want error: %v", d.value, err, d.err)
if diff := cmp.Diff(d.err, err, cmpopts.EquateErrors()); diff != "" {
t.Errorf("%#v: error value mismatch (-want,+got):\n%s", d.value, diff)
}
if err != nil {
continue
Expand All @@ -223,6 +253,35 @@ func TestWriteValue(t *testing.T) {
}
}

func TestWriteValueMarshalError(t *testing.T) {
data := []struct {
value any
want string
err error
}{
{value: errorMarshaler{}, err: &MarshalerError{Type: reflect.TypeOf(errorMarshaler{}), Err: errMarshaling}},
}

for _, d := range data {
w := &bytes.Buffer{}
err := writeValue(w, d.value)

switch err := err.(type) {
case nil:
t.Errorf("%#v: err == nil, want: not nil", d.value)
case *MarshalerError:
if got, want := err.Type, reflect.TypeOf(errorMarshaler{}); got != want {
t.Errorf("%#v: MarshalerError.Type == %v, want: %v", d.value, got, want)
}
if diff := cmp.Diff(errMarshaling, err.Err, cmpopts.EquateErrors()); diff != "" {
t.Errorf("%#v: MarshalerError.Err value mismatch (-want,+got):\n%s", d.value, diff)
}
default:
t.Errorf("%#v: unexpected error, got: %q, want: a MarshalerError", d.value, err)
}
}
}

type stringData string

type stringStringer string
Expand Down Expand Up @@ -266,7 +325,7 @@ func BenchmarkWriteStringKey(b *testing.B) {
for _, k := range keys {
b.Run(k, func(b *testing.B) {
for i := 0; i < b.N; i++ {
writeStringKey(ioutil.Discard, k)
writeStringKey(io.Discard, k)
}
})
}
Expand Down
Loading