From 34f3187dc14b7c9e431467ea083676cb6b5bbe2c Mon Sep 17 00:00:00 2001 From: tamirms Date: Thu, 8 Jan 2026 19:36:20 +0000 Subject: [PATCH] go: optimize XDR decoding performance and reduce allocations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Performance optimizations for Go XDR decoding: - Use byte slice Decoder instead of io.Reader for zero-copy decoding - Preserve slice capacity during decoding (grow-only, no shrinking) - Handle optional types without unnecessary allocations - Union arms with primitive types decode directly into value fields 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- lib/xdrgen/generators/go.rb | 560 ++++++++++++++---- .../block_comments.x/MyXDR_generated.go | 48 +- .../const.x/MyXDR_generated.go | 87 +-- .../enum.x/MyXDR_generated.go | 117 ++-- .../nesting.x/MyXDR_generated.go | 128 ++-- .../optional.x/MyXDR_generated.go | 55 +- .../struct.x/MyXDR_generated.go | 47 +- .../test.x/MyXDR_generated.go | 357 +++++------ .../union.x/MyXDR_generated.go | 219 +++---- 9 files changed, 892 insertions(+), 726 deletions(-) diff --git a/lib/xdrgen/generators/go.rb b/lib/xdrgen/generators/go.rb index 4d0ec53b2..46735095f 100644 --- a/lib/xdrgen/generators/go.rb +++ b/lib/xdrgen/generators/go.rb @@ -123,8 +123,7 @@ def render_union_typedef(out, typedef, union) out.break - # Add accessors for of form val, ok := union.GetArmName() - # and val := union.MustArmName() + # Add accessors: GetX() returns (T, bool), MustX() returns T union.arms.each do |arm| next if arm.void? out.puts <<-EOS.strip_heredoc @@ -242,6 +241,8 @@ def render_struct(out, struct) end def render_enum(out, enum) + is_contiguous, min_val, max_val, all_values = enum_info(enum) + # render the "enum" out.puts "type #{name enum} int32" out.puts "const (" @@ -256,7 +257,17 @@ def render_enum(out, enum) end out.puts ")" - # render the map used by xdr to decide valid values + # render min/max constants for optimized validation (contiguous enums) + if is_contiguous + out.puts "const (" + out.indent do + out.puts "_#{name enum}_Min int32 = #{min_val}" + out.puts "_#{name enum}_Max int32 = #{max_val}" + end + out.puts ")" + end + + # render the map used for String() method out.puts "var #{private_name enum}Map = map[int32]string{" out.indent do @@ -269,14 +280,30 @@ def render_enum(out, enum) out.break - out.puts <<-EOS.strip_heredoc - // ValidEnum validates a proposed value for this enum. Implements - // the Enum interface for #{name enum} - func (e #{name enum}) ValidEnum(v int32) bool { - _, ok := #{private_name enum}Map[v] - return ok - } - EOS + # ValidEnum - use range check for contiguous, switch for non-contiguous + if is_contiguous + out.puts <<-EOS.strip_heredoc + // ValidEnum validates a proposed value for this enum. Implements + // the Enum interface for #{name enum} + func (e #{name enum}) ValidEnum(v int32) bool { + return v >= _#{name enum}_Min && v <= _#{name enum}_Max + } + EOS + else + cases = all_values.join(", ") + out.puts <<-EOS.strip_heredoc + // ValidEnum validates a proposed value for this enum. Implements + // the Enum interface for #{name enum} + func (e #{name enum}) ValidEnum(v int32) bool { + switch v { + case #{cases}: + return true + default: + return false + } + } + EOS + end out.puts <<-EOS.strip_heredoc // String returns the name of `e` @@ -297,7 +324,13 @@ def render_union(out, union) union.arms.each do |arm| next if arm.void? - out.puts "#{name arm} *#{reference arm.type} #{field_tag union, arm}" + if is_union_inline_type?(arm.type) + # Primitive types as values (public field) - no allocation needed + out.puts "#{name arm} #{reference arm.type} #{field_tag union, arm}" + else + # Complex types as pointers (public field) + out.puts "#{name arm} *#{reference arm.type} #{field_tag union, arm}" + end end end out.puts "}" @@ -370,7 +403,13 @@ def render_union_encode_to_interface(out, union) out2.puts "// Void" else mn = name(arm) - render_encode_to_body(out2, "(*u.#{mn})", arm.type, self_encode: false) + if is_union_inline_type?(arm.type) + # Primitive: encode from value field + render_encode_to_body(out2, "u.#{mn}", arm.type, self_encode: false) + else + # Complex: dereference pointer field + render_encode_to_body(out2, "(*u.#{mn})", arm.type, self_encode: false) + end end out2.puts "return nil" out2.string @@ -386,18 +425,27 @@ def render_union_encode_to_interface(out, union) end def render_enum_encode_to_interface(out, typedef) - name = name(typedef) - type = typedef - out.puts <<-EOS.strip_heredoc - // EncodeTo encodes this value using the Encoder. - func (e #{name}) EncodeTo(enc *xdr.Encoder) error { - if _, ok := #{private_name type}Map[int32(e)]; !ok { - return fmt.Errorf("'%d' is not a valid #{name} enum value", e) - } - _, err := enc.EncodeInt(int32(e)) - return err - } - EOS + enum_name = name(typedef) + is_contiguous, _, _, all_values = enum_info(typedef) + + out.puts "// EncodeTo encodes this value using the Encoder." + out.puts "func (e #{enum_name}) EncodeTo(enc *xdr.Encoder) error {" + if is_contiguous + out.puts " if int32(e) < _#{enum_name}_Min || int32(e) > _#{enum_name}_Max {" + out.puts " return fmt.Errorf(\"'%d' is not a valid #{enum_name} enum value\", e)" + out.puts " }" + else + cases = all_values.join(", ") + out.puts " switch int32(e) {" + out.puts " case #{cases}:" + out.puts " // valid" + out.puts " default:" + out.puts " return fmt.Errorf(\"'%d' is not a valid #{enum_name} enum value\", e)" + out.puts " }" + end + out.puts " _, err := enc.EncodeInt(int32(e))" + out.puts " return err" + out.puts "}" end def is_fixed_array_type(type) @@ -425,6 +473,65 @@ def render_typedef_encode_to_interface(out, typedef) out.break end + # Returns [encode_method, decode_method, go_type] for primitive types, or nil if not primitive + def primitive_type_info(type) + case type + when AST::Typespecs::UnsignedHyper then ['EncodeUhyper', 'DecodeUhyper', 'uint64'] + when AST::Typespecs::Hyper then ['EncodeHyper', 'DecodeHyper', 'int64'] + when AST::Typespecs::UnsignedInt then ['EncodeUint', 'DecodeUint', 'uint32'] + when AST::Typespecs::Int then ['EncodeInt', 'DecodeInt', 'int32'] + when AST::Typespecs::Bool then ['EncodeBool', 'DecodeBool', 'bool'] + when AST::Typespecs::Float then ['EncodeFloat', 'DecodeFloat', 'float32'] + when AST::Typespecs::Double then ['EncodeDouble', 'DecodeDouble', 'float64'] + else nil + end + end + + # Generates encode code for primitive types (int, uint, hyper, bool, float, double) + # Handles :simple, :optional, :array, and :var_array sub_types + def render_encode_primitive(out, var, type, encode_method, go_type) + check_err = ->(str) { " if #{str}; err != nil {\n return err\n }" } + + case type.sub_type + when :simple, :optional + # :optional is handled like :simple here because the optional wrapping + # (nil check + bool encode) is done before calling this helper + out.puts check_err.call("_, err = e.#{encode_method}(#{go_type}(#{var}))") + when :array + out.puts " for i := 0; i < len(#{var}); i++ {" + out.puts check_err.call(" _, err = e.#{encode_method}(#{go_type}(#{var}[i]))") + out.puts " }" + when :var_array + out.puts check_err.call("_, err = e.EncodeUint(uint32(len(#{var})))") + out.puts " for i := 0; i < len(#{var}); i++ {" + out.puts check_err.call(" _, err = e.#{encode_method}(#{go_type}(#{var}[i]))") + out.puts " }" + else + raise "Unknown sub_type: #{type.sub_type}" + end + end + + # Generates decode code for primitive types (int, uint, hyper, bool, float, double) + # Handles :simple, :optional, :array, and :var_array sub_types + def render_decode_primitive(out, var, type, decode_method, go_type, declared_variables, tail) + case type.sub_type + when :simple, :optional + # :optional is handled like :simple here because the optional wrapping + # (bool decode + nil check) is done before calling this helper + out.puts " #{var}, nTmp, err = d.#{decode_method}()" + out.puts tail + when :array + out.puts " for i := 0; i < len(#{var}); i++ {" + out.puts " #{var}[i], nTmp, err = d.#{decode_method}()" + out.puts tail + out.puts " }" + when :var_array + render_decode_var_array_primitive(out, var, go_type, decode_method, type, declared_variables, tail) + else + raise "Unknown sub_type: #{type.sub_type}" + end + end + # render_encode_to_body assumes there is an `e` variable containing an # xdr.Encoder, and a variable defined by `var` that is the value to # encode. @@ -442,26 +549,20 @@ def check_error(str) out.puts " if #{var} != nil {" var = "(*#{var})" end - case type - when AST::Typespecs::UnsignedHyper - out.puts check_error " _, err = e.EncodeUhyper(uint64(#{var}))" - when AST::Typespecs::Hyper - out.puts check_error "_, err = e.EncodeHyper(int64(#{var}))" - when AST::Typespecs::UnsignedInt - out.puts check_error "_, err = e.EncodeUint(uint32(#{var}))" - when AST::Typespecs::Int - out.puts (check_error "_, err = e.EncodeInt(int32(#{var}))") - when AST::Typespecs::Bool - out.puts (check_error "_, err = e.EncodeBool(bool(#{var}))") - when AST::Typespecs::String + # Check if this is a primitive type we can handle with the helper + primitive_info = primitive_type_info(type) + if primitive_info + encode_method, _, go_type = primitive_info + render_encode_primitive(out, var, type, encode_method, go_type) + elsif type.is_a?(AST::Typespecs::String) out.puts check_error "_, err = e.EncodeString(string(#{var}))" - when AST::Typespecs::Opaque + elsif type.is_a?(AST::Typespecs::Opaque) if type.fixed? out.puts check_error "_, err = e.EncodeFixedOpaque(#{var}[:])" else out.puts check_error "_, err = e.EncodeOpaque(#{var}[:])" end - when AST::Typespecs::Simple + elsif type.is_a?(AST::Typespecs::Simple) case type.sub_type when :simple, :optional optional_within = type.is_a?(AST::Identifier) && type.resolved_type.sub_type == :optional @@ -519,7 +620,7 @@ def check_error(str) else raise "Unknown sub_type: #{type.sub_type}" end - when AST::Definitions::Base + elsif type.is_a?(AST::Definitions::Base) if self_encode out.puts check_error "err = #{name type}(#{var}).EncodeTo(e)" else @@ -535,7 +636,7 @@ def check_error(str) def render_struct_decode_from_interface(out, struct) name = name(struct) - out.puts "// DecodeFrom decodes this value using the Decoder." + out.puts "// DecodeFrom decodes this value from the given decoder." out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {" out.puts " if maxDepth == 0 {" out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)" @@ -555,7 +656,7 @@ def render_struct_decode_from_interface(out, struct) def render_union_decode_from_interface(out, union) name = name(union) - out.puts "// DecodeFrom decodes this value using the Decoder." + out.puts "// DecodeFrom decodes this value from the given decoder." out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {" out.puts " if maxDepth == 0 {" out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)" @@ -571,8 +672,16 @@ def render_union_decode_from_interface(out, union) else mn = name(arm) type = arm.type - out2.puts " u.#{mn} = new(#{reference arm.type})" - render_decode_from_body(out2, "(*u.#{mn})",type, declared_variables: [], self_encode: false) + if is_union_inline_type?(type) + # Primitive: decode directly into value field - no allocation + render_decode_from_body(out2, "u.#{mn}", type, declared_variables: [], self_encode: false) + else + # Complex: allocate if nil, then decode + out2.puts " if u.#{mn} == nil {" + out2.puts " u.#{mn} = new(#{reference arm.type})" + out2.puts " }" + render_decode_from_body(out2, "(*u.#{mn})", type, declared_variables: [], self_encode: false) + end end out2.puts " return n, nil" out2.string @@ -585,32 +694,40 @@ def render_union_decode_from_interface(out, union) end def render_enum_decode_from_interface(out, typedef) - name = name(typedef) - type = typedef - out.puts <<-EOS.strip_heredoc - // DecodeFrom decodes this value using the Decoder. - func (e *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { - if maxDepth == 0 { - return 0, fmt.Errorf("decoding #{name}: %w", ErrMaxDecodingDepthReached) - } - maxDepth -= 1 - v, n, err := d.DecodeInt() - if err != nil { - return n, fmt.Errorf("decoding #{name}: %w", err) - } - if _, ok := #{private_name type}Map[v]; !ok { - return n, fmt.Errorf("'%d' is not a valid #{name} enum value", v) - } - *e = #{name}(v) - return n, nil - } - EOS + enum_name = name(typedef) + is_contiguous, _, _, all_values = enum_info(typedef) + + out.puts "// DecodeFrom decodes this value from the given decoder." + out.puts "func (e *#{enum_name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {" + out.puts " if maxDepth == 0 {" + out.puts " return 0, fmt.Errorf(\"decoding #{enum_name}: %w\", ErrMaxDecodingDepthReached)" + out.puts " }" + out.puts " v, n, err := d.DecodeInt()" + out.puts " if err != nil {" + out.puts " return n, fmt.Errorf(\"decoding #{enum_name}: %w\", err)" + out.puts " }" + if is_contiguous + out.puts " if v < _#{enum_name}_Min || v > _#{enum_name}_Max {" + out.puts " return n, fmt.Errorf(\"'%d' is not a valid #{enum_name} enum value\", v)" + out.puts " }" + else + cases = all_values.join(", ") + out.puts " switch v {" + out.puts " case #{cases}:" + out.puts " // valid" + out.puts " default:" + out.puts " return n, fmt.Errorf(\"'%d' is not a valid #{enum_name} enum value\", v)" + out.puts " }" + end + out.puts " *e = #{enum_name}(v)" + out.puts " return n, nil" + out.puts "}" end def render_typedef_decode_from_interface(out, typedef) name = name(typedef) type = typedef.declaration.type - out.puts "// DecodeFrom decodes this value using the Decoder." + out.puts "// DecodeFrom decodes this value from the given decoder." out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {" out.puts " if maxDepth == 0 {" out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)" @@ -645,6 +762,37 @@ def render_variable_declaration(out, indent, var, type, declared_variables:) end end + # render_decode_var_array_primitive generates decode code for variable-length + # arrays of primitive types (int, uint, hyper, etc.) + def render_decode_var_array_primitive(out, var, go_type, decode_method, type, declared_variables, tail) + type_name = go_type # For primitives, the Go type is the element type + render_variable_declaration(out, " ", 'l', "uint32", declared_variables: declared_variables) + out.puts " l, nTmp, err = d.DecodeUint()" + out.puts tail + unless type.decl.resolved_size.nil? + out.puts " if l > #{type.decl.resolved_size} {" + out.puts " return n, fmt.Errorf(\"decoding #{type_name}: data size (%d) exceeds size limit (#{type.decl.resolved_size})\", l)" + out.puts " }" + end + # Slice capacity preservation (grow-only, no shrinking) + out.puts " if l == 0 {" + out.puts " #{var} = #{var}[:0]" + out.puts " } else {" + out.puts " if uint(d.Remaining()) < uint(l) {" + out.puts " return n, fmt.Errorf(\"decoding #{type_name}: length (%d) exceeds remaining input length (%d)\", l, d.Remaining())" + out.puts " }" + out.puts " if cap(#{var}) >= int(l) {" + out.puts " #{var} = #{var}[:l]" + out.puts " } else {" + out.puts " #{var} = make([]#{go_type}, l)" + out.puts " }" + out.puts " for i := uint32(0); i < l; i++ {" + out.puts " #{var}[i], nTmp, err = d.#{decode_method}()" + out.puts tail + out.puts " }" + out.puts " }" + end + # render_decode_from_body assumes there is an `d` variable containing an # xdr.Decoder, and a variable defined by `var` that is the value to # encode. @@ -664,28 +812,17 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts " if b {" out.puts " #{var} = new(#{name type})" end - case type - when AST::Typespecs::UnsignedHyper - out.puts " #{var}, nTmp, err = d.DecodeUhyper()" - out.puts tail - when AST::Typespecs::Hyper - out.puts " #{var}, nTmp, err = d.DecodeHyper()" - out.puts tail - when AST::Typespecs::UnsignedInt - out.puts " #{var}, nTmp, err = d.DecodeUint()" - out.puts tail - when AST::Typespecs::Int - out.puts " #{var}, nTmp, err = d.DecodeInt()" - out.puts tail - when AST::Typespecs::Bool - out.puts " #{var}, nTmp, err = d.DecodeBool()" - out.puts tail - when AST::Typespecs::String + # Check if this is a primitive type we can handle with the helper + primitive_info = primitive_type_info(type) + if primitive_info + _, decode_method, go_type = primitive_info + render_decode_primitive(out, var, type, decode_method, go_type, declared_variables, tail) + elsif type.is_a?(AST::Typespecs::String) arg = "0" arg = type.decl.resolved_size unless type.decl.resolved_size.nil? out.puts " #{var}, nTmp, err = d.DecodeString(#{arg})" out.puts tail - when AST::Typespecs::Opaque + elsif type.is_a?(AST::Typespecs::Opaque) if type.fixed? out.puts " nTmp, err = d.DecodeFixedOpaqueInplace(#{var}[:])" else @@ -694,7 +831,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts " #{var}, nTmp, err = d.DecodeOpaque(#{arg})" end out.puts tail - when AST::Typespecs::Simple + elsif type.is_a?(AST::Typespecs::Simple) case type.sub_type when :simple, :optional optional_within = type.is_a?(AST::Identifier) && type.resolved_type.sub_type == :optional @@ -739,12 +876,18 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts " return n, fmt.Errorf(\"decoding #{name type}: data size (%d) exceeds size limit (#{type.decl.resolved_size})\", l)" out.puts " }" end - out.puts " #{var} = nil" - out.puts " if l > 0 {" - out.puts " if il, ok := d.InputLen(); ok && uint(il) < uint(l) {" - out.puts " return n, fmt.Errorf(\"decoding #{name type}: length (%d) exceeds remaining input length (%d)\", l, il)" + # Slice capacity preservation (grow-only, no shrinking) + out.puts " if l == 0 {" + out.puts " #{var} = #{var}[:0]" + out.puts " } else {" + out.puts " if uint(d.Remaining()) < uint(l) {" + out.puts " return n, fmt.Errorf(\"decoding #{name type}: length (%d) exceeds remaining input length (%d)\", l, d.Remaining())" + out.puts " }" + out.puts " if cap(#{var}) >= int(l) {" + out.puts " #{var} = #{var}[:l]" + out.puts " } else {" + out.puts " #{var} = make([]#{name type}, l)" out.puts " }" - out.puts " #{var} = make([]#{name type}, l)" out.puts " for i := uint32(0); i < l; i++ {" element_var = "#{var}[i]" optional_within = type.is_a?(AST::Identifier) && type.resolved_type.sub_type == :optional @@ -767,7 +910,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) else raise "Unknown sub_type: #{type.sub_type}" end - when AST::Definitions::Base + elsif type.is_a?(AST::Definitions::Base) if self_encode out.puts " nTmp, err = #{name type}(#{var}).DecodeFrom(d, maxDepth)" else @@ -794,11 +937,8 @@ def render_binary_interface(out, name) out.break out.puts "// UnmarshalBinary implements encoding.BinaryUnmarshaler." out.puts "func (s *#{name}) UnmarshalBinary(inp []byte) error {" - out.puts " r := bytes.NewReader(inp)" - out.puts " o := xdr.DefaultDecodeOptions" - out.puts " o.MaxInputLen = len(inp)" - out.puts " d := xdr.NewDecoderWithOptions(r, o)" - out.puts " _, err := s.DecodeFrom(d, o.MaxDepth)" + out.puts " d := xdr.NewDecoder(inp)" + out.puts " _, err := s.DecodeFrom(d, d.MaxDepth())" out.puts " return err" out.puts "}" out.break @@ -818,7 +958,7 @@ def render_xdr_type_interface(out, name) end def render_decoder_from_interface(out, name) - out.puts "var _ decoderFrom = (*#{name})(nil)" + out.puts "var _ xdr.DecoderFrom = (*#{name})(nil)" end def render_top_matter(out) @@ -852,29 +992,22 @@ def render_top_matter(out) EOS out.break out.puts <<-EOS.strip_heredoc - var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") - type xdrType interface { xdrType() } - type decoderFrom interface { - DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) - } - - // Unmarshal reads an xdr element from `r` into `v`. - func Unmarshal(r io.Reader, v interface{}) (int, error) { - return UnmarshalWithOptions(r, v, xdr.DefaultDecodeOptions) - } + // ErrMaxDecodingDepthReached is returned when the maximum decoding depth is + // exceeded. This prevents stack overflow from deeply nested structures. + var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") - // UnmarshalWithOptions works like Unmarshal but uses decoding options. - func UnmarshalWithOptions(r io.Reader, v interface{}, options xdr.DecodeOptions) (int, error) { - if decodable, ok := v.(decoderFrom); ok { - d := xdr.NewDecoderWithOptions(r, options) - return decodable.DecodeFrom(d, options.MaxDepth) + // Unmarshal reads an xdr element from `data` into `v`. + func Unmarshal(data []byte, v interface{}) (int, error) { + if decodable, ok := v.(xdr.DecoderFrom); ok { + d := xdr.NewDecoder(data) + return decodable.DecodeFrom(d, d.MaxDepth()) } // delegate to xdr package's Unmarshal - return xdr.UnmarshalWithOptions(r, v, options) + return xdr.Unmarshal(data, v) } // Marshal writes an xdr element `v` into `w`. @@ -975,6 +1108,82 @@ def private_name(named) escape_name named.name.underscore.camelize(:lower) end + # Returns [is_contiguous, min_value, max_value, all_values] for an enum + # An enum is contiguous if its values form a complete sequence from min to max + def enum_info(enum) + values = enum.members.map { |m| m.value } + min_val = values.min + max_val = values.max + is_contiguous = (values.length == (max_val - min_val + 1)) && (values.uniq.length == values.length) + [is_contiguous, min_val, max_val, values] + end + + # Determines if a union arm type should be inlined (stored as a value, not + # a pointer) in the union struct. Inlined types eliminate heap allocations + # during decode. + # + # Only small primitives (≤8 bytes) are inlined: + # - Base primitives: bool, int32, uint32, int64, uint64, float32, float64 + # - Enums (int32) + # - Typedefs that resolve to the above + # + # NOT inlined (to prevent memory bloat in unions with many arms): + # - Fixed-length opaque (could be 32+ bytes like Hash) + # - Fixed-length arrays + # - Structs (even fixed-size ones like Int256Parts = 32 bytes) + # - Strings, variable-length opaque, variable-length arrays + # - Optional types (pointers) + def is_union_inline_type?(type) + # Reject optional, var_array, array sub_types + return false unless type.sub_type == :simple + + case type + when AST::Typespecs::Bool, AST::Typespecs::Int, AST::Typespecs::UnsignedInt, + AST::Typespecs::Hyper, AST::Typespecs::UnsignedHyper, + AST::Typespecs::Float, AST::Typespecs::Double + true # Base primitives ≤8 bytes + when AST::Typespecs::Simple + resolved = type.resolved_type + case resolved + when AST::Definitions::Typedef + is_union_inline_type?(resolved.declaration.type) + when AST::Definitions::Enum + true # Enums are int32 (4 bytes) + else + false # Exclude structs, unions + end + else + false # Exclude opaque, strings, arrays + end + end + + # Returns Go condition string for checking if an arm is active based on discriminant + # Returns nil for default arms (which need special handling) + def arm_discriminant_check(arm) + union = arm.union + discriminant_field = "u.#{name union.discriminant}" + + # Default arms need special handling - return nil to signal fallback + return nil if arm == union.default_arm + + # Build condition for all cases that map to this arm + conditions = arm.cases.map do |c| + value = if c.value.is_a?(AST::Identifier) + member = union.resolved_case(c) + if union.discriminant_type.nil? + "int32(#{name member.enum}#{name member})" + else + "#{name union.discriminant_type}#{name member}" + end + else + c.value.text_value + end + "#{reference union.discriminant.type}(#{discriminant_field}) == #{value}" + end + + conditions.join(" || ") + end + def escape_name(name) case name when "type" ; "aType" @@ -999,7 +1208,18 @@ def render_union_constructor(out, union) switch_for(out, union, discriminant_arg) do |arm, kase| if arm.void? "// void" + elsif is_union_inline_type?(arm.type) + # Primitive: direct value assignment + <<-EOS + tv, ok := value.(#{reference arm.type}) + if !ok { + err = errors.New("invalid value, must be #{reference arm.type}") + return + } + result.#{name arm} = tv + EOS else + # Complex: pointer assignment <<-EOS tv, ok := value.(#{reference arm.type}) if !ok { @@ -1018,33 +1238,113 @@ def render_union_constructor(out, union) end def access_arm(arm) + condition = arm_discriminant_check(arm) - <<-EOS.strip_heredoc - // Must#{name arm} retrieves the #{name arm} value from the union, - // panicing if the value is not set. - func (u #{name arm.union}) Must#{name arm}() #{reference arm.type} { - val, ok := u.Get#{name arm}() + # For default arms, fall back to ArmForSwitch (less common case) + if condition.nil? + return access_arm_with_arm_for_switch(arm) + end - if !ok { + # Optimized accessors using direct discriminant comparison + # Note: No X() *T pointer-returning accessors per spec + if is_union_inline_type?(arm.type) + # Primitive: GetX() returns (T, bool), MustX() returns T + <<-EOS.strip_heredoc + // Must#{name arm} retrieves the #{name arm} value from the union, + // panicing if the value is not set. + func (u #{name arm.union}) Must#{name arm}() #{reference arm.type} { + if #{condition} { + return u.#{name arm} + } panic("arm #{name arm} is not set") } - return val - } + // Get#{name arm} retrieves the #{name arm} value from the union, + // returning ok if the union's switch indicated the value is valid. + func (u #{name arm.union}) Get#{name arm}() (result #{reference arm.type}, ok bool) { + if #{condition} { + result = u.#{name arm} + ok = true + } + return + } + EOS + else + # Complex: GetX() returns (T, bool), MustX() returns T (dereference pointer) + <<-EOS.strip_heredoc + // Must#{name arm} retrieves the #{name arm} value from the union, + // panicing if the value is not set. + func (u #{name arm.union}) Must#{name arm}() #{reference arm.type} { + if #{condition} { + return *u.#{name arm} + } + panic("arm #{name arm} is not set") + } - // Get#{name arm} retrieves the #{name arm} value from the union, - // returning ok if the union's switch indicated the value is valid. - func (u #{name arm.union}) Get#{name arm}() (result #{reference arm.type}, ok bool) { - armName, _ := u.ArmForSwitch(int32(u.#{name arm.union.discriminant})) + // Get#{name arm} retrieves the #{name arm} value from the union, + // returning ok if the union's switch indicated the value is valid. + func (u #{name arm.union}) Get#{name arm}() (result #{reference arm.type}, ok bool) { + if #{condition} { + result = *u.#{name arm} + ok = true + } + return + } + EOS + end + end - if armName == "#{name arm}" { - result = *u.#{name arm} - ok = true + # Fallback for default arms - uses ArmForSwitch (less common) + def access_arm_with_arm_for_switch(arm) + if is_union_inline_type?(arm.type) + # Primitive: GetX() returns (T, bool), MustX() returns T + <<-EOS.strip_heredoc + // Must#{name arm} retrieves the #{name arm} value from the union, + // panicing if the value is not set. + func (u #{name arm.union}) Must#{name arm}() #{reference arm.type} { + val, ok := u.Get#{name arm}() + if !ok { + panic("arm #{name arm} is not set") + } + return val } - return - } - EOS + // Get#{name arm} retrieves the #{name arm} value from the union, + // returning ok if the union's switch indicated the value is valid. + func (u #{name arm.union}) Get#{name arm}() (result #{reference arm.type}, ok bool) { + armName, _ := u.ArmForSwitch(int32(u.#{name arm.union.discriminant})) + if armName == "#{name arm}" { + result = u.#{name arm} + ok = true + } + return + } + EOS + else + # Complex: GetX() returns (T, bool), MustX() returns T (dereference pointer) + <<-EOS.strip_heredoc + // Must#{name arm} retrieves the #{name arm} value from the union, + // panicing if the value is not set. + func (u #{name arm.union}) Must#{name arm}() #{reference arm.type} { + val, ok := u.Get#{name arm}() + if !ok { + panic("arm #{name arm} is not set") + } + return val + } + + // Get#{name arm} retrieves the #{name arm} value from the union, + // returning ok if the union's switch indicated the value is valid. + func (u #{name arm.union}) Get#{name arm}() (result #{reference arm.type}, ok bool) { + armName, _ := u.ArmForSwitch(int32(u.#{name arm.union.discriminant})) + if armName == "#{name arm}" { + result = *u.#{name arm} + ok = true + } + return + } + EOS + end end def size(size_s) diff --git a/spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go b/spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go index 887aeed6f..52e6c3377 100644 --- a/spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go @@ -23,29 +23,22 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/block_comments.x": "e13131bc4134f38da17b9d5e9f67d2695a69ef98e3ef272833f4c18d0cc88a30", } -var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") - type xdrType interface { xdrType() } -type decoderFrom interface { - DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) -} - -// Unmarshal reads an xdr element from `r` into `v`. -func Unmarshal(r io.Reader, v interface{}) (int, error) { - return UnmarshalWithOptions(r, v, xdr.DefaultDecodeOptions) -} +// ErrMaxDecodingDepthReached is returned when the maximum decoding depth is +// exceeded. This prevents stack overflow from deeply nested structures. +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") -// UnmarshalWithOptions works like Unmarshal but uses decoding options. -func UnmarshalWithOptions(r io.Reader, v interface{}, options xdr.DecodeOptions) (int, error) { - if decodable, ok := v.(decoderFrom); ok { - d := xdr.NewDecoderWithOptions(r, options) - return decodable.DecodeFrom(d, options.MaxDepth) +// Unmarshal reads an xdr element from `data` into `v`. +func Unmarshal(data []byte, v interface{}) (int, error) { + if decodable, ok := v.(xdr.DecoderFrom); ok { + d := xdr.NewDecoder(data) + return decodable.DecodeFrom(d, d.MaxDepth()) } // delegate to xdr package's Unmarshal - return xdr.UnmarshalWithOptions(r, v, options) + return xdr.Unmarshal(data, v) } // Marshal writes an xdr element `v` into `w`. @@ -74,6 +67,10 @@ type AccountFlags int32 const ( AccountFlagsAuthRequiredFlag AccountFlags = 1 ) +const ( + _AccountFlags_Min int32 = 1 + _AccountFlags_Max int32 = 1 +) var accountFlagsMap = map[int32]string{ 1: "AccountFlagsAuthRequiredFlag", } @@ -81,8 +78,7 @@ var accountFlagsMap = map[int32]string{ // ValidEnum validates a proposed value for this enum. Implements // the Enum interface for AccountFlags func (e AccountFlags) ValidEnum(v int32) bool { - _, ok := accountFlagsMap[v] - return ok + return v >= _AccountFlags_Min && v <= _AccountFlags_Max } // String returns the name of `e` func (e AccountFlags) String() string { @@ -92,24 +88,23 @@ func (e AccountFlags) String() string { // EncodeTo encodes this value using the Encoder. func (e AccountFlags) EncodeTo(enc *xdr.Encoder) error { - if _, ok := accountFlagsMap[int32(e)]; !ok { + if int32(e) < _AccountFlags_Min || int32(e) > _AccountFlags_Max { return fmt.Errorf("'%d' is not a valid AccountFlags enum value", e) } _, err := enc.EncodeInt(int32(e)) return err } -var _ decoderFrom = (*AccountFlags)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*AccountFlags)(nil) +// DecodeFrom decodes this value from the given decoder. func (e *AccountFlags) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding AccountFlags: %w", ErrMaxDecodingDepthReached) } - maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { return n, fmt.Errorf("decoding AccountFlags: %w", err) } - if _, ok := accountFlagsMap[v]; !ok { + if v < _AccountFlags_Min || v > _AccountFlags_Max { return n, fmt.Errorf("'%d' is not a valid AccountFlags enum value", v) } *e = AccountFlags(v) @@ -125,11 +120,8 @@ func (s AccountFlags) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *AccountFlags) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } diff --git a/spec/output/generator_spec_go/const.x/MyXDR_generated.go b/spec/output/generator_spec_go/const.x/MyXDR_generated.go index d3f983e1b..bf487bcf2 100644 --- a/spec/output/generator_spec_go/const.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/const.x/MyXDR_generated.go @@ -23,29 +23,22 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/const.x": "0bff3b37592fcc16cad2fe10b9a72f5d39d033a114917c24e86a9ebd9cda9c37", } -var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") - type xdrType interface { xdrType() } -type decoderFrom interface { - DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) -} - -// Unmarshal reads an xdr element from `r` into `v`. -func Unmarshal(r io.Reader, v interface{}) (int, error) { - return UnmarshalWithOptions(r, v, xdr.DefaultDecodeOptions) -} +// ErrMaxDecodingDepthReached is returned when the maximum decoding depth is +// exceeded. This prevents stack overflow from deeply nested structures. +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") -// UnmarshalWithOptions works like Unmarshal but uses decoding options. -func UnmarshalWithOptions(r io.Reader, v interface{}, options xdr.DecodeOptions) (int, error) { - if decodable, ok := v.(decoderFrom); ok { - d := xdr.NewDecoderWithOptions(r, options) - return decodable.DecodeFrom(d, options.MaxDepth) +// Unmarshal reads an xdr element from `data` into `v`. +func Unmarshal(data []byte, v interface{}) (int, error) { + if decodable, ok := v.(xdr.DecoderFrom); ok { + d := xdr.NewDecoder(data) + return decodable.DecodeFrom(d, d.MaxDepth()) } // delegate to xdr package's Unmarshal - return xdr.UnmarshalWithOptions(r, v, options) + return xdr.Unmarshal(data, v) } // Marshal writes an xdr element `v` into `w`. @@ -77,14 +70,16 @@ type TestArray [Foo]int32 // EncodeTo encodes this value using the Encoder. func (s *TestArray) EncodeTo(e *xdr.Encoder) error { var err error - if _, err = e.EncodeInt(int32(s)); err != nil { + for i := 0; i < len(s); i++ { + if _, err = e.EncodeInt(int32(s[i])); err != nil { return err } + } return nil } -var _ decoderFrom = (*TestArray)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*TestArray)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *TestArray) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding TestArray: %w", ErrMaxDecodingDepthReached) @@ -93,11 +88,13 @@ func (s *TestArray) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { var err error var n, nTmp int var v [Foo]int32 - v, nTmp, err = d.DecodeInt() + for i := 0; i < len(v); i++ { + v[i], nTmp, err = d.DecodeInt() n += nTmp if err != nil { return n, fmt.Errorf("decoding Int: %w", err) } + } *s = TestArray(v) return n, nil } @@ -112,11 +109,8 @@ func (s TestArray) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *TestArray) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -142,14 +136,19 @@ func (e TestArray2) XDRMaxSize() int { // EncodeTo encodes this value using the Encoder. func (s TestArray2) EncodeTo(e *xdr.Encoder) error { var err error - if _, err = e.EncodeInt(int32(s)); err != nil { + if _, err = e.EncodeUint(uint32(len(s))); err != nil { + return err + } + for i := 0; i < len(s); i++ { + if _, err = e.EncodeInt(int32(s[i])); err != nil { return err } + } return nil } -var _ decoderFrom = (*TestArray2)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*TestArray2)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *TestArray2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding TestArray2: %w", ErrMaxDecodingDepthReached) @@ -158,10 +157,33 @@ func (s *TestArray2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { var err error var n, nTmp int var v []int32 - v, nTmp, err = d.DecodeInt() + var l uint32 + l, nTmp, err = d.DecodeUint() + n += nTmp + if err != nil { + return n, fmt.Errorf("decoding Int: %w", err) + } + if l > 1 { + return n, fmt.Errorf("decoding int32: data size (%d) exceeds size limit (1)", l) + } + if l == 0 { + v = v[:0] + } else { + if uint(d.Remaining()) < uint(l) { + return n, fmt.Errorf("decoding int32: length (%d) exceeds remaining input length (%d)", l, d.Remaining()) + } + if cap(v) >= int(l) { + v = v[:l] + } else { + v = make([]int32, l) + } + for i := uint32(0); i < l; i++ { + v[i], nTmp, err = d.DecodeInt() n += nTmp if err != nil { return n, fmt.Errorf("decoding Int: %w", err) + } + } } *s = TestArray2(v) return n, nil @@ -177,11 +199,8 @@ func (s TestArray2) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *TestArray2) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } diff --git a/spec/output/generator_spec_go/enum.x/MyXDR_generated.go b/spec/output/generator_spec_go/enum.x/MyXDR_generated.go index 3fb16f7fc..f7466acca 100644 --- a/spec/output/generator_spec_go/enum.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/enum.x/MyXDR_generated.go @@ -23,29 +23,22 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/enum.x": "f764c2a2d349765e611f686e9d416b7f576ea881154d069355a2e75c898daf58", } -var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") - type xdrType interface { xdrType() } -type decoderFrom interface { - DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) -} - -// Unmarshal reads an xdr element from `r` into `v`. -func Unmarshal(r io.Reader, v interface{}) (int, error) { - return UnmarshalWithOptions(r, v, xdr.DefaultDecodeOptions) -} +// ErrMaxDecodingDepthReached is returned when the maximum decoding depth is +// exceeded. This prevents stack overflow from deeply nested structures. +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") -// UnmarshalWithOptions works like Unmarshal but uses decoding options. -func UnmarshalWithOptions(r io.Reader, v interface{}, options xdr.DecodeOptions) (int, error) { - if decodable, ok := v.(decoderFrom); ok { - d := xdr.NewDecoderWithOptions(r, options) - return decodable.DecodeFrom(d, options.MaxDepth) +// Unmarshal reads an xdr element from `data` into `v`. +func Unmarshal(data []byte, v interface{}) (int, error) { + if decodable, ok := v.(xdr.DecoderFrom); ok { + d := xdr.NewDecoder(data) + return decodable.DecodeFrom(d, d.MaxDepth()) } // delegate to xdr package's Unmarshal - return xdr.UnmarshalWithOptions(r, v, options) + return xdr.Unmarshal(data, v) } // Marshal writes an xdr element `v` into `w`. @@ -106,6 +99,10 @@ const ( MessageTypeFbaQuorumset MessageType = 12 MessageTypeFbaMessage MessageType = 13 ) +const ( + _MessageType_Min int32 = 0 + _MessageType_Max int32 = 13 +) var messageTypeMap = map[int32]string{ 0: "MessageTypeErrorMsg", 1: "MessageTypeHello", @@ -126,8 +123,7 @@ var messageTypeMap = map[int32]string{ // ValidEnum validates a proposed value for this enum. Implements // the Enum interface for MessageType func (e MessageType) ValidEnum(v int32) bool { - _, ok := messageTypeMap[v] - return ok + return v >= _MessageType_Min && v <= _MessageType_Max } // String returns the name of `e` func (e MessageType) String() string { @@ -137,24 +133,23 @@ func (e MessageType) String() string { // EncodeTo encodes this value using the Encoder. func (e MessageType) EncodeTo(enc *xdr.Encoder) error { - if _, ok := messageTypeMap[int32(e)]; !ok { + if int32(e) < _MessageType_Min || int32(e) > _MessageType_Max { return fmt.Errorf("'%d' is not a valid MessageType enum value", e) } _, err := enc.EncodeInt(int32(e)) return err } -var _ decoderFrom = (*MessageType)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*MessageType)(nil) +// DecodeFrom decodes this value from the given decoder. func (e *MessageType) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding MessageType: %w", ErrMaxDecodingDepthReached) } - maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { return n, fmt.Errorf("decoding MessageType: %w", err) } - if _, ok := messageTypeMap[v]; !ok { + if v < _MessageType_Min || v > _MessageType_Max { return n, fmt.Errorf("'%d' is not a valid MessageType enum value", v) } *e = MessageType(v) @@ -170,11 +165,8 @@ func (s MessageType) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *MessageType) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -202,6 +194,10 @@ const ( ColorGreen Color = 1 ColorBlue Color = 2 ) +const ( + _Color_Min int32 = 0 + _Color_Max int32 = 2 +) var colorMap = map[int32]string{ 0: "ColorRed", 1: "ColorGreen", @@ -211,8 +207,7 @@ var colorMap = map[int32]string{ // ValidEnum validates a proposed value for this enum. Implements // the Enum interface for Color func (e Color) ValidEnum(v int32) bool { - _, ok := colorMap[v] - return ok + return v >= _Color_Min && v <= _Color_Max } // String returns the name of `e` func (e Color) String() string { @@ -222,24 +217,23 @@ func (e Color) String() string { // EncodeTo encodes this value using the Encoder. func (e Color) EncodeTo(enc *xdr.Encoder) error { - if _, ok := colorMap[int32(e)]; !ok { + if int32(e) < _Color_Min || int32(e) > _Color_Max { return fmt.Errorf("'%d' is not a valid Color enum value", e) } _, err := enc.EncodeInt(int32(e)) return err } -var _ decoderFrom = (*Color)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Color)(nil) +// DecodeFrom decodes this value from the given decoder. func (e *Color) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Color: %w", ErrMaxDecodingDepthReached) } - maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { return n, fmt.Errorf("decoding Color: %w", err) } - if _, ok := colorMap[v]; !ok { + if v < _Color_Min || v > _Color_Max { return n, fmt.Errorf("'%d' is not a valid Color enum value", v) } *e = Color(v) @@ -255,11 +249,8 @@ func (s Color) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Color) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -287,6 +278,10 @@ const ( Color2Green2 Color2 = 1 Color2Blue2 Color2 = 2 ) +const ( + _Color2_Min int32 = 0 + _Color2_Max int32 = 2 +) var color2Map = map[int32]string{ 0: "Color2Red2", 1: "Color2Green2", @@ -296,8 +291,7 @@ var color2Map = map[int32]string{ // ValidEnum validates a proposed value for this enum. Implements // the Enum interface for Color2 func (e Color2) ValidEnum(v int32) bool { - _, ok := color2Map[v] - return ok + return v >= _Color2_Min && v <= _Color2_Max } // String returns the name of `e` func (e Color2) String() string { @@ -307,24 +301,23 @@ func (e Color2) String() string { // EncodeTo encodes this value using the Encoder. func (e Color2) EncodeTo(enc *xdr.Encoder) error { - if _, ok := color2Map[int32(e)]; !ok { + if int32(e) < _Color2_Min || int32(e) > _Color2_Max { return fmt.Errorf("'%d' is not a valid Color2 enum value", e) } _, err := enc.EncodeInt(int32(e)) return err } -var _ decoderFrom = (*Color2)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Color2)(nil) +// DecodeFrom decodes this value from the given decoder. func (e *Color2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Color2: %w", ErrMaxDecodingDepthReached) } - maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { return n, fmt.Errorf("decoding Color2: %w", err) } - if _, ok := color2Map[v]; !ok { + if v < _Color2_Min || v > _Color2_Max { return n, fmt.Errorf("'%d' is not a valid Color2 enum value", v) } *e = Color2(v) @@ -340,11 +333,8 @@ func (s Color2) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Color2) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -372,6 +362,10 @@ const ( Color3Red2Two Color3 = 2 Color3Red3 Color3 = 3 ) +const ( + _Color3_Min int32 = 1 + _Color3_Max int32 = 3 +) var color3Map = map[int32]string{ 1: "Color3Red1", 2: "Color3Red2Two", @@ -381,8 +375,7 @@ var color3Map = map[int32]string{ // ValidEnum validates a proposed value for this enum. Implements // the Enum interface for Color3 func (e Color3) ValidEnum(v int32) bool { - _, ok := color3Map[v] - return ok + return v >= _Color3_Min && v <= _Color3_Max } // String returns the name of `e` func (e Color3) String() string { @@ -392,24 +385,23 @@ func (e Color3) String() string { // EncodeTo encodes this value using the Encoder. func (e Color3) EncodeTo(enc *xdr.Encoder) error { - if _, ok := color3Map[int32(e)]; !ok { + if int32(e) < _Color3_Min || int32(e) > _Color3_Max { return fmt.Errorf("'%d' is not a valid Color3 enum value", e) } _, err := enc.EncodeInt(int32(e)) return err } -var _ decoderFrom = (*Color3)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Color3)(nil) +// DecodeFrom decodes this value from the given decoder. func (e *Color3) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Color3: %w", ErrMaxDecodingDepthReached) } - maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { return n, fmt.Errorf("decoding Color3: %w", err) } - if _, ok := color3Map[v]; !ok { + if v < _Color3_Min || v > _Color3_Max { return n, fmt.Errorf("'%d' is not a valid Color3 enum value", v) } *e = Color3(v) @@ -425,11 +417,8 @@ func (s Color3) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Color3) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } diff --git a/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go b/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go index dbd781b2e..9f175fe7c 100644 --- a/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go @@ -23,29 +23,22 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/nesting.x": "5537949272c11f1bd09cf613a3751668b5018d686a1c2aaa3baa91183ca18f6a", } -var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") - type xdrType interface { xdrType() } -type decoderFrom interface { - DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) -} - -// Unmarshal reads an xdr element from `r` into `v`. -func Unmarshal(r io.Reader, v interface{}) (int, error) { - return UnmarshalWithOptions(r, v, xdr.DefaultDecodeOptions) -} +// ErrMaxDecodingDepthReached is returned when the maximum decoding depth is +// exceeded. This prevents stack overflow from deeply nested structures. +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") -// UnmarshalWithOptions works like Unmarshal but uses decoding options. -func UnmarshalWithOptions(r io.Reader, v interface{}, options xdr.DecodeOptions) (int, error) { - if decodable, ok := v.(decoderFrom); ok { - d := xdr.NewDecoderWithOptions(r, options) - return decodable.DecodeFrom(d, options.MaxDepth) +// Unmarshal reads an xdr element from `data` into `v`. +func Unmarshal(data []byte, v interface{}) (int, error) { + if decodable, ok := v.(xdr.DecoderFrom); ok { + d := xdr.NewDecoder(data) + return decodable.DecodeFrom(d, d.MaxDepth()) } // delegate to xdr package's Unmarshal - return xdr.UnmarshalWithOptions(r, v, options) + return xdr.Unmarshal(data, v) } // Marshal writes an xdr element `v` into `w`. @@ -77,6 +70,10 @@ const ( UnionKeyTwo UnionKey = 2 UnionKeyOffer UnionKey = 3 ) +const ( + _UnionKey_Min int32 = 1 + _UnionKey_Max int32 = 3 +) var unionKeyMap = map[int32]string{ 1: "UnionKeyOne", 2: "UnionKeyTwo", @@ -86,8 +83,7 @@ var unionKeyMap = map[int32]string{ // ValidEnum validates a proposed value for this enum. Implements // the Enum interface for UnionKey func (e UnionKey) ValidEnum(v int32) bool { - _, ok := unionKeyMap[v] - return ok + return v >= _UnionKey_Min && v <= _UnionKey_Max } // String returns the name of `e` func (e UnionKey) String() string { @@ -97,24 +93,23 @@ func (e UnionKey) String() string { // EncodeTo encodes this value using the Encoder. func (e UnionKey) EncodeTo(enc *xdr.Encoder) error { - if _, ok := unionKeyMap[int32(e)]; !ok { + if int32(e) < _UnionKey_Min || int32(e) > _UnionKey_Max { return fmt.Errorf("'%d' is not a valid UnionKey enum value", e) } _, err := enc.EncodeInt(int32(e)) return err } -var _ decoderFrom = (*UnionKey)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*UnionKey)(nil) +// DecodeFrom decodes this value from the given decoder. func (e *UnionKey) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding UnionKey: %w", ErrMaxDecodingDepthReached) } - maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { return n, fmt.Errorf("decoding UnionKey: %w", err) } - if _, ok := unionKeyMap[v]; !ok { + if v < _UnionKey_Min || v > _UnionKey_Max { return n, fmt.Errorf("'%d' is not a valid UnionKey enum value", v) } *e = UnionKey(v) @@ -130,11 +125,8 @@ func (s UnionKey) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *UnionKey) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -163,8 +155,8 @@ func (s Foo) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Foo)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Foo)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Foo) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Foo: %w", ErrMaxDecodingDepthReached) @@ -192,11 +184,8 @@ func (s Foo) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Foo) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -229,8 +218,8 @@ func (s *MyUnionOne) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*MyUnionOne)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*MyUnionOne)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *MyUnionOne) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding MyUnionOne: %w", ErrMaxDecodingDepthReached) @@ -256,11 +245,8 @@ func (s MyUnionOne) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *MyUnionOne) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -298,8 +284,8 @@ func (s *MyUnionTwo) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*MyUnionTwo)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*MyUnionTwo)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *MyUnionTwo) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding MyUnionTwo: %w", ErrMaxDecodingDepthReached) @@ -330,11 +316,8 @@ func (s MyUnionTwo) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *MyUnionTwo) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -419,49 +402,37 @@ switch UnionKey(aType) { // MustOne retrieves the One value from the union, // panicing if the value is not set. func (u MyUnion) MustOne() MyUnionOne { - val, ok := u.GetOne() - - if !ok { - panic("arm One is not set") + if UnionKey(u.Type) == UnionKeyOne { + return *u.One } - - return val + panic("arm One is not set") } // GetOne retrieves the One value from the union, // returning ok if the union's switch indicated the value is valid. func (u MyUnion) GetOne() (result MyUnionOne, ok bool) { - armName, _ := u.ArmForSwitch(int32(u.Type)) - - if armName == "One" { + if UnionKey(u.Type) == UnionKeyOne { result = *u.One ok = true } - return } // MustTwo retrieves the Two value from the union, // panicing if the value is not set. func (u MyUnion) MustTwo() MyUnionTwo { - val, ok := u.GetTwo() - - if !ok { - panic("arm Two is not set") + if UnionKey(u.Type) == UnionKeyTwo { + return *u.Two } - - return val + panic("arm Two is not set") } // GetTwo retrieves the Two value from the union, // returning ok if the union's switch indicated the value is valid. func (u MyUnion) GetTwo() (result MyUnionTwo, ok bool) { - armName, _ := u.ArmForSwitch(int32(u.Type)) - - if armName == "Two" { + if UnionKey(u.Type) == UnionKeyTwo { result = *u.Two ok = true } - return } @@ -489,8 +460,8 @@ return nil return fmt.Errorf("Type (UnionKey) switch value '%d' is not valid for union MyUnion", u.Type) } -var _ decoderFrom = (*MyUnion)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*MyUnion)(nil) +// DecodeFrom decodes this value from the given decoder. func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding MyUnion: %w", ErrMaxDecodingDepthReached) @@ -505,7 +476,9 @@ func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } switch UnionKey(u.Type) { case UnionKeyOne: - u.One = new(MyUnionOne) + if u.One == nil { + u.One = new(MyUnionOne) + } nTmp, err = (*u.One).DecodeFrom(d, maxDepth) n += nTmp if err != nil { @@ -513,7 +486,9 @@ switch UnionKey(u.Type) { } return n, nil case UnionKeyTwo: - u.Two = new(MyUnionTwo) + if u.Two == nil { + u.Two = new(MyUnionTwo) + } nTmp, err = (*u.Two).DecodeFrom(d, maxDepth) n += nTmp if err != nil { @@ -537,11 +512,8 @@ func (s MyUnion) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *MyUnion) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } diff --git a/spec/output/generator_spec_go/optional.x/MyXDR_generated.go b/spec/output/generator_spec_go/optional.x/MyXDR_generated.go index 5ffb0fbd5..1a6357c99 100644 --- a/spec/output/generator_spec_go/optional.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/optional.x/MyXDR_generated.go @@ -23,29 +23,22 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/optional.x": "3241e832fcf00bca4315ecb6c259621dafb0e302a63a993f5504b0b5cebb6bd7", } -var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") - type xdrType interface { xdrType() } -type decoderFrom interface { - DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) -} - -// Unmarshal reads an xdr element from `r` into `v`. -func Unmarshal(r io.Reader, v interface{}) (int, error) { - return UnmarshalWithOptions(r, v, xdr.DefaultDecodeOptions) -} +// ErrMaxDecodingDepthReached is returned when the maximum decoding depth is +// exceeded. This prevents stack overflow from deeply nested structures. +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") -// UnmarshalWithOptions works like Unmarshal but uses decoding options. -func UnmarshalWithOptions(r io.Reader, v interface{}, options xdr.DecodeOptions) (int, error) { - if decodable, ok := v.(decoderFrom); ok { - d := xdr.NewDecoderWithOptions(r, options) - return decodable.DecodeFrom(d, options.MaxDepth) +// Unmarshal reads an xdr element from `data` into `v`. +func Unmarshal(data []byte, v interface{}) (int, error) { + if decodable, ok := v.(xdr.DecoderFrom); ok { + d := xdr.NewDecoder(data) + return decodable.DecodeFrom(d, d.MaxDepth()) } // delegate to xdr package's Unmarshal - return xdr.UnmarshalWithOptions(r, v, options) + return xdr.Unmarshal(data, v) } // Marshal writes an xdr element `v` into `w`. @@ -71,14 +64,16 @@ type Arr [2]int32 // EncodeTo encodes this value using the Encoder. func (s *Arr) EncodeTo(e *xdr.Encoder) error { var err error - if _, err = e.EncodeInt(int32(s)); err != nil { + for i := 0; i < len(s); i++ { + if _, err = e.EncodeInt(int32(s[i])); err != nil { return err } + } return nil } -var _ decoderFrom = (*Arr)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Arr)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Arr) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Arr: %w", ErrMaxDecodingDepthReached) @@ -87,11 +82,13 @@ func (s *Arr) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { var err error var n, nTmp int var v [2]int32 - v, nTmp, err = d.DecodeInt() + for i := 0; i < len(v); i++ { + v[i], nTmp, err = d.DecodeInt() n += nTmp if err != nil { return n, fmt.Errorf("decoding Int: %w", err) } + } *s = Arr(v) return n, nil } @@ -106,11 +103,8 @@ func (s Arr) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Arr) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -169,8 +163,8 @@ func (s *HasOptions) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*HasOptions)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*HasOptions)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding HasOptions: %w", ErrMaxDecodingDepthReached) @@ -234,11 +228,8 @@ func (s HasOptions) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *HasOptions) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } diff --git a/spec/output/generator_spec_go/struct.x/MyXDR_generated.go b/spec/output/generator_spec_go/struct.x/MyXDR_generated.go index f427dc4c2..d84bb9bad 100644 --- a/spec/output/generator_spec_go/struct.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/struct.x/MyXDR_generated.go @@ -23,29 +23,22 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/struct.x": "c6911a83390e3b499c078fd0c579132eacce88a4a0538d3b8b5e57747a58db4a", } -var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") - type xdrType interface { xdrType() } -type decoderFrom interface { - DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) -} - -// Unmarshal reads an xdr element from `r` into `v`. -func Unmarshal(r io.Reader, v interface{}) (int, error) { - return UnmarshalWithOptions(r, v, xdr.DefaultDecodeOptions) -} +// ErrMaxDecodingDepthReached is returned when the maximum decoding depth is +// exceeded. This prevents stack overflow from deeply nested structures. +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") -// UnmarshalWithOptions works like Unmarshal but uses decoding options. -func UnmarshalWithOptions(r io.Reader, v interface{}, options xdr.DecodeOptions) (int, error) { - if decodable, ok := v.(decoderFrom); ok { - d := xdr.NewDecoderWithOptions(r, options) - return decodable.DecodeFrom(d, options.MaxDepth) +// Unmarshal reads an xdr element from `data` into `v`. +func Unmarshal(data []byte, v interface{}) (int, error) { + if decodable, ok := v.(xdr.DecoderFrom); ok { + d := xdr.NewDecoder(data) + return decodable.DecodeFrom(d, d.MaxDepth()) } // delegate to xdr package's Unmarshal - return xdr.UnmarshalWithOptions(r, v, options) + return xdr.Unmarshal(data, v) } // Marshal writes an xdr element `v` into `w`. @@ -78,8 +71,8 @@ func (s Int64) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Int64)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Int64)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Int64) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Int64: %w", ErrMaxDecodingDepthReached) @@ -107,11 +100,8 @@ func (s Int64) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Int64) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -165,8 +155,8 @@ func (s *MyStruct) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*MyStruct)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*MyStruct)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *MyStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding MyStruct: %w", ErrMaxDecodingDepthReached) @@ -212,11 +202,8 @@ func (s MyStruct) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *MyStruct) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } diff --git a/spec/output/generator_spec_go/test.x/MyXDR_generated.go b/spec/output/generator_spec_go/test.x/MyXDR_generated.go index 1a20a446a..6f48940d3 100644 --- a/spec/output/generator_spec_go/test.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/test.x/MyXDR_generated.go @@ -23,29 +23,22 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/test.x": "d29a98a6a3b9bf533a3e6712d928e0bed655e0f462ac4dae810c65d52ca9af41", } -var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") - type xdrType interface { xdrType() } -type decoderFrom interface { - DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) -} - -// Unmarshal reads an xdr element from `r` into `v`. -func Unmarshal(r io.Reader, v interface{}) (int, error) { - return UnmarshalWithOptions(r, v, xdr.DefaultDecodeOptions) -} +// ErrMaxDecodingDepthReached is returned when the maximum decoding depth is +// exceeded. This prevents stack overflow from deeply nested structures. +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") -// UnmarshalWithOptions works like Unmarshal but uses decoding options. -func UnmarshalWithOptions(r io.Reader, v interface{}, options xdr.DecodeOptions) (int, error) { - if decodable, ok := v.(decoderFrom); ok { - d := xdr.NewDecoderWithOptions(r, options) - return decodable.DecodeFrom(d, options.MaxDepth) +// Unmarshal reads an xdr element from `data` into `v`. +func Unmarshal(data []byte, v interface{}) (int, error) { + if decodable, ok := v.(xdr.DecoderFrom); ok { + d := xdr.NewDecoder(data) + return decodable.DecodeFrom(d, d.MaxDepth()) } // delegate to xdr package's Unmarshal - return xdr.UnmarshalWithOptions(r, v, options) + return xdr.Unmarshal(data, v) } // Marshal writes an xdr element `v` into `w`. @@ -82,8 +75,8 @@ func (s *Uint512) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Uint512)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Uint512)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Uint512) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Uint512: %w", ErrMaxDecodingDepthReached) @@ -109,11 +102,8 @@ func (s Uint512) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Uint512) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -146,8 +136,8 @@ func (s Uint513) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Uint513)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Uint513)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Uint513) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Uint513: %w", ErrMaxDecodingDepthReached) @@ -173,11 +163,8 @@ func (s Uint513) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Uint513) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -206,8 +193,8 @@ func (s Uint514) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Uint514)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Uint514)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Uint514) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Uint514: %w", ErrMaxDecodingDepthReached) @@ -233,11 +220,8 @@ func (s Uint514) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Uint514) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -270,8 +254,8 @@ func (s Str) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Str)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Str)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Str) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Str: %w", ErrMaxDecodingDepthReached) @@ -299,11 +283,8 @@ func (s Str) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Str) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -332,8 +313,8 @@ func (s Str2) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Str2)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Str2)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Str2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Str2: %w", ErrMaxDecodingDepthReached) @@ -361,11 +342,8 @@ func (s Str2) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Str2) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -398,8 +376,8 @@ func (s *Hash) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Hash)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Hash)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Hash) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Hash: %w", ErrMaxDecodingDepthReached) @@ -425,11 +403,8 @@ func (s Hash) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Hash) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -459,8 +434,8 @@ func (s *Hashes1) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Hashes1)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Hashes1)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Hashes1) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Hashes1: %w", ErrMaxDecodingDepthReached) @@ -488,11 +463,8 @@ func (s Hashes1) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Hashes1) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -529,8 +501,8 @@ func (s Hashes2) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Hashes2)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Hashes2)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Hashes2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Hashes2: %w", ErrMaxDecodingDepthReached) @@ -547,12 +519,17 @@ func (s *Hashes2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if l > 12 { return n, fmt.Errorf("decoding Hash: data size (%d) exceeds size limit (12)", l) } - (*s) = nil - if l > 0 { - if il, ok := d.InputLen(); ok && uint(il) < uint(l) { - return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, il) + if l == 0 { + (*s) = (*s)[:0] + } else { + if uint(d.Remaining()) < uint(l) { + return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, d.Remaining()) + } + if cap((*s)) >= int(l) { + (*s) = (*s)[:l] + } else { + (*s) = make([]Hash, l) } - (*s) = make([]Hash, l) for i := uint32(0); i < l; i++ { nTmp, err = (*s)[i].DecodeFrom(d, maxDepth) n += nTmp @@ -574,11 +551,8 @@ func (s Hashes2) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Hashes2) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -611,8 +585,8 @@ func (s Hashes3) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Hashes3)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Hashes3)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Hashes3) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Hashes3: %w", ErrMaxDecodingDepthReached) @@ -626,12 +600,17 @@ func (s *Hashes3) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if err != nil { return n, fmt.Errorf("decoding Hash: %w", err) } - (*s) = nil - if l > 0 { - if il, ok := d.InputLen(); ok && uint(il) < uint(l) { - return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, il) + if l == 0 { + (*s) = (*s)[:0] + } else { + if uint(d.Remaining()) < uint(l) { + return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, d.Remaining()) + } + if cap((*s)) >= int(l) { + (*s) = (*s)[:l] + } else { + (*s) = make([]Hash, l) } - (*s) = make([]Hash, l) for i := uint32(0); i < l; i++ { nTmp, err = (*s)[i].DecodeFrom(d, maxDepth) n += nTmp @@ -653,11 +632,8 @@ func (s Hashes3) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Hashes3) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -696,8 +672,8 @@ func (s Int1) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Int1)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Int1)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Int1) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Int1: %w", ErrMaxDecodingDepthReached) @@ -725,11 +701,8 @@ func (s Int1) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Int1) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -758,8 +731,8 @@ func (s Int2) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Int2)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Int2)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Int2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Int2: %w", ErrMaxDecodingDepthReached) @@ -787,11 +760,8 @@ func (s Int2) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Int2) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -820,8 +790,8 @@ func (s Int3) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Int3)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Int3)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Int3) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Int3: %w", ErrMaxDecodingDepthReached) @@ -849,11 +819,8 @@ func (s Int3) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Int3) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -876,14 +843,14 @@ type Int4 uint64 // EncodeTo encodes this value using the Encoder. func (s Int4) EncodeTo(e *xdr.Encoder) error { var err error - if _, err = e.EncodeUhyper(uint64(s)); err != nil { + if _, err = e.EncodeUhyper(uint64(s)); err != nil { return err } return nil } -var _ decoderFrom = (*Int4)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Int4)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Int4) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Int4: %w", ErrMaxDecodingDepthReached) @@ -911,11 +878,8 @@ func (s Int4) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Int4) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -972,10 +936,10 @@ func (s *MyStruct) EncodeTo(e *xdr.Encoder) error { if _, err = e.EncodeUint(uint32(s.Field4)); err != nil { return err } - if _, err = e.Encode(s.Field5); err != nil { + if _, err = e.EncodeFloat(float32(s.Field5)); err != nil { return err } - if _, err = e.Encode(s.Field6); err != nil { + if _, err = e.EncodeDouble(float64(s.Field6)); err != nil { return err } if _, err = e.EncodeBool(bool(s.Field7)); err != nil { @@ -984,8 +948,8 @@ func (s *MyStruct) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*MyStruct)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*MyStruct)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *MyStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding MyStruct: %w", ErrMaxDecodingDepthReached) @@ -1023,12 +987,12 @@ func (s *MyStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if err != nil { return n, fmt.Errorf("decoding Unsigned int: %w", err) } - nTmp, err = d.DecodeWithMaxDepth(&s.Field5, maxDepth) + s.Field5, nTmp, err = d.DecodeFloat() n += nTmp if err != nil { return n, fmt.Errorf("decoding Float: %w", err) } - nTmp, err = d.DecodeWithMaxDepth(&s.Field6, maxDepth) + s.Field6, nTmp, err = d.DecodeDouble() n += nTmp if err != nil { return n, fmt.Errorf("decoding Double: %w", err) @@ -1051,11 +1015,8 @@ func (s MyStruct) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *MyStruct) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -1094,8 +1055,8 @@ func (s *LotsOfMyStructs) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*LotsOfMyStructs)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*LotsOfMyStructs)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *LotsOfMyStructs) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding LotsOfMyStructs: %w", ErrMaxDecodingDepthReached) @@ -1109,12 +1070,17 @@ func (s *LotsOfMyStructs) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) if err != nil { return n, fmt.Errorf("decoding MyStruct: %w", err) } - s.Members = nil - if l > 0 { - if il, ok := d.InputLen(); ok && uint(il) < uint(l) { - return n, fmt.Errorf("decoding MyStruct: length (%d) exceeds remaining input length (%d)", l, il) + if l == 0 { + s.Members = s.Members[:0] + } else { + if uint(d.Remaining()) < uint(l) { + return n, fmt.Errorf("decoding MyStruct: length (%d) exceeds remaining input length (%d)", l, d.Remaining()) + } + if cap(s.Members) >= int(l) { + s.Members = s.Members[:l] + } else { + s.Members = make([]MyStruct, l) } - s.Members = make([]MyStruct, l) for i := uint32(0); i < l; i++ { nTmp, err = s.Members[i].DecodeFrom(d, maxDepth) n += nTmp @@ -1136,11 +1102,8 @@ func (s LotsOfMyStructs) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *LotsOfMyStructs) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -1174,8 +1137,8 @@ func (s *HasStuff) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*HasStuff)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*HasStuff)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *HasStuff) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding HasStuff: %w", ErrMaxDecodingDepthReached) @@ -1201,11 +1164,8 @@ func (s HasStuff) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *HasStuff) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -1242,8 +1202,12 @@ var colorMap = map[int32]string{ // ValidEnum validates a proposed value for this enum. Implements // the Enum interface for Color func (e Color) ValidEnum(v int32) bool { - _, ok := colorMap[v] - return ok + switch v { + case 0, 5, 6: + return true + default: + return false + } } // String returns the name of `e` func (e Color) String() string { @@ -1253,24 +1217,29 @@ func (e Color) String() string { // EncodeTo encodes this value using the Encoder. func (e Color) EncodeTo(enc *xdr.Encoder) error { - if _, ok := colorMap[int32(e)]; !ok { + switch int32(e) { + case 0, 5, 6: + // valid + default: return fmt.Errorf("'%d' is not a valid Color enum value", e) } _, err := enc.EncodeInt(int32(e)) return err } -var _ decoderFrom = (*Color)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Color)(nil) +// DecodeFrom decodes this value from the given decoder. func (e *Color) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Color: %w", ErrMaxDecodingDepthReached) } - maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { return n, fmt.Errorf("decoding Color: %w", err) } - if _, ok := colorMap[v]; !ok { + switch v { + case 0, 5, 6: + // valid + default: return n, fmt.Errorf("'%d' is not a valid Color enum value", v) } *e = Color(v) @@ -1286,11 +1255,8 @@ func (s Color) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Color) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -1328,6 +1294,10 @@ const ( NesterNestedEnumBlah1 NesterNestedEnum = 0 NesterNestedEnumBlah2 NesterNestedEnum = 1 ) +const ( + _NesterNestedEnum_Min int32 = 0 + _NesterNestedEnum_Max int32 = 1 +) var nestedEnumMap = map[int32]string{ 0: "NesterNestedEnumBlah1", 1: "NesterNestedEnumBlah2", @@ -1336,8 +1306,7 @@ var nestedEnumMap = map[int32]string{ // ValidEnum validates a proposed value for this enum. Implements // the Enum interface for NesterNestedEnum func (e NesterNestedEnum) ValidEnum(v int32) bool { - _, ok := nestedEnumMap[v] - return ok + return v >= _NesterNestedEnum_Min && v <= _NesterNestedEnum_Max } // String returns the name of `e` func (e NesterNestedEnum) String() string { @@ -1347,24 +1316,23 @@ func (e NesterNestedEnum) String() string { // EncodeTo encodes this value using the Encoder. func (e NesterNestedEnum) EncodeTo(enc *xdr.Encoder) error { - if _, ok := nestedEnumMap[int32(e)]; !ok { + if int32(e) < _NesterNestedEnum_Min || int32(e) > _NesterNestedEnum_Max { return fmt.Errorf("'%d' is not a valid NesterNestedEnum enum value", e) } _, err := enc.EncodeInt(int32(e)) return err } -var _ decoderFrom = (*NesterNestedEnum)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*NesterNestedEnum)(nil) +// DecodeFrom decodes this value from the given decoder. func (e *NesterNestedEnum) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding NesterNestedEnum: %w", ErrMaxDecodingDepthReached) } - maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { return n, fmt.Errorf("decoding NesterNestedEnum: %w", err) } - if _, ok := nestedEnumMap[v]; !ok { + if v < _NesterNestedEnum_Min || v > _NesterNestedEnum_Max { return n, fmt.Errorf("'%d' is not a valid NesterNestedEnum enum value", v) } *e = NesterNestedEnum(v) @@ -1380,11 +1348,8 @@ func (s NesterNestedEnum) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *NesterNestedEnum) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -1417,8 +1382,8 @@ func (s *NesterNestedStruct) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*NesterNestedStruct)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*NesterNestedStruct)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *NesterNestedStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding NesterNestedStruct: %w", ErrMaxDecodingDepthReached) @@ -1444,11 +1409,8 @@ func (s NesterNestedStruct) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *NesterNestedStruct) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -1473,7 +1435,7 @@ var _ xdrType = (*NesterNestedStruct)(nil) // type NesterNestedUnion struct{ Color Color - Blah2 *int32 + Blah2 int32 } // SwitchFieldName returns the field name in which this union's @@ -1505,7 +1467,7 @@ switch Color(color) { err = errors.New("invalid value, must be int32") return } - result.Blah2 = &tv + result.Blah2 = tv } return } @@ -1513,11 +1475,9 @@ switch Color(color) { // panicing if the value is not set. func (u NesterNestedUnion) MustBlah2() int32 { val, ok := u.GetBlah2() - if !ok { panic("arm Blah2 is not set") } - return val } @@ -1525,12 +1485,10 @@ func (u NesterNestedUnion) MustBlah2() int32 { // returning ok if the union's switch indicated the value is valid. func (u NesterNestedUnion) GetBlah2() (result int32, ok bool) { armName, _ := u.ArmForSwitch(int32(u.Color)) - if armName == "Blah2" { - result = *u.Blah2 + result = u.Blah2 ok = true } - return } @@ -1545,15 +1503,15 @@ switch Color(u.Color) { // Void return nil default: - if _, err = e.EncodeInt(int32((*u.Blah2))); err != nil { + if _, err = e.EncodeInt(int32(u.Blah2)); err != nil { return err } return nil } } -var _ decoderFrom = (*NesterNestedUnion)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*NesterNestedUnion)(nil) +// DecodeFrom decodes this value from the given decoder. func (u *NesterNestedUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding NesterNestedUnion: %w", ErrMaxDecodingDepthReached) @@ -1571,8 +1529,7 @@ switch Color(u.Color) { // Void return n, nil default: - u.Blah2 = new(int32) - (*u.Blah2), nTmp, err = d.DecodeInt() + u.Blah2, nTmp, err = d.DecodeInt() n += nTmp if err != nil { return n, fmt.Errorf("decoding Int: %w", err) @@ -1591,11 +1548,8 @@ func (s NesterNestedUnion) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *NesterNestedUnion) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -1653,8 +1607,8 @@ func (s *Nester) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Nester)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Nester)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Nester) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Nester: %w", ErrMaxDecodingDepthReached) @@ -1690,11 +1644,8 @@ func (s Nester) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Nester) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } diff --git a/spec/output/generator_spec_go/union.x/MyXDR_generated.go b/spec/output/generator_spec_go/union.x/MyXDR_generated.go index 88281c4d5..ec54d6ef9 100644 --- a/spec/output/generator_spec_go/union.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/union.x/MyXDR_generated.go @@ -23,29 +23,22 @@ var XdrFilesSHA256 = map[string]string{ "spec/fixtures/generator/union.x": "c251258d967223b341ebcf2d5bb0718e9a039b46232cb743865d9acd0c4bbe41", } -var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") - type xdrType interface { xdrType() } -type decoderFrom interface { - DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) -} - -// Unmarshal reads an xdr element from `r` into `v`. -func Unmarshal(r io.Reader, v interface{}) (int, error) { - return UnmarshalWithOptions(r, v, xdr.DefaultDecodeOptions) -} +// ErrMaxDecodingDepthReached is returned when the maximum decoding depth is +// exceeded. This prevents stack overflow from deeply nested structures. +var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached") -// UnmarshalWithOptions works like Unmarshal but uses decoding options. -func UnmarshalWithOptions(r io.Reader, v interface{}, options xdr.DecodeOptions) (int, error) { - if decodable, ok := v.(decoderFrom); ok { - d := xdr.NewDecoderWithOptions(r, options) - return decodable.DecodeFrom(d, options.MaxDepth) +// Unmarshal reads an xdr element from `data` into `v`. +func Unmarshal(data []byte, v interface{}) (int, error) { + if decodable, ok := v.(xdr.DecoderFrom); ok { + d := xdr.NewDecoder(data) + return decodable.DecodeFrom(d, d.MaxDepth()) } // delegate to xdr package's Unmarshal - return xdr.UnmarshalWithOptions(r, v, options) + return xdr.Unmarshal(data, v) } // Marshal writes an xdr element `v` into `w`. @@ -78,8 +71,8 @@ func (s Error) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Error)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Error)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Error) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Error: %w", ErrMaxDecodingDepthReached) @@ -107,11 +100,8 @@ func (s Error) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Error) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -140,8 +130,8 @@ func (s Multi) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*Multi)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*Multi)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *Multi) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding Multi: %w", ErrMaxDecodingDepthReached) @@ -169,11 +159,8 @@ func (s Multi) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *Multi) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -199,6 +186,10 @@ const ( UnionKeyError UnionKey = 0 UnionKeyMulti UnionKey = 1 ) +const ( + _UnionKey_Min int32 = 0 + _UnionKey_Max int32 = 1 +) var unionKeyMap = map[int32]string{ 0: "UnionKeyError", 1: "UnionKeyMulti", @@ -207,8 +198,7 @@ var unionKeyMap = map[int32]string{ // ValidEnum validates a proposed value for this enum. Implements // the Enum interface for UnionKey func (e UnionKey) ValidEnum(v int32) bool { - _, ok := unionKeyMap[v] - return ok + return v >= _UnionKey_Min && v <= _UnionKey_Max } // String returns the name of `e` func (e UnionKey) String() string { @@ -218,24 +208,23 @@ func (e UnionKey) String() string { // EncodeTo encodes this value using the Encoder. func (e UnionKey) EncodeTo(enc *xdr.Encoder) error { - if _, ok := unionKeyMap[int32(e)]; !ok { + if int32(e) < _UnionKey_Min || int32(e) > _UnionKey_Max { return fmt.Errorf("'%d' is not a valid UnionKey enum value", e) } _, err := enc.EncodeInt(int32(e)) return err } -var _ decoderFrom = (*UnionKey)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*UnionKey)(nil) +// DecodeFrom decodes this value from the given decoder. func (e *UnionKey) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding UnionKey: %w", ErrMaxDecodingDepthReached) } - maxDepth -= 1 v, n, err := d.DecodeInt() if err != nil { return n, fmt.Errorf("decoding UnionKey: %w", err) } - if _, ok := unionKeyMap[v]; !ok { + if v < _UnionKey_Min || v > _UnionKey_Max { return n, fmt.Errorf("'%d' is not a valid UnionKey enum value", v) } *e = UnionKey(v) @@ -251,11 +240,8 @@ func (s UnionKey) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *UnionKey) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -283,7 +269,7 @@ var _ xdrType = (*UnionKey)(nil) // type MyUnion struct{ Type UnionKey - Error *Error + Error Error Things *[]Multi } @@ -315,7 +301,7 @@ switch UnionKey(aType) { err = errors.New("invalid value, must be Error") return } - result.Error = &tv + result.Error = tv case UnionKeyMulti: tv, ok := value.([]Multi) if !ok { @@ -329,49 +315,37 @@ switch UnionKey(aType) { // MustError retrieves the Error value from the union, // panicing if the value is not set. func (u MyUnion) MustError() Error { - val, ok := u.GetError() - - if !ok { - panic("arm Error is not set") + if UnionKey(u.Type) == UnionKeyError { + return u.Error } - - return val + panic("arm Error is not set") } // GetError retrieves the Error value from the union, // returning ok if the union's switch indicated the value is valid. func (u MyUnion) GetError() (result Error, ok bool) { - armName, _ := u.ArmForSwitch(int32(u.Type)) - - if armName == "Error" { - result = *u.Error + if UnionKey(u.Type) == UnionKeyError { + result = u.Error ok = true } - return } // MustThings retrieves the Things value from the union, // panicing if the value is not set. func (u MyUnion) MustThings() []Multi { - val, ok := u.GetThings() - - if !ok { - panic("arm Things is not set") + if UnionKey(u.Type) == UnionKeyMulti { + return *u.Things } - - return val + panic("arm Things is not set") } // GetThings retrieves the Things value from the union, // returning ok if the union's switch indicated the value is valid. func (u MyUnion) GetThings() (result []Multi, ok bool) { - armName, _ := u.ArmForSwitch(int32(u.Type)) - - if armName == "Things" { + if UnionKey(u.Type) == UnionKeyMulti { result = *u.Things ok = true } - return } @@ -383,7 +357,7 @@ func (u MyUnion) EncodeTo(e *xdr.Encoder) error { } switch UnionKey(u.Type) { case UnionKeyError: - if err = (*u.Error).EncodeTo(e); err != nil { + if err = u.Error.EncodeTo(e); err != nil { return err } return nil @@ -401,8 +375,8 @@ return nil return fmt.Errorf("Type (UnionKey) switch value '%d' is not valid for union MyUnion", u.Type) } -var _ decoderFrom = (*MyUnion)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*MyUnion)(nil) +// DecodeFrom decodes this value from the given decoder. func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding MyUnion: %w", ErrMaxDecodingDepthReached) @@ -417,27 +391,33 @@ func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } switch UnionKey(u.Type) { case UnionKeyError: - u.Error = new(Error) - nTmp, err = (*u.Error).DecodeFrom(d, maxDepth) + nTmp, err = u.Error.DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding Error: %w", err) } return n, nil case UnionKeyMulti: - u.Things = new([]Multi) + if u.Things == nil { + u.Things = new([]Multi) + } var l uint32 l, nTmp, err = d.DecodeUint() n += nTmp if err != nil { return n, fmt.Errorf("decoding Multi: %w", err) } - (*u.Things) = nil - if l > 0 { - if il, ok := d.InputLen(); ok && uint(il) < uint(l) { - return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, il) + if l == 0 { + (*u.Things) = (*u.Things)[:0] + } else { + if uint(d.Remaining()) < uint(l) { + return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, d.Remaining()) + } + if cap((*u.Things)) >= int(l) { + (*u.Things) = (*u.Things)[:l] + } else { + (*u.Things) = make([]Multi, l) } - (*u.Things) = make([]Multi, l) for i := uint32(0); i < l; i++ { nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth) n += nTmp @@ -461,11 +441,8 @@ func (s MyUnion) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *MyUnion) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -492,7 +469,7 @@ var _ xdrType = (*MyUnion)(nil) // type IntUnion struct{ Type int32 - Error *Error + Error Error Things *[]Multi } @@ -524,7 +501,7 @@ switch int32(aType) { err = errors.New("invalid value, must be Error") return } - result.Error = &tv + result.Error = tv case 1: tv, ok := value.([]Multi) if !ok { @@ -538,49 +515,37 @@ switch int32(aType) { // MustError retrieves the Error value from the union, // panicing if the value is not set. func (u IntUnion) MustError() Error { - val, ok := u.GetError() - - if !ok { - panic("arm Error is not set") + if int32(u.Type) == 0 { + return u.Error } - - return val + panic("arm Error is not set") } // GetError retrieves the Error value from the union, // returning ok if the union's switch indicated the value is valid. func (u IntUnion) GetError() (result Error, ok bool) { - armName, _ := u.ArmForSwitch(int32(u.Type)) - - if armName == "Error" { - result = *u.Error + if int32(u.Type) == 0 { + result = u.Error ok = true } - return } // MustThings retrieves the Things value from the union, // panicing if the value is not set. func (u IntUnion) MustThings() []Multi { - val, ok := u.GetThings() - - if !ok { - panic("arm Things is not set") + if int32(u.Type) == 1 { + return *u.Things } - - return val + panic("arm Things is not set") } // GetThings retrieves the Things value from the union, // returning ok if the union's switch indicated the value is valid. func (u IntUnion) GetThings() (result []Multi, ok bool) { - armName, _ := u.ArmForSwitch(int32(u.Type)) - - if armName == "Things" { + if int32(u.Type) == 1 { result = *u.Things ok = true } - return } @@ -592,7 +557,7 @@ func (u IntUnion) EncodeTo(e *xdr.Encoder) error { } switch int32(u.Type) { case 0: - if err = (*u.Error).EncodeTo(e); err != nil { + if err = u.Error.EncodeTo(e); err != nil { return err } return nil @@ -610,8 +575,8 @@ return nil return fmt.Errorf("Type (int32) switch value '%d' is not valid for union IntUnion", u.Type) } -var _ decoderFrom = (*IntUnion)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*IntUnion)(nil) +// DecodeFrom decodes this value from the given decoder. func (u *IntUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding IntUnion: %w", ErrMaxDecodingDepthReached) @@ -626,27 +591,33 @@ func (u *IntUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } switch int32(u.Type) { case 0: - u.Error = new(Error) - nTmp, err = (*u.Error).DecodeFrom(d, maxDepth) + nTmp, err = u.Error.DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding Error: %w", err) } return n, nil case 1: - u.Things = new([]Multi) + if u.Things == nil { + u.Things = new([]Multi) + } var l uint32 l, nTmp, err = d.DecodeUint() n += nTmp if err != nil { return n, fmt.Errorf("decoding Multi: %w", err) } - (*u.Things) = nil - if l > 0 { - if il, ok := d.InputLen(); ok && uint(il) < uint(l) { - return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, il) + if l == 0 { + (*u.Things) = (*u.Things)[:0] + } else { + if uint(d.Remaining()) < uint(l) { + return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, d.Remaining()) + } + if cap((*u.Things)) >= int(l) { + (*u.Things) = (*u.Things)[:l] + } else { + (*u.Things) = make([]Multi, l) } - (*u.Things) = make([]Multi, l) for i := uint32(0); i < l; i++ { nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth) n += nTmp @@ -670,11 +641,8 @@ func (s IntUnion) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *IntUnion) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err } @@ -744,8 +712,8 @@ func (s IntUnion2) EncodeTo(e *xdr.Encoder) error { return nil } -var _ decoderFrom = (*IntUnion2)(nil) -// DecodeFrom decodes this value using the Decoder. +var _ xdr.DecoderFrom = (*IntUnion2)(nil) +// DecodeFrom decodes this value from the given decoder. func (s *IntUnion2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if maxDepth == 0 { return 0, fmt.Errorf("decoding IntUnion2: %w", ErrMaxDecodingDepthReached) @@ -771,11 +739,8 @@ func (s IntUnion2) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements encoding.BinaryUnmarshaler. func (s *IntUnion2) UnmarshalBinary(inp []byte) error { - r := bytes.NewReader(inp) - o := xdr.DefaultDecodeOptions - o.MaxInputLen = len(inp) - d := xdr.NewDecoderWithOptions(r, o) - _, err := s.DecodeFrom(d, o.MaxDepth) + d := xdr.NewDecoder(inp) + _, err := s.DecodeFrom(d, d.MaxDepth()) return err }