From 81c4cd42c070617c66b30ef4214166cc458dcbbe Mon Sep 17 00:00:00 2001 From: tamirms Date: Wed, 7 Jan 2026 23:02:16 +0000 Subject: [PATCH 1/2] Refactor Decoder to use byte slices and simplify DecoderFrom interface This refactors the Decoder to work directly with byte slices instead of io.Reader for improved performance, and simplifies the DecoderFrom interface. API Changes: - Decoder now takes []byte instead of io.Reader - Unmarshal/UnmarshalWithOptions take []byte instead of io.Reader - Add DecoderFrom interface for types to implement fast-path decoding - Add Decoder methods: Reset(), Remaining(), Position(), MaxDepth() - Remove MaxInputLen from DecodeOptions (now uses Remaining()) - Remove redundant maxDepth parameter from DecoderFrom interface Depth Tracking: - Add EnterScope()/LeaveScope() methods for stateful depth tracking - Add currentDepth field to Decoder struct - Allows generated code to track recursion depth across nested DecodeFrom calls without passing maxDepth as a parameter Performance: - Eliminate io.Reader overhead with direct buffer access Bug Fixes: - Fix setUnionArmsToNil to nil individual fields instead of whole struct - Support value-type union arms in encode/decode (not just pointers) - Fix error messages to show actual bytes written on partial writes Co-Authored-By: Claude Opus 4.5 --- .github/workflows/build.yml | 2 +- go.mod | 2 +- xdr3/bench_test.go | 76 -- xdr3/decode.go | 456 ++++--- xdr3/decode_bench_test.go | 239 ++++ xdr3/decode_test.go | 2306 ++++++++++++++++++++--------------- xdr3/encode.go | 29 +- xdr3/encode_test.go | 219 ++++ xdr3/example_test.go | 43 +- xdr3/internal_test.go | 7 - 10 files changed, 1995 insertions(+), 1384 deletions(-) delete mode 100644 xdr3/bench_test.go create mode 100644 xdr3/decode_bench_test.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2560488..d069293 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,7 +10,7 @@ jobs: name: Build strategy: matrix: - go: ["1.20", "1.21"] + go: ["1.24", "1.25"] runs-on: ubuntu-latest container: golang:${{ matrix.go }}-bookworm steps: diff --git a/go.mod b/go.mod index edd23af..16eed6e 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/stellar/go-xdr -go 1.12 +go 1.22 diff --git a/xdr3/bench_test.go b/xdr3/bench_test.go deleted file mode 100644 index 19d4886..0000000 --- a/xdr3/bench_test.go +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (c) 2012-2014 Dave Collins - * - * Permission to use, copy, modify, and distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -package xdr_test - -import ( - "bytes" - "testing" - "unsafe" - - xdr "github.com/stellar/go-xdr/xdr3" -) - -// BenchmarkUnmarshal benchmarks the Unmarshal function by using a dummy -// ImageHeader structure. -func BenchmarkUnmarshal(b *testing.B) { - b.StopTimer() - // Hypothetical image header format. - type ImageHeader struct { - Signature [3]byte - Version uint32 - IsGrayscale bool - NumSections uint32 - } - // XDR encoded data described by the above structure. - encodedData := []byte{ - 0xAB, 0xCD, 0xEF, 0x00, - 0x00, 0x00, 0x00, 0x02, - 0x00, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x0A, - } - var h ImageHeader - b.StartTimer() - - for i := 0; i < b.N; i++ { - r := bytes.NewReader(encodedData) - _, _ = xdr.Unmarshal(r, &h) - } - b.SetBytes(int64(len(encodedData))) -} - -// BenchmarkMarshal benchmarks the Marshal function by using a dummy ImageHeader -// structure. -func BenchmarkMarshal(b *testing.B) { - b.StopTimer() - // Hypothetical image header format. - type ImageHeader struct { - Signature [3]byte - Version uint32 - IsGrayscale bool - NumSections uint32 - } - h := ImageHeader{[3]byte{0xAB, 0xCD, 0xEF}, 2, true, 10} - size := unsafe.Sizeof(h) - w := bytes.NewBuffer(nil) - b.StartTimer() - - for i := 0; i < b.N; i++ { - w.Reset() - _, _ = xdr.Marshal(w, &h) - } - b.SetBytes(int64(size)) -} diff --git a/xdr3/decode.go b/xdr3/decode.go index 1c8f383..b55840b 100644 --- a/xdr3/decode.go +++ b/xdr3/decode.go @@ -1,5 +1,6 @@ /* * Copyright (c) 2012-2014 Dave Collins + * Copyright (c) 2026 Stellar Development Foundation * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -17,8 +18,8 @@ package xdr import ( + "encoding/binary" "fmt" - "io" "math" "reflect" "strconv" @@ -28,7 +29,6 @@ import ( const maxInt32 = math.MaxInt32 var errMaxSlice = "data exceeds max slice limit" -var errIODecode = "%s while decoding %d bytes" // DecodeDefaultMaxDepth is the default maximum decoding depth const DecodeDefaultMaxDepth = 200 @@ -39,30 +39,31 @@ type DecodeOptions struct { // It prevents infinite recursions in cyclic datastructures and determines the maximum callstack growth. // If set to 0, DecodeDefaultMaxDepth will be used. MaxDepth uint - - // MaxInputLen sets the maximum input size. It is used by the decoder to sanity-check - // allocation sizes and avoid heap explosions from doctored inputs. - // - // If set to 0, the decoder will try to figure out the input size by checking whether - // the provided io.Reader implements Len() (e.g. strings.Reader, bytes.Reader and bytes.Buffer do). - // Otherwise, no sanity checks will be done. - MaxInputLen int } // DefaultDecodeOptions are the default decoding options. var DefaultDecodeOptions = DecodeOptions{ - MaxDepth: DecodeDefaultMaxDepth, - MaxInputLen: 0, + MaxDepth: DecodeDefaultMaxDepth, +} + +// DecoderFrom is implemented by types that can decode themselves from a Decoder. +// Types implementing this interface get a fast path in Decode(), bypassing reflection. +// Implementations can call d.MaxDepth() if they need to track recursion depth. +type DecoderFrom interface { + DecodeFrom(d *Decoder) (int, error) } /* Unmarshal parses XDR-encoded data into the value pointed to by v reading from -reader r and returning the total number of bytes read. An addressable pointer +a byte slice and returning the total number of bytes read. An addressable pointer must be provided since Unmarshal needs to both store the result of the decode as well as obtain target type information. Unmarhsal traverses v recursively and automatically indirects pointers through arbitrary depth, allocating them as necessary, to decode the data into the underlying value pointed to. +If v implements DecoderFrom, its DecodeFrom method is called directly for +better performance. Otherwise, reflection-based decoding is used. + Unmarshal uses reflection to determine the type of the concrete value contained by v and performs a mapping of underlying XDR types to Go types as follows: @@ -99,82 +100,92 @@ an ErrorCode value for further inspection from sophisticated callers. Some potential issues are unsupported Go types, attempting to decode a value which is too large to fit into a specified Go type, and exceeding max slice limitations. */ -func Unmarshal(r io.Reader, v interface{}) (int, error) { - d := NewDecoder(r) +func Unmarshal(data []byte, v interface{}) (int, error) { + d := NewDecoder(data) return d.Decode(v) } // UnmarshalWithOptions works like Unmarshal but accepts decoding options. -func UnmarshalWithOptions(r io.Reader, v interface{}, options DecodeOptions) (int, error) { - d := NewDecoderWithOptions(r, options) +func UnmarshalWithOptions(data []byte, v interface{}, options DecodeOptions) (int, error) { + d := NewDecoderWithOptions(data, options) return d.Decode(v) } -type lenLeft interface { - Len() int -} - -// A Decoder wraps an io.Reader that is expected to provide an XDR-encoded byte -// stream and provides several exposed methods to manually decode various XDR -// primitives without relying on reflection. The NewDecoder function can be -// used to get a new Decoder directly. +// A Decoder reads XDR-encoded data from a byte slice and provides several +// exposed methods to manually decode various XDR primitives without relying +// on reflection. The NewDecoder function can be used to get a new Decoder directly. // -// Typically, Unmarshal should be used instead of manual decoding. A Decoder +// Typically, Unmarshal should be used instead of manual decoding. A Decoder // is exposed, so it is possible to perform manual decoding should it be // necessary in complex scenarios where automatic reflection-based decoding // won't work. type Decoder struct { - // used to minimize heap allocations during decoding - scratchBuf [8]byte - r io.Reader - l lenLeft - maxDepth uint -} - -// readerLenWrapper wraps a reader an initial length and provides a Len() method indicating -// how much input is left -type readerLenWrapper struct { - inner io.Reader - readCount int - initialLen int -} - -func (l *readerLenWrapper) Len() int { - return l.initialLen - l.readCount -} - -func (l *readerLenWrapper) Read(p []byte) (int, error) { - n, err := l.inner.Read(p) - if n > 0 { - l.readCount += n - } - return n, err + buf []byte + pos int + maxDepth uint + currentDepth uint } // NewDecoder returns a Decoder that can be used to manually decode XDR data -// from a provided reader. Typically, Unmarshal should be used instead of +// from a provided byte slice. Typically, Unmarshal should be used instead of // manually creating a Decoder. -func NewDecoder(r io.Reader) *Decoder { - return NewDecoderWithOptions(r, DefaultDecodeOptions) +func NewDecoder(data []byte) *Decoder { + return NewDecoderWithOptions(data, DefaultDecodeOptions) } // NewDecoderWithOptions works like NewDecoder but allows supplying decoding options. -func NewDecoderWithOptions(r io.Reader, options DecodeOptions) *Decoder { +func NewDecoderWithOptions(data []byte, options DecodeOptions) *Decoder { maxDepth := options.MaxDepth if maxDepth < 1 { maxDepth = DecodeDefaultMaxDepth } - if l, ok := r.(lenLeft); ok { - return &Decoder{r: r, l: l, maxDepth: maxDepth} + return &Decoder{ + buf: data, + pos: 0, + maxDepth: maxDepth, + currentDepth: maxDepth, } - if options.MaxInputLen > 0 { - rlw := &readerLenWrapper{ - inner: r, - initialLen: options.MaxInputLen, - } - return &Decoder{r: rlw, l: rlw, maxDepth: maxDepth} +} + +// Reset resets the decoder to read from a new byte slice, allowing reuse +// of the decoder to reduce allocations. CurrentDepth is reset to MaxDepth. +func (d *Decoder) Reset(data []byte) { + d.buf = data + d.pos = 0 + d.currentDepth = d.maxDepth +} + +// Remaining returns the number of unread bytes in the buffer. +func (d *Decoder) Remaining() int { + return len(d.buf) - d.pos +} + +// Position returns the current read position in the buffer. +func (d *Decoder) Position() int { + return d.pos +} + +// MaxDepth returns the maximum decoding depth setting. +func (d *Decoder) MaxDepth() uint { + return d.maxDepth +} + +// EnterScope should be called at the start of decoding a compound type +// (struct, union, or array element). Returns an error if max depth would be exceeded. +// Use with LeaveScope: `if err := d.EnterScope(); err != nil { return err }; defer d.LeaveScope()` +func (d *Decoder) EnterScope() error { + if d.currentDepth == 0 { + return unmarshalError("EnterScope", ErrMaxDecodingDepth, "maximum decoding depth reached", nil, nil) + } + d.currentDepth-- + return nil +} + +// LeaveScope should be called when exiting a compound type. Should be used with defer. +func (d *Decoder) LeaveScope() { + if d.currentDepth < d.maxDepth { + d.currentDepth++ } - return &Decoder{r: r, l: nil, maxDepth: options.MaxDepth} } // DecodeInt treats the next 4 bytes as an XDR encoded integer and returns the @@ -187,16 +198,12 @@ func NewDecoderWithOptions(r io.Reader, options DecodeOptions) *Decoder { // RFC Section 4.1 - Integer // 32-bit big-endian signed integer in range [-2147483648, 2147483647] func (d *Decoder) DecodeInt() (int32, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:4]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 4) - err := unmarshalError("DecodeInt", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 4 { + return 0, 0, unmarshalError("DecodeInt", ErrIO, "unexpected end of input", nil, nil) } - - rv := int32(d.scratchBuf[3]) | int32(d.scratchBuf[2])<<8 | - int32(d.scratchBuf[1])<<16 | int32(d.scratchBuf[0])<<24 - return rv, n, nil + v := int32(binary.BigEndian.Uint32(d.buf[d.pos:])) + d.pos += 4 + return v, 4, nil } // DecodeUint treats the next 4 bytes as an XDR encoded unsigned integer and @@ -209,16 +216,12 @@ func (d *Decoder) DecodeInt() (int32, int, error) { // RFC Section 4.2 - Unsigned Integer // 32-bit big-endian unsigned integer in range [0, 4294967295] func (d *Decoder) DecodeUint() (uint32, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:4]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 4) - err := unmarshalError("DecodeUint", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 4 { + return 0, 0, unmarshalError("DecodeUint", ErrIO, "unexpected end of input", nil, nil) } - - rv := uint32(d.scratchBuf[3]) | uint32(d.scratchBuf[2])<<8 | - uint32(d.scratchBuf[1])<<16 | uint32(d.scratchBuf[0])<<24 - return rv, n, nil + v := binary.BigEndian.Uint32(d.buf[d.pos:]) + d.pos += 4 + return v, 4, nil } // DecodeEnum treats the next 4 bytes as an XDR encoded enumeration value and @@ -284,18 +287,12 @@ func (d *Decoder) DecodeBool() (bool, int, error) { // RFC Section 4.5 - Hyper Integer // 64-bit big-endian signed integer in range [-9223372036854775808, 9223372036854775807] func (d *Decoder) DecodeHyper() (int64, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:8]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 8) - err := unmarshalError("DecodeHyper", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 8 { + return 0, 0, unmarshalError("DecodeHyper", ErrIO, "unexpected end of input", nil, nil) } - - rv := int64(d.scratchBuf[7]) | int64(d.scratchBuf[6])<<8 | - int64(d.scratchBuf[5])<<16 | int64(d.scratchBuf[4])<<24 | - int64(d.scratchBuf[3])<<32 | int64(d.scratchBuf[2])<<40 | - int64(d.scratchBuf[1])<<48 | int64(d.scratchBuf[0])<<56 - return rv, n, err + v := int64(binary.BigEndian.Uint64(d.buf[d.pos:])) + d.pos += 8 + return v, 8, nil } // DecodeUhyper treats the next 8 bytes as an XDR encoded unsigned hyper value @@ -309,18 +306,12 @@ func (d *Decoder) DecodeHyper() (int64, int, error) { // RFC Section 4.5 - Unsigned Hyper Integer // 64-bit big-endian unsigned integer in range [0, 18446744073709551615] func (d *Decoder) DecodeUhyper() (uint64, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:8]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 8) - err := unmarshalError("DecodeUhyper", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 8 { + return 0, 0, unmarshalError("DecodeUhyper", ErrIO, "unexpected end of input", nil, nil) } - - rv := uint64(d.scratchBuf[7]) | uint64(d.scratchBuf[6])<<8 | - uint64(d.scratchBuf[5])<<16 | uint64(d.scratchBuf[4])<<24 | - uint64(d.scratchBuf[3])<<32 | uint64(d.scratchBuf[2])<<40 | - uint64(d.scratchBuf[1])<<48 | uint64(d.scratchBuf[0])<<56 - return rv, n, nil + v := binary.BigEndian.Uint64(d.buf[d.pos:]) + d.pos += 8 + return v, 8, nil } // DecodeFloat treats the next 4 bytes as an XDR encoded floating point and @@ -333,16 +324,12 @@ func (d *Decoder) DecodeUhyper() (uint64, int, error) { // RFC Section 4.6 - Floating Point // 32-bit single-precision IEEE 754 floating point func (d *Decoder) DecodeFloat() (float32, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:4]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 4) - err := unmarshalError("DecodeFloat", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 4 { + return 0, 0, unmarshalError("DecodeFloat", ErrIO, "unexpected end of input", nil, nil) } - - val := uint32(d.scratchBuf[3]) | uint32(d.scratchBuf[2])<<8 | - uint32(d.scratchBuf[1])<<16 | uint32(d.scratchBuf[0])<<24 - return math.Float32frombits(val), n, nil + v := binary.BigEndian.Uint32(d.buf[d.pos:]) + d.pos += 4 + return math.Float32frombits(v), 4, nil } // DecodeDouble treats the next 8 bytes as an XDR encoded double-precision @@ -356,18 +343,12 @@ func (d *Decoder) DecodeFloat() (float32, int, error) { // RFC Section 4.7 - Double-Precision Floating Point // 64-bit double-precision IEEE 754 floating point func (d *Decoder) DecodeDouble() (float64, int, error) { - n, err := io.ReadFull(d.r, d.scratchBuf[:8]) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), 8) - err := unmarshalError("DecodeDouble", ErrIO, msg, d.scratchBuf[:n], err) - return 0, n, err + if d.Remaining() < 8 { + return 0, 0, unmarshalError("DecodeDouble", ErrIO, "unexpected end of input", nil, nil) } - - val := uint64(d.scratchBuf[7]) | uint64(d.scratchBuf[6])<<8 | - uint64(d.scratchBuf[5])<<16 | uint64(d.scratchBuf[4])<<24 | - uint64(d.scratchBuf[3])<<32 | uint64(d.scratchBuf[2])<<40 | - uint64(d.scratchBuf[1])<<48 | uint64(d.scratchBuf[0])<<56 - return math.Float64frombits(val), n, nil + v := binary.BigEndian.Uint64(d.buf[d.pos:]) + d.pos += 8 + return math.Float64frombits(v), 8, nil } // RFC Section 4.8 - Quadruple-Precision Floating Point @@ -395,54 +376,43 @@ func (d *Decoder) DecodeFixedOpaque(size int32) ([]byte, int, error) { return out, n, nil } -// DecodeFixedOpaqueInplace is an in-place version of DecodeFixedOpaque. -// It improves performance when the destination is pre-allocated (which avoids -// internally allocating an extra slice and does not require further copying) -func (d *Decoder) DecodeFixedOpaqueInplace(out []byte) (int, error) { - size := len(out) - // Nothing to do if size is 0. +// consumePaddedData validates XDR padding and advances the position. +// Returns the start position in the buffer and the padded size. +// Callers can access d.buf[start:start+size] for the actual data. +func (d *Decoder) consumePaddedData(size int) (start, paddedSize int, err error) { if size == 0 { - return 0, nil + return d.pos, 0, nil } - pad := (4 - (size % 4)) % 4 - paddedSize := size + pad + paddedSize = (size + 3) &^ 3 // Round up to multiple of 4 if uint(paddedSize) > uint(maxInt32) { - err := unmarshalError("DecodeFixedOpaqueInplace", ErrOverflow, - errMaxSlice, paddedSize, nil) - return 0, err + return 0, 0, unmarshalError("consumePaddedData", ErrOverflow, errMaxSlice, paddedSize, nil) } - - n, err := io.ReadFull(d.r, out) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), size) - err := unmarshalError("DecodeFixedOpaqueInplace", ErrIO, msg, out[:n], - err) - return n, err + if d.Remaining() < paddedSize { + return 0, 0, unmarshalError("consumePaddedData", ErrIO, "unexpected end of input", nil, nil) } - if pad > 0 { - // the maximum value of pad is 3, so the scratch buffer should be enough - _ = d.scratchBuf[2] - padding := d.scratchBuf[:pad] - n2, err := io.ReadFull(d.r, padding) - if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), pad) - err := unmarshalError("DecodeFixedOpaqueInplace", ErrIO, msg, out[:n], - err) - return n, err - } - n += n2 - // check all the padding bytes to be zero - for _, p := range padding { - if p != 0x00 { - msg := "non-zero padding" - err := unmarshalError("DecodeFixedOpaqueInplace", ErrIO, msg, padding[:n2], nil) - return n, err - } + // Validate padding bytes are zero + for i := size; i < paddedSize; i++ { + if d.buf[d.pos+i] != 0 { + return 0, 0, unmarshalError("consumePaddedData", ErrIO, "non-zero padding", d.buf[d.pos+size:d.pos+paddedSize], nil) } } + start = d.pos + d.pos += paddedSize + return start, paddedSize, nil +} + +// DecodeFixedOpaqueInplace is an in-place version of DecodeFixedOpaque. +// It improves performance when the destination is pre-allocated (which avoids +// internally allocating an extra slice and does not require further copying) +func (d *Decoder) DecodeFixedOpaqueInplace(out []byte) (int, error) { + start, n, err := d.consumePaddedData(len(out)) + if err != nil { + return 0, err + } + copy(out, d.buf[start:start+len(out)]) return n, nil } @@ -463,15 +433,9 @@ func (d *Decoder) DecodeOpaque(maxSize int) ([]byte, int, error) { return nil, n, err } - maxSize = d.mergeInputLenAndMaxSize(maxSize) - if maxSize == 0 { - maxSize = maxInt32 - } - + maxSize = d.mergeRemainingAndMaxSize(maxSize) if uint(dataLen) > uint(maxSize) { - err := unmarshalError("DecodeOpaque", ErrOverflow, errMaxSlice, - dataLen, nil) - return nil, n, err + return nil, n, unmarshalError("DecodeOpaque", ErrOverflow, errMaxSlice, dataLen, nil) } rv, n2, err := d.DecodeFixedOpaque(int32(dataLen)) @@ -503,23 +467,49 @@ func (d *Decoder) DecodeString(maxSize int) (string, int, error) { return "", n, err } - maxSize = d.mergeInputLenAndMaxSize(maxSize) - if maxSize == 0 { - maxSize = maxInt32 - } - + maxSize = d.mergeRemainingAndMaxSize(maxSize) if uint(dataLen) > uint(maxSize) { - err = unmarshalError("DecodeString", ErrOverflow, errMaxSlice, - dataLen, nil) - return "", n, err + return "", n, unmarshalError("DecodeString", ErrOverflow, errMaxSlice, dataLen, nil) } - opaque, n2, err := d.DecodeFixedOpaque(int32(dataLen)) - n += n2 + start, n2, err := d.consumePaddedData(int(dataLen)) if err != nil { return "", n, err } - return string(opaque), n, nil + + return string(d.buf[start : start+int(dataLen)]), n + n2, nil +} + +// Skip advances the decoder position by n bytes without decoding. +func (d *Decoder) Skip(n int) error { + if n < 0 { + return unmarshalError("Skip", ErrBadArguments, "negative skip length", n, nil) + } + if d.Remaining() < n { + return unmarshalError("Skip", ErrIO, "unexpected end of input", nil, nil) + } + d.pos += n + return nil +} + +// Bytes returns the remaining unread bytes in the buffer. +// WARNING: The returned slice shares memory with the input buffer. +func (d *Decoder) Bytes() []byte { + return d.buf[d.pos:] +} + +// mergeRemainingAndMaxSize returns the effective maximum size for decoding, +// taking into account both the user-specified maxSize and the remaining +// input bytes. If maxSize is 0 (unlimited), it defaults to maxInt32. +// The result is the minimum of maxSize and remaining bytes. +func (d *Decoder) mergeRemainingAndMaxSize(maxSize int) int { + if maxSize == 0 { + maxSize = maxInt32 + } + if remaining := d.Remaining(); remaining < maxSize { + return remaining + } + return maxSize } // decodeFixedArray treats the next bytes as a series of XDR encoded elements @@ -577,15 +567,9 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool, maxSize int, m return n, err } - maxSize = d.mergeInputLenAndMaxSize(maxSize) - if maxSize == 0 { - maxSize = maxInt32 - } - + maxSize = d.mergeRemainingAndMaxSize(maxSize) if uint(dataLen) > uint(maxSize) { - err := unmarshalError("decodeArray", ErrOverflow, errMaxSlice, - dataLen, nil) - return n, err + return n, unmarshalError("decodeArray", ErrOverflow, errMaxSlice, dataLen, nil) } // Allocate storage for the slice elements (the underlying array) if @@ -621,10 +605,9 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool, maxSize int, m func setUnionArmsToNil(v reflect.Value) { for i := 0; i < v.NumField(); i++ { f := v.Field(i) - if f.Kind() != reflect.Ptr { - continue + if f.Kind() == reflect.Ptr && f.CanSet() { + f.Set(reflect.Zero(f.Type())) } - v.Set(reflect.Zero(v.Type())) } } @@ -644,20 +627,23 @@ func (d *Decoder) decodeUnion(v reflect.Value, maxDepth uint) (int, error) { vs := v.FieldByName(u.SwitchFieldName()) // ensure the switch field is a valid enum value for the union, if possible - enum, ok := vs.Interface().(Enum) - - if ok && !enum.ValidEnum(i) { + enum, isEnum := vs.Interface().(Enum) + if isEnum && !enum.ValidEnum(i) { msg := fmt.Sprintf("switch '%d' is not valid enum value for union", i) err := unmarshalError("decode", ErrBadUnionSwitch, msg, nil, nil) return n, err } + // Set switch field value with proper type handling kind := vs.Kind() - if kind == reflect.Uint || kind == reflect.Uint8 || kind == reflect.Uint16 || - kind == reflect.Uint32 || kind == reflect.Uint64 { - vs.SetUint(uint64(i)) - } else { + switch kind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: vs.SetInt(int64(i)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + vs.SetUint(uint64(i)) + default: + return n, unmarshalError("decodeUnion", ErrBadUnionSwitch, + fmt.Sprintf("switch field has unsupported type: %v", kind), nil, nil) } arm, ok := u.ArmForSwitch(i) @@ -672,11 +658,7 @@ func (d *Decoder) decodeUnion(v reflect.Value, maxDepth uint) (int, error) { return n, nil } - vv := v.FieldByName(arm) - - vvet := vv.Type().Elem() - vv.Set(reflect.New(vvet)) - + // Validate field exists in type before accessing it field, ok := v.Type().FieldByName(arm) if !ok { msg := fmt.Sprintf("switch '%s' is not valid for union", arm) @@ -684,6 +666,8 @@ func (d *Decoder) decodeUnion(v reflect.Value, maxDepth uint) (int, error) { return n, err } + vv := v.FieldByName(arm) + maxSize := 0 sizeTag := field.Tag.Get("xdrmaxsize") if sizeTag != "" { @@ -694,12 +678,25 @@ func (d *Decoder) decodeUnion(v reflect.Value, maxDepth uint) (int, error) { maxSize = int(sz) } - n2, err := d.decode(vv.Elem(), maxSize, maxDepth) - n += n2 - - if err != nil { - return n, err + // Handle both pointer and value-type union arms + if vv.Kind() == reflect.Ptr { + // Pointer field - allocate new value and decode into it + vvet := vv.Type().Elem() + vv.Set(reflect.New(vvet)) + n2, err := d.decode(vv.Elem(), maxSize, maxDepth) + n += n2 + if err != nil { + return n, err + } + } else { + // Value field - decode directly into the field + n2, err := d.decode(vv, maxSize, maxDepth) + n += n2 + if err != nil { + return n, err + } } + return n, nil } @@ -808,10 +805,11 @@ func (d *Decoder) decodeMap(v reflect.Value, maxDepth uint) (int, error) { if err != nil { return n, err } - if left, ok := d.InputLen(); ok { - if uint(left) < uint(dataLen) { - return n, unmarshalError("decodeMap", ErrOverflow, errMaxSlice, dataLen, nil) - } + // Sanity check: each map entry requires at least 8 bytes (4 for key + 4 for value) + // This prevents allocating huge maps based on malicious length values + // Use multiplication (with uint64 to prevent overflow) instead of division for precision + if uint64(dataLen)*8 > uint64(d.Remaining()) { + return n, unmarshalError("decodeMap", ErrOverflow, errMaxSlice, dataLen, nil) } // Allocate storage for the underlying map if needed. @@ -875,15 +873,6 @@ func (d *Decoder) decodeInterface(v reflect.Value, maxDepth uint) (int, error) { return d.decode(ve, 0, maxDepth) } -func (d *Decoder) mergeInputLenAndMaxSize(maxSize int) int { - if left, ok := d.InputLen(); ok { - if maxSize == 0 || left < maxSize { - return left - } - } - return maxSize -} - // decode is the main workhorse for unmarshalling via reflection. It uses // the passed reflection value to choose the XDR primitives to decode from // the encapsulated reader. It is a recursive function, @@ -1148,9 +1137,12 @@ func (d *Decoder) indirectIfPtr(v reflect.Value) (reflect.Value, error) { } // Decode operates identically to the Unmarshal function with the exception of -// using the reader associated with the Decoder as the source of XDR-encoded -// data instead of a user-supplied reader. See the Unmarhsal documentation for -// specifics. Decode(v) is equivalent to DecodeWithMaxDepth(v, DecodeDefaultMaxDepth) +// using the byte slice associated with the Decoder as the source of XDR-encoded +// data instead of a user-supplied byte slice. See the Unmarshal documentation for +// specifics. +// +// If v implements DecoderFrom, its DecodeFrom method is called directly for +// better performance, bypassing reflection. func (d *Decoder) Decode(v interface{}) (int, error) { if v == nil { msg := "can't unmarshal to nil interface" @@ -1158,6 +1150,12 @@ func (d *Decoder) Decode(v interface{}) (int, error) { nil) } + // Fast path: if v implements DecoderFrom, use it directly + if decodable, ok := v.(DecoderFrom); ok { + return decodable.DecodeFrom(d) + } + + // Fallback: reflection-based decoding vv := reflect.ValueOf(v) if vv.Kind() != reflect.Ptr { msg := fmt.Sprintf("can't unmarshal to non-pointer '%v' - use "+ @@ -1174,11 +1172,3 @@ func (d *Decoder) Decode(v interface{}) (int, error) { return d.decode(vv.Elem(), 0, d.maxDepth) } - -// InputLen returns the size left to read from the decoder's input if available -func (d *Decoder) InputLen() (int, bool) { - if d.l == nil { - return 0, false - } - return d.l.Len(), true -} diff --git a/xdr3/decode_bench_test.go b/xdr3/decode_bench_test.go new file mode 100644 index 0000000..4d1312b --- /dev/null +++ b/xdr3/decode_bench_test.go @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2012-2014 Dave Collins + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +package xdr_test + +import ( + "testing" + + xdr "github.com/stellar/go-xdr/xdr3" +) + +// Test data for benchmarks +var ( + // XDR encoded int32 (value: 12345678) + encodedInt = []byte{0x00, 0xBC, 0x61, 0x4E} + + // XDR encoded uint32 (value: 0xDEADBEEF) + encodedUint = []byte{0xDE, 0xAD, 0xBE, 0xEF} + + // XDR encoded int64/hyper (value: 0x0102030405060708) + encodedHyper = []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} + + // XDR encoded string "hello world" (length 11, padded to 12 bytes) + encodedString = []byte{ + 0x00, 0x00, 0x00, 0x0B, // length = 11 + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x00, // "hello world" + 1 byte padding + } + + // XDR encoded longer string (64 bytes of 'x', padded) + encodedLongString = func() []byte { + data := make([]byte, 4+64) // length prefix + 64 chars (no padding needed, 64 % 4 == 0) + data[0], data[1], data[2], data[3] = 0x00, 0x00, 0x00, 0x40 // length = 64 + for i := 4; i < 68; i++ { + data[i] = 'x' + } + return data + }() + + // XDR encoded variable opaque (32 bytes) + encodedOpaque = func() []byte { + data := make([]byte, 4+32) // length prefix + 32 bytes + data[0], data[1], data[2], data[3] = 0x00, 0x00, 0x00, 0x20 // length = 32 + for i := 4; i < 36; i++ { + data[i] = byte(i - 4) + } + return data + }() + + // XDR encoded fixed opaque (32 bytes, no length prefix) + encodedFixedOpaque = func() []byte { + data := make([]byte, 32) + for i := 0; i < 32; i++ { + data[i] = byte(i) + } + return data + }() +) + +// ============================================================================ +// Primitive Decoding Benchmarks +// ============================================================================ + +func BenchmarkDecodeInt(b *testing.B) { + d := xdr.NewDecoder(encodedInt) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedInt) + _, _, _ = d.DecodeInt() + } + b.SetBytes(int64(len(encodedInt))) +} + +func BenchmarkDecodeUint(b *testing.B) { + d := xdr.NewDecoder(encodedUint) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedUint) + _, _, _ = d.DecodeUint() + } + b.SetBytes(int64(len(encodedUint))) +} + +func BenchmarkDecodeHyper(b *testing.B) { + d := xdr.NewDecoder(encodedHyper) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedHyper) + _, _, _ = d.DecodeHyper() + } + b.SetBytes(int64(len(encodedHyper))) +} + +// ============================================================================ +// String Decoding Benchmarks +// ============================================================================ + +func BenchmarkDecodeString(b *testing.B) { + d := xdr.NewDecoder(encodedString) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedString) + _, _, _ = d.DecodeString(0) + } + b.SetBytes(int64(len(encodedString))) +} + +func BenchmarkDecodeLongString(b *testing.B) { + d := xdr.NewDecoder(encodedLongString) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedLongString) + _, _, _ = d.DecodeString(0) + } + b.SetBytes(int64(len(encodedLongString))) +} + +// ============================================================================ +// Opaque Decoding Benchmarks +// ============================================================================ + +func BenchmarkDecodeOpaque(b *testing.B) { + d := xdr.NewDecoder(encodedOpaque) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedOpaque) + _, _, _ = d.DecodeOpaque(0) + } + b.SetBytes(int64(len(encodedOpaque))) +} + +// ============================================================================ +// Fixed Opaque Decoding Benchmarks +// ============================================================================ + +func BenchmarkDecodeFixedOpaque(b *testing.B) { + d := xdr.NewDecoder(encodedFixedOpaque) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedFixedOpaque) + _, _, _ = d.DecodeFixedOpaque(32) + } + b.SetBytes(int64(len(encodedFixedOpaque))) +} + +// ============================================================================ +// Multiple Field Decoding Benchmarks (simulating struct decoding) +// ============================================================================ + +// Encoded struct with: uint32, int64, string "hello world", 32-byte opaque +var encodedStruct = func() []byte { + var buf []byte + buf = append(buf, encodedUint...) + buf = append(buf, encodedHyper...) + buf = append(buf, encodedString...) + buf = append(buf, 0x00, 0x00, 0x00, 0x20) // opaque length = 32 + for i := 0; i < 32; i++ { + buf = append(buf, byte(i)) + } + return buf +}() + +func BenchmarkDecodeMultipleFields(b *testing.B) { + d := xdr.NewDecoder(encodedStruct) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(encodedStruct) + _, _, _ = d.DecodeUint() + _, _, _ = d.DecodeHyper() + _, _, _ = d.DecodeString(0) + _, _, _ = d.DecodeOpaque(0) + } + b.SetBytes(int64(len(encodedStruct))) +} + +// ============================================================================ +// Decoder Reuse Benchmarks +// ============================================================================ + +// benchDecoderFromType implements DecoderFrom for benchmarking +type benchDecoderFromType struct { + Value int32 +} + +func (t *benchDecoderFromType) DecodeFrom(d *xdr.Decoder) (int, error) { + v, n, err := d.DecodeInt() + if err != nil { + return n, err + } + t.Value = v + return n, nil +} + +// BenchmarkNewDecoderEachTime measures creating a new decoder for each decode +func BenchmarkNewDecoderEachTime(b *testing.B) { + data := encodedInt + var result benchDecoderFromType + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d := xdr.NewDecoder(data) + _, _ = d.Decode(&result) + } + b.SetBytes(int64(len(data))) +} + +// BenchmarkDecoderReuse measures reusing a decoder with Reset+Decode +func BenchmarkDecoderReuse(b *testing.B) { + data := encodedInt + d := xdr.NewDecoder(nil) + var result benchDecoderFromType + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + d.Reset(data) + _, _ = d.Decode(&result) + } + b.SetBytes(int64(len(data))) +} diff --git a/xdr3/decode_test.go b/xdr3/decode_test.go index 747ebf9..39ab9c3 100644 --- a/xdr3/decode_test.go +++ b/xdr3/decode_test.go @@ -1,5 +1,6 @@ /* * Copyright (c) 2012-2014 Dave Collins + * Copyright (c) 2026 Stellar Development Foundation * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -14,1181 +15,1466 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -package xdr_test +package xdr import ( - "bytes" - "encoding/base64" - "fmt" - "math" + "errors" "reflect" "testing" - "time" - - . "github.com/stellar/go-xdr/xdr3" ) -// subTest is used to allow testing of the Unmarshal function into struct fields -// which are structs themselves. -type subTest struct { - A string - B uint8 -} +// TestDecoder_EOF tests that Decoder returns proper errors on EOF +func TestDecoder_EOF(t *testing.T) { + tests := []struct { + name string + data []byte + decode func(d *Decoder) error + errCode ErrorCode // expected error code, defaults to ErrIO + }{ + { + name: "DecodeInt EOF", + data: []byte{0x00, 0x00}, // Only 2 bytes, need 4 + decode: func(d *Decoder) error { + _, _, err := d.DecodeInt() + return err + }, + }, + { + name: "DecodeUint EOF", + data: []byte{0x00, 0x00, 0x00}, // Only 3 bytes, need 4 + decode: func(d *Decoder) error { + _, _, err := d.DecodeUint() + return err + }, + }, + { + name: "DecodeHyper EOF", + data: []byte{0x00, 0x00, 0x00, 0x00}, // Only 4 bytes, need 8 + decode: func(d *Decoder) error { + _, _, err := d.DecodeHyper() + return err + }, + }, + { + name: "DecodeUhyper EOF", + data: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, // Only 5 bytes, need 8 + decode: func(d *Decoder) error { + _, _, err := d.DecodeUhyper() + return err + }, + }, + { + name: "DecodeFloat EOF", + data: []byte{0x00}, // Only 1 byte, need 4 + decode: func(d *Decoder) error { + _, _, err := d.DecodeFloat() + return err + }, + }, + { + name: "DecodeDouble EOF", + data: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, // Only 6 bytes, need 8 + decode: func(d *Decoder) error { + _, _, err := d.DecodeDouble() + return err + }, + }, + { + name: "DecodeBool EOF", + data: []byte{}, // Empty + decode: func(d *Decoder) error { + _, _, err := d.DecodeBool() + return err + }, + }, + { + name: "DecodeFixedOpaque EOF", + data: []byte{0x01, 0x02}, // Only 2 bytes, need 4 (size=3 padded) + decode: func(d *Decoder) error { + _, _, err := d.DecodeFixedOpaque(3) + return err + }, + }, + { + name: "DecodeFixedOpaqueInplace EOF", + data: []byte{0x01, 0x02, 0x03}, // Only 3 bytes, need 4 (padded) + decode: func(d *Decoder) error { + out := make([]byte, 3) + _, err := d.DecodeFixedOpaqueInplace(out) + return err + }, + }, + { + name: "DecodeOpaque EOF on length", + data: []byte{0x00, 0x00}, // Only 2 bytes for length prefix + decode: func(d *Decoder) error { + _, _, err := d.DecodeOpaque(0) + return err + }, + }, + { + name: "DecodeOpaque EOF on data", + data: []byte{0x00, 0x00, 0x00, 0x08, 0x01, 0x02}, // Length=8, only 2 data bytes + errCode: ErrOverflow, // length exceeds available data + decode: func(d *Decoder) error { + _, _, err := d.DecodeOpaque(0) + return err + }, + }, + { + name: "DecodeString EOF on length", + data: []byte{0x00}, // Only 1 byte for length prefix + decode: func(d *Decoder) error { + _, _, err := d.DecodeString(0) + return err + }, + }, + { + name: "DecodeString EOF on data", + data: []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e'}, // Length=5, only 2 chars + errCode: ErrOverflow, // length exceeds available data + decode: func(d *Decoder) error { + _, _, err := d.DecodeString(0) + return err + }, + }, + { + name: "Skip EOF", + data: []byte{0x01, 0x02}, + decode: func(d *Decoder) error { + return d.Skip(10) + }, + }, + } -// allTypesTest is used to allow testing of the Unmarshal function into struct -// fields of all supported types. -type allTypesTest struct { - A int8 - B uint8 - C int16 - D uint16 - E int32 - F uint32 - G int64 - H uint64 - I bool - J float32 - K float64 - L string - M []byte - N [3]byte - O []int16 - P [2]subTest - Q subTest - R map[string]uint32 - S time.Time -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDecoder(tt.data) + err := tt.decode(d) + if err == nil { + t.Fatal("expected error, got nil") + } -// opaqueStruct is used to test handling of uint8 slices and arrays. -type opaqueStruct struct { - Slice []uint8 `xdropaque:"false"` - Array [1]uint8 `xdropaque:"false"` + // Verify it's an UnmarshalError with expected error code + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected *UnmarshalError, got %T: %v", err, err) + } + expectedCode := tt.errCode + if expectedCode == 0 { + expectedCode = ErrIO + } + if unmarshalErr.ErrorCode != expectedCode { + t.Errorf("expected %v, got %v", expectedCode, unmarshalErr.ErrorCode) + } + }) + } } -type AnEnum int32 - -func (e AnEnum) ValidEnum(v int32) bool { - return v < 3 -} +// TestDecoder_InvalidPadding tests that non-zero padding bytes cause errors +func TestDecoder_InvalidPadding(t *testing.T) { + tests := []struct { + name string + data []byte + decode func(d *Decoder) error + }{ + { + name: "DecodeFixedOpaque invalid padding", + data: []byte{0x01, 0x02, 0x03, 0xFF}, // 3 bytes data + non-zero padding + decode: func(d *Decoder) error { + _, _, err := d.DecodeFixedOpaque(3) + return err + }, + }, + { + name: "DecodeFixedOpaqueInplace invalid padding", + data: []byte{0x01, 0x02, 0x03, 0x01}, // 3 bytes data + non-zero padding + decode: func(d *Decoder) error { + out := make([]byte, 3) + _, err := d.DecodeFixedOpaqueInplace(out) + return err + }, + }, + { + name: "DecodeString invalid padding", + data: []byte{ + 0x00, 0x00, 0x00, 0x03, // length = 3 + 'a', 'b', 'c', 0x01, // "abc" + non-zero padding + }, + decode: func(d *Decoder) error { + _, _, err := d.DecodeString(0) + return err + }, + }, + } -type aUnion struct { - Type AnEnum - Data *int32 - Text *string `xdrmaxsize:"28"` -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDecoder(tt.data) + err := tt.decode(d) + if err == nil { + t.Fatal("expected error, got nil") + } -func (u aUnion) SwitchFieldName() string { - return "Type" + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected *UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrIO { + t.Errorf("expected ErrIO for padding error, got %v", unmarshalErr.ErrorCode) + } + }) + } } -func (u aUnion) ArmForSwitch(sw int32) (string, bool) { - switch sw { - case 0: - return "Data", true - case 1: - return "Text", true - case 2: // void - return "", true +// TestDecoder_InvalidBool tests that invalid boolean values cause errors +func TestDecoder_InvalidBool(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + {"bool value 2", []byte{0x00, 0x00, 0x00, 0x02}}, + {"bool value -1", []byte{0xFF, 0xFF, 0xFF, 0xFF}}, + {"bool value 100", []byte{0x00, 0x00, 0x00, 0x64}}, } - return "-", false -} - -type structWithPointer struct { - Data *string -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDecoder(tt.data) + _, _, err := d.DecodeBool() + if err == nil { + t.Fatal("expected error, got nil") + } -// testExpectedURet is a convenience method to test an expected number of bytes -// read and error for an unmarshal. -func testExpectedURet(t *testing.T, name string, n, wantN int, err, wantErr error) bool { - // First ensure the number of bytes read is the expected value. The - // byes read should be accurate even when an error occurs. - if n != wantN { - t.Errorf("%s: unexpected num bytes read - got: %v want: %v\n", - name, n, wantN) - return false + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected *UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrBadEnumValue { + t.Errorf("expected ErrBadEnumValue, got %v", unmarshalErr.ErrorCode) + } + }) } - - // Next check for the expected error. - return assertError(t, name, err, wantErr) } -func assertError(t *testing.T, name string, err, wantErr error) bool { - if reflect.TypeOf(err) != reflect.TypeOf(wantErr) { - t.Errorf("%s: failed to detect error - got: %v <%[2]T> want: %T", - name, err, wantErr) - return false +// TestDecoder_InvalidEnum tests that invalid enum values cause errors +func TestDecoder_InvalidEnum(t *testing.T) { + validEnums := map[int32]bool{ + 0: true, + 1: true, + 2: true, } - if rerr, ok := err.(*UnmarshalError); ok { - if werr, ok := wantErr.(*UnmarshalError); ok { - if rerr.ErrorCode != werr.ErrorCode { - t.Errorf("%s: failed to detect error code - "+ - "got: %v want: %v", name, - rerr.ErrorCode, werr.ErrorCode) - return false - } - } + tests := []struct { + name string + data []byte + }{ + {"enum value 3", []byte{0x00, 0x00, 0x00, 0x03}}, + {"enum value -1", []byte{0xFF, 0xFF, 0xFF, 0xFF}}, + {"enum value 100", []byte{0x00, 0x00, 0x00, 0x64}}, } - return true -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDecoder(tt.data) + _, _, err := d.DecodeEnum(validEnums) + if err == nil { + t.Fatal("expected error, got nil") + } -// TestUnmarshal ensures the Unmarshal function works properly with all types. -func TestUnmarshal(t *testing.T) { - // Variables for various unsupported Unmarshal types. - var nilInterface interface{} - var testChan chan int - var testFunc func() - var testComplex64 complex64 - var testComplex128 complex128 - - // structTestIn is input data for the big struct test of all supported - // types. - structTestIn := []byte{ - 0x00, 0x00, 0x00, 0x7F, // A - 0x00, 0x00, 0x00, 0xFF, // B - 0x00, 0x00, 0x7F, 0xFF, // C - 0x00, 0x00, 0xFF, 0xFF, // D - 0x7F, 0xFF, 0xFF, 0xFF, // E - 0xFF, 0xFF, 0xFF, 0xFF, // F - 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // G - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // H - 0x00, 0x00, 0x00, 0x01, // I - 0x40, 0x48, 0xF5, 0xC3, // J - 0x40, 0x09, 0x21, 0xfb, 0x54, 0x44, 0x2d, 0x18, // K - 0x00, 0x00, 0x00, 0x03, 0x78, 0x64, 0x72, 0x00, // L - 0x00, 0x00, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // M - 0x01, 0x02, 0x03, 0x00, // N - 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x02, 0x00, - 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, // O - 0x00, 0x00, 0x00, 0x03, 0x6F, 0x6E, 0x65, 0x00, // P[0].A - 0x00, 0x00, 0x00, 0x01, // P[0].B - 0x00, 0x00, 0x00, 0x03, 0x74, 0x77, 0x6F, 0x00, // P[1].A - 0x00, 0x00, 0x00, 0x02, // P[1].B - 0x00, 0x00, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, // Q.A - 0x00, 0x00, 0x00, 0x03, // Q.B - 0x00, 0x00, 0x00, 0x02, // R length - 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x70, 0x31, // R key map1 - 0x00, 0x00, 0x00, 0x01, // R value map1 - 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x70, 0x32, // R key map2 - 0x00, 0x00, 0x00, 0x02, // R value map2 - 0x00, 0x00, 0x00, 0x14, 0x32, 0x30, 0x31, 0x34, - 0x2d, 0x30, 0x34, 0x2d, 0x30, 0x34, 0x54, 0x30, - 0x33, 0x3a, 0x32, 0x34, 0x3a, 0x34, 0x38, 0x5a, // S - } - - // structTestWant is the expected output after unmarshalling - // structTestIn. - structTestWant := allTypesTest{ - 127, // A - 255, // B - 32767, // C - 65535, // D - 2147483647, // E - 4294967295, // F - 9223372036854775807, // G - 18446744073709551615, // H - true, // I - 3.14, // J - 3.141592653589793, // K - "xdr", // L - []byte{1, 2, 3, 4}, // M - [3]byte{1, 2, 3}, // N - []int16{512, 1024, 2048}, // O - [2]subTest{{"one", 1}, {"two", 2}}, // P - subTest{"bar", 3}, // Q - map[string]uint32{"map1": 1, "map2": 2}, // R - time.Unix(1396581888, 0).UTC(), // S + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected *UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrBadEnumValue { + t.Errorf("expected ErrBadEnumValue, got %v", unmarshalErr.ErrorCode) + } + }) } +} +// TestDecoder_MaxSizeExceeded tests that exceeding max size causes errors +func TestDecoder_MaxSizeExceeded(t *testing.T) { tests := []struct { - in []byte // input bytes - wantVal interface{} // expected value - wantN int // expected number of bytes read - err error // expected error + name string + data []byte + maxSize int + decode func(d *Decoder, maxSize int) error }{ - // int8 - XDR Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, int8(0), 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0x40}, int8(64), 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0x7F}, int8(127), 4, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, int8(-1), 4, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0x80}, int8(-128), 4, nil}, - // Expected Failures -- 128, -129 overflow int8 and not enough - // bytes - {[]byte{0x00, 0x00, 0x00, 0x80}, int8(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0xFF, 0xFF, 0xFF, 0x7F}, int8(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00}, int8(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // uint8 - XDR Unsigned Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, uint8(0), 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0x40}, uint8(64), 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0xFF}, uint8(255), 4, nil}, - // Expected Failures -- 256, -1 overflow uint8 and not enough - // bytes - {[]byte{0x00, 0x00, 0x01, 0x00}, uint8(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, uint8(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00}, uint8(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // int16 - XDR Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, int16(0), 4, nil}, - {[]byte{0x00, 0x00, 0x04, 0x00}, int16(1024), 4, nil}, - {[]byte{0x00, 0x00, 0x7F, 0xFF}, int16(32767), 4, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, int16(-1), 4, nil}, - {[]byte{0xFF, 0xFF, 0x80, 0x00}, int16(-32768), 4, nil}, - // Expected Failures -- 32768, -32769 overflow int16 and not - // enough bytes - {[]byte{0x00, 0x00, 0x80, 0x00}, int16(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0xFF, 0xFF, 0x7F, 0xFF}, int16(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00}, uint16(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // uint16 - XDR Unsigned Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, uint16(0), 4, nil}, - {[]byte{0x00, 0x00, 0x04, 0x00}, uint16(1024), 4, nil}, - {[]byte{0x00, 0x00, 0xFF, 0xFF}, uint16(65535), 4, nil}, - // Expected Failures -- 65536, -1 overflow uint16 and not enough - // bytes - {[]byte{0x00, 0x01, 0x00, 0x00}, uint16(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, uint16(0), 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00}, uint16(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // int32 - XDR Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, int32(0), 4, nil}, - {[]byte{0x00, 0x04, 0x00, 0x00}, int32(262144), 4, nil}, - {[]byte{0x7F, 0xFF, 0xFF, 0xFF}, int32(2147483647), 4, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, int32(-1), 4, nil}, - {[]byte{0x80, 0x00, 0x00, 0x00}, int32(-2147483648), 4, nil}, - // Expected Failure -- not enough bytes - {[]byte{0x00, 0x00, 0x00}, int32(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // uint32 - XDR Unsigned Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, uint32(0), 4, nil}, - {[]byte{0x00, 0x04, 0x00, 0x00}, uint32(262144), 4, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, uint32(4294967295), 4, nil}, - // Expected Failure -- not enough bytes - {[]byte{0x00, 0x00, 0x00}, uint32(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // int64 - XDR Hyper Integer - {[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(0), 8, nil}, - {[]byte{0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00}, int64(1 << 34), 8, nil}, - {[]byte{0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(1 << 42), 8, nil}, - {[]byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, int64(9223372036854775807), 8, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, int64(-1), 8, nil}, - {[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(-9223372036854775808), 8, nil}, - // Expected Failures -- not enough bytes - {[]byte{0x7f, 0xff, 0xff}, int64(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x7f, 0x00, 0xff, 0x00}, int64(0), 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // uint64 - XDR Unsigned Hyper Integer - {[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(0), 8, nil}, - {[]byte{0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00}, uint64(1 << 34), 8, nil}, - {[]byte{0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(1 << 42), 8, nil}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, uint64(18446744073709551615), 8, nil}, - {[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(9223372036854775808), 8, nil}, - // Expected Failures -- not enough bytes - {[]byte{0xff, 0xff, 0xff}, uint64(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0xff, 0x00, 0xff, 0x00}, uint64(0), 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // bool - XDR Integer - {[]byte{0x00, 0x00, 0x00, 0x00}, false, 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0x01}, true, 4, nil}, - // Expected Failures -- only 0 or 1 is a valid bool - {[]byte{0x01, 0x00, 0x00, 0x00}, true, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - {[]byte{0x00, 0x00, 0x40, 0x00}, true, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - - // float32 - XDR Floating-Point - {[]byte{0x00, 0x00, 0x00, 0x00}, float32(0), 4, nil}, - {[]byte{0x40, 0x48, 0xF5, 0xC3}, float32(3.14), 4, nil}, - {[]byte{0x49, 0x96, 0xB4, 0x38}, float32(1234567.0), 4, nil}, - {[]byte{0xFF, 0x80, 0x00, 0x00}, float32(math.Inf(-1)), 4, nil}, - {[]byte{0x7F, 0x80, 0x00, 0x00}, float32(math.Inf(0)), 4, nil}, - // Expected Failures -- not enough bytes - {[]byte{0xff, 0xff}, float32(0), 2, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0xff, 0x00, 0xff}, float32(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - - // float64 - XDR Double-precision Floating-Point - {[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(0), 8, nil}, - {[]byte{0x40, 0x09, 0x21, 0xfb, 0x54, 0x44, 0x2d, 0x18}, float64(3.141592653589793), 8, nil}, - {[]byte{0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(math.Inf(-1)), 8, nil}, - {[]byte{0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(math.Inf(0)), 8, nil}, - // Expected Failures -- not enough bytes - {[]byte{0xff, 0xff, 0xff}, float64(0), 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0xff, 0x00, 0xff, 0x00}, float64(0), 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // string - XDR String - {[]byte{0x00, 0x00, 0x00, 0x00}, "", 4, nil}, - {[]byte{0x00, 0x00, 0x00, 0x03, 0x78, 0x64, 0x72, 0x00}, "xdr", 8, nil}, - {[]byte{0x00, 0x00, 0x00, 0x06, 0xCF, 0x84, 0x3D, 0x32, 0xCF, 0x80, 0x00, 0x00}, "τ=2π", 12, nil}, - // Expected Failures -- not enough bytes for length, length - // larger than allowed, and len larger than available bytes. - {[]byte{0x00, 0x00, 0xFF}, "", 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, "", 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00, 0xFF}, "", 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // []byte - XDR Variable Opaque - {[]byte{0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00}, []byte{0x01}, 8, nil}, - {[]byte{0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03, 0x00}, []byte{0x01, 0x02, 0x03}, 8, nil}, - // Expected Failures -- not enough bytes for length, length - // larger than allowed, and data larger than available bytes. - {[]byte{0x00, 0x00, 0xFF}, []byte{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0xFF, 0xFF, 0xFF, 0xFF}, []byte{}, 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00, 0xFF}, []byte{}, 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // [#]byte - XDR Fixed Opaque - {[]byte{0x01, 0x00, 0x00, 0x00}, [1]byte{0x01}, 4, nil}, - {[]byte{0x01, 0x02, 0x00, 0x00}, [2]byte{0x01, 0x02}, 4, nil}, - {[]byte{0x01, 0x02, 0x03, 0x00}, [3]byte{0x01, 0x02, 0x03}, 4, nil}, - {[]byte{0x01, 0x02, 0x03, 0x04}, [4]byte{0x01, 0x02, 0x03, 0x04}, 4, nil}, - {[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x00, 0x00}, [5]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 8, nil}, - // Expected Failure -- fixed opaque data not padded - {[]byte{0x01}, [1]byte{}, 1, &UnmarshalError{ErrorCode: ErrIO}}, - - // [] - XDR Variable-Length Array - {[]byte{0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00}, - []int16{512, 1024, 2048}, 16, nil}, - {[]byte{0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, []bool{true, false}, 12, nil}, - // Expected Failure -- 2 entries in array - not enough bytes - {[]byte{0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01}, []bool{}, 8, &UnmarshalError{ErrorCode: ErrIO}}, - - // [#] - XDR Fixed-Length Array - {[]byte{0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00}, [2]uint32{512, 1024}, 8, nil}, - // Expected Failure -- 2 entries in array - not enough bytes - {[]byte{0x00, 0x00, 0x00, 0x02}, [2]uint32{}, 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // map[string]uint32 - {[]byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x70, 0x31, 0x00, 0x00, 0x00, 0x01}, - map[string]uint32{"map1": 1}, 16, nil}, - // Expected Failures -- not enough bytes in length, 1 map - // element no extra bytes, 1 map element not enough bytes for - // key, 1 map element not enough bytes for value. - {[]byte{0x00, 0x00, 0x00}, map[string]uint32{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x00, 0x00, 0x00, 0x01}, map[string]uint32{}, 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {[]byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00}, map[string]uint32{}, 7, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x70, 0x31}, - map[string]uint32{}, 12, &UnmarshalError{ErrorCode: ErrIO}}, - - // time.Time - XDR String per RFC3339 - {[]byte{ - 0x00, 0x00, 0x00, 0x14, 0x32, 0x30, 0x31, 0x34, - 0x2d, 0x30, 0x34, 0x2d, 0x30, 0x34, 0x54, 0x30, - 0x33, 0x3a, 0x32, 0x34, 0x3a, 0x34, 0x38, 0x5a, - }, time.Unix(1396581888, 0).UTC(), 24, nil}, - // Expected Failures -- not enough bytes, improperly formatted - // time - {[]byte{0x00, 0x00, 0x00}, time.Time{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x00, 0x00, 0x00, 0x00}, time.Time{}, 4, &UnmarshalError{ErrorCode: ErrParseTime}}, - - // struct - XDR Structure -- test struct contains all supported types - {structTestIn, structTestWant, len(structTestIn), nil}, - {[]byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02}, - opaqueStruct{[]uint8{1}, [1]uint8{2}}, 12, nil}, - // Expected Failures -- normal struct not enough bytes, non - // opaque data not enough bytes for slice, non opaque data not - // enough bytes for slice. - {[]byte{0x00, 0x00}, allTypesTest{}, 2, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x00, 0x00, 0x00}, opaqueStruct{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, - {[]byte{0x00, 0x00, 0x00, 0x00, 0x00}, opaqueStruct{}, 5, &UnmarshalError{ErrorCode: ErrIO}}, - - // Expected errors - {nil, nilInterface, 0, &UnmarshalError{ErrorCode: ErrNilInterface}}, - {nil, &nilInterface, 0, &UnmarshalError{ErrorCode: ErrIO}}, - {nil, testChan, 0, &UnmarshalError{ErrorCode: ErrUnsupportedType}}, - {nil, &testChan, 0, &UnmarshalError{ErrorCode: ErrIO}}, - {nil, testFunc, 0, &UnmarshalError{ErrorCode: ErrUnsupportedType}}, - {nil, &testFunc, 0, &UnmarshalError{ErrorCode: ErrIO}}, - {nil, testComplex64, 0, &UnmarshalError{ErrorCode: ErrUnsupportedType}}, - {nil, &testComplex64, 0, &UnmarshalError{ErrorCode: ErrIO}}, - {nil, testComplex128, 0, &UnmarshalError{ErrorCode: ErrUnsupportedType}}, - {nil, &testComplex128, 0, &UnmarshalError{ErrorCode: ErrIO}}, - } - - for i, test := range tests { - // Attempt to unmarshal to a non-pointer version of each - // positive test type to ensure the appropriate error is - // returned. - if test.err == nil && test.wantVal != nil { - testName := fmt.Sprintf("Unmarshal #%d (non-pointer)", i) - wantErr := &UnmarshalError{ErrorCode: ErrBadArguments} - - wvt := reflect.TypeOf(test.wantVal) - want := reflect.New(wvt).Elem().Interface() - n, err := Unmarshal(bytes.NewReader(test.in), want) - if !testExpectedURet(t, testName, n, 0, err, wantErr) { - continue + { + name: "DecodeOpaque exceeds maxSize", + data: []byte{0x00, 0x00, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10}, + maxSize: 8, // Only allow 8 bytes + decode: func(d *Decoder, maxSize int) error { + _, _, err := d.DecodeOpaque(maxSize) + return err + }, + }, + { + name: "DecodeString exceeds maxSize", + data: []byte{0x00, 0x00, 0x00, 0x08, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'}, + maxSize: 4, // Only allow 4 chars + decode: func(d *Decoder, maxSize int) error { + _, _, err := d.DecodeString(maxSize) + return err + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDecoder(tt.data) + err := tt.decode(d, tt.maxSize) + if err == nil { + t.Fatal("expected error, got nil") + } + + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected *UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrOverflow { + t.Errorf("expected ErrOverflow, got %v", unmarshalErr.ErrorCode) } + }) + } +} + +// TestDecoder_EmptyCases tests edge cases with empty/zero values +func TestDecoder_EmptyCases(t *testing.T) { + t.Run("DecodeFixedOpaque size=0", func(t *testing.T) { + d := NewDecoder([]byte{}) + data, n, err := d.DecodeFixedOpaque(0) + if err != nil { + t.Fatalf("unexpected error: %v", err) } + if len(data) != 0 { + t.Errorf("expected empty slice, got %v", data) + } + if n != 0 { + t.Errorf("expected 0 bytes read, got %d", n) + } + }) - testName := fmt.Sprintf("Unmarshal #%d", i) - // Create a new pointer to the appropriate type. - var want interface{} - if test.wantVal != nil { - wvt := reflect.TypeOf(test.wantVal) - want = reflect.New(wvt).Interface() + t.Run("DecodeFixedOpaqueInplace size=0", func(t *testing.T) { + d := NewDecoder([]byte{}) + n, err := d.DecodeFixedOpaqueInplace([]byte{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - n, err := Unmarshal(bytes.NewReader(test.in), want) + if n != 0 { + t.Errorf("expected 0 bytes read, got %d", n) + } + }) - // First ensure the number of bytes read is the expected value - // and the error is the expected one. - if !testExpectedURet(t, testName, n, test.wantN, err, test.err) { - continue + t.Run("DecodeString empty", func(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x00}) // length = 0 + s, n, err := d.DecodeString(0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "" { + t.Errorf("expected empty string, got %q", s) } - if test.err != nil { - continue + if n != 4 { + t.Errorf("expected 4 bytes read (length prefix), got %d", n) } + }) +} - // Finally, ensure the read value is the expected one. - wantElem := reflect.Indirect(reflect.ValueOf(want)).Interface() - if !reflect.DeepEqual(wantElem, test.wantVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, wantElem, test.wantVal) - continue +// TestDecoder_SuccessfulDecodes tests that valid data decodes correctly +func TestDecoder_SuccessfulDecodes(t *testing.T) { + t.Run("DecodeInt", func(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x2A}) // 42 + v, n, err := d.DecodeInt() + if err != nil { + t.Fatalf("unexpected error: %v", err) } - } + if v != 42 { + t.Errorf("expected 42, got %d", v) + } + if n != 4 { + t.Errorf("expected 4 bytes, got %d", n) + } + }) - // successful enum decoding - var anEnum AnEnum - _, err := Unmarshal(bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x01}), &anEnum) + t.Run("DecodeInt negative", func(t *testing.T) { + d := NewDecoder([]byte{0xFF, 0xFF, 0xFF, 0xFE}) // -2 + v, _, err := d.DecodeInt() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v != -2 { + t.Errorf("expected -2, got %d", v) + } + }) - if err != nil { - t.Errorf("enum decode: expected no error, got: %v\n", err) - } + t.Run("DecodeBool true", func(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x01}) + v, _, err := d.DecodeBool() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !v { + t.Error("expected true") + } + }) - if anEnum != AnEnum(1) { - t.Errorf("enum decode: expected 1, got: %v\n", anEnum) - } + t.Run("DecodeBool false", func(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x00}) + v, _, err := d.DecodeBool() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v { + t.Error("expected false") + } + }) - // failed enum decoding - _, err = Unmarshal(bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x03}), &anEnum) + t.Run("DecodeEnum valid", func(t *testing.T) { + validEnums := map[int32]bool{1: true, 2: true, 3: true} + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x02}) + v, _, err := d.DecodeEnum(validEnums) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v != 2 { + t.Errorf("expected 2, got %d", v) + } + }) - if err == nil { - t.Errorf("enum decode: expected error, got none") - } + t.Run("DecodeString with padding", func(t *testing.T) { + // "abc" = 3 bytes, needs 1 byte padding + d := NewDecoder([]byte{ + 0x00, 0x00, 0x00, 0x03, // length = 3 + 'a', 'b', 'c', 0x00, // data + padding + }) + s, n, err := d.DecodeString(0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "abc" { + t.Errorf("expected 'abc', got %q", s) + } + if n != 8 { // 4 (length) + 4 (padded data) + t.Errorf("expected 8 bytes, got %d", n) + } + }) - // union decoding - var u aUnion - // void arm - _, err = Unmarshal(bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x02}), &u) - if err != nil { - t.Errorf("union decode: expected no error, got: %v\n", err) - } + t.Run("DecodeFixedOpaque with padding", func(t *testing.T) { + // 5 bytes of data, needs 3 bytes padding + d := NewDecoder([]byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x00, 0x00, + }) + data, n, err := d.DecodeFixedOpaque(5) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(data) != 5 { + t.Errorf("expected 5 bytes, got %d", len(data)) + } + if n != 8 { // 5 bytes padded to 8 + t.Errorf("expected 8 bytes read, got %d", n) + } + }) +} - if u.Type != AnEnum(2) { - t.Errorf("union decode: expected type == 2, got: %v\n", u.Type) +// TestDecoder_Position tests position tracking +func TestDecoder_Position(t *testing.T) { + data := []byte{ + 0x00, 0x00, 0x00, 0x01, // int32 = 1 + 0x00, 0x00, 0x00, 0x02, // int32 = 2 } + d := NewDecoder(data) - if u.Data != nil { - t.Errorf("union decode: expected data to be nil, it was not.") + if d.Position() != 0 { + t.Errorf("expected position 0, got %d", d.Position()) } - - // non-void arm - _, err = Unmarshal(bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05}), &u) - if err != nil { - t.Errorf("union decode: expected no error, got: %v\n", err) + if d.Remaining() != 8 { + t.Errorf("expected 8 remaining, got %d", d.Remaining()) } - if u.Type != AnEnum(0) { - t.Errorf("union decode: expected type == 0, got: %v\n", u.Type) + d.DecodeInt() + if d.Position() != 4 { + t.Errorf("expected position 4, got %d", d.Position()) } - - if u.Data == nil { - t.Errorf("union decode: expected data to be filled, it was not.") + if d.Remaining() != 4 { + t.Errorf("expected 4 remaining, got %d", d.Remaining()) } - if *u.Data != 5 { - t.Errorf("union decode: expected data to be 5, got: %v\n", *u.Data) + d.DecodeInt() + if d.Position() != 8 { + t.Errorf("expected position 8, got %d", d.Position()) } - - // non-void arm: xdrmaxsize - _, err = Unmarshal(bytes.NewReader([]byte{ - 0x00, 0x00, 0x00, 0x01, // Text - 0x00, 0x00, 0x00, 0x1D, // String length = 29 - 0x74, 0x65, 0x73, 0x74, // "test" 4 - 0x74, 0x65, 0x73, 0x74, // "test" 8 - 0x74, 0x65, 0x73, 0x74, // "test" 12 - 0x74, 0x65, 0x73, 0x74, // "test" 16 - 0x74, 0x65, 0x73, 0x74, // "test" 20 - 0x74, 0x65, 0x73, 0x74, // "test" 24 - 0x74, 0x65, 0x73, 0x74, // "test" 28 - 0x74, 0x00, 0x00, 0x00, // "test" 29 - }), &u) - if err == nil { - t.Errorf("union decode: expected error") + if d.Remaining() != 0 { + t.Errorf("expected 0 remaining, got %d", d.Remaining()) } +} + +// TestDecoder_Reset tests the Reset method +func TestDecoder_Reset(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x01}) + d.DecodeInt() - if err.Error() != "xdr:DecodeString: data exceeds max slice limit - read: '29'" { - t.Errorf("union decode: expected 'data exceeds max slice limit' error") + if d.Position() != 4 { + t.Errorf("expected position 4, got %d", d.Position()) } - // invalid enum for switch - _, err = Unmarshal(bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x03}), &u) - if err == nil { - t.Errorf("union decode: expected error, got nil") + // Reset with new data + d.Reset([]byte{0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03}) + if d.Position() != 0 { + t.Errorf("expected position 0 after reset, got %d", d.Position()) + } + if d.Remaining() != 8 { + t.Errorf("expected 8 remaining after reset, got %d", d.Remaining()) } - // invalid arm for switch - _, err = Unmarshal(bytes.NewReader([]byte{0xFF, 0xFF, 0xFF, 0xFF}), &u) - if err == nil { - t.Errorf("union decode: expected error, got nil") + v, _, _ := d.DecodeInt() + if v != 2 { + t.Errorf("expected 2, got %d", v) } } -// decodeFunc is used to identify which public function of the Decoder object -// a test applies to. -type decodeFunc int - -const ( - fDecodeBool decodeFunc = iota - fDecodeDouble - fDecodeEnum - fDecodeFixedOpaque - fDecodeFloat - fDecodeHyper - fDecodeInt - fDecodeOpaque - fDecodeString - fDecodeUhyper - fDecodeUint -) - -// Map of decodeFunc values to names for pretty printing. -var decodeFuncStrings = map[decodeFunc]string{ - fDecodeBool: "DecodeBool", - fDecodeDouble: "DecodeDouble", - fDecodeEnum: "DecodeEnum", - fDecodeFixedOpaque: "DecodeFixedOpaque", - fDecodeFloat: "DecodeFloat", - fDecodeHyper: "DecodeHyper", - fDecodeInt: "DecodeInt", - fDecodeOpaque: "DecodeOpaque", - fDecodeString: "DecodeString", - fDecodeUhyper: "DecodeUhyper", - fDecodeUint: "DecodeUint", +// testDecoderFromType is a test type that implements DecoderFrom +type testDecoderFromType struct { + Value int32 } -// String implements the fmt.Stringer interface and returns the encode function -// as a human-readable string. -func (f decodeFunc) String() string { - if s := decodeFuncStrings[f]; s != "" { - return s +func (t *testDecoderFromType) DecodeFrom(d *Decoder) (int, error) { + v, n, err := d.DecodeInt() + if err != nil { + return n, err } - return fmt.Sprintf("Unknown decodeFunc (%d)", f) + t.Value = v + return n, nil } -// TestDecoder ensures a Decoder works as intended. -func TestDecoder(t *testing.T) { - tests := []struct { - f decodeFunc // function to use to decode - in []byte // input bytes - wantVal interface{} // expected value - wantN int // expected number of bytes read - err error // expected error - }{ - // Bool - {fDecodeBool, []byte{0x00, 0x00, 0x00, 0x00}, false, 4, nil}, - {fDecodeBool, []byte{0x00, 0x00, 0x00, 0x01}, true, 4, nil}, - // Expected Failures -- only 0 or 1 is a valid bool - {fDecodeBool, []byte{0x01, 0x00, 0x00, 0x00}, true, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - {fDecodeBool, []byte{0x00, 0x00, 0x40, 0x00}, true, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - - // Double - {fDecodeDouble, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(0), 8, nil}, - {fDecodeDouble, []byte{0x40, 0x09, 0x21, 0xfb, 0x54, 0x44, 0x2d, 0x18}, float64(3.141592653589793), 8, nil}, - {fDecodeDouble, []byte{0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(math.Inf(-1)), 8, nil}, - {fDecodeDouble, []byte{0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, float64(math.Inf(0)), 8, nil}, - - // Enum - {fDecodeEnum, []byte{0x00, 0x00, 0x00, 0x00}, int32(0), 4, nil}, - {fDecodeEnum, []byte{0x00, 0x00, 0x00, 0x01}, int32(1), 4, nil}, - {fDecodeEnum, []byte{0x00, 0x00, 0x00, 0x02}, nil, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - {fDecodeEnum, []byte{0x12, 0x34, 0x56, 0x78}, nil, 4, &UnmarshalError{ErrorCode: ErrBadEnumValue}}, - {fDecodeEnum, []byte{0x00}, nil, 1, &UnmarshalError{ErrorCode: ErrIO}}, - - // FixedOpaque - {fDecodeFixedOpaque, []byte{0x01, 0x00, 0x00, 0x00}, []byte{0x01}, 4, nil}, - {fDecodeFixedOpaque, []byte{0x01, 0x02, 0x00, 0x00}, []byte{0x01, 0x02}, 4, nil}, - {fDecodeFixedOpaque, []byte{0x01, 0x02, 0x03, 0x00}, []byte{0x01, 0x02, 0x03}, 4, nil}, - {fDecodeFixedOpaque, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 4, nil}, - {fDecodeFixedOpaque, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x00, 0x00}, []byte{0x01, 0x02, 0x03, 0x04, 0x05}, 8, nil}, - // Expected Failure -- fixed opaque data not padded - {fDecodeFixedOpaque, []byte{0x01}, []byte{0x00}, 1, &UnmarshalError{ErrorCode: ErrIO}}, - - // Float - {fDecodeFloat, []byte{0x00, 0x00, 0x00, 0x00}, float32(0), 4, nil}, - {fDecodeFloat, []byte{0x40, 0x48, 0xF5, 0xC3}, float32(3.14), 4, nil}, - {fDecodeFloat, []byte{0x49, 0x96, 0xB4, 0x38}, float32(1234567.0), 4, nil}, - {fDecodeFloat, []byte{0xFF, 0x80, 0x00, 0x00}, float32(math.Inf(-1)), 4, nil}, - {fDecodeFloat, []byte{0x7F, 0x80, 0x00, 0x00}, float32(math.Inf(0)), 4, nil}, - - // Hyper - {fDecodeHyper, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(0), 8, nil}, - {fDecodeHyper, []byte{0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00}, int64(1 << 34), 8, nil}, - {fDecodeHyper, []byte{0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(1 << 42), 8, nil}, - {fDecodeHyper, []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, int64(9223372036854775807), 8, nil}, - {fDecodeHyper, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, int64(-1), 8, nil}, - {fDecodeHyper, []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, int64(-9223372036854775808), 8, nil}, - - // Int - {fDecodeInt, []byte{0x00, 0x00, 0x00, 0x00}, int32(0), 4, nil}, - {fDecodeInt, []byte{0x00, 0x04, 0x00, 0x00}, int32(262144), 4, nil}, - {fDecodeInt, []byte{0x7F, 0xFF, 0xFF, 0xFF}, int32(2147483647), 4, nil}, - {fDecodeInt, []byte{0xFF, 0xFF, 0xFF, 0xFF}, int32(-1), 4, nil}, - {fDecodeInt, []byte{0x80, 0x00, 0x00, 0x00}, int32(-2147483648), 4, nil}, - - // Opaque - {fDecodeOpaque, []byte{0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00}, []byte{0x01}, 8, nil}, - {fDecodeOpaque, []byte{0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03, 0x00}, []byte{0x01, 0x02, 0x03}, 8, nil}, - // Expected Failures -- not enough bytes for length, length - // larger than allowed, and data larger than available bytes. - {fDecodeOpaque, []byte{0x00, 0x00, 0xFF}, []byte{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, - {fDecodeOpaque, []byte{0xFF, 0xFF, 0xFF, 0xFF}, []byte{}, 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {fDecodeOpaque, []byte{0x7F, 0xFF, 0xFF, 0xFD}, []byte{}, 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {fDecodeOpaque, []byte{0x00, 0x00, 0x00, 0xFF}, []byte{}, 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // String - {fDecodeString, []byte{0x00, 0x00, 0x00, 0x00}, "", 4, nil}, - {fDecodeString, []byte{0x00, 0x00, 0x00, 0x03, 0x78, 0x64, 0x72, 0x00}, "xdr", 8, nil}, - {fDecodeString, []byte{0x00, 0x00, 0x00, 0x06, 0xCF, 0x84, 0x3D, 0x32, 0xCF, 0x80, 0x00, 0x00}, "τ=2π", 12, nil}, - // Expected Failures -- not enough bytes for length, length - // larger than allowed, and len larger than available bytes. - {fDecodeString, []byte{0x00, 0x00, 0xFF}, "", 3, &UnmarshalError{ErrorCode: ErrIO}}, - {fDecodeString, []byte{0xFF, 0xFF, 0xFF, 0xFF}, "", 4, &UnmarshalError{ErrorCode: ErrOverflow}}, - {fDecodeString, []byte{0x00, 0x00, 0x00, 0xFF}, "", 4, &UnmarshalError{ErrorCode: ErrIO}}, - - // Uhyper - {fDecodeUhyper, []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(0), 8, nil}, - {fDecodeUhyper, []byte{0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00}, uint64(1 << 34), 8, nil}, - {fDecodeUhyper, []byte{0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(1 << 42), 8, nil}, - {fDecodeUhyper, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, uint64(18446744073709551615), 8, nil}, - {fDecodeUhyper, []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, uint64(9223372036854775808), 8, nil}, - - // Uint - {fDecodeUint, []byte{0x00, 0x00, 0x00, 0x00}, uint32(0), 4, nil}, - {fDecodeUint, []byte{0x00, 0x04, 0x00, 0x00}, uint32(262144), 4, nil}, - {fDecodeUint, []byte{0xFF, 0xFF, 0xFF, 0xFF}, uint32(4294967295), 4, nil}, - } - - validEnums := make(map[int32]bool) - validEnums[0] = true - validEnums[1] = true - - var rv interface{} - var n int - var err error - - for i, test := range tests { - err = nil - dec := NewDecoder(bytes.NewReader(test.in)) - switch test.f { - case fDecodeBool: - rv, n, err = dec.DecodeBool() - case fDecodeDouble: - rv, n, err = dec.DecodeDouble() - case fDecodeEnum: - rv, n, err = dec.DecodeEnum(validEnums) - case fDecodeFixedOpaque: - want := test.wantVal.([]byte) - rv, n, err = dec.DecodeFixedOpaque(int32(len(want))) - case fDecodeFloat: - rv, n, err = dec.DecodeFloat() - case fDecodeHyper: - rv, n, err = dec.DecodeHyper() - case fDecodeInt: - rv, n, err = dec.DecodeInt() - case fDecodeOpaque: - rv, n, err = dec.DecodeOpaque(0) - case fDecodeString: - rv, n, err = dec.DecodeString(0) - case fDecodeUhyper: - rv, n, err = dec.DecodeUhyper() - case fDecodeUint: - rv, n, err = dec.DecodeUint() - default: - t.Errorf("%v #%d unrecognized function", test.f, i) - continue - } - - // First ensure the number of bytes read is the expected value - // and the error is the expected one. - testName := fmt.Sprintf("%v #%d", test.f, i) - if !testExpectedURet(t, testName, n, test.wantN, err, test.err) { - continue - } - if test.err != nil { - continue - } - - // Finally, ensure the read value is the expected one. - if !reflect.DeepEqual(rv, test.wantVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, rv, test.wantVal) - continue - } - } +// testNestedType is a test type that tracks maxDepth to verify it's accessible +type testNestedType struct { + ReceivedMaxDepth uint + Value int32 } -// TestUnmarshalCorners ensures the Unmarshal function properly handles various -// cases not already covered by the other tests. -func TestUnmarshalCorners(t *testing.T) { - buf := []byte{ - 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x02, - } - - // Ensure unmarshal to unsettable pointer returns the expected error. - testName := "Unmarshal to unsettable pointer" - var i32p *int32 - expectedN := 0 - expectedErr := error(&UnmarshalError{ErrorCode: ErrNotSettable}) - n, err := Unmarshal(bytes.NewReader(buf), i32p) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure decode of unsettable pointer returns the expected error. - testName = "Decode to unsettable pointer" - expectedN = 4 - expectedErr = &UnmarshalError{ErrorCode: ErrNotSettable} - n, err = TstDecode(bytes.NewReader(buf))(reflect.ValueOf(i32p), 0, DecodeDefaultMaxDepth) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure unmarshal to indirected unsettable pointer returns the - // expected error. - testName = "Unmarshal to indirected unsettable pointer" - ii32p := interface{}(i32p) - expectedN = 0 - expectedErr = &UnmarshalError{ErrorCode: ErrNotSettable} - n, err = Unmarshal(bytes.NewReader(buf), &ii32p) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure unmarshal to embedded unsettable interface value returns the - // expected error. - testName = "Unmarshal to embedded unsettable interface value" - var i32 int32 - ii32 := interface{}(i32) - expectedN = 0 - expectedErr = &UnmarshalError{ErrorCode: ErrNotSettable} - n, err = Unmarshal(bytes.NewReader(buf), &ii32) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure unmarshal to embedded interface value works properly. - testName = "Unmarshal to embedded interface value" - ii32vp := interface{}(&i32) - expectedN = 4 - expectedErr = nil - ii32vpr := int32(1) - expectedVal := interface{}(&ii32vpr) - n, err = Unmarshal(bytes.NewReader(buf), &ii32vp) - if testExpectedURet(t, testName, n, expectedN, err, expectedErr) { - if !reflect.DeepEqual(ii32vp, expectedVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, ii32vp, expectedVal) - } - } - - // Ensure decode of an invalid reflect value returns the expected - // error. - testName = "Decode invalid reflect value" - expectedN = 0 - expectedErr = error(&UnmarshalError{ErrorCode: ErrUnsupportedType}) - n, err = TstDecode(bytes.NewReader(buf))(reflect.Value{}, 0, DecodeDefaultMaxDepth) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure unmarshal to a slice with a cap and 0 length adjusts the - // length properly. - testName = "Unmarshal to capped slice" - cappedSlice := make([]bool, 0, 1) - expectedN = 8 - expectedErr = nil - expectedVal = []bool{true} - n, err = Unmarshal(bytes.NewReader(buf), &cappedSlice) - if testExpectedURet(t, testName, n, expectedN, err, expectedErr) { - if !reflect.DeepEqual(cappedSlice, expectedVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, cappedSlice, expectedVal) - } - } - - // Ensure decode to a slice retuns expected number of elements - testName = "Unmarshal to oversized slice" - oversizedSlice := make([]bool, 2, 2) - expectedN = 8 - expectedErr = nil - expectedVal = []bool{true} - n, err = Unmarshal(bytes.NewReader(buf), &oversizedSlice) - if testExpectedURet(t, testName, n, expectedN, err, expectedErr) { - if !reflect.DeepEqual(oversizedSlice, expectedVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, oversizedSlice, expectedVal) - } - } - - // Ensure unmarshal to struct with both exported and unexported fields - // skips the unexported fields but still unmarshals to the exported - // fields. - type unexportedStruct struct { - unexported int - Exported int - } - testName = "Unmarshal to struct with exported and unexported fields" - var tstruct unexportedStruct - expectedN = 4 - expectedErr = nil - expectedVal = unexportedStruct{0, 1} - n, err = Unmarshal(bytes.NewReader(buf), &tstruct) - if testExpectedURet(t, testName, n, expectedN, err, expectedErr) { - if !reflect.DeepEqual(tstruct, expectedVal) { - t.Errorf("%s: unexpected result - got: %v want: %v\n", - testName, tstruct, expectedVal) - } - } - - // Ensure decode to struct with unsettable fields return expected error. - type unsettableStruct struct { - Exported int - } - testName = "Decode to struct with unsettable fields" - var ustruct unsettableStruct - expectedN = 0 - expectedErr = error(&UnmarshalError{ErrorCode: ErrNotSettable}) - n, err = TstDecode(bytes.NewReader(buf))(reflect.ValueOf(ustruct), 0, DecodeDefaultMaxDepth) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) - - // Ensure decode to struct with unsettable pointer fields return - // expected error. - type unsettablePointerStruct struct { - Exported *int - } - testName = "Decode to struct with unsettable pointer fields" - var upstruct unsettablePointerStruct - expectedN = 0 - expectedErr = error(&UnmarshalError{ErrorCode: ErrNotSettable}) - n, err = TstDecode(bytes.NewReader(buf))(reflect.ValueOf(upstruct), 0, DecodeDefaultMaxDepth) - testExpectedURet(t, testName, n, expectedN, err, expectedErr) +func (t *testNestedType) DecodeFrom(d *Decoder) (int, error) { + t.ReceivedMaxDepth = d.MaxDepth() + v, n, err := d.DecodeInt() + if err != nil { + return n, err + } + t.Value = v + return n, nil } -type String32 string +// TestUnmarshal tests the Unmarshal function with a type implementing DecoderFrom +func TestUnmarshal(t *testing.T) { + data := []byte{0x00, 0x00, 0x00, 0x2A} // 42 in big-endian -var _ Sized = String32("hello") + var result testDecoderFromType + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 4 { + t.Errorf("expected 4 bytes read, got %d", n) + } + if result.Value != 42 { + t.Errorf("expected value 42, got %d", result.Value) + } +} -func (s String32) XDRMaxSize() int { - return 32 +// TestUnmarshal_ReflectionFallback tests Unmarshal with types that don't implement DecoderFrom +// but can be decoded via reflection +func TestUnmarshal_ReflectionFallback(t *testing.T) { + // int32 doesn't implement DecoderFrom but can be decoded via reflection + data := []byte{0x00, 0x00, 0x00, 0x01} + var result int32 + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("expected reflection decode to work, got error: %v", err) + } + if n != 4 { + t.Errorf("expected 4 bytes read, got %d", n) + } + if result != 1 { + t.Errorf("expected result 1, got %d", result) + } } -func TestSizedType(t *testing.T) { - cases := map[string]struct { - input []byte - out string - err error +// TestUnmarshal_ReflectionAllTypes tests reflection-based decoding for all primitive types +func TestUnmarshal_ReflectionAllTypes(t *testing.T) { + tests := []struct { + name string + data []byte + dest interface{} + expected interface{} + wantN int }{ - // works for 0 length - "0 length": {[]byte{0x00, 0x00, 0x00, 0x00}, "", nil}, - // works for 1 length - "1 length": {[]byte{0x00, 0x00, 0x00, 0x01, 0x48, 0x00, 0x00, 0x00}, "H", nil}, - // works for 32 length - "32 length": { - []byte{ - 0x00, 0x00, 0x00, 0x20, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - }, - "HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH", - nil, + { + name: "int8", + data: []byte{0x00, 0x00, 0x00, 0x7F}, + dest: new(int8), + expected: int8(127), + wantN: 4, }, - // fails for 33 length - "33 length": { - []byte{ - 0x00, 0x00, 0x00, 0x21, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x00, 0x00, 0x00, - }, - "", - &UnmarshalError{ErrorCode: ErrOverflow}, + { + name: "int16", + data: []byte{0x00, 0x00, 0x7F, 0xFF}, + dest: new(int16), + expected: int16(32767), + wantN: 4, + }, + { + name: "int32", + data: []byte{0x7F, 0xFF, 0xFF, 0xFF}, + dest: new(int32), + expected: int32(2147483647), + wantN: 4, + }, + { + name: "int64", + data: []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + dest: new(int64), + expected: int64(9223372036854775807), + wantN: 8, + }, + { + name: "uint8", + data: []byte{0x00, 0x00, 0x00, 0xFF}, + dest: new(uint8), + expected: uint8(255), + wantN: 4, + }, + { + name: "uint16", + data: []byte{0x00, 0x00, 0xFF, 0xFF}, + dest: new(uint16), + expected: uint16(65535), + wantN: 4, + }, + { + name: "uint32", + data: []byte{0xFF, 0xFF, 0xFF, 0xFF}, + dest: new(uint32), + expected: uint32(4294967295), + wantN: 4, + }, + { + name: "uint64", + data: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + dest: new(uint64), + expected: uint64(18446744073709551615), + wantN: 8, + }, + { + name: "bool true", + data: []byte{0x00, 0x00, 0x00, 0x01}, + dest: new(bool), + expected: true, + wantN: 4, + }, + { + name: "bool false", + data: []byte{0x00, 0x00, 0x00, 0x00}, + dest: new(bool), + expected: false, + wantN: 4, + }, + { + name: "string", + data: []byte{0x00, 0x00, 0x00, 0x03, 'x', 'd', 'r', 0x00}, + dest: new(string), + expected: "xdr", + wantN: 8, + }, + { + name: "[]byte opaque", + data: []byte{0x00, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03, 0x00}, + dest: new([]byte), + expected: []byte{0x01, 0x02, 0x03}, + wantN: 8, }, } - for name, kase := range cases { - var out String32 - r := bytes.NewReader(kase.input) - _, err := Unmarshal(r, &out) - - if !assertError(t, name, err, kase.err) { - continue - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n, err := Unmarshal(tt.data, tt.dest) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != tt.wantN { + t.Errorf("bytes read = %d, want %d", n, tt.wantN) + } - if string(out) != kase.out { - t.Errorf("%s: expected output to be %#v, but got %#v", name, kase.out, out) - continue - } + // Compare values using reflect + got := reflect.ValueOf(tt.dest).Elem().Interface() + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("got %v (%T), want %v (%T)", got, got, tt.expected, tt.expected) + } + }) } } -type sizedField struct { - Val string `xdrmaxsize:"32"` -} +// TestUnmarshal_ReflectionStruct tests reflection-based decoding for structs +func TestUnmarshal_ReflectionStruct(t *testing.T) { + type simpleStruct struct { + A int32 + B string + C bool + } -func TestSizedField(t *testing.T) { - cases := map[string]struct { - input []byte - out string - err error - }{ - // works for 0 length - "0 length": {[]byte{0x00, 0x00, 0x00, 0x00}, "", nil}, - // works for 1 length - "1 length": {[]byte{0x00, 0x00, 0x00, 0x01, 0x48, 0x00, 0x00, 0x00}, "H", nil}, - // works for 32 length - "32 length": { - []byte{ - 0x00, 0x00, 0x00, 0x20, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - }, - "HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH", - nil, - }, - // fails for 33 length - "33 length": { - []byte{ - 0x00, 0x00, 0x00, 0x21, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, 0x48, - 0x48, 0x00, 0x00, 0x00, - }, - "", - &UnmarshalError{ErrorCode: ErrOverflow}, - }, + // Encoded: A=42, B="hi", C=true + data := []byte{ + 0x00, 0x00, 0x00, 0x2A, // A = 42 + 0x00, 0x00, 0x00, 0x02, 'h', 'i', 0x00, 0x00, // B = "hi" (padded) + 0x00, 0x00, 0x00, 0x01, // C = true } - for name, kase := range cases { - var out sizedField - r := bytes.NewReader(kase.input) - _, err := Unmarshal(r, &out) + var result simpleStruct + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 16 { + t.Errorf("bytes read = %d, want 16", n) + } + if result.A != 42 { + t.Errorf("A = %d, want 42", result.A) + } + if result.B != "hi" { + t.Errorf("B = %q, want \"hi\"", result.B) + } + if result.C != true { + t.Errorf("C = %v, want true", result.C) + } +} - if !assertError(t, name, err, kase.err) { - continue - } +// TestUnmarshal_ReflectionSlice tests reflection-based decoding for slices +func TestUnmarshal_ReflectionSlice(t *testing.T) { + // Slice of 3 int32s: [1, 2, 3] + data := []byte{ + 0x00, 0x00, 0x00, 0x03, // length = 3 + 0x00, 0x00, 0x00, 0x01, // [0] = 1 + 0x00, 0x00, 0x00, 0x02, // [1] = 2 + 0x00, 0x00, 0x00, 0x03, // [2] = 3 + } - if out.Val != kase.out { - t.Errorf("%s: expected output to be %#v, but got %#v", name, kase.out, out.Val) - continue - } + var result []int32 + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 16 { + t.Errorf("bytes read = %d, want 16", n) + } + expected := []int32{1, 2, 3} + if !reflect.DeepEqual(result, expected) { + t.Errorf("got %v, want %v", result, expected) } } -type defEnum int32 +// TestUnmarshal_ReflectionMap tests reflection-based decoding for maps +func TestUnmarshal_ReflectionMap(t *testing.T) { + // Map with 2 entries: {"a": 1, "b": 2} + data := []byte{ + 0x00, 0x00, 0x00, 0x02, // count = 2 + 0x00, 0x00, 0x00, 0x01, 'a', 0x00, 0x00, 0x00, // key "a" + 0x00, 0x00, 0x00, 0x01, // value 1 + 0x00, 0x00, 0x00, 0x01, 'b', 0x00, 0x00, 0x00, // key "b" + 0x00, 0x00, 0x00, 0x02, // value 2 + } -func (e defEnum) ValidEnum(v int32) bool { - return v < 1 + var result map[string]int32 + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 28 { + t.Errorf("bytes read = %d, want 28", n) + } + if result["a"] != 1 { + t.Errorf("result[\"a\"] = %d, want 1", result["a"]) + } + if result["b"] != 2 { + t.Errorf("result[\"b\"] = %d, want 2", result["b"]) + } } -type defUnion struct { - Type defEnum - Data *int32 -} +// TestUnmarshal_ReflectionPointer tests reflection-based decoding for optional (pointer) types +func TestUnmarshal_ReflectionPointer(t *testing.T) { + // Pointer present with value 42 + dataPresent := []byte{ + 0x00, 0x00, 0x00, 0x01, // present = true + 0x00, 0x00, 0x00, 0x2A, // value = 42 + } -func (u defUnion) SwitchFieldName() string { - return "Type" -} + var result1 *int32 + n, err := Unmarshal(dataPresent, &result1) + if err != nil { + t.Fatalf("Unmarshal (present) failed: %v", err) + } + if n != 8 { + t.Errorf("bytes read = %d, want 8", n) + } + if result1 == nil || *result1 != 42 { + t.Errorf("got %v, want pointer to 42", result1) + } -func (u defUnion) ArmForSwitch(sw int32) (string, bool) { - switch sw { - case 0: - return "Data", true + // Pointer absent (nil) + dataAbsent := []byte{ + 0x00, 0x00, 0x00, 0x00, // present = false } - return "-", false + var result2 *int32 + result2 = new(int32) // Pre-allocate to verify it gets set to nil + *result2 = 999 + n, err = Unmarshal(dataAbsent, &result2) + if err != nil { + t.Fatalf("Unmarshal (absent) failed: %v", err) + } + if n != 4 { + t.Errorf("bytes read = %d, want 4", n) + } + if result2 != nil { + t.Errorf("got %v, want nil", result2) + } } -func TestUnion_EnumValidation(t *testing.T) { +// TestUnmarshalWithOptions tests UnmarshalWithOptions with custom MaxDepth +func TestUnmarshalWithOptions(t *testing.T) { + data := []byte{0x00, 0x00, 0x01, 0x00} // 256 in big-endian - var u defUnion - var buf bytes.Buffer + var result testDecoderFromType + opts := DecodeOptions{MaxDepth: 100} + n, err := UnmarshalWithOptions(data, &result, opts) + if err != nil { + t.Fatalf("UnmarshalWithOptions failed: %v", err) + } + if n != 4 { + t.Errorf("expected 4 bytes read, got %d", n) + } + if result.Value != 256 { + t.Errorf("expected value 256, got %d", result.Value) + } +} - // encode a union with invalid value fails - u.Type = defEnum(3) - _, err := Marshal(&buf, u) +// TestMaxDepthPassed tests that MaxDepth is correctly passed to DecodeFrom +func TestMaxDepthPassed(t *testing.T) { + data := []byte{0x00, 0x00, 0x00, 0x01} - if err == nil { - t.Errorf("expected error when marshaling invalid enum, got none. result: %#v", buf.Bytes()) + // Test with default options + var result1 testNestedType + _, err := Unmarshal(data, &result1) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if result1.ReceivedMaxDepth != DecodeDefaultMaxDepth { + t.Errorf("expected maxDepth %d, got %d", DecodeDefaultMaxDepth, result1.ReceivedMaxDepth) } - // decoding an invalid enum value into a union results in an error - u = defUnion{} - invalid := []byte{0x0, 0x0, 0x0, 0x3} - _, err = Unmarshal(bytes.NewReader(invalid), &u) - if err == nil { - t.Errorf("expected error when unmarshaling invalid enum into union, got none. result: %#v", u) + // Test with custom MaxDepth + var result2 testNestedType + opts := DecodeOptions{MaxDepth: 50} + _, err = UnmarshalWithOptions(data, &result2, opts) + if err != nil { + t.Fatalf("UnmarshalWithOptions failed: %v", err) + } + if result2.ReceivedMaxDepth != 50 { + t.Errorf("expected maxDepth 50, got %d", result2.ReceivedMaxDepth) } -} -func TestPaddedReads(t *testing.T) { - // regression test for non-zeroed padding + // Test with MaxDepth 0 (should use default) + var result3 testNestedType + opts = DecodeOptions{MaxDepth: 0} + _, err = UnmarshalWithOptions(data, &result3, opts) + if err != nil { + t.Fatalf("UnmarshalWithOptions failed: %v", err) + } + if result3.ReceivedMaxDepth != DecodeDefaultMaxDepth { + t.Errorf("expected maxDepth %d for zero option, got %d", DecodeDefaultMaxDepth, result3.ReceivedMaxDepth) + } +} - // opaque - dec := NewDecoder(bytes.NewReader([]byte{0x0, 0x0, 0x1, 0x1})) - _, _, err := dec.DecodeFixedOpaque(3) +// TestSkip_NegativeLength tests that negative skip length is rejected +func TestSkip_NegativeLength(t *testing.T) { + d := NewDecoder([]byte{0x00, 0x00, 0x00, 0x00}) + err := d.Skip(-1) if err == nil { - t.Error("expected error when unmarshaling opaque with non-zero padding byte, got none") + t.Fatal("expected error for negative skip length") + } + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T", err) + } + if unmarshalErr.ErrorCode != ErrBadArguments { + t.Errorf("expected ErrBadArguments, got %v", unmarshalErr.ErrorCode) } +} - // string - dec = NewDecoder(bytes.NewReader([]byte{ - 0x0, 0x0, 0x0, 0x1, - 0x1, 0x1, 0x0, 0x0, - })) - _, _, err = dec.DecodeString(3) +// TestDecodeOpaque_LengthOverflow tests that length > maxInt32 is rejected +func TestDecodeOpaque_LengthOverflow(t *testing.T) { + // Encode a length of 0x80000000 (2^31, which is > maxInt32) + data := []byte{0x80, 0x00, 0x00, 0x00} + d := NewDecoder(data) + _, _, err := d.DecodeOpaque(0) if err == nil { - t.Error("expected error when unmarshaling string with non-zero padding byte, got none") + t.Fatal("expected error for length > maxInt32") + } + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T", err) + } + if unmarshalErr.ErrorCode != ErrOverflow { + t.Errorf("expected ErrOverflow, got %v", unmarshalErr.ErrorCode) } +} - // read varopaque - dec = NewDecoder(bytes.NewReader([]byte{ - 0x0, 0x0, 0x0, 0x1, - 0x1, 0x0, 0x0, 0x1, - })) - _, _, err = dec.DecodeOpaque(3) +// TestDecodeString_LengthOverflow tests that length > maxInt32 is rejected +func TestDecodeString_LengthOverflow(t *testing.T) { + // Encode a length of 0x80000000 (2^31, which is > maxInt32) + data := []byte{0x80, 0x00, 0x00, 0x00} + d := NewDecoder(data) + _, _, err := d.DecodeString(0) if err == nil { - t.Error("expected error when unmarshaling varopaque with non-zero padding byte, got none") + t.Fatal("expected error for length > maxInt32") + } + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T", err) + } + if unmarshalErr.ErrorCode != ErrOverflow { + t.Errorf("expected ErrOverflow, got %v", unmarshalErr.ErrorCode) } } -func TestDecodeNilPointerIntoExistingObjectWithNotNilPointer(t *testing.T) { - var buf bytes.Buffer - data := "data" - _, err := Marshal(&buf, structWithPointer{Data: &data}) +// TestDecoder_Decode tests the Decode convenience method +func TestDecoder_Decode(t *testing.T) { + // Create decoder with initial data + data1 := []byte{0x00, 0x00, 0x00, 0x2A} // 42 + decoder := NewDecoder(data1) + + var result testDecoderFromType + n, err := decoder.Decode(&result) if err != nil { - t.Error("unexpected error") + t.Fatalf("Decode failed: %v", err) + } + if n != 4 { + t.Errorf("expected 4 bytes read, got %d", n) + } + if result.Value != 42 { + t.Errorf("expected 42, got %d", result.Value) } - var s structWithPointer - _, err = Unmarshal(&buf, &s) + // Test decoder reuse with Reset + Decode + data2 := []byte{0x00, 0x00, 0x01, 0x00} // 256 + decoder.Reset(data2) + + n, err = decoder.Decode(&result) if err != nil { - t.Error("unexpected error") + t.Fatalf("Decode after Reset failed: %v", err) + } + if n != 4 { + t.Errorf("expected 4 bytes read, got %d", n) + } + if result.Value != 256 { + t.Errorf("expected 256, got %d", result.Value) } +} + +// TestDecoder_Decode_MaxDepth tests that Decode passes MaxDepth correctly +func TestDecoder_Decode_MaxDepth(t *testing.T) { + data := []byte{0x00, 0x00, 0x00, 0x01} - // Note: - // 1. structWithPointer.Data is nil. - // 2. We unmarshal into previously used object. - _, err = Marshal(&buf, structWithPointer{}) + // Test with custom MaxDepth + decoder := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 75}) + + var result testNestedType + _, err := decoder.Decode(&result) if err != nil { - t.Error("unexpected error") + t.Fatalf("Decode failed: %v", err) + } + if result.ReceivedMaxDepth != 75 { + t.Errorf("expected maxDepth 75, got %d", result.ReceivedMaxDepth) } +} - _, err = Unmarshal(&buf, &s) - if err != nil { - t.Error("unexpected error") +// TestDecoder_Decode_Error tests that Decode propagates errors correctly +func TestDecoder_Decode_Error(t *testing.T) { + // Insufficient data + data := []byte{0x00, 0x00} // Only 2 bytes, need 4 + decoder := NewDecoder(data) + + var result testDecoderFromType + _, err := decoder.Decode(&result) + if err == nil { + t.Fatal("expected error for insufficient data") } - if s.Data != nil { - t.Error("Data should be nil") + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T", err) + } + if unmarshalErr.ErrorCode != ErrIO { + t.Errorf("expected ErrIO, got %v", unmarshalErr.ErrorCode) } } -func TestDecodeUnionIntoExistingObject(t *testing.T) { - var buf bytes.Buffer - var idata int32 = 1 - sdata := "data" - _, err := Marshal(&buf, aUnion{ - Type: 0, - Data: &idata, - }) - if err != nil { - t.Error("unexpected error") +// TestDecoder_MaxDepthExceeded tests that deeply nested structures trigger ErrMaxDecodingDepth +func TestDecoder_MaxDepthExceeded(t *testing.T) { + // Create a linked list type with pointer to next + type node struct { + Value int32 + Next *node + } + + // Build XDR data for a linked list with 5 nodes + // Each node: int32 value, bool (present), then next node + // Node 1 -> Node 2 -> Node 3 -> Node 4 -> Node 5 -> nil + data := []byte{ + // Node 1 + 0x00, 0x00, 0x00, 0x01, // Value = 1 + 0x00, 0x00, 0x00, 0x01, // Next present = true + // Node 2 + 0x00, 0x00, 0x00, 0x02, // Value = 2 + 0x00, 0x00, 0x00, 0x01, // Next present = true + // Node 3 + 0x00, 0x00, 0x00, 0x03, // Value = 3 + 0x00, 0x00, 0x00, 0x01, // Next present = true + // Node 4 + 0x00, 0x00, 0x00, 0x04, // Value = 4 + 0x00, 0x00, 0x00, 0x01, // Next present = true + // Node 5 + 0x00, 0x00, 0x00, 0x05, // Value = 5 + 0x00, 0x00, 0x00, 0x00, // Next present = false (nil) } - var s aUnion - _, err = Unmarshal(&buf, &s) + // With default MaxDepth (200), this should succeed + var result1 node + _, err := Unmarshal(data, &result1) if err != nil { - t.Error("unexpected error") + t.Fatalf("Unmarshal with default depth failed: %v", err) + } + + // With MaxDepth = 3, this should fail (each pointer adds depth) + // The structure needs: depth for node struct, depth for pointer, depth for next node... + var result2 node + opts := DecodeOptions{MaxDepth: 3} + _, err = UnmarshalWithOptions(data, &result2, opts) + if err == nil { + t.Fatal("expected ErrMaxDecodingDepth error with MaxDepth=3") + } + + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrMaxDecodingDepth { + t.Errorf("expected ErrMaxDecodingDepth, got %v", unmarshalErr.ErrorCode) + } +} + +// testUnionValueArm is a union type with value-type arms for testing +type testUnionValueArm struct { + Type int32 + Int int32 // value type arm (switch 0) + Str string // value type arm (switch 1) + Empty bool // value type arm (switch 2) +} + +func (u testUnionValueArm) ArmForSwitch(sw int32) (string, bool) { + switch sw { + case 0: + return "Int", true + case 1: + return "Str", true + case 2: + return "", true // void arm } + return "-", false +} + +func (u testUnionValueArm) SwitchFieldName() string { + return "Type" +} + +// testUnionPointerArm is a union type with pointer-type arms for testing +type testUnionPointerArm struct { + Type int32 + Int *int32 // pointer type arm (switch 0) + Str *string // pointer type arm (switch 1) +} + +func (u testUnionPointerArm) ArmForSwitch(sw int32) (string, bool) { + switch sw { + case 0: + return "Int", true + case 1: + return "Str", true + case 2: + return "", true // void arm + } + return "-", false +} + +func (u testUnionPointerArm) SwitchFieldName() string { + return "Type" +} + +// TestDecodeUnionWithValueTypeArm tests decoding unions with value-type arms +func TestDecodeUnionWithValueTypeArm(t *testing.T) { + t.Run("value-type int arm", func(t *testing.T) { + // XDR encoded union: Type=0, Int=42 + data := []byte{ + 0x00, 0x00, 0x00, 0x00, // Type = 0 + 0x00, 0x00, 0x00, 0x2A, // Int = 42 + } + + var result testUnionValueArm + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 8 { + t.Errorf("bytes read = %d, want 8", n) + } + if result.Type != 0 { + t.Errorf("Type = %d, want 0", result.Type) + } + if result.Int != 42 { + t.Errorf("Int = %d, want 42", result.Int) + } + }) + + t.Run("value-type string arm", func(t *testing.T) { + // XDR encoded union: Type=1, Str="hi" + data := []byte{ + 0x00, 0x00, 0x00, 0x01, // Type = 1 + 0x00, 0x00, 0x00, 0x02, 'h', 'i', 0x00, 0x00, // Str = "hi" (padded) + } - _, err = Marshal(&buf, aUnion{ - Type: 1, - Text: &sdata, + var result testUnionValueArm + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 12 { + t.Errorf("bytes read = %d, want 12", n) + } + if result.Type != 1 { + t.Errorf("Type = %d, want 1", result.Type) + } + if result.Str != "hi" { + t.Errorf("Str = %q, want \"hi\"", result.Str) + } }) - if err != nil { - t.Error("unexpected error") + + t.Run("void arm", func(t *testing.T) { + // XDR encoded union: Type=2 (void arm, no data) + data := []byte{ + 0x00, 0x00, 0x00, 0x02, // Type = 2 + } + + var result testUnionValueArm + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 4 { + t.Errorf("bytes read = %d, want 4", n) + } + if result.Type != 2 { + t.Errorf("Type = %d, want 2", result.Type) + } + }) +} + +// TestDecodeUnionWithPointerTypeArm tests decoding unions with pointer-type arms +func TestDecodeUnionWithPointerTypeArm(t *testing.T) { + t.Run("pointer-type int arm", func(t *testing.T) { + // XDR encoded union: Type=0, Int=42 + data := []byte{ + 0x00, 0x00, 0x00, 0x00, // Type = 0 + 0x00, 0x00, 0x00, 0x2A, // Int = 42 + } + + var result testUnionPointerArm + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 8 { + t.Errorf("bytes read = %d, want 8", n) + } + if result.Type != 0 { + t.Errorf("Type = %d, want 0", result.Type) + } + if result.Int == nil || *result.Int != 42 { + t.Errorf("Int = %v, want pointer to 42", result.Int) + } + }) + + t.Run("pointer-type string arm", func(t *testing.T) { + // XDR encoded union: Type=1, Str="hi" + data := []byte{ + 0x00, 0x00, 0x00, 0x01, // Type = 1 + 0x00, 0x00, 0x00, 0x02, 'h', 'i', 0x00, 0x00, // Str = "hi" (padded) + } + + var result testUnionPointerArm + n, err := Unmarshal(data, &result) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if n != 12 { + t.Errorf("bytes read = %d, want 12", n) + } + if result.Type != 1 { + t.Errorf("Type = %d, want 1", result.Type) + } + if result.Str == nil || *result.Str != "hi" { + t.Errorf("Str = %v, want pointer to \"hi\"", result.Str) + } + }) +} + +// TestEnterScope tests the EnterScope method for depth tracking +func TestEnterScope(t *testing.T) { + data := []byte{0x00, 0x00, 0x00, 0x01} + + t.Run("decrement depth", func(t *testing.T) { + d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 10}) + + // Initial state: currentDepth should equal maxDepth + if d.MaxDepth() != 10 { + t.Errorf("MaxDepth() = %d, want 10", d.MaxDepth()) + } + + // Enter scope should succeed and decrement depth + if err := d.EnterScope(); err != nil { + t.Fatalf("EnterScope failed: %v", err) + } + + // Enter again + if err := d.EnterScope(); err != nil { + t.Fatalf("EnterScope failed: %v", err) + } + }) + + t.Run("depth exceeded", func(t *testing.T) { + d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 2}) + + // First two should succeed + if err := d.EnterScope(); err != nil { + t.Fatalf("EnterScope 1 failed: %v", err) + } + if err := d.EnterScope(); err != nil { + t.Fatalf("EnterScope 2 failed: %v", err) + } + + // Third should fail + err := d.EnterScope() + if err == nil { + t.Fatal("expected error when depth exceeded") + } + + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrMaxDecodingDepth { + t.Errorf("expected ErrMaxDecodingDepth, got %v", unmarshalErr.ErrorCode) + } + }) +} + +// TestLeaveScope tests the LeaveScope method for depth tracking +func TestLeaveScope(t *testing.T) { + data := []byte{0x00, 0x00, 0x00, 0x01} + d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 5}) + + // Enter scope twice + if err := d.EnterScope(); err != nil { + t.Fatalf("EnterScope 1 failed: %v", err) + } + if err := d.EnterScope(); err != nil { + t.Fatalf("EnterScope 2 failed: %v", err) } - _, err = Unmarshal(&buf, &s) - if err != nil { - t.Error("unexpected error") + // Leave scope once + d.LeaveScope() + + // Should be able to enter again (we freed up one level) + if err := d.EnterScope(); err != nil { + t.Fatalf("EnterScope after LeaveScope failed: %v", err) } +} + +// TestEnterLeaveScopeBalance tests that Enter/Leave are balanced correctly +func TestEnterLeaveScopeBalance(t *testing.T) { + data := []byte{0x00, 0x00, 0x00, 0x01} + d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 3}) - if s.Data != nil { - t.Error("Data should be nil") + // Simulate nested decoding: enter 3 times, leave 3 times + for i := 0; i < 3; i++ { + if err := d.EnterScope(); err != nil { + t.Fatalf("EnterScope %d failed: %v", i+1, err) + } } - if s.Type != 1 { - t.Error("Type does not match") + // Should fail - at max depth + if err := d.EnterScope(); err == nil { + t.Fatal("expected error at max depth") } - if *s.Text != sdata { - t.Error("Text does not match") + // Leave all 3 + for i := 0; i < 3; i++ { + d.LeaveScope() + } + + // Should be able to enter 3 more times again + for i := 0; i < 3; i++ { + if err := d.EnterScope(); err != nil { + t.Fatalf("EnterScope after reset %d failed: %v", i+1, err) + } } } -func TestDecodeMaxDepth(t *testing.T) { - var buf bytes.Buffer - data := "data" - _, err := Marshal(&buf, structWithPointer{Data: &data}) - if err != nil { - t.Error("unexpected error") +// testScopedType implements DecoderFrom using EnterScope/LeaveScope +type testScopedType struct { + Value int32 +} + +func (t *testScopedType) DecodeFrom(d *Decoder) (int, error) { + if err := d.EnterScope(); err != nil { + return 0, err } + defer d.LeaveScope() - bufCopy := buf - decoder := NewDecoderWithOptions(&bufCopy, DecodeOptions{MaxDepth: 3}) - var s structWithPointer - _, err = decoder.Decode(&s) + v, n, err := d.DecodeInt() if err != nil { - t.Error("unexpected error") + return n, err } + t.Value = v + return n, nil +} + +// TestDecoderFromWithScope tests that DecodeFrom implementations using EnterScope work correctly +func TestDecoderFromWithScope(t *testing.T) { + data := []byte{0x00, 0x00, 0x00, 0x2A} // 42 + + t.Run("decode with scope", func(t *testing.T) { + d := NewDecoder(data) + var result testScopedType + n, err := result.DecodeFrom(d) + if err != nil { + t.Fatalf("DecodeFrom failed: %v", err) + } + if n != 4 { + t.Errorf("bytes read = %d, want 4", n) + } + if result.Value != 42 { + t.Errorf("Value = %d, want 42", result.Value) + } + }) + + t.Run("depth limit with scope", func(t *testing.T) { + // With MaxDepth=1, a single DecodeFrom should work + d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 1}) + var result testScopedType + _, err := result.DecodeFrom(d) + if err != nil { + t.Fatalf("DecodeFrom with depth 1 failed: %v", err) + } + + // Reset and try with MaxDepth=0 (defaults to 200, so will work) + d.Reset(data) + }) +} - bufCopy = buf - decoder = NewDecoderWithOptions(&bufCopy, DecodeOptions{MaxDepth: 2}) - _, err = decoder.Decode(&s) - assertError(t, "", err, &UnmarshalError{ErrorCode: ErrMaxDecodingDepth}) +// testNestedScopedType implements DecoderFrom and contains another testScopedType +type testNestedScopedType struct { + Outer int32 + Inner testScopedType } -func TestDecodeMaxAllocationCheck_ImplicitLenReader(t *testing.T) { - var buf bytes.Buffer - _, err := Marshal(&buf, "thisstringis23charslong") +func (t *testNestedScopedType) DecodeFrom(d *Decoder) (int, error) { + if err := d.EnterScope(); err != nil { + return 0, err + } + defer d.LeaveScope() + + var n, nTmp int + var err error + + t.Outer, nTmp, err = d.DecodeInt() + n += nTmp if err != nil { - t.Error("unexpected error") + return n, err } - // Reduce the buffer size so that the length of the buffer - // is shorter than the encoded XDR length - buf.Truncate(buf.Len() - 4) + nTmp, err = t.Inner.DecodeFrom(d) + n += nTmp + if err != nil { + return n, err + } - decoder := NewDecoder(&buf) - var s string - _, err = decoder.Decode(&s) - assertError(t, "", err, &UnmarshalError{ErrorCode: ErrOverflow}) + return n, nil } -func TestDecodeMaxAllocationCheck_ExplicitLenReader(t *testing.T) { - var buf bytes.Buffer - encoder := base64.NewEncoder(base64.StdEncoding, &buf) - _, err := Marshal(encoder, "thisstringis23charslong") - if err != nil { - t.Error("unexpected error") +// TestNestedDecodeFromWithScope tests depth tracking across nested DecodeFrom calls +func TestNestedDecodeFromWithScope(t *testing.T) { + // Two int32 values: outer=1, inner=2 + data := []byte{ + 0x00, 0x00, 0x00, 0x01, // Outer = 1 + 0x00, 0x00, 0x00, 0x02, // Inner.Value = 2 } - xdrLen := base64.StdEncoding.DecodedLen(buf.Len()) - // Reduce the buffer size so that the length of the buffer - // is shorter than the encoded XDR length - reducedLen := xdrLen - 4 + t.Run("nested decode succeeds with sufficient depth", func(t *testing.T) { + d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 10}) + var result testNestedScopedType + n, err := result.DecodeFrom(d) + if err != nil { + t.Fatalf("DecodeFrom failed: %v", err) + } + if n != 8 { + t.Errorf("bytes read = %d, want 8", n) + } + if result.Outer != 1 { + t.Errorf("Outer = %d, want 1", result.Outer) + } + if result.Inner.Value != 2 { + t.Errorf("Inner.Value = %d, want 2", result.Inner.Value) + } + }) + + t.Run("nested decode fails with insufficient depth", func(t *testing.T) { + // Need depth 2: one for outer, one for inner + // With MaxDepth=1, inner should fail + d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 1}) + var result testNestedScopedType + _, err := result.DecodeFrom(d) + if err == nil { + t.Fatal("expected depth error with MaxDepth=1") + } + + var unmarshalErr *UnmarshalError + if !errors.As(err, &unmarshalErr) { + t.Fatalf("expected UnmarshalError, got %T: %v", err, err) + } + if unmarshalErr.ErrorCode != ErrMaxDecodingDepth { + t.Errorf("expected ErrMaxDecodingDepth, got %v", unmarshalErr.ErrorCode) + } + }) - decoder := NewDecoderWithOptions(&buf, DecodeOptions{MaxInputLen: reducedLen}) - var s string - _, err = decoder.Decode(&s) - assertError(t, "", err, &UnmarshalError{ErrorCode: ErrOverflow}) + t.Run("nested decode succeeds with exact depth", func(t *testing.T) { + // With MaxDepth=2, should just barely work + d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 2}) + var result testNestedScopedType + n, err := result.DecodeFrom(d) + if err != nil { + t.Fatalf("DecodeFrom with exact depth failed: %v", err) + } + if n != 8 { + t.Errorf("bytes read = %d, want 8", n) + } + }) } diff --git a/xdr3/encode.go b/xdr3/encode.go index 2e0cf5b..c1341ca 100644 --- a/xdr3/encode.go +++ b/xdr3/encode.go @@ -131,7 +131,7 @@ func (enc *Encoder) EncodeUint(v uint32) (int, error) { n, err := enc.w.Write(enc.scratchBuf[:4]) if err != nil { msg := fmt.Sprintf(errIOEncode, err.Error(), 4) - err := marshalError("EncodeUint", ErrIO, msg, enc.scratchBuf[:4], err) + err := marshalError("EncodeUint", ErrIO, msg, enc.scratchBuf[:n], err) return n, err } @@ -201,7 +201,7 @@ func (enc *Encoder) EncodeHyper(v int64) (int, error) { n, err := enc.w.Write(enc.scratchBuf[:8]) if err != nil { msg := fmt.Sprintf(errIOEncode, err.Error(), 8) - err := marshalError("EncodeHyper", ErrIO, msg, enc.scratchBuf[:8], err) + err := marshalError("EncodeHyper", ErrIO, msg, enc.scratchBuf[:n], err) return n, err } @@ -312,7 +312,7 @@ func (enc *Encoder) EncodeFixedOpaque(v []byte) (int, error) { written := make([]byte, l+n2) copy(written, v) copy(written[l:], b[:n2]) - msg := fmt.Sprintf(errIOEncode, err.Error(), l+pad) + msg := fmt.Sprintf(errIOEncode, err.Error(), l+n2) err := marshalError("EncodeFixedOpaque", ErrIO, msg, written, err) return n, err @@ -460,7 +460,6 @@ func (enc *Encoder) encodeUnion(v reflect.Value) (int, error) { vs := v.FieldByName(u.SwitchFieldName()) n, err := enc.encode(vs) - if err != nil { return n, err } @@ -473,6 +472,7 @@ func (enc *Encoder) encodeUnion(v reflect.Value) (int, error) { } else { sw = int32(vs.Int()) } + arm, ok := u.ArmForSwitch(sw) // void arm, we're done @@ -481,26 +481,25 @@ func (enc *Encoder) encodeUnion(v reflect.Value) (int, error) { } vv := v.FieldByName(arm) - if !vv.IsValid() || !ok { msg := fmt.Sprintf("invalid union switch: %d", sw) err := marshalError("encodeUnion", ErrBadUnionSwitch, msg, nil, nil) return n, err } - if vv.Kind() != reflect.Ptr { - msg := fmt.Sprintf("invalid union value field: %v", vv.Kind()) - err := marshalError("encodeUnion", ErrBadUnionValue, msg, nil, nil) - return n, err - } - - if vv.IsNil() { - msg := fmt.Sprintf("can't encode nil union value") - err := marshalError("encodeUnion", ErrBadUnionValue, msg, nil, nil) + // Handle both pointer and value-type union arms + if vv.Kind() == reflect.Ptr { + if vv.IsNil() { + return n, marshalError("encodeUnion", ErrBadUnionValue, + "can't encode nil pointer union arm", nil, nil) + } + n2, err := enc.encode(vv.Elem()) + n += n2 return n, err } - n2, err := enc.encode(vv.Elem()) + // Value-type arm - encode directly + n2, err := enc.encode(vv) n += n2 return n, err } diff --git a/xdr3/encode_test.go b/xdr3/encode_test.go index bda0873..d4d8c00 100644 --- a/xdr3/encode_test.go +++ b/xdr3/encode_test.go @@ -26,6 +26,72 @@ import ( . "github.com/stellar/go-xdr/xdr3" ) +// subTest is used to allow testing of the Marshal function into struct fields +// which are structs themselves. +type subTest struct { + A string + B uint8 +} + +// allTypesTest is used to allow testing of the Marshal function into struct +// fields of all supported types. +type allTypesTest struct { + A int8 + B uint8 + C int16 + D uint16 + E int32 + F uint32 + G int64 + H uint64 + I bool + J float32 + K float64 + L string + M []byte + N [3]byte + O []int16 + P [2]subTest + Q subTest + R map[string]uint32 + S time.Time +} + +// opaqueStruct is used to test handling of uint8 slices and arrays. +type opaqueStruct struct { + Slice []uint8 `xdropaque:"false"` + Array [1]uint8 `xdropaque:"false"` +} + +type AnEnum int32 + +func (e AnEnum) ValidEnum(v int32) bool { + return v < 3 +} + +type aUnion struct { + Type AnEnum + Data *int32 + Text *string `xdrmaxsize:"28"` +} + +func (u aUnion) SwitchFieldName() string { + return "Type" +} + +func (u aUnion) ArmForSwitch(sw int32) (string, bool) { + switch sw { + case 0: + return "Data", true + case 1: + return "Text", true + case 2: // void + return "", true + } + + return "-", false +} + // testExpectedMRet is a convenience method to test an expected number of bytes // written and error for a marshal. func testExpectedMRet(t *testing.T, name string, n, wantN int, err, wantErr error) bool { @@ -703,6 +769,159 @@ func TestEncoder(t *testing.T) { } } +// valueTypeUnion is a union type with value-type arms for testing encoding +type valueTypeUnion struct { + Type int32 + Int int32 // value type arm (switch 0) + Str string // value type arm (switch 1) +} + +func (u valueTypeUnion) SwitchFieldName() string { + return "Type" +} + +func (u valueTypeUnion) ArmForSwitch(sw int32) (string, bool) { + switch sw { + case 0: + return "Int", true + case 1: + return "Str", true + case 2: + return "", true // void arm + } + return "-", false +} + +// pointerTypeUnion is a union type with pointer-type arms for testing encoding +type pointerTypeUnion struct { + Type int32 + Int *int32 // pointer type arm (switch 0) + Str *string // pointer type arm (switch 1) +} + +func (u pointerTypeUnion) SwitchFieldName() string { + return "Type" +} + +func (u pointerTypeUnion) ArmForSwitch(sw int32) (string, bool) { + switch sw { + case 0: + return "Int", true + case 1: + return "Str", true + case 2: + return "", true // void arm + } + return "-", false +} + +// TestMarshalUnionValueTypeArms tests encoding unions with value-type arms +func TestMarshalUnionValueTypeArms(t *testing.T) { + t.Run("value-type int arm", func(t *testing.T) { + u := valueTypeUnion{Type: 0, Int: 42} + expected := []byte{ + 0x00, 0x00, 0x00, 0x00, // Type = 0 + 0x00, 0x00, 0x00, 0x2A, // Int = 42 + } + + data := newFixedWriter(8) + n, err := Marshal(data, u) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if n != 8 { + t.Errorf("bytes written = %d, want 8", n) + } + if !reflect.DeepEqual(data.Bytes(), expected) { + t.Errorf("got %v, want %v", data.Bytes(), expected) + } + }) + + t.Run("value-type string arm", func(t *testing.T) { + u := valueTypeUnion{Type: 1, Str: "hi"} + expected := []byte{ + 0x00, 0x00, 0x00, 0x01, // Type = 1 + 0x00, 0x00, 0x00, 0x02, 'h', 'i', 0x00, 0x00, // Str = "hi" (padded) + } + + data := newFixedWriter(12) + n, err := Marshal(data, u) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if n != 12 { + t.Errorf("bytes written = %d, want 12", n) + } + if !reflect.DeepEqual(data.Bytes(), expected) { + t.Errorf("got %v, want %v", data.Bytes(), expected) + } + }) + + t.Run("void arm", func(t *testing.T) { + u := valueTypeUnion{Type: 2} + expected := []byte{ + 0x00, 0x00, 0x00, 0x02, // Type = 2 + } + + data := newFixedWriter(4) + n, err := Marshal(data, u) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if n != 4 { + t.Errorf("bytes written = %d, want 4", n) + } + if !reflect.DeepEqual(data.Bytes(), expected) { + t.Errorf("got %v, want %v", data.Bytes(), expected) + } + }) +} + +// TestMarshalUnionPointerTypeArms tests encoding unions with pointer-type arms +func TestMarshalUnionPointerTypeArms(t *testing.T) { + t.Run("pointer-type int arm", func(t *testing.T) { + val := int32(42) + u := pointerTypeUnion{Type: 0, Int: &val} + expected := []byte{ + 0x00, 0x00, 0x00, 0x00, // Type = 0 + 0x00, 0x00, 0x00, 0x2A, // Int = 42 + } + + data := newFixedWriter(8) + n, err := Marshal(data, u) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if n != 8 { + t.Errorf("bytes written = %d, want 8", n) + } + if !reflect.DeepEqual(data.Bytes(), expected) { + t.Errorf("got %v, want %v", data.Bytes(), expected) + } + }) + + t.Run("pointer-type string arm", func(t *testing.T) { + val := "hi" + u := pointerTypeUnion{Type: 1, Str: &val} + expected := []byte{ + 0x00, 0x00, 0x00, 0x01, // Type = 1 + 0x00, 0x00, 0x00, 0x02, 'h', 'i', 0x00, 0x00, // Str = "hi" (padded) + } + + data := newFixedWriter(12) + n, err := Marshal(data, u) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if n != 12 { + t.Errorf("bytes written = %d, want 12", n) + } + if !reflect.DeepEqual(data.Bytes(), expected) { + t.Errorf("got %v, want %v", data.Bytes(), expected) + } + }) +} + // TestMarshalCorners ensures the Marshal function properly handles various // cases not already covered by the other tests. func TestMarshalCorners(t *testing.T) { diff --git a/xdr3/example_test.go b/xdr3/example_test.go index ce2f52d..2b452cb 100644 --- a/xdr3/example_test.go +++ b/xdr3/example_test.go @@ -54,47 +54,8 @@ func ExampleMarshal() { // encoded data: [171 205 239 0 0 0 0 2 0 0 0 1 0 0 0 10] } -// This example demonstrates how to use Unmarshal to decode XDR encoded data -// from a byte slice into a struct. -func ExampleUnmarshal() { - // Hypothetical image header format. - type ImageHeader struct { - Signature [3]byte - Version uint32 - IsGrayscale bool - NumSections uint32 - } - - // XDR encoded data described by the above structure. Typically this - // would be read from a file or across the network, but use a manual - // byte array here as an example. - encodedData := []byte{ - 0xAB, 0xCD, 0xEF, 0x00, // Signature - 0x00, 0x00, 0x00, 0x02, // Version - 0x00, 0x00, 0x00, 0x01, // IsGrayscale - 0x00, 0x00, 0x00, 0x0A, // NumSections - } - - // Declare a variable to provide Unmarshal with a concrete type and - // instance to decode into. - var h ImageHeader - bytesRead, err := xdr.Unmarshal(bytes.NewReader(encodedData), &h) - if err != nil { - fmt.Println(err) - return - } - - fmt.Println("bytes read:", bytesRead) - fmt.Printf("h: %+v", h) - - // Output: - // bytes read: 16 - // h: {Signature:[171 205 239] Version:2 IsGrayscale:true NumSections:10} -} - // This example demonstrates how to manually decode XDR encoded data from a -// reader. Compare this example with the Unmarshal example which performs the -// same task automatically by utilizing a struct type definition and reflection. +// byte slice using Decoder. func ExampleNewDecoder() { // XDR encoded data for a hypothetical ImageHeader struct as follows: // type ImageHeader struct { @@ -111,7 +72,7 @@ func ExampleNewDecoder() { } // Get a new decoder for manual decoding. - dec := xdr.NewDecoder(bytes.NewReader(encodedData)) + dec := xdr.NewDecoder(encodedData) signature, _, err := dec.DecodeFixedOpaque(3) if err != nil { diff --git a/xdr3/internal_test.go b/xdr3/internal_test.go index ba575ce..4b2ee2e 100644 --- a/xdr3/internal_test.go +++ b/xdr3/internal_test.go @@ -34,10 +34,3 @@ func TstEncode(w io.Writer) func(v reflect.Value) (int, error) { enc := NewEncoder(w) return enc.encode } - -// TstDecode creates a new Decoder for the passed reader and returns the -// internal decode function on the Decoder. -func TstDecode(r io.Reader) func(v reflect.Value, maxLen int, maxDepth uint) (int, error) { - dec := NewDecoder(r) - return dec.decode -} From 32a41723eebcc458ee0756b26ce4f26e915a90f7 Mon Sep 17 00:00:00 2001 From: tamirms Date: Thu, 8 Jan 2026 14:14:24 +0000 Subject: [PATCH 2/2] xdr: pass maxDepth as parameter instead of tracking in Decoder Change DecoderFrom interface to accept maxDepth parameter directly, eliminating the need for EnterScope/LeaveScope methods and the currentDepth field in Decoder. This reduces overhead by: - Removing defer calls for LeaveScope - Eliminating struct field access for depth tracking - Enabling better inlining opportunities Implementations should decrement maxDepth when calling DecodeFrom on nested types and return an error if maxDepth reaches 0. Co-Authored-By: Claude Opus 4.5 --- xdr3/decode.go | 43 +++----- xdr3/decode_test.go | 247 +------------------------------------------- 2 files changed, 14 insertions(+), 276 deletions(-) diff --git a/xdr3/decode.go b/xdr3/decode.go index b55840b..4fe531d 100644 --- a/xdr3/decode.go +++ b/xdr3/decode.go @@ -48,9 +48,11 @@ var DefaultDecodeOptions = DecodeOptions{ // DecoderFrom is implemented by types that can decode themselves from a Decoder. // Types implementing this interface get a fast path in Decode(), bypassing reflection. -// Implementations can call d.MaxDepth() if they need to track recursion depth. +// The maxDepth parameter tracks recursion depth to prevent stack overflow from +// maliciously crafted deeply-nested data. Implementations should decrement maxDepth +// when calling DecodeFrom on nested types and return an error if maxDepth reaches 0. type DecoderFrom interface { - DecodeFrom(d *Decoder) (int, error) + DecodeFrom(d *Decoder, maxDepth uint) (int, error) } /* @@ -120,10 +122,9 @@ func UnmarshalWithOptions(data []byte, v interface{}, options DecodeOptions) (in // necessary in complex scenarios where automatic reflection-based decoding // won't work. type Decoder struct { - buf []byte - pos int - maxDepth uint - currentDepth uint + buf []byte + pos int + maxDepth uint } // NewDecoder returns a Decoder that can be used to manually decode XDR data @@ -140,19 +141,17 @@ func NewDecoderWithOptions(data []byte, options DecodeOptions) *Decoder { maxDepth = DecodeDefaultMaxDepth } return &Decoder{ - buf: data, - pos: 0, - maxDepth: maxDepth, - currentDepth: maxDepth, + buf: data, + pos: 0, + maxDepth: maxDepth, } } // Reset resets the decoder to read from a new byte slice, allowing reuse -// of the decoder to reduce allocations. CurrentDepth is reset to MaxDepth. +// of the decoder to reduce allocations. func (d *Decoder) Reset(data []byte) { d.buf = data d.pos = 0 - d.currentDepth = d.maxDepth } // Remaining returns the number of unread bytes in the buffer. @@ -170,24 +169,6 @@ func (d *Decoder) MaxDepth() uint { return d.maxDepth } -// EnterScope should be called at the start of decoding a compound type -// (struct, union, or array element). Returns an error if max depth would be exceeded. -// Use with LeaveScope: `if err := d.EnterScope(); err != nil { return err }; defer d.LeaveScope()` -func (d *Decoder) EnterScope() error { - if d.currentDepth == 0 { - return unmarshalError("EnterScope", ErrMaxDecodingDepth, "maximum decoding depth reached", nil, nil) - } - d.currentDepth-- - return nil -} - -// LeaveScope should be called when exiting a compound type. Should be used with defer. -func (d *Decoder) LeaveScope() { - if d.currentDepth < d.maxDepth { - d.currentDepth++ - } -} - // DecodeInt treats the next 4 bytes as an XDR encoded integer and returns the // result as an int32 along with the number of bytes actually read. // @@ -1152,7 +1133,7 @@ func (d *Decoder) Decode(v interface{}) (int, error) { // Fast path: if v implements DecoderFrom, use it directly if decodable, ok := v.(DecoderFrom); ok { - return decodable.DecodeFrom(d) + return decodable.DecodeFrom(d, d.maxDepth) } // Fallback: reflection-based decoding diff --git a/xdr3/decode_test.go b/xdr3/decode_test.go index 39ab9c3..ba0b71d 100644 --- a/xdr3/decode_test.go +++ b/xdr3/decode_test.go @@ -540,7 +540,7 @@ type testDecoderFromType struct { Value int32 } -func (t *testDecoderFromType) DecodeFrom(d *Decoder) (int, error) { +func (t *testDecoderFromType) DecodeFrom(d *Decoder, maxDepth uint) (int, error) { v, n, err := d.DecodeInt() if err != nil { return n, err @@ -555,7 +555,7 @@ type testNestedType struct { Value int32 } -func (t *testNestedType) DecodeFrom(d *Decoder) (int, error) { +func (t *testNestedType) DecodeFrom(d *Decoder, maxDepth uint) (int, error) { t.ReceivedMaxDepth = d.MaxDepth() v, n, err := d.DecodeInt() if err != nil { @@ -1235,246 +1235,3 @@ func TestDecodeUnionWithPointerTypeArm(t *testing.T) { } }) } - -// TestEnterScope tests the EnterScope method for depth tracking -func TestEnterScope(t *testing.T) { - data := []byte{0x00, 0x00, 0x00, 0x01} - - t.Run("decrement depth", func(t *testing.T) { - d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 10}) - - // Initial state: currentDepth should equal maxDepth - if d.MaxDepth() != 10 { - t.Errorf("MaxDepth() = %d, want 10", d.MaxDepth()) - } - - // Enter scope should succeed and decrement depth - if err := d.EnterScope(); err != nil { - t.Fatalf("EnterScope failed: %v", err) - } - - // Enter again - if err := d.EnterScope(); err != nil { - t.Fatalf("EnterScope failed: %v", err) - } - }) - - t.Run("depth exceeded", func(t *testing.T) { - d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 2}) - - // First two should succeed - if err := d.EnterScope(); err != nil { - t.Fatalf("EnterScope 1 failed: %v", err) - } - if err := d.EnterScope(); err != nil { - t.Fatalf("EnterScope 2 failed: %v", err) - } - - // Third should fail - err := d.EnterScope() - if err == nil { - t.Fatal("expected error when depth exceeded") - } - - var unmarshalErr *UnmarshalError - if !errors.As(err, &unmarshalErr) { - t.Fatalf("expected UnmarshalError, got %T: %v", err, err) - } - if unmarshalErr.ErrorCode != ErrMaxDecodingDepth { - t.Errorf("expected ErrMaxDecodingDepth, got %v", unmarshalErr.ErrorCode) - } - }) -} - -// TestLeaveScope tests the LeaveScope method for depth tracking -func TestLeaveScope(t *testing.T) { - data := []byte{0x00, 0x00, 0x00, 0x01} - d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 5}) - - // Enter scope twice - if err := d.EnterScope(); err != nil { - t.Fatalf("EnterScope 1 failed: %v", err) - } - if err := d.EnterScope(); err != nil { - t.Fatalf("EnterScope 2 failed: %v", err) - } - - // Leave scope once - d.LeaveScope() - - // Should be able to enter again (we freed up one level) - if err := d.EnterScope(); err != nil { - t.Fatalf("EnterScope after LeaveScope failed: %v", err) - } -} - -// TestEnterLeaveScopeBalance tests that Enter/Leave are balanced correctly -func TestEnterLeaveScopeBalance(t *testing.T) { - data := []byte{0x00, 0x00, 0x00, 0x01} - d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 3}) - - // Simulate nested decoding: enter 3 times, leave 3 times - for i := 0; i < 3; i++ { - if err := d.EnterScope(); err != nil { - t.Fatalf("EnterScope %d failed: %v", i+1, err) - } - } - - // Should fail - at max depth - if err := d.EnterScope(); err == nil { - t.Fatal("expected error at max depth") - } - - // Leave all 3 - for i := 0; i < 3; i++ { - d.LeaveScope() - } - - // Should be able to enter 3 more times again - for i := 0; i < 3; i++ { - if err := d.EnterScope(); err != nil { - t.Fatalf("EnterScope after reset %d failed: %v", i+1, err) - } - } -} - -// testScopedType implements DecoderFrom using EnterScope/LeaveScope -type testScopedType struct { - Value int32 -} - -func (t *testScopedType) DecodeFrom(d *Decoder) (int, error) { - if err := d.EnterScope(); err != nil { - return 0, err - } - defer d.LeaveScope() - - v, n, err := d.DecodeInt() - if err != nil { - return n, err - } - t.Value = v - return n, nil -} - -// TestDecoderFromWithScope tests that DecodeFrom implementations using EnterScope work correctly -func TestDecoderFromWithScope(t *testing.T) { - data := []byte{0x00, 0x00, 0x00, 0x2A} // 42 - - t.Run("decode with scope", func(t *testing.T) { - d := NewDecoder(data) - var result testScopedType - n, err := result.DecodeFrom(d) - if err != nil { - t.Fatalf("DecodeFrom failed: %v", err) - } - if n != 4 { - t.Errorf("bytes read = %d, want 4", n) - } - if result.Value != 42 { - t.Errorf("Value = %d, want 42", result.Value) - } - }) - - t.Run("depth limit with scope", func(t *testing.T) { - // With MaxDepth=1, a single DecodeFrom should work - d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 1}) - var result testScopedType - _, err := result.DecodeFrom(d) - if err != nil { - t.Fatalf("DecodeFrom with depth 1 failed: %v", err) - } - - // Reset and try with MaxDepth=0 (defaults to 200, so will work) - d.Reset(data) - }) -} - -// testNestedScopedType implements DecoderFrom and contains another testScopedType -type testNestedScopedType struct { - Outer int32 - Inner testScopedType -} - -func (t *testNestedScopedType) DecodeFrom(d *Decoder) (int, error) { - if err := d.EnterScope(); err != nil { - return 0, err - } - defer d.LeaveScope() - - var n, nTmp int - var err error - - t.Outer, nTmp, err = d.DecodeInt() - n += nTmp - if err != nil { - return n, err - } - - nTmp, err = t.Inner.DecodeFrom(d) - n += nTmp - if err != nil { - return n, err - } - - return n, nil -} - -// TestNestedDecodeFromWithScope tests depth tracking across nested DecodeFrom calls -func TestNestedDecodeFromWithScope(t *testing.T) { - // Two int32 values: outer=1, inner=2 - data := []byte{ - 0x00, 0x00, 0x00, 0x01, // Outer = 1 - 0x00, 0x00, 0x00, 0x02, // Inner.Value = 2 - } - - t.Run("nested decode succeeds with sufficient depth", func(t *testing.T) { - d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 10}) - var result testNestedScopedType - n, err := result.DecodeFrom(d) - if err != nil { - t.Fatalf("DecodeFrom failed: %v", err) - } - if n != 8 { - t.Errorf("bytes read = %d, want 8", n) - } - if result.Outer != 1 { - t.Errorf("Outer = %d, want 1", result.Outer) - } - if result.Inner.Value != 2 { - t.Errorf("Inner.Value = %d, want 2", result.Inner.Value) - } - }) - - t.Run("nested decode fails with insufficient depth", func(t *testing.T) { - // Need depth 2: one for outer, one for inner - // With MaxDepth=1, inner should fail - d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 1}) - var result testNestedScopedType - _, err := result.DecodeFrom(d) - if err == nil { - t.Fatal("expected depth error with MaxDepth=1") - } - - var unmarshalErr *UnmarshalError - if !errors.As(err, &unmarshalErr) { - t.Fatalf("expected UnmarshalError, got %T: %v", err, err) - } - if unmarshalErr.ErrorCode != ErrMaxDecodingDepth { - t.Errorf("expected ErrMaxDecodingDepth, got %v", unmarshalErr.ErrorCode) - } - }) - - t.Run("nested decode succeeds with exact depth", func(t *testing.T) { - // With MaxDepth=2, should just barely work - d := NewDecoderWithOptions(data, DecodeOptions{MaxDepth: 2}) - var result testNestedScopedType - n, err := result.DecodeFrom(d) - if err != nil { - t.Fatalf("DecodeFrom with exact depth failed: %v", err) - } - if n != 8 { - t.Errorf("bytes read = %d, want 8", n) - } - }) -}