diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2560488..d069293 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,7 +10,7 @@ jobs: name: Build strategy: matrix: - go: ["1.20", "1.21"] + go: ["1.24", "1.25"] runs-on: ubuntu-latest container: golang:${{ matrix.go }}-bookworm steps: diff --git a/go.mod b/go.mod index edd23af..16eed6e 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/stellar/go-xdr -go 1.12 +go 1.22 diff --git a/xdr3/bench_test.go b/xdr3/bench_test.go deleted file mode 100644 index 19d4886..0000000 --- a/xdr3/bench_test.go +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (c) 2012-2014 Dave Collins - * - * Permission to use, copy, modify, and distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -package xdr_test - -import ( - "bytes" - "testing" - "unsafe" - - xdr "github.com/stellar/go-xdr/xdr3" -) - -// BenchmarkUnmarshal benchmarks the Unmarshal function by using a dummy -// ImageHeader structure. -func BenchmarkUnmarshal(b *testing.B) { - b.StopTimer() - // Hypothetical image header format. - type ImageHeader struct { - Signature [3]byte - Version uint32 - IsGrayscale bool - NumSections uint32 - } - // XDR encoded data described by the above structure. - encodedData := []byte{ - 0xAB, 0xCD, 0xEF, 0x00, - 0x00, 0x00, 0x00, 0x02, - 0x00, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x0A, - } - var h ImageHeader - b.StartTimer() - - for i := 0; i < b.N; i++ { - r := bytes.NewReader(encodedData) - _, _ = xdr.Unmarshal(r, &h) - } - b.SetBytes(int64(len(encodedData))) -} - -// BenchmarkMarshal benchmarks the Marshal function by using a dummy ImageHeader -// structure. -func BenchmarkMarshal(b *testing.B) { - b.StopTimer() - // Hypothetical image header format. - type ImageHeader struct { - Signature [3]byte - Version uint32 - IsGrayscale bool - NumSections uint32 - } - h := ImageHeader{[3]byte{0xAB, 0xCD, 0xEF}, 2, true, 10} - size := unsafe.Sizeof(h) - w := bytes.NewBuffer(nil) - b.StartTimer() - - for i := 0; i < b.N; i++ { - w.Reset() - _, _ = xdr.Marshal(w, &h) - } - b.SetBytes(int64(size)) -} diff --git a/xdr3/decode.go b/xdr3/decode.go index 1c8f383..4fe531d 100644 --- a/xdr3/decode.go +++ b/xdr3/decode.go @@ -1,5 +1,6 @@ /* * Copyright (c) 2012-2014 Dave Collins + * Copyright (c) 2026 Stellar Development Foundation * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -17,8 +18,8 @@ package xdr import ( + "encoding/binary" "fmt" - "io" "math" "reflect" "strconv" @@ -28,7 +29,6 @@ import ( const maxInt32 = math.MaxInt32 var errMaxSlice = "data exceeds max slice limit" -var errIODecode = "%s while decoding %d bytes" // DecodeDefaultMaxDepth is the default maximum decoding depth const DecodeDefaultMaxDepth = 200 @@ -39,30 +39,33 @@ type DecodeOptions struct { // It prevents infinite recursions in cyclic datastructures and determines the maximum callstack growth. // If set to 0, DecodeDefaultMaxDepth will be used. MaxDepth uint - - // MaxInputLen sets the maximum input size. It is used by the decoder to sanity-check - // allocation sizes and avoid heap explosions from doctored inputs. - // - // If set to 0, the decoder will try to figure out the input size by checking whether - // the provided io.Reader implements Len() (e.g. strings.Reader, bytes.Reader and bytes.Buffer do). - // Otherwise, no sanity checks will be done. - MaxInputLen int } // DefaultDecodeOptions are the default decoding options. var DefaultDecodeOptions = DecodeOptions{ - MaxDepth: DecodeDefaultMaxDepth, - MaxInputLen: 0, + MaxDepth: DecodeDefaultMaxDepth, +} + +// DecoderFrom is implemented by types that can decode themselves from a Decoder. +// Types implementing this interface get a fast path in Decode(), bypassing reflection. +// The maxDepth parameter tracks recursion depth to prevent stack overflow from +// maliciously crafted deeply-nested data. Implementations should decrement maxDepth +// when calling DecodeFrom on nested types and return an error if maxDepth reaches 0. +type DecoderFrom interface { + DecodeFrom(d *Decoder, maxDepth uint) (int, error) } /* Unmarshal parses XDR-encoded data into the value pointed to by v reading from -reader r and returning the total number of bytes read. An addressable pointer +a byte slice and returning the total number of bytes read. An addressable pointer must be provided since Unmarshal needs to both store the result of the decode as well as obtain target type information. Unmarhsal traverses v recursively and automatically indirects pointers through arbitrary depth, allocating them as necessary, to decode the data into the underlying value pointed to. +If v implements DecoderFrom, its DecodeFrom method is called directly for +better performance. Otherwise, reflection-based decoding is used. + Unmarshal uses reflection to determine the type of the concrete value contained by v and performs a mapping of underlying XDR types to Go types as follows: @@ -99,82 +102,71 @@ an ErrorCode value for further inspection from sophisticated callers. Some potential issues are unsupported Go types, attempting to decode a value which is too large to fit into a specified Go type, and exceeding max slice limitations. */ -func Unmarshal(r io.Reader, v interface{}) (int, error) { - d := NewDecoder(r) +func Unmarshal(data []byte, v interface{}) (int, error) { + d := NewDecoder(data) return d.Decode(v) } // UnmarshalWithOptions works like Unmarshal but accepts decoding options. -func UnmarshalWithOptions(r io.Reader, v interface{}, options DecodeOptions) (int, error) { - d := NewDecoderWithOptions(r, options) +func UnmarshalWithOptions(data []byte, v interface{}, options DecodeOptions) (int, error) { + d := NewDecoderWithOptions(data, options) return d.Decode(v) } -type lenLeft interface { - Len() int -} - -// A Decoder wraps an io.Reader that is expected to provide an XDR-encoded byte -// stream and provides several exposed methods to manually decode various XDR -// primitives without relying on reflection. The NewDecoder function can be -// used to get a new Decoder directly. +// A Decoder reads XDR-encoded data from a byte slice and provides several +// exposed methods to manually decode various XDR primitives without relying +// on reflection. The NewDecoder function can be used to get a new Decoder directly. // -// Typically, Unmarshal should be used instead of manual decoding. A Decoder +// Typically, Unmarshal should be used instead of manual decoding. A Decoder // is exposed, so it is possible to perform manual decoding should it be // necessary in complex scenarios where automatic reflection-based decoding // won't work. type Decoder struct { - // used to minimize heap allocations during decoding - scratchBuf [8]byte - r io.Reader - l lenLeft - maxDepth uint -} - -// readerLenWrapper wraps a reader an initial length and provides a Len() method indicating -// how much input is left -type readerLenWrapper struct { - inner io.Reader - readCount int - initialLen int -} - -func (l *readerLenWrapper) Len() int { - return l.initialLen - l.readCount -} - -func (l *readerLenWrapper) Read(p []byte) (int, error) { - n, err := l.inner.Read(p) - if n > 0 { - l.readCount += n - } - return n, err + buf []byte + pos int + maxDepth uint } // NewDecoder returns a Decoder that can be used to manually decode XDR data -// from a provided reader. Typically, Unmarshal should be used instead of +// from a provided byte slice. Typically, Unmarshal should be used instead of // manually creating a Decoder. -func NewDecoder(r io.Reader) *Decoder { - return NewDecoderWithOptions(r, DefaultDecodeOptions) +func NewDecoder(data []byte) *Decoder { + return NewDecoderWithOptions(data, DefaultDecodeOptions) } // NewDecoderWithOptions works like NewDecoder but allows supplying decoding options. -func NewDecoderWithOptions(r io.Reader, options DecodeOptions) *Decoder { +func NewDecoderWithOptions(data []byte, options DecodeOptions) *Decoder { maxDepth := options.MaxDepth if maxDepth < 1 { maxDepth = DecodeDefaultMaxDepth } - if l, ok := r.(lenLeft); ok { - return &Decoder{r: r, l: l, maxDepth: maxDepth} + return &Decoder{ + buf: data, + pos: 0, + maxDepth: maxDepth, } - if options.MaxInputLen > 0 { - rlw := &readerLenWrapper{ - inner: r, - initialLen: options.MaxInputLen, - } - return &Decoder{r: rlw, l: rlw, maxDepth: maxDepth} - } - return &Decoder{r: r, l: nil, maxDepth: options.MaxDepth} +} + +// Reset resets the decoder to read from a new byte slice, allowing reuse +// of the decoder to reduce allocations. +func (d *Decoder) Reset(data []byte) { + d.buf = data + d.pos = 0 +} + +// Remaining returns the number of unread bytes in the buffer. +func (d *Decoder) Remaining() int { + return len(d.buf) - d.pos +} + +// Position returns the current read position in the buffer. +func (d *Decoder) Position() int { + return d.pos +} + +// MaxDepth returns the maximum decoding depth setting. +func (d *Decoder) MaxDepth() uint { + return d.maxDepth } // DecodeInt treats the next 4 bytes as an XDR encoded integer and returns the @@ -187,16 +179,12 @@ func NewDecoderWithOptions(r io.Reader, options DecodeOptions) *Decoder { // RFC Section 4.1 - Integer // 32-bit big-endian signed integer in range [-2147483648, 2147483647] func (d *Decoder) DecodeInt() (int32, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:4]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 4) - err := unmarshalError("DecodeInt", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 4 { + return 0, 0, unmarshalError("DecodeInt", ErrIO, "unexpected end of input", nil, nil) } - - rv := int32(d.scratchBuf[3]) | int32(d.scratchBuf[2])<<8 | - int32(d.scratchBuf[1])<<16 | int32(d.scratchBuf[0])<<24 - return rv, n, nil + v := int32(binary.BigEndian.Uint32(d.buf[d.pos:])) + d.pos += 4 + return v, 4, nil } // DecodeUint treats the next 4 bytes as an XDR encoded unsigned integer and @@ -209,16 +197,12 @@ func (d *Decoder) DecodeInt() (int32, int, error) { // RFC Section 4.2 - Unsigned Integer // 32-bit big-endian unsigned integer in range [0, 4294967295] func (d *Decoder) DecodeUint() (uint32, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:4]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 4) - err := unmarshalError("DecodeUint", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 4 { + return 0, 0, unmarshalError("DecodeUint", ErrIO, "unexpected end of input", nil, nil) } - - rv := uint32(d.scratchBuf[3]) | uint32(d.scratchBuf[2])<<8 | - uint32(d.scratchBuf[1])<<16 | uint32(d.scratchBuf[0])<<24 - return rv, n, nil + v := binary.BigEndian.Uint32(d.buf[d.pos:]) + d.pos += 4 + return v, 4, nil } // DecodeEnum treats the next 4 bytes as an XDR encoded enumeration value and @@ -284,18 +268,12 @@ func (d *Decoder) DecodeBool() (bool, int, error) { // RFC Section 4.5 - Hyper Integer // 64-bit big-endian signed integer in range [-9223372036854775808, 9223372036854775807] func (d *Decoder) DecodeHyper() (int64, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:8]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 8) - err := unmarshalError("DecodeHyper", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 8 { + return 0, 0, unmarshalError("DecodeHyper", ErrIO, "unexpected end of input", nil, nil) } - - rv := int64(d.scratchBuf[7]) | int64(d.scratchBuf[6])<<8 | - int64(d.scratchBuf[5])<<16 | int64(d.scratchBuf[4])<<24 | - int64(d.scratchBuf[3])<<32 | int64(d.scratchBuf[2])<<40 | - int64(d.scratchBuf[1])<<48 | int64(d.scratchBuf[0])<<56 - return rv, n, err + v := int64(binary.BigEndian.Uint64(d.buf[d.pos:])) + d.pos += 8 + return v, 8, nil } // DecodeUhyper treats the next 8 bytes as an XDR encoded unsigned hyper value @@ -309,18 +287,12 @@ func (d *Decoder) DecodeHyper() (int64, int, error) { // RFC Section 4.5 - Unsigned Hyper Integer // 64-bit big-endian unsigned integer in range [0, 18446744073709551615] func (d *Decoder) DecodeUhyper() (uint64, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:8]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 8) - err := unmarshalError("DecodeUhyper", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 8 { + return 0, 0, unmarshalError("DecodeUhyper", ErrIO, "unexpected end of input", nil, nil) } - - rv := uint64(d.scratchBuf[7]) | uint64(d.scratchBuf[6])<<8 | - uint64(d.scratchBuf[5])<<16 | uint64(d.scratchBuf[4])<<24 | - uint64(d.scratchBuf[3])<<32 | uint64(d.scratchBuf[2])<<40 | - uint64(d.scratchBuf[1])<<48 | uint64(d.scratchBuf[0])<<56 - return rv, n, nil + v := binary.BigEndian.Uint64(d.buf[d.pos:]) + d.pos += 8 + return v, 8, nil } // DecodeFloat treats the next 4 bytes as an XDR encoded floating point and @@ -333,16 +305,12 @@ func (d *Decoder) DecodeUhyper() (uint64, int, error) { // RFC Section 4.6 - Floating Point // 32-bit single-precision IEEE 754 floating point func (d *Decoder) DecodeFloat() (float32, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:4]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 4) - err := unmarshalError("DecodeFloat", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 4 { + return 0, 0, unmarshalError("DecodeFloat", ErrIO, "unexpected end of input", nil, nil) } - - val := uint32(d.scratchBuf[3]) | uint32(d.scratchBuf[2])<<8 | - uint32(d.scratchBuf[1])<<16 | uint32(d.scratchBuf[0])<<24 - return math.Float32frombits(val), n, nil + v := binary.BigEndian.Uint32(d.buf[d.pos:]) + d.pos += 4 + return math.Float32frombits(v), 4, nil } // DecodeDouble treats the next 8 bytes as an XDR encoded double-precision @@ -356,18 +324,12 @@ func (d *Decoder) DecodeFloat() (float32, int, error) { // RFC Section 4.7 - Double-Precision Floating Point // 64-bit double-precision IEEE 754 floating point func (d *Decoder) DecodeDouble() (float64, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:8]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 8) - err := unmarshalError("DecodeDouble", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 8 { + return 0, 0, unmarshalError("DecodeDouble", ErrIO, "unexpected end of input", nil, nil) } - - val := uint64(d.scratchBuf[7]) | uint64(d.scratchBuf[6])<<8 | - uint64(d.scratchBuf[5])<<16 | uint64(d.scratchBuf[4])<<24 | - uint64(d.scratchBuf[3])<<32 | uint64(d.scratchBuf[2])<<40 | - uint64(d.scratchBuf[1])<<48 | uint64(d.scratchBuf[0])<<56 - return math.Float64frombits(val), n, nil + v := binary.BigEndian.Uint64(d.buf[d.pos:]) + d.pos += 8 + return math.Float64frombits(v), 8, nil } // RFC Section 4.8 - Quadruple-Precision Floating Point @@ -395,54 +357,43 @@ func (d *Decoder) DecodeFixedOpaque(size int32) ([]byte, int, error) { return out, n, nil } -// DecodeFixedOpaqueInplace is an in-place version of DecodeFixedOpaque. -// It improves performance when the destination is pre-allocated (which avoids -// internally allocating an extra slice and does not require further copying) -func (d *Decoder) DecodeFixedOpaqueInplace(out []byte) (int, error) { - size := len(out) - // Nothing to do if size is 0. +// consumePaddedData validates XDR padding and advances the position. +// Returns the start position in the buffer and the padded size. +// Callers can access d.buf[start:start+size] for the actual data. +func (d *Decoder) consumePaddedData(size int) (start, paddedSize int, err error) { if size == 0 { - return 0, nil + return d.pos, 0, nil } - pad := (4 - (size % 4)) % 4 - paddedSize := size + pad + paddedSize = (size + 3) &^ 3 // Round up to multiple of 4 if uint(paddedSize) > uint(maxInt32) { - err := unmarshalError("DecodeFixedOpaqueInplace", ErrOverflow, - errMaxSlice, paddedSize, nil) - return 0, err + return 0, 0, unmarshalError("consumePaddedData", ErrOverflow, errMaxSlice, paddedSize, nil) } - - n, err := io.ReadFull(d.r, out) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), size) - err := unmarshalError("DecodeFixedOpaqueInplace", ErrIO, msg, out[:n], - err) - return n, err + if d.Remaining() < paddedSize { + return 0, 0, unmarshalError("consumePaddedData", ErrIO, "unexpected end of input", nil, nil) } - if pad > 0 { - // the maximum value of pad is 3, so the scratch buffer should be enough - _ = d.scratchBuf[2] - padding := d.scratchBuf[:pad] - n2, err := io.ReadFull(d.r, padding) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), pad) - err := unmarshalError("DecodeFixedOpaqueInplace", ErrIO, msg, out[:n], - err) - return n, err - } - n += n2 - // check all the padding bytes to be zero - for _, p := range padding { - if p != 0x00 { - msg := "non-zero padding" - err := unmarshalError("DecodeFixedOpaqueInplace", ErrIO, msg, padding[:n2], nil) - return n, err - } + // Validate padding bytes are zero + for i := size; i < paddedSize; i++ { + if d.buf[d.pos+i] != 0 { + return 0, 0, unmarshalError("consumePaddedData", ErrIO, "non-zero padding", d.buf[d.pos+size:d.pos+paddedSize], nil) } } + start = d.pos + d.pos += paddedSize + return start, paddedSize, nil +} + +// DecodeFixedOpaqueInplace is an in-place version of DecodeFixedOpaque. +// It improves performance when the destination is pre-allocated (which avoids +// internally allocating an extra slice and does not require further copying) +func (d *Decoder) DecodeFixedOpaqueInplace(out []byte) (int, error) { + start, n, err := d.consumePaddedData(len(out)) + if err != nil { + return 0, err + } + copy(out, d.buf[start:start+len(out)]) return n, nil } @@ -463,15 +414,9 @@ func (d *Decoder) DecodeOpaque(maxSize int) ([]byte, int, error) { return nil, n, err } - maxSize = d.mergeInputLenAndMaxSize(maxSize) - if maxSize == 0 { - maxSize = maxInt32 - } - + maxSize = d.mergeRemainingAndMaxSize(maxSize) if uint(dataLen) > uint(maxSize) { - err := unmarshalError("DecodeOpaque", ErrOverflow, errMaxSlice, - dataLen, nil) - return nil, n, err + return nil, n, unmarshalError("DecodeOpaque", ErrOverflow, errMaxSlice, dataLen, nil) } rv, n2, err := d.DecodeFixedOpaque(int32(dataLen)) @@ -503,23 +448,49 @@ func (d *Decoder) DecodeString(maxSize int) (string, int, error) { return "", n, err } - maxSize = d.mergeInputLenAndMaxSize(maxSize) - if maxSize == 0 { - maxSize = maxInt32 - } - + maxSize = d.mergeRemainingAndMaxSize(maxSize) if uint(dataLen) > uint(maxSize) { - err = unmarshalError("DecodeString", ErrOverflow, errMaxSlice, - dataLen, nil) - return "", n, err + return "", n, unmarshalError("DecodeString", ErrOverflow, errMaxSlice, dataLen, nil) } - opaque, n2, err := d.DecodeFixedOpaque(int32(dataLen)) - n += n2 + start, n2, err := d.consumePaddedData(int(dataLen)) if err != nil { return "", n, err } - return string(opaque), n, nil + + return string(d.buf[start : start+int(dataLen)]), n + n2, nil +} + +// Skip advances the decoder position by n bytes without decoding. +func (d *Decoder) Skip(n int) error { + if n < 0 { + return unmarshalError("Skip", ErrBadArguments, "negative skip length", n, nil) + } + if d.Remaining() < n { + return unmarshalError("Skip", ErrIO, "unexpected end of input", nil, nil) + } + d.pos += n + return nil +} + +// Bytes returns the remaining unread bytes in the buffer. +// WARNING: The returned slice shares memory with the input buffer. +func (d *Decoder) Bytes() []byte { + return d.buf[d.pos:] +} + +// mergeRemainingAndMaxSize returns the effective maximum size for decoding, +// taking into account both the user-specified maxSize and the remaining +// input bytes. If maxSize is 0 (unlimited), it defaults to maxInt32. +// The result is the minimum of maxSize and remaining bytes. +func (d *Decoder) mergeRemainingAndMaxSize(maxSize int) int { + if maxSize == 0 { + maxSize = maxInt32 + } + if remaining := d.Remaining(); remaining < maxSize { + return remaining + } + return maxSize } // decodeFixedArray treats the next bytes as a series of XDR encoded elements @@ -577,15 +548,9 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool, maxSize int, m return n, err } - maxSize = d.mergeInputLenAndMaxSize(maxSize) - if maxSize == 0 { - maxSize = maxInt32 - } - + maxSize = d.mergeRemainingAndMaxSize(maxSize) if uint(dataLen) > uint(maxSize) { - err := unmarshalError("decodeArray", ErrOverflow, errMaxSlice, - dataLen, nil) - return n, err + return n, unmarshalError("decodeArray", ErrOverflow, errMaxSlice, dataLen, nil) } // Allocate storage for the slice elements (the underlying array) if @@ -621,10 +586,9 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool, maxSize int, m func setUnionArmsToNil(v reflect.Value) { for i := 0; i < v.NumField(); i++ { f := v.Field(i) - if f.Kind() != reflect.Ptr { - continue + if f.Kind() == reflect.Ptr && f.CanSet() { + f.Set(reflect.Zero(f.Type())) } - v.Set(reflect.Zero(v.Type())) } } @@ -644,20 +608,23 @@ func (d *Decoder) decodeUnion(v reflect.Value, maxDepth uint) (int, error) { vs := v.FieldByName(u.SwitchFieldName()) // ensure the switch field is a valid enum value for the union, if possible - enum, ok := vs.Interface().(Enum) - - if ok && !enum.ValidEnum(i) { + enum, isEnum := vs.Interface().(Enum) + if isEnum && !enum.ValidEnum(i) { msg := fmt.Sprintf("switch '%d' is not valid enum value for union", i) err := unmarshalError("decode", ErrBadUnionSwitch, msg, nil, nil) return n, err } + // Set switch field value with proper type handling kind := vs.Kind() - if kind == reflect.Uint || kind == reflect.Uint8 || kind == reflect.Uint16 || - kind == reflect.Uint32 || kind == reflect.Uint64 { - vs.SetUint(uint64(i)) - } else { + switch kind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: vs.SetInt(int64(i)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + vs.SetUint(uint64(i)) + default: + return n, unmarshalError("decodeUnion", ErrBadUnionSwitch, + fmt.Sprintf("switch field has unsupported type: %v", kind), nil, nil) } arm, ok := u.ArmForSwitch(i) @@ -672,11 +639,7 @@ func (d *Decoder) decodeUnion(v reflect.Value, maxDepth uint) (int, error) { return n, nil } - vv := v.FieldByName(arm) - - vvet := vv.Type().Elem() - vv.Set(reflect.New(vvet)) - + // Validate field exists in type before accessing it field, ok := v.Type().FieldByName(arm) if !ok { msg := fmt.Sprintf("switch '%s' is not valid for union", arm) @@ -684,6 +647,8 @@ func (d *Decoder) decodeUnion(v reflect.Value, maxDepth uint) (int, error) { return n, err } + vv := v.FieldByName(arm) + maxSize := 0 sizeTag := field.Tag.Get("xdrmaxsize") if sizeTag != "" { @@ -694,12 +659,25 @@ func (d *Decoder) decodeUnion(v reflect.Value, maxDepth uint) (int, error) { maxSize = int(sz) } - n2, err := d.decode(vv.Elem(), maxSize, maxDepth) - n += n2 - - if err != nil { - return n, err + // Handle both pointer and value-type union arms + if vv.Kind() == reflect.Ptr { + // Pointer field - allocate new value and decode into it + vvet := vv.Type().Elem() + vv.Set(reflect.New(vvet)) + n2, err := d.decode(vv.Elem(), maxSize, maxDepth) + n += n2 + if err != nil { + return n, err + } + } else { + // Value field - decode directly into the field + n2, err := d.decode(vv, maxSize, maxDepth) + n += n2 + if err != nil { + return n, err + } } + return n, nil } @@ -808,10 +786,11 @@ func (d *Decoder) decodeMap(v reflect.Value, maxDepth uint) (int, error) { if err != nil { return n, err } - if left, ok := d.InputLen(); ok { - if uint(left) < uint(dataLen) { - return n, unmarshalError("decodeMap", ErrOverflow, errMaxSlice, dataLen, nil) - } + // Sanity check: each map entry requires at least 8 bytes (4 for key + 4 for value) + // This prevents allocating huge maps based on malicious length values + // Use multiplication (with uint64 to prevent overflow) instead of division for precision + if uint64(dataLen)*8 > uint64(d.Remaining()) { + return n, unmarshalError("decodeMap", ErrOverflow, errMaxSlice, dataLen, nil) } // Allocate storage for the underlying map if needed. @@ -875,15 +854,6 @@ func (d *Decoder) decodeInterface(v reflect.Value, maxDepth uint) (int, error) { return d.decode(ve, 0, maxDepth) } -func (d *Decoder) mergeInputLenAndMaxSize(maxSize int) int { - if left, ok := d.InputLen(); ok { - if maxSize == 0 || left < maxSize { - return left - } - } - return maxSize -} - // decode is the main workhorse for unmarshalling via reflection. It uses // the passed reflection value to choose the XDR primitives to decode from // the encapsulated reader. It is a recursive function, @@ -1148,9 +1118,12 @@ func (d *Decoder) indirectIfPtr(v reflect.Value) (reflect.Value, error) { } // Decode operates identically to the Unmarshal function with the exception of -// using the reader associated with the Decoder as the source of XDR-encoded -// data instead of a user-supplied reader. See the Unmarhsal documentation for -// specifics. Decode(v) is equivalent to DecodeWithMaxDepth(v, DecodeDefaultMaxDepth) +// using the byte slice associated with the Decoder as the source of XDR-encoded +// data instead of a user-supplied byte slice. See the Unmarshal documentation for +// specifics. +// +// If v implements DecoderFrom, its DecodeFrom method is called directly for +// better performance, bypassing reflection. func (d *Decoder) Decode(v interface{}) (int, error) { if v == nil { msg := "can't unmarshal to nil interface" @@ -1158,6 +1131,12 @@ func (d *Decoder) Decode(v interface{}) (int, error) { nil) } + // Fast path: if v implements DecoderFrom, use it directly + if decodable, ok := v.(DecoderFrom); ok { + return decodable.DecodeFrom(d, d.maxDepth) + } + + // Fallback: reflection-based decoding vv := reflect.ValueOf(v) if vv.Kind() != reflect.Ptr { msg := fmt.Sprintf("can't unmarshal to non-pointer '%v' - use "+ @@ -1174,11 +1153,3 @@ func (d *Decoder) Decode(v interface{}) (int, error) { return d.decode(vv.Elem(), 0, d.maxDepth) } - -// InputLen returns the size left to read from the decoder's input if available -func (d *Decoder) InputLen() (int, bool) { - if d.l == nil { - return 0, false - } - return d.l.Len(), true -} diff --git a/xdr3/decode_bench_test.go b/xdr3/decode_bench_test.go new file mode 100644 index 0000000..4d1312b --- /dev/null +++ b/xdr3/decode_bench_test.go @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2012-2014 Dave Collins + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +package xdr_test + +import ( + "testing" + + xdr "github.com/stellar/go-xdr/xdr3" +) + +// Test data for benchmarks +var ( + // XDR encoded int32 (value: 12345678) + encodedInt = []byte{0x00, 0xBC, 0x61, 0x4E} + + // XDR encoded uint32 (value: 0xDEADBEEF) + encodedUint = []byte{0xDE, 0xAD, 0xBE, 0xEF} + + // XDR encoded int64/hyper (value: 0x0102030405060708) + encodedHyper = []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} + + // XDR encoded string "hello world" (length 11, padded to 12 bytes) + encodedString = []byte{ + 0x00, 0x00, 0x00, 0x0B, // length = 11 + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x00, // "hello world" + 1 byte padding + } + + // XDR encoded longer string (64 bytes of 'x', padded) + encodedLongString = func() []byte { + data := make([]byte, 4+64) // length prefix + 64 chars (no padding needed, 64 % 4 == 0) + data[0], data[1], data[2], data[3] = 0x00, 0x00, 0x00, 0x40 // length = 64 + for i := 4; i < 68; i++ { + data[i] = 'x' + } + return data + }() + + // XDR encoded variable opaque (32 bytes) + encodedOpaque = func() []byte { + data := make([]byte, 4+32) // length prefix + 32 bytes + data[0], data[1], data[2], data[3] = 0x00, 0x00, 0x00, 0x20 // length = 32 + for i := 4; i < 36; i++ { + data[i] = byte(i - 4) + } + return data + }() + + // XDR encoded fixed opaque (32 bytes, no length prefix) + encodedFixedOpaque = func() []byte { + data := make([]byte, 32) + for i := 0; i < 32; i++ { + data[i] = byte(i) + } + return data + }() +) + +// ============================================================================ +// Primitive Decoding Benchmarks +// ============================================================================ + +func BenchmarkDecodeInt(b *testing.B) { + d := xdr.NewDecoder(encodedInt) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedInt) + _, _, _ = d.DecodeInt() + } + b.SetBytes(int64(len(encodedInt))) +} + +func BenchmarkDecodeUint(b *testing.B) { + d := xdr.NewDecoder(encodedUint) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedUint) + _, _, _ = d.DecodeUint() + } + b.SetBytes(int64(len(encodedUint))) +} + +func BenchmarkDecodeHyper(b *testing.B) { + d := xdr.NewDecoder(encodedHyper) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedHyper) + _, _, _ = d.DecodeHyper() + } + b.SetBytes(int64(len(encodedHyper))) +} + +// ============================================================================ +// String Decoding Benchmarks +// ============================================================================ + +func BenchmarkDecodeString(b *testing.B) { + d := xdr.NewDecoder(encodedString) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedString) + _, _, _ = d.DecodeString(0) + } + b.SetBytes(int64(len(encodedString))) +} + +func BenchmarkDecodeLongString(b *testing.B) { + d := xdr.NewDecoder(encodedLongString) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedLongString) + _, _, _ = d.DecodeString(0) + } + b.SetBytes(int64(len(encodedLongString))) +} + +// ============================================================================ +// Opaque Decoding Benchmarks +// ============================================================================ + +func BenchmarkDecodeOpaque(b *testing.B) { + d := xdr.NewDecoder(encodedOpaque) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedOpaque) + _, _, _ = d.DecodeOpaque(0) + } + b.SetBytes(int64(len(encodedOpaque))) +} + +// ============================================================================ +// Fixed Opaque Decoding Benchmarks +// ============================================================================ + +func BenchmarkDecodeFixedOpaque(b *testing.B) { + d := xdr.NewDecoder(encodedFixedOpaque) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedFixedOpaque) + _, _, _ = d.DecodeFixedOpaque(32) + } + b.SetBytes(int64(len(encodedFixedOpaque))) +} + +// ============================================================================ +// Multiple Field Decoding Benchmarks (simulating struct decoding) +// ============================================================================ + +// Encoded struct with: uint32, int64, string "hello world", 32-byte opaque +var encodedStruct = func() []byte { + var buf []byte + buf = append(buf, encodedUint...) + buf = append(buf, encodedHyper...) + buf = append(buf, encodedString...) + buf = append(buf, 0x00, 0x00, 0x00, 0x20) // opaque length = 32 + for i := 0; i < 32; i++ { + buf = append(buf, byte(i)) + } + return buf +}() + +func BenchmarkDecodeMultipleFields(b *testing.B) { + d := xdr.NewDecoder(encodedStruct) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedStruct) + _, _, _ = d.DecodeUint() + _, _, _ = d.DecodeHyper() + _, _, _ = d.DecodeString(0) + _, _, _ = d.DecodeOpaque(0) + } + b.SetBytes(int64(len(encodedStruct))) +} + +// ============================================================================ +// Decoder Reuse Benchmarks +// ============================================================================ + +// benchDecoderFromType implements DecoderFrom for benchmarking +type benchDecoderFromType struct { + Value int32 +} + +func (t *benchDecoderFromType) DecodeFrom(d *xdr.Decoder) (int, error) { + v, n, err := d.DecodeInt() + if err != nil { + return n, err + } + t.Value = v + return n, nil +} + +// BenchmarkNewDecoderEachTime measures creating a new decoder for each decode +func BenchmarkNewDecoderEachTime(b *testing.B) { + data := encodedInt + var result benchDecoderFromType + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d := xdr.NewDecoder(data) + _, _ = d.Decode(&result) + } + b.SetBytes(int64(len(data))) +} + +// BenchmarkDecoderReuse measures reusing a decoder with Reset+Decode +func BenchmarkDecoderReuse(b *testing.B) { + data := encodedInt + d := xdr.NewDecoder(nil) + var result benchDecoderFromType + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(data) + _, _ = d.Decode(&result) + } + b.SetBytes(int64(len(data))) +} diff --git a/xdr3/decode_test.go b/xdr3/decode_test.go index 747ebf9..ba0b71d 100644 --- a/xdr3/decode_test.go +++ b/xdr3/decode_test.go @@ -1,5 +1,6 @@ /* * Copyright (c) 2012-2014 Dave Collins + * Copyright (c) 2026 Stellar Development Foundation * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -14,1181 +15,1223 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -package xdr_test +package xdr import ( - "bytes" - "encoding/base64" - "fmt" - "math" + "errors" "reflect" "testing" - "time" - - . "github.com/stellar/go-xdr/xdr3" ) -// subTest is used to allow testing of the Unmarshal function into struct fields -// which are structs themselves. -type subTest struct { - A string - B uint8 -} +// TestDecoder_EOF tests that Decoder returns proper errors on EOF +func TestDecoder_EOF(t *testing.T) { + tests := []struct { + name string + data []byte + decode func(d *Decoder) error + errCode ErrorCode // expected error code, defaults to ErrIO + }{ + { + name: "DecodeInt EOF", + data: []byte{0x00, 0x00}, // Only 2 bytes, need 4 + decode: func(d *Decoder) error { + _, _, err := d.DecodeInt() + return err + }, + }, + { + name: "DecodeUint EOF", + data: []byte{0x00, 0x00, 0x00}, // Only 3 bytes, need 4 + decode: func(d *Decoder) error { + _, _, err := d.DecodeUint() + return err + }, + }, + { + name: "DecodeHyper EOF", + data: []byte{0x00, 0x00, 0x00, 0x00}, // Only 4 bytes, need 8 + decode: func(d *Decoder) error { + _, _, err := d.DecodeHyper() + return err + }, + }, + { + name: "DecodeUhyper EOF", + data: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, // Only 5 bytes, need 8 + decode: func(d *Decoder) error { + _, _, err := d.DecodeUhyper() + return err + }, + }, + { + name: "DecodeFloat EOF", + data: []byte{0x00}, // Only 1 byte, need 4 + decode: func(d *Decoder) error { + _, _, err := d.DecodeFloat() + return err + }, + }, + { + name: "DecodeDouble EOF", + data: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, // Only 6 bytes, need 8 + decode: func(d *Decoder) error { + _, _, err := d.DecodeDouble() + return err + }, + }, + { + name: "DecodeBool EOF", + data: []byte{}, // Empty + decode: func(d *Decoder) error { + _, _, err := d.DecodeBool() + return err + }, + }, + { + name: "DecodeFixedOpaque EOF", + data: []byte{0x01, 0x02}, // Only 2 bytes, need 4 (size=3 padded) + decode: func(d *Decoder) error { + _, _, err := d.DecodeFixedOpaque(3) + return err + }, + }, + { + name: "DecodeFixedOpaqueInplace EOF", + data: []byte{0x01, 0x02, 0x03}, // Only 3 bytes, need 4 (padded) + decode: func(d *Decoder) error { + out := make([]byte, 3) + _, err := d.DecodeFixedOpaqueInplace(out) + return err + }, + }, + { + name: "DecodeOpaque EOF on length", + data: []byte{0x00, 0x00}, // Only 2 bytes for length prefix + decode: func(d *Decoder) error { + _, _, err := d.DecodeOpaque(0) + return err + }, + }, + { + name: "DecodeOpaque EOF on data", + data: []byte{0x00, 0x00, 0x00, 0x08, 0x01, 0x02}, // Length=8, only 2 data bytes + errCode: ErrOverflow, // length exceeds available data + decode: func(d *Decoder) error { + _, _, err := d.DecodeOpaque(0) + return err + }, + }, + { + name: "DecodeString EOF on length", + data: []byte{0x00}, // Only 1 byte for length prefix + decode: func(d *Decoder) error { + _, _, err := d.DecodeString(0) + return err + }, + }, + { + name: "DecodeString EOF on data", + data: []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e'}, // Length=5, only 2 chars + errCode: ErrOverflow, // length exceeds available data + decode: func(d *Decoder) error { + _, _, err := d.DecodeString(0) + return err + }, + }, + { + name: "Skip EOF", + data: []byte{0x01, 0x02}, + decode: func(d *Decoder) error { + return d.Skip(10) + }, + }, + } -// allTypesTest is used to allow testing of the Unmarshal function into struct -// fields of all supported types. -type allTypesTest struct { - A int8 - B uint8 - C int16 - D uint16 - E int32 - F uint32 - G int64 - H uint64 - I bool - J float32 - K float64 - L string - M []byte - N [3]byte - O []int16 - P [2]subTest - Q subTest - R map[string]uint32 - S time.Time -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDecoder(tt.data) + err := tt.decode(d) + if err == nil { + t.Fatal("expected error, got nil") + } -// opaqueStruct is used to test handling of uint8 slices and arrays. -type opaqueStruct struct { - Slice []uint8 `xdropaque:"false"` - Array [1]uint8 `xdropaque:"false"` + // Verify it's an UnmarshalError with expected error code + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected *UnmarshalError, got %T: %v", err, err) + } + expectedCode := tt.errCode + if expectedCode == 0 { + expectedCode = ErrIO + } + if unmarshalErr.ErrorCode != expectedCode { + t.Errorf("expected %v, got %v", expectedCode, unmarshalErr.ErrorCode) + } + }) + } } -type AnEnum int32 - -func (e AnEnum) ValidEnum(v int32) bool { - return v < 3 -} +// TestDecoder_InvalidPadding tests that non-zero padding bytes cause errors +func TestDecoder_InvalidPadding(t *testing.T) { + tests := []struct { + name string + data []byte + decode func(d *Decoder) error + }{ + { + name: "DecodeFixedOpaque invalid padding", + data: []byte{0x01, 0x02, 0x03, 0xFF}, // 3 bytes data + non-zero padding + decode: func(d *Decoder) error { + _, _, err := d.DecodeFixedOpaque(3) + return err + }, + }, + { + name: "DecodeFixedOpaqueInplace invalid padding", + data: []byte{0x01, 0x02, 0x03, 0x01}, // 3 bytes data + non-zero padding + decode: func(d *Decoder) error { + out := make([]byte, 3) + _, err := d.DecodeFixedOpaqueInplace(out) + return err + }, + }, + { + name: "DecodeString invalid padding", + data: []byte{ + 0x00, 0x00, 0x00, 0x03, // length = 3 + 'a', 'b', 'c', 0x01, // "abc" + non-zero padding + }, + decode: func(d *Decoder) error { + _, _, err := d.DecodeString(0) + return err + }, + }, + } -type aUnion struct { - Type AnEnum - Data *int32 - Text *string `xdrmaxsize:"28"` -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDecoder(tt.data) + err := tt.decode(d) + if err == nil { + t.Fatal("expected error, got nil") + } -func (u aUnion) SwitchFieldName() string { - return "Type" + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected *UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrIO { + t.Errorf("expected ErrIO for padding error, got %v", unmarshalErr.ErrorCode) + } + }) + } } -func (u aUnion) ArmForSwitch(sw int32) (string, bool) { - switch sw { - case 0: - return "Data", true - case 1: - return "Text", true - case 2: // void - return "", true +// TestDecoder_InvalidBool tests that invalid boolean values cause errors +func TestDecoder_InvalidBool(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + {"bool value 2", []byte{0x00, 0x00, 0x00, 0x02}}, + {"bool value -1", []byte{0xFF, 0xFF, 0xFF, 0xFF}}, + {"bool value 100", []byte{0x00, 0x00, 0x00, 0x64}}, } - return "-", false -} - -type structWithPointer struct { - Data *string -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDecoder(tt.data) + _, _, err := d.DecodeBool() + if err == nil { + t.Fatal("expected error, got nil") + } -// testExpectedURet is a convenience method to test an expected number of bytes -// read and error for an unmarshal. -func testExpectedURet(t *testing.T, name string, n, wantN int, err, wantErr error) bool { - // First ensure the number of bytes read is the expected value. The - // byes read should be accurate even when an error occurs. - if n != wantN { - t.Errorf("%s: unexpected num bytes read - got: %v want: %v\n", - name, n, wantN) - return false + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected *UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrBadEnumValue { + t.Errorf("expected ErrBadEnumValue, got %v", unmarshalErr.ErrorCode) + } + }) } - - // Next check for the expected error. - return assertError(t, name, err, wantErr) } -func assertError(t *testing.T, name string, err, wantErr error) bool { - if reflect.TypeOf(err) != reflect.TypeOf(wantErr) { - t.Errorf("%s: failed to detect error - got: %v <%[2]T> want: %T", - name, err, wantErr) - return false +// TestDecoder_InvalidEnum tests that invalid enum values cause errors +func TestDecoder_InvalidEnum(t *testing.T) { + validEnums := map[int32]bool{ + 0: true, + 1: true, + 2: true, } - if rerr, ok := err.(*UnmarshalError); ok { - if werr, ok := wantErr.(*UnmarshalError); ok { - if rerr.ErrorCode != werr.ErrorCode { - t.Errorf("%s: failed to detect error code - "+ - "got: %v want: %v", name, - rerr.ErrorCode, werr.ErrorCode) - return false - } - } + tests := []struct { + name string + data []byte + }{ + {"enum value 3", []byte{0x00, 0x00, 0x00, 0x03}}, + {"enum value -1", []byte{0xFF, 0xFF, 0xFF, 0xFF}}, + {"enum value 100", []byte{0x00, 0x00, 0x00, 0x64}}, } - return true -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDecoder(tt.data) + _, _, err := d.DecodeEnum(validEnums) + if err == nil { + t.Fatal("expected error, got nil") + } -// TestUnmarshal ensures the Unmarshal function works properly with all types. -func TestUnmarshal(t *testing.T) { - // Variables for various unsupported Unmarshal types. - var nilInterface interface{} - var testChan chan int - var testFunc func() - var testComplex64 complex64 - var testComplex128 complex128 - - // structTestIn is input data for the big struct test of all supported - // types. - structTestIn := []byte{ - 0x00, 0x00, 0x00, 0x7F, // A - 0x00, 0x00, 0x00, 0xFF, // B - 0x00, 0x00, 0x7F, 0xFF, // C - 0x00, 0x00, 0xFF, 0xFF, // D - 0x7F, 0xFF, 0xFF, 0xFF, // E - 0xFF, 0xFF, 0xFF, 0xFF, // F - 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // G - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // H - 0x00, 0x00, 0x00, 0x01, // I - 0x40, 0x48, 0xF5, 0xC3, // J - 0x40, 0x09, 0x21, 0xfb, 0x54, 0x44, 0x2d, 0x18, // K - 0x00, 0x00, 0x00, 0x03, 0x78, 0x64, 0x72, 0x00, // L - 0x00, 0x00, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // M - 0x01, 0x02, 0x03, 0x00, // N - 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x02, 0x00, - 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, // O - 0x00, 0x00, 0x00, 0x03, 0x6F, 0x6E, 0x65, 0x00, // P[0].A - 0x00, 0x00, 0x00, 0x01, // P[0].B - 0x00, 0x00, 0x00, 0x03, 0x74, 0x77, 0x6F, 0x00, // P[1].A - 0x00, 0x00, 0x00, 0x02, // P[1].B - 0x00, 0x00, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, // Q.A - 0x00, 0x00, 0x00, 0x03, // Q.B - 0x00, 0x00, 0x00, 0x02, // R length - 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x70, 0x31, // R key map1 - 0x00, 0x00, 0x00, 0x01, // R value map1 - 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x70, 0x32, // R key map2 - 0x00, 0x00, 0x00, 0x02, // R value map2 - 0x00, 0x00, 0x00, 0x14, 0x32, 0x30, 0x31, 0x34, - 0x2d, 0x30, 0x34, 0x2d, 0x30, 0x34, 0x54, 0x30, - 0x33, 0x3a, 0x32, 0x34, 0x3a, 0x34, 0x38, 0x5a, // S - } - - // structTestWant is the expected output after unmarshalling - // structTestIn. - structTestWant := allTypesTest{ - 127, // A - 255, // B - 32767, // C - 65535, // D - 2147483647, // E - 4294967295, // F - 9223372036854775807, // G - 18446744073709551615, // H - true, // I - 3.14, // J - 3.141592653589793, // K - "xdr", // L - []byte{1, 2, 3, 4}, // M - [3]byte{1, 2, 3}, // N - []int16{512, 1024, 2048}, // O - [2]subTest{{"one", 1}, {"two", 2}}, // P - subTest{"bar", 3}, // Q - map[string]uint32{"map1": 1, "map2": 2}, // R - time.Unix(1396581888, 0).UTC(), // S + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected *UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrBadEnumValue { + t.Errorf("expected ErrBadEnumValue, got %v", unmarshalErr.ErrorCode) + } + }) } +} +// TestDecoder_MaxSizeExceeded tests that exceeding max size causes errors +func TestDecoder_MaxSizeExceeded(t *testing.T) { tests := []struct { - in []byte // input bytes - wantVal interface{} // expected value - wantN int // expected number of bytes read - err error // expected error + name string + data []byte + maxSize int + decode func(d *Decoder, maxSize int) error }{ - // int8 - XDR Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, int8(0), 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0x40}, int8(64), 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0x7F}, int8(127), 4, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, int8(-1), 4, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0x80}, int8(-128), 4, nil}, - // Expected Failures -- 128, -129 overflow int8 and not enough - // bytes - {[]byte{0x00, 0x00, 0x00, 0x80}, int8(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0xFF, 0xFF, 0xFF, 0x7F}, int8(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00}, int8(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // uint8 - XDR Unsigned Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, uint8(0), 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0x40}, uint8(64), 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0xFF}, uint8(255), 4, nil}, - // Expected Failures -- 256, -1 overflow uint8 and not enough - // bytes - {[]byte{0x00, 0x00, 0x01, 0x00}, uint8(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, uint8(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00}, uint8(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // int16 - XDR Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, int16(0), 4, nil}, - {[]byte{0x00, 0x00, 0x04, 0x00}, int16(1024), 4, nil}, - {[]byte{0x00, 0x00, 0x7F, 0xFF}, int16(32767), 4, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, int16(-1), 4, nil}, - {[]byte{0xFF, 0xFF, 0x80, 0x00}, int16(-32768), 4, nil}, - // Expected Failures -- 32768, -32769 overflow int16 and not - // enough bytes - {[]byte{0x00, 0x00, 0x80, 0x00}, int16(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0xFF, 0xFF, 0x7F, 0xFF}, int16(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00}, uint16(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // uint16 - XDR Unsigned Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, uint16(0), 4, nil}, - {[]byte{0x00, 0x00, 0x04, 0x00}, uint16(1024), 4, nil}, - {[]byte{0x00, 0x00, 0xFF, 0xFF}, uint16(65535), 4, nil}, - // Expected Failures -- 65536, -1 overflow uint16 and not enough - // bytes - {[]byte{0x00, 0x01, 0x00, 0x00}, uint16(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, uint16(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00}, uint16(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // int32 - XDR Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, int32(0), 4, nil}, - {[]byte{0x00, 0x04, 0x00, 0x00}, int32(262144), 4, nil}, - {[]byte{0x7F, 0xFF, 0xFF, 0xFF}, int32(2147483647), 4, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, int32(-1), 4, nil}, - {[]byte{0x80, 0x00, 0x00, 0x00}, int32(-2147483648), 4, nil}, - // Expected Failure -- not enough bytes - {[]byte{0x00, 0x00, 0x00}, int32(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // uint32 - XDR Unsigned Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, uint32(0), 4, nil}, - {[]byte{0x00, 0x04, 0x00, 0x00}, uint32(262144), 4, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, uint32(4294967295), 4, nil}, - // Expected Failure -- not enough bytes - {[]byte{0x00, 0x00, 0x00}, uint32(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // int64 - XDR Hyper Integer - {[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(0), 8, nil}, - {[]byte{0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00}, int64(1 << 34), 8, nil}, - {[]byte{0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(1 << 42), 8, nil}, - {[]byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, int64(9223372036854775807), 8, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, int64(-1), 8, nil}, - {[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(-9223372036854775808), 8, nil}, - // Expected Failures -- not enough bytes - {[]byte{0x7f, 0xff, 0xff}, int64(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x7f, 0x00, 0xff, 0x00}, int64(0), 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // uint64 - XDR Unsigned Hyper Integer - {[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(0), 8, nil}, - {[]byte{0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00}, uint64(1 << 34), 8, nil}, - {[]byte{0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(1 << 42), 8, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, uint64(18446744073709551615), 8, nil}, - {[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(9223372036854775808), 8, nil}, - // Expected Failures -- not enough bytes - {[]byte{0xff, 0xff, 0xff}, uint64(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0xff, 0x00, 0xff, 0x00}, uint64(0), 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // bool - XDR Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, false, 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0x01}, true, 4, nil}, - // Expected Failures -- only 0 or 1 is a valid bool - {[]byte{0x01, 0x00, 0x00, 0x00}, true, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - {[]byte{0x00, 0x00, 0x40, 0x00}, true, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - - // float32 - XDR Floating-Point - {[]byte{0x00, 0x00, 0x00, 0x00}, float32(0), 4, nil}, - {[]byte{0x40, 0x48, 0xF5, 0xC3}, float32(3.14), 4, nil}, - {[]byte{0x49, 0x96, 0xB4, 0x38}, float32(1234567.0), 4, nil}, - {[]byte{0xFF, 0x80, 0x00, 0x00}, float32(math.Inf(-1)), 4, nil}, - {[]byte{0x7F, 0x80, 0x00, 0x00}, float32(math.Inf(0)), 4, nil}, - // Expected Failures -- not enough bytes - {[]byte{0xff, 0xff}, float32(0), 2, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0xff, 0x00, 0xff}, float32(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // float64 - XDR Double-precision Floating-Point - {[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(0), 8, nil}, - {[]byte{0x40, 0x09, 0x21, 0xfb, 0x54, 0x44, 0x2d, 0x18}, float64(3.141592653589793), 8, nil}, - {[]byte{0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(math.Inf(-1)), 8, nil}, - {[]byte{0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(math.Inf(0)), 8, nil}, - // Expected Failures -- not enough bytes - {[]byte{0xff, 0xff, 0xff}, float64(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0xff, 0x00, 0xff, 0x00}, float64(0), 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // string - XDR String - {[]byte{0x00, 0x00, 0x00, 0x00}, "", 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0x03, 0x78, 0x64, 0x72, 0x00}, "xdr", 8, nil}, - {[]byte{0x00, 0x00, 0x00, 0x06, 0xCF, 0x84, 0x3D, 0x32, 0xCF, 0x80, 0x00, 0x00}, "τ=2π", 12, nil}, - // Expected Failures -- not enough bytes for length, length - // larger than allowed, and len larger than available bytes. - {[]byte{0x00, 0x00, 0xFF}, "", 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, "", 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00, 0xFF}, "", 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // []byte - XDR Variable Opaque - {[]byte{0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00}, []byte{0x01}, 8, nil}, - {[]byte{0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03, 0x00}, []byte{0x01, 0x02, 0x03}, 8, nil}, - // Expected Failures -- not enough bytes for length, length - // larger than allowed, and data larger than available bytes. - {[]byte{0x00, 0x00, 0xFF}, []byte{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, []byte{}, 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00, 0xFF}, []byte{}, 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // [#]byte - XDR Fixed Opaque - {[]byte{0x01, 0x00, 0x00, 0x00}, [1]byte{0x01}, 4, nil}, - {[]byte{0x01, 0x02, 0x00, 0x00}, [2]byte{0x01, 0x02}, 4, nil}, - {[]byte{0x01, 0x02, 0x03, 0x00}, [3]byte{0x01, 0x02, 0x03}, 4, nil}, - {[]byte{0x01, 0x02, 0x03, 0x04}, [4]byte{0x01, 0x02, 0x03, 0x04}, 4, nil}, - {[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x00, 0x00}, [5]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 8, nil}, - // Expected Failure -- fixed opaque data not padded - {[]byte{0x01}, [1]byte{}, 1, &UnmarshalError{ErrorCode: ErrIO}}, - - // [] - XDR Variable-Length Array - {[]byte{0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00}, - []int16{512, 1024, 2048}, 16, nil}, - {[]byte{0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, []bool{true, false}, 12, nil}, - // Expected Failure -- 2 entries in array - not enough bytes - {[]byte{0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01}, []bool{}, 8, &UnmarshalError{ErrorCode: ErrIO}}, - - // [#] - XDR Fixed-Length Array - {[]byte{0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00}, [2]uint32{512, 1024}, 8, nil}, - // Expected Failure -- 2 entries in array - not enough bytes - {[]byte{0x00, 0x00, 0x00, 0x02}, [2]uint32{}, 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // map[string]uint32 - {[]byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x70, 0x31, 0x00, 0x00, 0x00, 0x01}, - map[string]uint32{"map1": 1}, 16, nil}, - // Expected Failures -- not enough bytes in length, 1 map - // element no extra bytes, 1 map element not enough bytes for - // key, 1 map element not enough bytes for value. - {[]byte{0x00, 0x00, 0x00}, map[string]uint32{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x00, 0x00, 0x00, 0x01}, map[string]uint32{}, 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00}, map[string]uint32{}, 7, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x70, 0x31}, - map[string]uint32{}, 12, &UnmarshalError{ErrorCode: ErrIO}}, - - // time.Time - XDR String per RFC3339 - {[]byte{ - 0x00, 0x00, 0x00, 0x14, 0x32, 0x30, 0x31, 0x34, - 0x2d, 0x30, 0x34, 0x2d, 0x30, 0x34, 0x54, 0x30, - 0x33, 0x3a, 0x32, 0x34, 0x3a, 0x34, 0x38, 0x5a, - }, time.Unix(1396581888, 0).UTC(), 24, nil}, - // Expected Failures -- not enough bytes, improperly formatted - // time - {[]byte{0x00, 0x00, 0x00}, time.Time{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x00, 0x00, 0x00, 0x00}, time.Time{}, 4, &UnmarshalError{ErrorCode: ErrParseTime}}, - - // struct - XDR Structure -- test struct contains all supported types - {structTestIn, structTestWant, len(structTestIn), nil}, - {[]byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02}, - opaqueStruct{[]uint8{1}, [1]uint8{2}}, 12, nil}, - // Expected Failures -- normal struct not enough bytes, non - // opaque data not enough bytes for slice, non opaque data not - // enough bytes for slice. - {[]byte{0x00, 0x00}, allTypesTest{}, 2, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x00, 0x00, 0x00}, opaqueStruct{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x00, 0x00, 0x00, 0x00, 0x00}, opaqueStruct{}, 5, &UnmarshalError{ErrorCode: ErrIO}}, - - // Expected errors - {nil, nilInterface, 0, &UnmarshalError{ErrorCode: ErrNilInterface}}, - {nil, &nilInterface, 0, &UnmarshalError{ErrorCode: ErrIO}}, - {nil, testChan, 0, &UnmarshalError{ErrorCode: ErrUnsupportedType}}, - {nil, &testChan, 0, &UnmarshalError{ErrorCode: ErrIO}}, - {nil, testFunc, 0, &UnmarshalError{ErrorCode: ErrUnsupportedType}}, - {nil, &testFunc, 0, &UnmarshalError{ErrorCode: ErrIO}}, - {nil, testComplex64, 0, &UnmarshalError{ErrorCode: ErrUnsupportedType}}, - {nil, &testComplex64, 0, &UnmarshalError{ErrorCode: ErrIO}}, - {nil, testComplex128, 0, &UnmarshalError{ErrorCode: ErrUnsupportedType}}, - {nil, &testComplex128, 0, &UnmarshalError{ErrorCode: ErrIO}}, - } - - for i, test := range tests { - // Attempt to unmarshal to a non-pointer version of each - // positive test type to ensure the appropriate error is - // returned. - if test.err == nil && test.wantVal != nil { - testName := fmt.Sprintf("Unmarshal #%d (non-pointer)", i) - wantErr := &UnmarshalError{ErrorCode: ErrBadArguments} - - wvt := reflect.TypeOf(test.wantVal) - want := reflect.New(wvt).Elem().Interface() - n, err := Unmarshal(bytes.NewReader(test.in), want) - if !testExpectedURet(t, testName, n, 0, err, wantErr) { - continue + { + name: "DecodeOpaque exceeds maxSize", + data: []byte{0x00, 0x00, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10}, + maxSize: 8, // Only allow 8 bytes + decode: func(d *Decoder, maxSize int) error { + _, _, err := d.DecodeOpaque(maxSize) + return err + }, + }, + { + name: "DecodeString exceeds maxSize", + data: []byte{0x00, 0x00, 0x00, 0x08, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'}, + maxSize: 4, // Only allow 4 chars + decode: func(d *Decoder, maxSize int) error { + _, _, err := d.DecodeString(maxSize) + return err + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDecoder(tt.data) + err := tt.decode(d, tt.maxSize) + if err == nil { + t.Fatal("expected error, got nil") + } + + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected *UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrOverflow { + t.Errorf("expected ErrOverflow, got %v", unmarshalErr.ErrorCode) } + }) + } +} + +// TestDecoder_EmptyCases tests edge cases with empty/zero values +func TestDecoder_EmptyCases(t *testing.T) { + t.Run("DecodeFixedOpaque size=0", func(t *testing.T) { + d := NewDecoder([]byte{}) + data, n, err := d.DecodeFixedOpaque(0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(data) != 0 { + t.Errorf("expected empty slice, got %v", data) } + if n != 0 { + t.Errorf("expected 0 bytes read, got %d", n) + } + }) - testName := fmt.Sprintf("Unmarshal #%d", i) - // Create a new pointer to the appropriate type. - var want interface{} - if test.wantVal != nil { - wvt := reflect.TypeOf(test.wantVal) - want = reflect.New(wvt).Interface() + t.Run("DecodeFixedOpaqueInplace size=0", func(t *testing.T) { + d := NewDecoder([]byte{}) + n, err := d.DecodeFixedOpaqueInplace([]byte{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - n, err := Unmarshal(bytes.NewReader(test.in), want) + if n != 0 { + t.Errorf("expected 0 bytes read, got %d", n) + } + }) - // First ensure the number of bytes read is the expected value - // and the error is the expected one. - if !testExpectedURet(t, testName, n, test.wantN, err, test.err) { - continue + t.Run("DecodeString empty", func(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x00}) // length = 0 + s, n, err := d.DecodeString(0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "" { + t.Errorf("expected empty string, got %q", s) } - if test.err != nil { - continue + if n != 4 { + t.Errorf("expected 4 bytes read (length prefix), got %d", n) } + }) +} - // Finally, ensure the read value is the expected one. - wantElem := reflect.Indirect(reflect.ValueOf(want)).Interface() - if !reflect.DeepEqual(wantElem, test.wantVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, wantElem, test.wantVal) - continue +// TestDecoder_SuccessfulDecodes tests that valid data decodes correctly +func TestDecoder_SuccessfulDecodes(t *testing.T) { + t.Run("DecodeInt", func(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x2A}) // 42 + v, n, err := d.DecodeInt() + if err != nil { + t.Fatalf("unexpected error: %v", err) } - } + if v != 42 { + t.Errorf("expected 42, got %d", v) + } + if n != 4 { + t.Errorf("expected 4 bytes, got %d", n) + } + }) - // successful enum decoding - var anEnum AnEnum - _, err := Unmarshal(bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x01}), &anEnum) + t.Run("DecodeInt negative", func(t *testing.T) { + d := NewDecoder([]byte{0xFF, 0xFF, 0xFF, 0xFE}) // -2 + v, _, err := d.DecodeInt() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v != -2 { + t.Errorf("expected -2, got %d", v) + } + }) - if err != nil { - t.Errorf("enum decode: expected no error, got: %v\n", err) - } + t.Run("DecodeBool true", func(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x01}) + v, _, err := d.DecodeBool() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !v { + t.Error("expected true") + } + }) - if anEnum != AnEnum(1) { - t.Errorf("enum decode: expected 1, got: %v\n", anEnum) - } + t.Run("DecodeBool false", func(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x00}) + v, _, err := d.DecodeBool() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v { + t.Error("expected false") + } + }) - // failed enum decoding - _, err = Unmarshal(bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x03}), &anEnum) + t.Run("DecodeEnum valid", func(t *testing.T) { + validEnums := map[int32]bool{1: true, 2: true, 3: true} + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x02}) + v, _, err := d.DecodeEnum(validEnums) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v != 2 { + t.Errorf("expected 2, got %d", v) + } + }) - if err == nil { - t.Errorf("enum decode: expected error, got none") - } + t.Run("DecodeString with padding", func(t *testing.T) { + // "abc" = 3 bytes, needs 1 byte padding + d := NewDecoder([]byte{ + 0x00, 0x00, 0x00, 0x03, // length = 3 + 'a', 'b', 'c', 0x00, // data + padding + }) + s, n, err := d.DecodeString(0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "abc" { + t.Errorf("expected 'abc', got %q", s) + } + if n != 8 { // 4 (length) + 4 (padded data) + t.Errorf("expected 8 bytes, got %d", n) + } + }) - // union decoding - var u aUnion - // void arm - _, err = Unmarshal(bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x02}), &u) - if err != nil { - t.Errorf("union decode: expected no error, got: %v\n", err) - } + t.Run("DecodeFixedOpaque with padding", func(t *testing.T) { + // 5 bytes of data, needs 3 bytes padding + d := NewDecoder([]byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x00, 0x00, + }) + data, n, err := d.DecodeFixedOpaque(5) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(data) != 5 { + t.Errorf("expected 5 bytes, got %d", len(data)) + } + if n != 8 { // 5 bytes padded to 8 + t.Errorf("expected 8 bytes read, got %d", n) + } + }) +} - if u.Type != AnEnum(2) { - t.Errorf("union decode: expected type == 2, got: %v\n", u.Type) +// TestDecoder_Position tests position tracking +func TestDecoder_Position(t *testing.T) { + data := []byte{ + 0x00, 0x00, 0x00, 0x01, // int32 = 1 + 0x00, 0x00, 0x00, 0x02, // int32 = 2 } + d := NewDecoder(data) - if u.Data != nil { - t.Errorf("union decode: expected data to be nil, it was not.") + if d.Position() != 0 { + t.Errorf("expected position 0, got %d", d.Position()) } - - // non-void arm - _, err = Unmarshal(bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05}), &u) - if err != nil { - t.Errorf("union decode: expected no error, got: %v\n", err) + if d.Remaining() != 8 { + t.Errorf("expected 8 remaining, got %d", d.Remaining()) } - if u.Type != AnEnum(0) { - t.Errorf("union decode: expected type == 0, got: %v\n", u.Type) + d.DecodeInt() + if d.Position() != 4 { + t.Errorf("expected position 4, got %d", d.Position()) } - - if u.Data == nil { - t.Errorf("union decode: expected data to be filled, it was not.") + if d.Remaining() != 4 { + t.Errorf("expected 4 remaining, got %d", d.Remaining()) } - if *u.Data != 5 { - t.Errorf("union decode: expected data to be 5, got: %v\n", *u.Data) + d.DecodeInt() + if d.Position() != 8 { + t.Errorf("expected position 8, got %d", d.Position()) } - - // non-void arm: xdrmaxsize - _, err = Unmarshal(bytes.NewReader([]byte{ - 0x00, 0x00, 0x00, 0x01, // Text - 0x00, 0x00, 0x00, 0x1D, // String length = 29 - 0x74, 0x65, 0x73, 0x74, // "test" 4 - 0x74, 0x65, 0x73, 0x74, // "test" 8 - 0x74, 0x65, 0x73, 0x74, // "test" 12 - 0x74, 0x65, 0x73, 0x74, // "test" 16 - 0x74, 0x65, 0x73, 0x74, // "test" 20 - 0x74, 0x65, 0x73, 0x74, // "test" 24 - 0x74, 0x65, 0x73, 0x74, // "test" 28 - 0x74, 0x00, 0x00, 0x00, // "test" 29 - }), &u) - if err == nil { - t.Errorf("union decode: expected error") + if d.Remaining() != 0 { + t.Errorf("expected 0 remaining, got %d", d.Remaining()) } +} + +// TestDecoder_Reset tests the Reset method +func TestDecoder_Reset(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x01}) + d.DecodeInt() - if err.Error() != "xdr:DecodeString: data exceeds max slice limit - read: '29'" { - t.Errorf("union decode: expected 'data exceeds max slice limit' error") + if d.Position() != 4 { + t.Errorf("expected position 4, got %d", d.Position()) } - // invalid enum for switch - _, err = Unmarshal(bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x03}), &u) - if err == nil { - t.Errorf("union decode: expected error, got nil") + // Reset with new data + d.Reset([]byte{0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03}) + if d.Position() != 0 { + t.Errorf("expected position 0 after reset, got %d", d.Position()) + } + if d.Remaining() != 8 { + t.Errorf("expected 8 remaining after reset, got %d", d.Remaining()) } - // invalid arm for switch - _, err = Unmarshal(bytes.NewReader([]byte{0xFF, 0xFF, 0xFF, 0xFF}), &u) - if err == nil { - t.Errorf("union decode: expected error, got nil") + v, _, _ := d.DecodeInt() + if v != 2 { + t.Errorf("expected 2, got %d", v) } } -// decodeFunc is used to identify which public function of the Decoder object -// a test applies to. -type decodeFunc int - -const ( - fDecodeBool decodeFunc = iota - fDecodeDouble - fDecodeEnum - fDecodeFixedOpaque - fDecodeFloat - fDecodeHyper - fDecodeInt - fDecodeOpaque - fDecodeString - fDecodeUhyper - fDecodeUint -) - -// Map of decodeFunc values to names for pretty printing. -var decodeFuncStrings = map[decodeFunc]string{ - fDecodeBool: "DecodeBool", - fDecodeDouble: "DecodeDouble", - fDecodeEnum: "DecodeEnum", - fDecodeFixedOpaque: "DecodeFixedOpaque", - fDecodeFloat: "DecodeFloat", - fDecodeHyper: "DecodeHyper", - fDecodeInt: "DecodeInt", - fDecodeOpaque: "DecodeOpaque", - fDecodeString: "DecodeString", - fDecodeUhyper: "DecodeUhyper", - fDecodeUint: "DecodeUint", +// testDecoderFromType is a test type that implements DecoderFrom +type testDecoderFromType struct { + Value int32 } -// String implements the fmt.Stringer interface and returns the encode function -// as a human-readable string. -func (f decodeFunc) String() string { - if s := decodeFuncStrings[f]; s != "" { - return s +func (t *testDecoderFromType) DecodeFrom(d *Decoder, maxDepth uint) (int, error) { + v, n, err := d.DecodeInt() + if err != nil { + return n, err } - return fmt.Sprintf("Unknown decodeFunc (%d)", f) + t.Value = v + return n, nil } -// TestDecoder ensures a Decoder works as intended. -func TestDecoder(t *testing.T) { - tests := []struct { - f decodeFunc // function to use to decode - in []byte // input bytes - wantVal interface{} // expected value - wantN int // expected number of bytes read - err error // expected error - }{ - // Bool - {fDecodeBool, []byte{0x00, 0x00, 0x00, 0x00}, false, 4, nil}, - {fDecodeBool, []byte{0x00, 0x00, 0x00, 0x01}, true, 4, nil}, - // Expected Failures -- only 0 or 1 is a valid bool - {fDecodeBool, []byte{0x01, 0x00, 0x00, 0x00}, true, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - {fDecodeBool, []byte{0x00, 0x00, 0x40, 0x00}, true, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - - // Double - {fDecodeDouble, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(0), 8, nil}, - {fDecodeDouble, []byte{0x40, 0x09, 0x21, 0xfb, 0x54, 0x44, 0x2d, 0x18}, float64(3.141592653589793), 8, nil}, - {fDecodeDouble, []byte{0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(math.Inf(-1)), 8, nil}, - {fDecodeDouble, []byte{0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(math.Inf(0)), 8, nil}, - - // Enum - {fDecodeEnum, []byte{0x00, 0x00, 0x00, 0x00}, int32(0), 4, nil}, - {fDecodeEnum, []byte{0x00, 0x00, 0x00, 0x01}, int32(1), 4, nil}, - {fDecodeEnum, []byte{0x00, 0x00, 0x00, 0x02}, nil, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - {fDecodeEnum, []byte{0x12, 0x34, 0x56, 0x78}, nil, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - {fDecodeEnum, []byte{0x00}, nil, 1, &UnmarshalError{ErrorCode: ErrIO}}, - - // FixedOpaque - {fDecodeFixedOpaque, []byte{0x01, 0x00, 0x00, 0x00}, []byte{0x01}, 4, nil}, - {fDecodeFixedOpaque, []byte{0x01, 0x02, 0x00, 0x00}, []byte{0x01, 0x02}, 4, nil}, - {fDecodeFixedOpaque, []byte{0x01, 0x02, 0x03, 0x00}, []byte{0x01, 0x02, 0x03}, 4, nil}, - {fDecodeFixedOpaque, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 4, nil}, - {fDecodeFixedOpaque, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x00, 0x00}, []byte{0x01, 0x02, 0x03, 0x04, 0x05}, 8, nil}, - // Expected Failure -- fixed opaque data not padded - {fDecodeFixedOpaque, []byte{0x01}, []byte{0x00}, 1, &UnmarshalError{ErrorCode: ErrIO}}, - - // Float - {fDecodeFloat, []byte{0x00, 0x00, 0x00, 0x00}, float32(0), 4, nil}, - {fDecodeFloat, []byte{0x40, 0x48, 0xF5, 0xC3}, float32(3.14), 4, nil}, - {fDecodeFloat, []byte{0x49, 0x96, 0xB4, 0x38}, float32(1234567.0), 4, nil}, - {fDecodeFloat, []byte{0xFF, 0x80, 0x00, 0x00}, float32(math.Inf(-1)), 4, nil}, - {fDecodeFloat, []byte{0x7F, 0x80, 0x00, 0x00}, float32(math.Inf(0)), 4, nil}, - - // Hyper - {fDecodeHyper, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(0), 8, nil}, - {fDecodeHyper, []byte{0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00}, int64(1 << 34), 8, nil}, - {fDecodeHyper, []byte{0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(1 << 42), 8, nil}, - {fDecodeHyper, []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, int64(9223372036854775807), 8, nil}, - {fDecodeHyper, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, int64(-1), 8, nil}, - {fDecodeHyper, []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(-9223372036854775808), 8, nil}, - - // Int - {fDecodeInt, []byte{0x00, 0x00, 0x00, 0x00}, int32(0), 4, nil}, - {fDecodeInt, []byte{0x00, 0x04, 0x00, 0x00}, int32(262144), 4, nil}, - {fDecodeInt, []byte{0x7F, 0xFF, 0xFF, 0xFF}, int32(2147483647), 4, nil}, - {fDecodeInt, []byte{0xFF, 0xFF, 0xFF, 0xFF}, int32(-1), 4, nil}, - {fDecodeInt, []byte{0x80, 0x00, 0x00, 0x00}, int32(-2147483648), 4, nil}, - - // Opaque - {fDecodeOpaque, []byte{0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00}, []byte{0x01}, 8, nil}, - {fDecodeOpaque, []byte{0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03, 0x00}, []byte{0x01, 0x02, 0x03}, 8, nil}, - // Expected Failures -- not enough bytes for length, length - // larger than allowed, and data larger than available bytes. - {fDecodeOpaque, []byte{0x00, 0x00, 0xFF}, []byte{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, - {fDecodeOpaque, []byte{0xFF, 0xFF, 0xFF, 0xFF}, []byte{}, 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {fDecodeOpaque, []byte{0x7F, 0xFF, 0xFF, 0xFD}, []byte{}, 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {fDecodeOpaque, []byte{0x00, 0x00, 0x00, 0xFF}, []byte{}, 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // String - {fDecodeString, []byte{0x00, 0x00, 0x00, 0x00}, "", 4, nil}, - {fDecodeString, []byte{0x00, 0x00, 0x00, 0x03, 0x78, 0x64, 0x72, 0x00}, "xdr", 8, nil}, - {fDecodeString, []byte{0x00, 0x00, 0x00, 0x06, 0xCF, 0x84, 0x3D, 0x32, 0xCF, 0x80, 0x00, 0x00}, "τ=2π", 12, nil}, - // Expected Failures -- not enough bytes for length, length - // larger than allowed, and len larger than available bytes. - {fDecodeString, []byte{0x00, 0x00, 0xFF}, "", 3, &UnmarshalError{ErrorCode: ErrIO}}, - {fDecodeString, []byte{0xFF, 0xFF, 0xFF, 0xFF}, "", 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {fDecodeString, []byte{0x00, 0x00, 0x00, 0xFF}, "", 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // Uhyper - {fDecodeUhyper, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(0), 8, nil}, - {fDecodeUhyper, []byte{0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00}, uint64(1 << 34), 8, nil}, - {fDecodeUhyper, []byte{0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(1 << 42), 8, nil}, - {fDecodeUhyper, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, uint64(18446744073709551615), 8, nil}, - {fDecodeUhyper, []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(9223372036854775808), 8, nil}, - - // Uint - {fDecodeUint, []byte{0x00, 0x00, 0x00, 0x00}, uint32(0), 4, nil}, - {fDecodeUint, []byte{0x00, 0x04, 0x00, 0x00}, uint32(262144), 4, nil}, - {fDecodeUint, []byte{0xFF, 0xFF, 0xFF, 0xFF}, uint32(4294967295), 4, nil}, - } - - validEnums := make(map[int32]bool) - validEnums[0] = true - validEnums[1] = true - - var rv interface{} - var n int - var err error - - for i, test := range tests { - err = nil - dec := NewDecoder(bytes.NewReader(test.in)) - switch test.f { - case fDecodeBool: - rv, n, err = dec.DecodeBool() - case fDecodeDouble: - rv, n, err = dec.DecodeDouble() - case fDecodeEnum: - rv, n, err = dec.DecodeEnum(validEnums) - case fDecodeFixedOpaque: - want := test.wantVal.([]byte) - rv, n, err = dec.DecodeFixedOpaque(int32(len(want))) - case fDecodeFloat: - rv, n, err = dec.DecodeFloat() - case fDecodeHyper: - rv, n, err = dec.DecodeHyper() - case fDecodeInt: - rv, n, err = dec.DecodeInt() - case fDecodeOpaque: - rv, n, err = dec.DecodeOpaque(0) - case fDecodeString: - rv, n, err = dec.DecodeString(0) - case fDecodeUhyper: - rv, n, err = dec.DecodeUhyper() - case fDecodeUint: - rv, n, err = dec.DecodeUint() - default: - t.Errorf("%v #%d unrecognized function", test.f, i) - continue - } - - // First ensure the number of bytes read is the expected value - // and the error is the expected one. - testName := fmt.Sprintf("%v #%d", test.f, i) - if !testExpectedURet(t, testName, n, test.wantN, err, test.err) { - continue - } - if test.err != nil { - continue - } +// testNestedType is a test type that tracks maxDepth to verify it's accessible +type testNestedType struct { + ReceivedMaxDepth uint + Value int32 +} - // Finally, ensure the read value is the expected one. - if !reflect.DeepEqual(rv, test.wantVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, rv, test.wantVal) - continue - } +func (t *testNestedType) DecodeFrom(d *Decoder, maxDepth uint) (int, error) { + t.ReceivedMaxDepth = d.MaxDepth() + v, n, err := d.DecodeInt() + if err != nil { + return n, err } + t.Value = v + return n, nil } -// TestUnmarshalCorners ensures the Unmarshal function properly handles various -// cases not already covered by the other tests. -func TestUnmarshalCorners(t *testing.T) { - buf := []byte{ - 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x02, - } - - // Ensure unmarshal to unsettable pointer returns the expected error. - testName := "Unmarshal to unsettable pointer" - var i32p *int32 - expectedN := 0 - expectedErr := error(&UnmarshalError{ErrorCode: ErrNotSettable}) - n, err := Unmarshal(bytes.NewReader(buf), i32p) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure decode of unsettable pointer returns the expected error. - testName = "Decode to unsettable pointer" - expectedN = 4 - expectedErr = &UnmarshalError{ErrorCode: ErrNotSettable} - n, err = TstDecode(bytes.NewReader(buf))(reflect.ValueOf(i32p), 0, DecodeDefaultMaxDepth) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure unmarshal to indirected unsettable pointer returns the - // expected error. - testName = "Unmarshal to indirected unsettable pointer" - ii32p := interface{}(i32p) - expectedN = 0 - expectedErr = &UnmarshalError{ErrorCode: ErrNotSettable} - n, err = Unmarshal(bytes.NewReader(buf), &ii32p) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure unmarshal to embedded unsettable interface value returns the - // expected error. - testName = "Unmarshal to embedded unsettable interface value" - var i32 int32 - ii32 := interface{}(i32) - expectedN = 0 - expectedErr = &UnmarshalError{ErrorCode: ErrNotSettable} - n, err = Unmarshal(bytes.NewReader(buf), &ii32) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure unmarshal to embedded interface value works properly. - testName = "Unmarshal to embedded interface value" - ii32vp := interface{}(&i32) - expectedN = 4 - expectedErr = nil - ii32vpr := int32(1) - expectedVal := interface{}(&ii32vpr) - n, err = Unmarshal(bytes.NewReader(buf), &ii32vp) - if testExpectedURet(t, testName, n, expectedN, err, expectedErr) { - if !reflect.DeepEqual(ii32vp, expectedVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, ii32vp, expectedVal) - } - } +// TestUnmarshal tests the Unmarshal function with a type implementing DecoderFrom +func TestUnmarshal(t *testing.T) { + data := []byte{0x00, 0x00, 0x00, 0x2A} // 42 in big-endian - // Ensure decode of an invalid reflect value returns the expected - // error. - testName = "Decode invalid reflect value" - expectedN = 0 - expectedErr = error(&UnmarshalError{ErrorCode: ErrUnsupportedType}) - n, err = TstDecode(bytes.NewReader(buf))(reflect.Value{}, 0, DecodeDefaultMaxDepth) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure unmarshal to a slice with a cap and 0 length adjusts the - // length properly. - testName = "Unmarshal to capped slice" - cappedSlice := make([]bool, 0, 1) - expectedN = 8 - expectedErr = nil - expectedVal = []bool{true} - n, err = Unmarshal(bytes.NewReader(buf), &cappedSlice) - if testExpectedURet(t, testName, n, expectedN, err, expectedErr) { - if !reflect.DeepEqual(cappedSlice, expectedVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, cappedSlice, expectedVal) - } + var result testDecoderFromType + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) } - - // Ensure decode to a slice retuns expected number of elements - testName = "Unmarshal to oversized slice" - oversizedSlice := make([]bool, 2, 2) - expectedN = 8 - expectedErr = nil - expectedVal = []bool{true} - n, err = Unmarshal(bytes.NewReader(buf), &oversizedSlice) - if testExpectedURet(t, testName, n, expectedN, err, expectedErr) { - if !reflect.DeepEqual(oversizedSlice, expectedVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, oversizedSlice, expectedVal) - } + if n != 4 { + t.Errorf("expected 4 bytes read, got %d", n) } - - // Ensure unmarshal to struct with both exported and unexported fields - // skips the unexported fields but still unmarshals to the exported - // fields. - type unexportedStruct struct { - unexported int - Exported int - } - testName = "Unmarshal to struct with exported and unexported fields" - var tstruct unexportedStruct - expectedN = 4 - expectedErr = nil - expectedVal = unexportedStruct{0, 1} - n, err = Unmarshal(bytes.NewReader(buf), &tstruct) - if testExpectedURet(t, testName, n, expectedN, err, expectedErr) { - if !reflect.DeepEqual(tstruct, expectedVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, tstruct, expectedVal) - } + if result.Value != 42 { + t.Errorf("expected value 42, got %d", result.Value) } - - // Ensure decode to struct with unsettable fields return expected error. - type unsettableStruct struct { - Exported int - } - testName = "Decode to struct with unsettable fields" - var ustruct unsettableStruct - expectedN = 0 - expectedErr = error(&UnmarshalError{ErrorCode: ErrNotSettable}) - n, err = TstDecode(bytes.NewReader(buf))(reflect.ValueOf(ustruct), 0, DecodeDefaultMaxDepth) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure decode to struct with unsettable pointer fields return - // expected error. - type unsettablePointerStruct struct { - Exported *int - } - testName = "Decode to struct with unsettable pointer fields" - var upstruct unsettablePointerStruct - expectedN = 0 - expectedErr = error(&UnmarshalError{ErrorCode: ErrNotSettable}) - n, err = TstDecode(bytes.NewReader(buf))(reflect.ValueOf(upstruct), 0, DecodeDefaultMaxDepth) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) } -type String32 string - -var _ Sized = String32("hello") - -func (s String32) XDRMaxSize() int { - return 32 +// TestUnmarshal_ReflectionFallback tests Unmarshal with types that don't implement DecoderFrom +// but can be decoded via reflection +func TestUnmarshal_ReflectionFallback(t *testing.T) { + // int32 doesn't implement DecoderFrom but can be decoded via reflection + data := []byte{0x00, 0x00, 0x00, 0x01} + var result int32 + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("expected reflection decode to work, got error: %v", err) + } + if n != 4 { + t.Errorf("expected 4 bytes read, got %d", n) + } + if result != 1 { + t.Errorf("expected result 1, got %d", result) + } } -func TestSizedType(t *testing.T) { - cases := map[string]struct { - input []byte - out string - err error +// TestUnmarshal_ReflectionAllTypes tests reflection-based decoding for all primitive types +func TestUnmarshal_ReflectionAllTypes(t *testing.T) { + tests := []struct { + name string + data []byte + dest interface{} + expected interface{} + wantN int }{ - // works for 0 length - "0 length": {[]byte{0x00, 0x00, 0x00, 0x00}, "", nil}, - // works for 1 length - "1 length": {[]byte{0x00, 0x00, 0x00, 0x01, 0x48, 0x00, 0x00, 0x00}, "H", nil}, - // works for 32 length - "32 length": { - []byte{ - 0x00, 0x00, 0x00, 0x20, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - }, - "HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH", - nil, + { + name: "int8", + data: []byte{0x00, 0x00, 0x00, 0x7F}, + dest: new(int8), + expected: int8(127), + wantN: 4, }, - // fails for 33 length - "33 length": { - []byte{ - 0x00, 0x00, 0x00, 0x21, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x00, 0x00, 0x00, - }, - "", - &UnmarshalError{ErrorCode: ErrOverflow}, + { + name: "int16", + data: []byte{0x00, 0x00, 0x7F, 0xFF}, + dest: new(int16), + expected: int16(32767), + wantN: 4, + }, + { + name: "int32", + data: []byte{0x7F, 0xFF, 0xFF, 0xFF}, + dest: new(int32), + expected: int32(2147483647), + wantN: 4, + }, + { + name: "int64", + data: []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + dest: new(int64), + expected: int64(9223372036854775807), + wantN: 8, + }, + { + name: "uint8", + data: []byte{0x00, 0x00, 0x00, 0xFF}, + dest: new(uint8), + expected: uint8(255), + wantN: 4, + }, + { + name: "uint16", + data: []byte{0x00, 0x00, 0xFF, 0xFF}, + dest: new(uint16), + expected: uint16(65535), + wantN: 4, + }, + { + name: "uint32", + data: []byte{0xFF, 0xFF, 0xFF, 0xFF}, + dest: new(uint32), + expected: uint32(4294967295), + wantN: 4, + }, + { + name: "uint64", + data: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + dest: new(uint64), + expected: uint64(18446744073709551615), + wantN: 8, + }, + { + name: "bool true", + data: []byte{0x00, 0x00, 0x00, 0x01}, + dest: new(bool), + expected: true, + wantN: 4, + }, + { + name: "bool false", + data: []byte{0x00, 0x00, 0x00, 0x00}, + dest: new(bool), + expected: false, + wantN: 4, + }, + { + name: "string", + data: []byte{0x00, 0x00, 0x00, 0x03, 'x', 'd', 'r', 0x00}, + dest: new(string), + expected: "xdr", + wantN: 8, + }, + { + name: "[]byte opaque", + data: []byte{0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03, 0x00}, + dest: new([]byte), + expected: []byte{0x01, 0x02, 0x03}, + wantN: 8, }, } - for name, kase := range cases { - var out String32 - r := bytes.NewReader(kase.input) - _, err := Unmarshal(r, &out) - - if !assertError(t, name, err, kase.err) { - continue - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n, err := Unmarshal(tt.data, tt.dest) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != tt.wantN { + t.Errorf("bytes read = %d, want %d", n, tt.wantN) + } - if string(out) != kase.out { - t.Errorf("%s: expected output to be %#v, but got %#v", name, kase.out, out) - continue - } + // Compare values using reflect + got := reflect.ValueOf(tt.dest).Elem().Interface() + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("got %v (%T), want %v (%T)", got, got, tt.expected, tt.expected) + } + }) } } -type sizedField struct { - Val string `xdrmaxsize:"32"` -} +// TestUnmarshal_ReflectionStruct tests reflection-based decoding for structs +func TestUnmarshal_ReflectionStruct(t *testing.T) { + type simpleStruct struct { + A int32 + B string + C bool + } -func TestSizedField(t *testing.T) { - cases := map[string]struct { - input []byte - out string - err error - }{ - // works for 0 length - "0 length": {[]byte{0x00, 0x00, 0x00, 0x00}, "", nil}, - // works for 1 length - "1 length": {[]byte{0x00, 0x00, 0x00, 0x01, 0x48, 0x00, 0x00, 0x00}, "H", nil}, - // works for 32 length - "32 length": { - []byte{ - 0x00, 0x00, 0x00, 0x20, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - }, - "HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH", - nil, - }, - // fails for 33 length - "33 length": { - []byte{ - 0x00, 0x00, 0x00, 0x21, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x00, 0x00, 0x00, - }, - "", - &UnmarshalError{ErrorCode: ErrOverflow}, - }, + // Encoded: A=42, B="hi", C=true + data := []byte{ + 0x00, 0x00, 0x00, 0x2A, // A = 42 + 0x00, 0x00, 0x00, 0x02, 'h', 'i', 0x00, 0x00, // B = "hi" (padded) + 0x00, 0x00, 0x00, 0x01, // C = true } - for name, kase := range cases { - var out sizedField - r := bytes.NewReader(kase.input) - _, err := Unmarshal(r, &out) + var result simpleStruct + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 16 { + t.Errorf("bytes read = %d, want 16", n) + } + if result.A != 42 { + t.Errorf("A = %d, want 42", result.A) + } + if result.B != "hi" { + t.Errorf("B = %q, want \"hi\"", result.B) + } + if result.C != true { + t.Errorf("C = %v, want true", result.C) + } +} - if !assertError(t, name, err, kase.err) { - continue - } +// TestUnmarshal_ReflectionSlice tests reflection-based decoding for slices +func TestUnmarshal_ReflectionSlice(t *testing.T) { + // Slice of 3 int32s: [1, 2, 3] + data := []byte{ + 0x00, 0x00, 0x00, 0x03, // length = 3 + 0x00, 0x00, 0x00, 0x01, // [0] = 1 + 0x00, 0x00, 0x00, 0x02, // [1] = 2 + 0x00, 0x00, 0x00, 0x03, // [2] = 3 + } - if out.Val != kase.out { - t.Errorf("%s: expected output to be %#v, but got %#v", name, kase.out, out.Val) - continue - } + var result []int32 + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 16 { + t.Errorf("bytes read = %d, want 16", n) + } + expected := []int32{1, 2, 3} + if !reflect.DeepEqual(result, expected) { + t.Errorf("got %v, want %v", result, expected) } } -type defEnum int32 +// TestUnmarshal_ReflectionMap tests reflection-based decoding for maps +func TestUnmarshal_ReflectionMap(t *testing.T) { + // Map with 2 entries: {"a": 1, "b": 2} + data := []byte{ + 0x00, 0x00, 0x00, 0x02, // count = 2 + 0x00, 0x00, 0x00, 0x01, 'a', 0x00, 0x00, 0x00, // key "a" + 0x00, 0x00, 0x00, 0x01, // value 1 + 0x00, 0x00, 0x00, 0x01, 'b', 0x00, 0x00, 0x00, // key "b" + 0x00, 0x00, 0x00, 0x02, // value 2 + } -func (e defEnum) ValidEnum(v int32) bool { - return v < 1 + var result map[string]int32 + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 28 { + t.Errorf("bytes read = %d, want 28", n) + } + if result["a"] != 1 { + t.Errorf("result[\"a\"] = %d, want 1", result["a"]) + } + if result["b"] != 2 { + t.Errorf("result[\"b\"] = %d, want 2", result["b"]) + } } -type defUnion struct { - Type defEnum - Data *int32 -} +// TestUnmarshal_ReflectionPointer tests reflection-based decoding for optional (pointer) types +func TestUnmarshal_ReflectionPointer(t *testing.T) { + // Pointer present with value 42 + dataPresent := []byte{ + 0x00, 0x00, 0x00, 0x01, // present = true + 0x00, 0x00, 0x00, 0x2A, // value = 42 + } -func (u defUnion) SwitchFieldName() string { - return "Type" -} + var result1 *int32 + n, err := Unmarshal(dataPresent, &result1) + if err != nil { + t.Fatalf("Unmarshal (present) failed: %v", err) + } + if n != 8 { + t.Errorf("bytes read = %d, want 8", n) + } + if result1 == nil || *result1 != 42 { + t.Errorf("got %v, want pointer to 42", result1) + } -func (u defUnion) ArmForSwitch(sw int32) (string, bool) { - switch sw { - case 0: - return "Data", true + // Pointer absent (nil) + dataAbsent := []byte{ + 0x00, 0x00, 0x00, 0x00, // present = false } - return "-", false + var result2 *int32 + result2 = new(int32) // Pre-allocate to verify it gets set to nil + *result2 = 999 + n, err = Unmarshal(dataAbsent, &result2) + if err != nil { + t.Fatalf("Unmarshal (absent) failed: %v", err) + } + if n != 4 { + t.Errorf("bytes read = %d, want 4", n) + } + if result2 != nil { + t.Errorf("got %v, want nil", result2) + } } -func TestUnion_EnumValidation(t *testing.T) { +// TestUnmarshalWithOptions tests UnmarshalWithOptions with custom MaxDepth +func TestUnmarshalWithOptions(t *testing.T) { + data := []byte{0x00, 0x00, 0x01, 0x00} // 256 in big-endian - var u defUnion - var buf bytes.Buffer + var result testDecoderFromType + opts := DecodeOptions{MaxDepth: 100} + n, err := UnmarshalWithOptions(data, &result, opts) + if err != nil { + t.Fatalf("UnmarshalWithOptions failed: %v", err) + } + if n != 4 { + t.Errorf("expected 4 bytes read, got %d", n) + } + if result.Value != 256 { + t.Errorf("expected value 256, got %d", result.Value) + } +} - // encode a union with invalid value fails - u.Type = defEnum(3) - _, err := Marshal(&buf, u) +// TestMaxDepthPassed tests that MaxDepth is correctly passed to DecodeFrom +func TestMaxDepthPassed(t *testing.T) { + data := []byte{0x00, 0x00, 0x00, 0x01} - if err == nil { - t.Errorf("expected error when marshaling invalid enum, got none. result: %#v", buf.Bytes()) + // Test with default options + var result1 testNestedType + _, err := Unmarshal(data, &result1) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if result1.ReceivedMaxDepth != DecodeDefaultMaxDepth { + t.Errorf("expected maxDepth %d, got %d", DecodeDefaultMaxDepth, result1.ReceivedMaxDepth) } - // decoding an invalid enum value into a union results in an error - u = defUnion{} - invalid := []byte{0x0, 0x0, 0x0, 0x3} - _, err = Unmarshal(bytes.NewReader(invalid), &u) - if err == nil { - t.Errorf("expected error when unmarshaling invalid enum into union, got none. result: %#v", u) + // Test with custom MaxDepth + var result2 testNestedType + opts := DecodeOptions{MaxDepth: 50} + _, err = UnmarshalWithOptions(data, &result2, opts) + if err != nil { + t.Fatalf("UnmarshalWithOptions failed: %v", err) + } + if result2.ReceivedMaxDepth != 50 { + t.Errorf("expected maxDepth 50, got %d", result2.ReceivedMaxDepth) } -} -func TestPaddedReads(t *testing.T) { - // regression test for non-zeroed padding + // Test with MaxDepth 0 (should use default) + var result3 testNestedType + opts = DecodeOptions{MaxDepth: 0} + _, err = UnmarshalWithOptions(data, &result3, opts) + if err != nil { + t.Fatalf("UnmarshalWithOptions failed: %v", err) + } + if result3.ReceivedMaxDepth != DecodeDefaultMaxDepth { + t.Errorf("expected maxDepth %d for zero option, got %d", DecodeDefaultMaxDepth, result3.ReceivedMaxDepth) + } +} - // opaque - dec := NewDecoder(bytes.NewReader([]byte{0x0, 0x0, 0x1, 0x1})) - _, _, err := dec.DecodeFixedOpaque(3) +// TestSkip_NegativeLength tests that negative skip length is rejected +func TestSkip_NegativeLength(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x00}) + err := d.Skip(-1) if err == nil { - t.Error("expected error when unmarshaling opaque with non-zero padding byte, got none") + t.Fatal("expected error for negative skip length") + } + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T", err) + } + if unmarshalErr.ErrorCode != ErrBadArguments { + t.Errorf("expected ErrBadArguments, got %v", unmarshalErr.ErrorCode) } +} - // string - dec = NewDecoder(bytes.NewReader([]byte{ - 0x0, 0x0, 0x0, 0x1, - 0x1, 0x1, 0x0, 0x0, - })) - _, _, err = dec.DecodeString(3) +// TestDecodeOpaque_LengthOverflow tests that length > maxInt32 is rejected +func TestDecodeOpaque_LengthOverflow(t *testing.T) { + // Encode a length of 0x80000000 (2^31, which is > maxInt32) + data := []byte{0x80, 0x00, 0x00, 0x00} + d := NewDecoder(data) + _, _, err := d.DecodeOpaque(0) if err == nil { - t.Error("expected error when unmarshaling string with non-zero padding byte, got none") + t.Fatal("expected error for length > maxInt32") + } + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T", err) } + if unmarshalErr.ErrorCode != ErrOverflow { + t.Errorf("expected ErrOverflow, got %v", unmarshalErr.ErrorCode) + } +} - // read varopaque - dec = NewDecoder(bytes.NewReader([]byte{ - 0x0, 0x0, 0x0, 0x1, - 0x1, 0x0, 0x0, 0x1, - })) - _, _, err = dec.DecodeOpaque(3) +// TestDecodeString_LengthOverflow tests that length > maxInt32 is rejected +func TestDecodeString_LengthOverflow(t *testing.T) { + // Encode a length of 0x80000000 (2^31, which is > maxInt32) + data := []byte{0x80, 0x00, 0x00, 0x00} + d := NewDecoder(data) + _, _, err := d.DecodeString(0) if err == nil { - t.Error("expected error when unmarshaling varopaque with non-zero padding byte, got none") + t.Fatal("expected error for length > maxInt32") + } + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T", err) + } + if unmarshalErr.ErrorCode != ErrOverflow { + t.Errorf("expected ErrOverflow, got %v", unmarshalErr.ErrorCode) } } -func TestDecodeNilPointerIntoExistingObjectWithNotNilPointer(t *testing.T) { - var buf bytes.Buffer - data := "data" - _, err := Marshal(&buf, structWithPointer{Data: &data}) +// TestDecoder_Decode tests the Decode convenience method +func TestDecoder_Decode(t *testing.T) { + // Create decoder with initial data + data1 := []byte{0x00, 0x00, 0x00, 0x2A} // 42 + decoder := NewDecoder(data1) + + var result testDecoderFromType + n, err := decoder.Decode(&result) if err != nil { - t.Error("unexpected error") + t.Fatalf("Decode failed: %v", err) + } + if n != 4 { + t.Errorf("expected 4 bytes read, got %d", n) } + if result.Value != 42 { + t.Errorf("expected 42, got %d", result.Value) + } + + // Test decoder reuse with Reset + Decode + data2 := []byte{0x00, 0x00, 0x01, 0x00} // 256 + decoder.Reset(data2) - var s structWithPointer - _, err = Unmarshal(&buf, &s) + n, err = decoder.Decode(&result) if err != nil { - t.Error("unexpected error") + t.Fatalf("Decode after Reset failed: %v", err) + } + if n != 4 { + t.Errorf("expected 4 bytes read, got %d", n) } + if result.Value != 256 { + t.Errorf("expected 256, got %d", result.Value) + } +} - // Note: - // 1. structWithPointer.Data is nil. - // 2. We unmarshal into previously used object. - _, err = Marshal(&buf, structWithPointer{}) +// TestDecoder_Decode_MaxDepth tests that Decode passes MaxDepth correctly +func TestDecoder_Decode_MaxDepth(t *testing.T) { + data := []byte{0x00, 0x00, 0x00, 0x01} + + // Test with custom MaxDepth + decoder := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 75}) + + var result testNestedType + _, err := decoder.Decode(&result) if err != nil { - t.Error("unexpected error") + t.Fatalf("Decode failed: %v", err) + } + if result.ReceivedMaxDepth != 75 { + t.Errorf("expected maxDepth 75, got %d", result.ReceivedMaxDepth) } +} - _, err = Unmarshal(&buf, &s) - if err != nil { - t.Error("unexpected error") +// TestDecoder_Decode_Error tests that Decode propagates errors correctly +func TestDecoder_Decode_Error(t *testing.T) { + // Insufficient data + data := []byte{0x00, 0x00} // Only 2 bytes, need 4 + decoder := NewDecoder(data) + + var result testDecoderFromType + _, err := decoder.Decode(&result) + if err == nil { + t.Fatal("expected error for insufficient data") } - if s.Data != nil { - t.Error("Data should be nil") + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T", err) + } + if unmarshalErr.ErrorCode != ErrIO { + t.Errorf("expected ErrIO, got %v", unmarshalErr.ErrorCode) } } -func TestDecodeUnionIntoExistingObject(t *testing.T) { - var buf bytes.Buffer - var idata int32 = 1 - sdata := "data" - _, err := Marshal(&buf, aUnion{ - Type: 0, - Data: &idata, - }) - if err != nil { - t.Error("unexpected error") +// TestDecoder_MaxDepthExceeded tests that deeply nested structures trigger ErrMaxDecodingDepth +func TestDecoder_MaxDepthExceeded(t *testing.T) { + // Create a linked list type with pointer to next + type node struct { + Value int32 + Next *node } - var s aUnion - _, err = Unmarshal(&buf, &s) - if err != nil { - t.Error("unexpected error") + // Build XDR data for a linked list with 5 nodes + // Each node: int32 value, bool (present), then next node + // Node 1 -> Node 2 -> Node 3 -> Node 4 -> Node 5 -> nil + data := []byte{ + // Node 1 + 0x00, 0x00, 0x00, 0x01, // Value = 1 + 0x00, 0x00, 0x00, 0x01, // Next present = true + // Node 2 + 0x00, 0x00, 0x00, 0x02, // Value = 2 + 0x00, 0x00, 0x00, 0x01, // Next present = true + // Node 3 + 0x00, 0x00, 0x00, 0x03, // Value = 3 + 0x00, 0x00, 0x00, 0x01, // Next present = true + // Node 4 + 0x00, 0x00, 0x00, 0x04, // Value = 4 + 0x00, 0x00, 0x00, 0x01, // Next present = true + // Node 5 + 0x00, 0x00, 0x00, 0x05, // Value = 5 + 0x00, 0x00, 0x00, 0x00, // Next present = false (nil) } - _, err = Marshal(&buf, aUnion{ - Type: 1, - Text: &sdata, - }) + // With default MaxDepth (200), this should succeed + var result1 node + _, err := Unmarshal(data, &result1) if err != nil { - t.Error("unexpected error") + t.Fatalf("Unmarshal with default depth failed: %v", err) } - _, err = Unmarshal(&buf, &s) - if err != nil { - t.Error("unexpected error") + // With MaxDepth = 3, this should fail (each pointer adds depth) + // The structure needs: depth for node struct, depth for pointer, depth for next node... + var result2 node + opts := DecodeOptions{MaxDepth: 3} + _, err = UnmarshalWithOptions(data, &result2, opts) + if err == nil { + t.Fatal("expected ErrMaxDecodingDepth error with MaxDepth=3") } - if s.Data != nil { - t.Error("Data should be nil") + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T: %v", err, err) } - - if s.Type != 1 { - t.Error("Type does not match") + if unmarshalErr.ErrorCode != ErrMaxDecodingDepth { + t.Errorf("expected ErrMaxDecodingDepth, got %v", unmarshalErr.ErrorCode) } +} - if *s.Text != sdata { - t.Error("Text does not match") - } +// testUnionValueArm is a union type with value-type arms for testing +type testUnionValueArm struct { + Type int32 + Int int32 // value type arm (switch 0) + Str string // value type arm (switch 1) + Empty bool // value type arm (switch 2) } -func TestDecodeMaxDepth(t *testing.T) { - var buf bytes.Buffer - data := "data" - _, err := Marshal(&buf, structWithPointer{Data: &data}) - if err != nil { - t.Error("unexpected error") +func (u testUnionValueArm) ArmForSwitch(sw int32) (string, bool) { + switch sw { + case 0: + return "Int", true + case 1: + return "Str", true + case 2: + return "", true // void arm } + return "-", false +} - bufCopy := buf - decoder := NewDecoderWithOptions(&bufCopy, DecodeOptions{MaxDepth: 3}) - var s structWithPointer - _, err = decoder.Decode(&s) - if err != nil { - t.Error("unexpected error") - } +func (u testUnionValueArm) SwitchFieldName() string { + return "Type" +} - bufCopy = buf - decoder = NewDecoderWithOptions(&bufCopy, DecodeOptions{MaxDepth: 2}) - _, err = decoder.Decode(&s) - assertError(t, "", err, &UnmarshalError{ErrorCode: ErrMaxDecodingDepth}) +// testUnionPointerArm is a union type with pointer-type arms for testing +type testUnionPointerArm struct { + Type int32 + Int *int32 // pointer type arm (switch 0) + Str *string // pointer type arm (switch 1) } -func TestDecodeMaxAllocationCheck_ImplicitLenReader(t *testing.T) { - var buf bytes.Buffer - _, err := Marshal(&buf, "thisstringis23charslong") - if err != nil { - t.Error("unexpected error") +func (u testUnionPointerArm) ArmForSwitch(sw int32) (string, bool) { + switch sw { + case 0: + return "Int", true + case 1: + return "Str", true + case 2: + return "", true // void arm } + return "-", false +} + +func (u testUnionPointerArm) SwitchFieldName() string { + return "Type" +} - // Reduce the buffer size so that the length of the buffer - // is shorter than the encoded XDR length - buf.Truncate(buf.Len() - 4) +// TestDecodeUnionWithValueTypeArm tests decoding unions with value-type arms +func TestDecodeUnionWithValueTypeArm(t *testing.T) { + t.Run("value-type int arm", func(t *testing.T) { + // XDR encoded union: Type=0, Int=42 + data := []byte{ + 0x00, 0x00, 0x00, 0x00, // Type = 0 + 0x00, 0x00, 0x00, 0x2A, // Int = 42 + } - decoder := NewDecoder(&buf) - var s string - _, err = decoder.Decode(&s) - assertError(t, "", err, &UnmarshalError{ErrorCode: ErrOverflow}) + var result testUnionValueArm + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 8 { + t.Errorf("bytes read = %d, want 8", n) + } + if result.Type != 0 { + t.Errorf("Type = %d, want 0", result.Type) + } + if result.Int != 42 { + t.Errorf("Int = %d, want 42", result.Int) + } + }) + + t.Run("value-type string arm", func(t *testing.T) { + // XDR encoded union: Type=1, Str="hi" + data := []byte{ + 0x00, 0x00, 0x00, 0x01, // Type = 1 + 0x00, 0x00, 0x00, 0x02, 'h', 'i', 0x00, 0x00, // Str = "hi" (padded) + } + + var result testUnionValueArm + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 12 { + t.Errorf("bytes read = %d, want 12", n) + } + if result.Type != 1 { + t.Errorf("Type = %d, want 1", result.Type) + } + if result.Str != "hi" { + t.Errorf("Str = %q, want \"hi\"", result.Str) + } + }) + + t.Run("void arm", func(t *testing.T) { + // XDR encoded union: Type=2 (void arm, no data) + data := []byte{ + 0x00, 0x00, 0x00, 0x02, // Type = 2 + } + + var result testUnionValueArm + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 4 { + t.Errorf("bytes read = %d, want 4", n) + } + if result.Type != 2 { + t.Errorf("Type = %d, want 2", result.Type) + } + }) } -func TestDecodeMaxAllocationCheck_ExplicitLenReader(t *testing.T) { - var buf bytes.Buffer - encoder := base64.NewEncoder(base64.StdEncoding, &buf) - _, err := Marshal(encoder, "thisstringis23charslong") - if err != nil { - t.Error("unexpected error") - } +// TestDecodeUnionWithPointerTypeArm tests decoding unions with pointer-type arms +func TestDecodeUnionWithPointerTypeArm(t *testing.T) { + t.Run("pointer-type int arm", func(t *testing.T) { + // XDR encoded union: Type=0, Int=42 + data := []byte{ + 0x00, 0x00, 0x00, 0x00, // Type = 0 + 0x00, 0x00, 0x00, 0x2A, // Int = 42 + } + + var result testUnionPointerArm + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 8 { + t.Errorf("bytes read = %d, want 8", n) + } + if result.Type != 0 { + t.Errorf("Type = %d, want 0", result.Type) + } + if result.Int == nil || *result.Int != 42 { + t.Errorf("Int = %v, want pointer to 42", result.Int) + } + }) - xdrLen := base64.StdEncoding.DecodedLen(buf.Len()) - // Reduce the buffer size so that the length of the buffer - // is shorter than the encoded XDR length - reducedLen := xdrLen - 4 + t.Run("pointer-type string arm", func(t *testing.T) { + // XDR encoded union: Type=1, Str="hi" + data := []byte{ + 0x00, 0x00, 0x00, 0x01, // Type = 1 + 0x00, 0x00, 0x00, 0x02, 'h', 'i', 0x00, 0x00, // Str = "hi" (padded) + } - decoder := NewDecoderWithOptions(&buf, DecodeOptions{MaxInputLen: reducedLen}) - var s string - _, err = decoder.Decode(&s) - assertError(t, "", err, &UnmarshalError{ErrorCode: ErrOverflow}) + var result testUnionPointerArm + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 12 { + t.Errorf("bytes read = %d, want 12", n) + } + if result.Type != 1 { + t.Errorf("Type = %d, want 1", result.Type) + } + if result.Str == nil || *result.Str != "hi" { + t.Errorf("Str = %v, want pointer to \"hi\"", result.Str) + } + }) } diff --git a/xdr3/encode.go b/xdr3/encode.go index 2e0cf5b..c1341ca 100644 --- a/xdr3/encode.go +++ b/xdr3/encode.go @@ -131,7 +131,7 @@ func (enc *Encoder) EncodeUint(v uint32) (int, error) { n, err := enc.w.Write(enc.scratchBuf[:4]) if err != nil { msg := fmt.Sprintf(errIOEncode, err.Error(), 4) - err := marshalError("EncodeUint", ErrIO, msg, enc.scratchBuf[:4], err) + err := marshalError("EncodeUint", ErrIO, msg, enc.scratchBuf[:n], err) return n, err } @@ -201,7 +201,7 @@ func (enc *Encoder) EncodeHyper(v int64) (int, error) { n, err := enc.w.Write(enc.scratchBuf[:8]) if err != nil { msg := fmt.Sprintf(errIOEncode, err.Error(), 8) - err := marshalError("EncodeHyper", ErrIO, msg, enc.scratchBuf[:8], err) + err := marshalError("EncodeHyper", ErrIO, msg, enc.scratchBuf[:n], err) return n, err } @@ -312,7 +312,7 @@ func (enc *Encoder) EncodeFixedOpaque(v []byte) (int, error) { written := make([]byte, l+n2) copy(written, v) copy(written[l:], b[:n2]) - msg := fmt.Sprintf(errIOEncode, err.Error(), l+pad) + msg := fmt.Sprintf(errIOEncode, err.Error(), l+n2) err := marshalError("EncodeFixedOpaque", ErrIO, msg, written, err) return n, err @@ -460,7 +460,6 @@ func (enc *Encoder) encodeUnion(v reflect.Value) (int, error) { vs := v.FieldByName(u.SwitchFieldName()) n, err := enc.encode(vs) - if err != nil { return n, err } @@ -473,6 +472,7 @@ func (enc *Encoder) encodeUnion(v reflect.Value) (int, error) { } else { sw = int32(vs.Int()) } + arm, ok := u.ArmForSwitch(sw) // void arm, we're done @@ -481,26 +481,25 @@ func (enc *Encoder) encodeUnion(v reflect.Value) (int, error) { } vv := v.FieldByName(arm) - if !vv.IsValid() || !ok { msg := fmt.Sprintf("invalid union switch: %d", sw) err := marshalError("encodeUnion", ErrBadUnionSwitch, msg, nil, nil) return n, err } - if vv.Kind() != reflect.Ptr { - msg := fmt.Sprintf("invalid union value field: %v", vv.Kind()) - err := marshalError("encodeUnion", ErrBadUnionValue, msg, nil, nil) - return n, err - } - - if vv.IsNil() { - msg := fmt.Sprintf("can't encode nil union value") - err := marshalError("encodeUnion", ErrBadUnionValue, msg, nil, nil) + // Handle both pointer and value-type union arms + if vv.Kind() == reflect.Ptr { + if vv.IsNil() { + return n, marshalError("encodeUnion", ErrBadUnionValue, + "can't encode nil pointer union arm", nil, nil) + } + n2, err := enc.encode(vv.Elem()) + n += n2 return n, err } - n2, err := enc.encode(vv.Elem()) + // Value-type arm - encode directly + n2, err := enc.encode(vv) n += n2 return n, err } diff --git a/xdr3/encode_test.go b/xdr3/encode_test.go index bda0873..d4d8c00 100644 --- a/xdr3/encode_test.go +++ b/xdr3/encode_test.go @@ -26,6 +26,72 @@ import ( . "github.com/stellar/go-xdr/xdr3" ) +// subTest is used to allow testing of the Marshal function into struct fields +// which are structs themselves. +type subTest struct { + A string + B uint8 +} + +// allTypesTest is used to allow testing of the Marshal function into struct +// fields of all supported types. +type allTypesTest struct { + A int8 + B uint8 + C int16 + D uint16 + E int32 + F uint32 + G int64 + H uint64 + I bool + J float32 + K float64 + L string + M []byte + N [3]byte + O []int16 + P [2]subTest + Q subTest + R map[string]uint32 + S time.Time +} + +// opaqueStruct is used to test handling of uint8 slices and arrays. +type opaqueStruct struct { + Slice []uint8 `xdropaque:"false"` + Array [1]uint8 `xdropaque:"false"` +} + +type AnEnum int32 + +func (e AnEnum) ValidEnum(v int32) bool { + return v < 3 +} + +type aUnion struct { + Type AnEnum + Data *int32 + Text *string `xdrmaxsize:"28"` +} + +func (u aUnion) SwitchFieldName() string { + return "Type" +} + +func (u aUnion) ArmForSwitch(sw int32) (string, bool) { + switch sw { + case 0: + return "Data", true + case 1: + return "Text", true + case 2: // void + return "", true + } + + return "-", false +} + // testExpectedMRet is a convenience method to test an expected number of bytes // written and error for a marshal. func testExpectedMRet(t *testing.T, name string, n, wantN int, err, wantErr error) bool { @@ -703,6 +769,159 @@ func TestEncoder(t *testing.T) { } } +// valueTypeUnion is a union type with value-type arms for testing encoding +type valueTypeUnion struct { + Type int32 + Int int32 // value type arm (switch 0) + Str string // value type arm (switch 1) +} + +func (u valueTypeUnion) SwitchFieldName() string { + return "Type" +} + +func (u valueTypeUnion) ArmForSwitch(sw int32) (string, bool) { + switch sw { + case 0: + return "Int", true + case 1: + return "Str", true + case 2: + return "", true // void arm + } + return "-", false +} + +// pointerTypeUnion is a union type with pointer-type arms for testing encoding +type pointerTypeUnion struct { + Type int32 + Int *int32 // pointer type arm (switch 0) + Str *string // pointer type arm (switch 1) +} + +func (u pointerTypeUnion) SwitchFieldName() string { + return "Type" +} + +func (u pointerTypeUnion) ArmForSwitch(sw int32) (string, bool) { + switch sw { + case 0: + return "Int", true + case 1: + return "Str", true + case 2: + return "", true // void arm + } + return "-", false +} + +// TestMarshalUnionValueTypeArms tests encoding unions with value-type arms +func TestMarshalUnionValueTypeArms(t *testing.T) { + t.Run("value-type int arm", func(t *testing.T) { + u := valueTypeUnion{Type: 0, Int: 42} + expected := []byte{ + 0x00, 0x00, 0x00, 0x00, // Type = 0 + 0x00, 0x00, 0x00, 0x2A, // Int = 42 + } + + data := newFixedWriter(8) + n, err := Marshal(data, u) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if n != 8 { + t.Errorf("bytes written = %d, want 8", n) + } + if !reflect.DeepEqual(data.Bytes(), expected) { + t.Errorf("got %v, want %v", data.Bytes(), expected) + } + }) + + t.Run("value-type string arm", func(t *testing.T) { + u := valueTypeUnion{Type: 1, Str: "hi"} + expected := []byte{ + 0x00, 0x00, 0x00, 0x01, // Type = 1 + 0x00, 0x00, 0x00, 0x02, 'h', 'i', 0x00, 0x00, // Str = "hi" (padded) + } + + data := newFixedWriter(12) + n, err := Marshal(data, u) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if n != 12 { + t.Errorf("bytes written = %d, want 12", n) + } + if !reflect.DeepEqual(data.Bytes(), expected) { + t.Errorf("got %v, want %v", data.Bytes(), expected) + } + }) + + t.Run("void arm", func(t *testing.T) { + u := valueTypeUnion{Type: 2} + expected := []byte{ + 0x00, 0x00, 0x00, 0x02, // Type = 2 + } + + data := newFixedWriter(4) + n, err := Marshal(data, u) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if n != 4 { + t.Errorf("bytes written = %d, want 4", n) + } + if !reflect.DeepEqual(data.Bytes(), expected) { + t.Errorf("got %v, want %v", data.Bytes(), expected) + } + }) +} + +// TestMarshalUnionPointerTypeArms tests encoding unions with pointer-type arms +func TestMarshalUnionPointerTypeArms(t *testing.T) { + t.Run("pointer-type int arm", func(t *testing.T) { + val := int32(42) + u := pointerTypeUnion{Type: 0, Int: &val} + expected := []byte{ + 0x00, 0x00, 0x00, 0x00, // Type = 0 + 0x00, 0x00, 0x00, 0x2A, // Int = 42 + } + + data := newFixedWriter(8) + n, err := Marshal(data, u) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if n != 8 { + t.Errorf("bytes written = %d, want 8", n) + } + if !reflect.DeepEqual(data.Bytes(), expected) { + t.Errorf("got %v, want %v", data.Bytes(), expected) + } + }) + + t.Run("pointer-type string arm", func(t *testing.T) { + val := "hi" + u := pointerTypeUnion{Type: 1, Str: &val} + expected := []byte{ + 0x00, 0x00, 0x00, 0x01, // Type = 1 + 0x00, 0x00, 0x00, 0x02, 'h', 'i', 0x00, 0x00, // Str = "hi" (padded) + } + + data := newFixedWriter(12) + n, err := Marshal(data, u) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if n != 12 { + t.Errorf("bytes written = %d, want 12", n) + } + if !reflect.DeepEqual(data.Bytes(), expected) { + t.Errorf("got %v, want %v", data.Bytes(), expected) + } + }) +} + // TestMarshalCorners ensures the Marshal function properly handles various // cases not already covered by the other tests. func TestMarshalCorners(t *testing.T) { diff --git a/xdr3/example_test.go b/xdr3/example_test.go index ce2f52d..2b452cb 100644 --- a/xdr3/example_test.go +++ b/xdr3/example_test.go @@ -54,47 +54,8 @@ func ExampleMarshal() { // encoded data: [171 205 239 0 0 0 0 2 0 0 0 1 0 0 0 10] } -// This example demonstrates how to use Unmarshal to decode XDR encoded data -// from a byte slice into a struct. -func ExampleUnmarshal() { - // Hypothetical image header format. - type ImageHeader struct { - Signature [3]byte - Version uint32 - IsGrayscale bool - NumSections uint32 - } - - // XDR encoded data described by the above structure. Typically this - // would be read from a file or across the network, but use a manual - // byte array here as an example. - encodedData := []byte{ - 0xAB, 0xCD, 0xEF, 0x00, // Signature - 0x00, 0x00, 0x00, 0x02, // Version - 0x00, 0x00, 0x00, 0x01, // IsGrayscale - 0x00, 0x00, 0x00, 0x0A, // NumSections - } - - // Declare a variable to provide Unmarshal with a concrete type and - // instance to decode into. - var h ImageHeader - bytesRead, err := xdr.Unmarshal(bytes.NewReader(encodedData), &h) - if err != nil { - fmt.Println(err) - return - } - - fmt.Println("bytes read:", bytesRead) - fmt.Printf("h: %+v", h) - - // Output: - // bytes read: 16 - // h: {Signature:[171 205 239] Version:2 IsGrayscale:true NumSections:10} -} - // This example demonstrates how to manually decode XDR encoded data from a -// reader. Compare this example with the Unmarshal example which performs the -// same task automatically by utilizing a struct type definition and reflection. +// byte slice using Decoder. func ExampleNewDecoder() { // XDR encoded data for a hypothetical ImageHeader struct as follows: // type ImageHeader struct { @@ -111,7 +72,7 @@ func ExampleNewDecoder() { } // Get a new decoder for manual decoding. - dec := xdr.NewDecoder(bytes.NewReader(encodedData)) + dec := xdr.NewDecoder(encodedData) signature, _, err := dec.DecodeFixedOpaque(3) if err != nil { diff --git a/xdr3/internal_test.go b/xdr3/internal_test.go index ba575ce..4b2ee2e 100644 --- a/xdr3/internal_test.go +++ b/xdr3/internal_test.go @@ -34,10 +34,3 @@ func TstEncode(w io.Writer) func(v reflect.Value) (int, error) { enc := NewEncoder(w) return enc.encode } - -// TstDecode creates a new Decoder for the passed reader and returns the -// internal decode function on the Decoder. -func TstDecode(r io.Reader) func(v reflect.Value, maxLen int, maxDepth uint) (int, error) { - dec := NewDecoder(r) - return dec.decode -}