From e009504868352d9afe17b72dcfbcfec20d591ee2 Mon Sep 17 00:00:00 2001 From: tamirms Date: Wed, 11 Mar 2026 12:33:10 -0500 Subject: [PATCH] Go generator: emit output size tracking and capped pre-allocation Update the Go code generator to emit TrackOutputBytesOf calls before each heap allocation site (union arms, optional fields, array elements) and cap initial array allocation at 256 elements with append-based growth. This works with the new MaxOutputBytes option in go-xdr to allow callers to limit cumulative decoded output size. Co-Authored-By: Claude Opus 4.6 --- lib/xdrgen/generators/go.rb | 33 ++++++++++++- .../nesting.x/MyXDR_generated.go | 10 +++- .../optional.x/MyXDR_generated.go | 9 ++++ .../test.x/MyXDR_generated.go | 47 +++++++++++++++++-- .../union.x/MyXDR_generated.go | 46 +++++++++++++++--- 5 files changed, 131 insertions(+), 14 deletions(-) diff --git a/lib/xdrgen/generators/go.rb b/lib/xdrgen/generators/go.rb index 4d0ec53b2..ccb8ab8de 100644 --- a/lib/xdrgen/generators/go.rb +++ b/lib/xdrgen/generators/go.rb @@ -571,6 +571,9 @@ def render_union_decode_from_interface(out, union) else mn = name(arm) type = arm.type + out2.puts " if err = xdr.TrackOutputBytesOf[#{reference arm.type}](d); err != nil {" + out2.puts " return n, fmt.Errorf(\"decoding #{reference arm.type}: %w\", err)" + out2.puts " }" out2.puts " u.#{mn} = new(#{reference arm.type})" render_decode_from_body(out2, "(*u.#{mn})",type, declared_variables: [], self_encode: false) end @@ -662,6 +665,9 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts tail out.puts " #{var} = nil" out.puts " if b {" + out.puts " if err = xdr.TrackOutputBytesOf[#{name type}](d); err != nil {" + out.puts " return n, fmt.Errorf(\"decoding #{name type}: %w\", err)" + out.puts " }" out.puts " #{var} = new(#{name type})" end case type @@ -704,6 +710,9 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts tail out.puts " #{var} = nil" out.puts " if b {" + out.puts " if err = xdr.TrackOutputBytesOf[#{name type.resolved_type.declaration.type}](d); err != nil {" + out.puts " return n, fmt.Errorf(\"decoding #{name type.resolved_type.declaration.type}: %w\", err)" + out.puts " }" out.puts " #{var} = new(#{name type.resolved_type.declaration.type})" end var = "(*#{name type})(#{var})" if self_encode @@ -744,9 +753,25 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts " if il, ok := d.InputLen(); ok && uint(il) < uint(l) {" out.puts " return n, fmt.Errorf(\"decoding #{name type}: length (%d) exceeds remaining input length (%d)\", l, il)" out.puts " }" - out.puts " #{var} = make([]#{name type}, l)" + # Cap pre-allocation to avoid memory amplification from untrusted inputs. + # The InputLen check above compares element count against remaining + # input bytes, but each element may be much larger in memory than on + # the wire. Capping initial allocation and growing via append ensures + # memory usage stays proportional to data actually decoded. + slice_var = var # save before optional handling may reassign var + out.puts " {" + out.puts " initialCap := l" + out.puts " if initialCap > xdr.MaxPrealloc {" + out.puts " initialCap = xdr.MaxPrealloc" + out.puts " }" + out.puts " #{slice_var} = make([]#{name type}, 0, initialCap)" + out.puts " var empty #{name type}" out.puts " for i := uint32(0); i < l; i++ {" - element_var = "#{var}[i]" + out.puts " if err = xdr.TrackOutputBytesOf[#{name type}](d); err != nil {" + out.puts " return n, fmt.Errorf(\"decoding #{name type}: %w\", err)" + out.puts " }" + out.puts " #{slice_var} = append(#{slice_var}, empty)" + element_var = "#{slice_var}[i]" optional_within = type.is_a?(AST::Identifier) && type.resolved_type.sub_type == :optional if optional_within out.puts " var eb bool" @@ -754,6 +779,9 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts tail out.puts " #{element_var} = nil" out.puts " if eb {" + out.puts " if err = xdr.TrackOutputBytesOf[#{name type.resolved_type.declaration.type}](d); err != nil {" + out.puts " return n, fmt.Errorf(\"decoding #{name type.resolved_type.declaration.type}: %w\", err)" + out.puts " }" out.puts " #{element_var} = new(#{name type.resolved_type.declaration.type})" var = "(*#{element_var})" end @@ -763,6 +791,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts " }" end out.puts " }" + out.puts " }" out.puts " }" else raise "Unknown sub_type: #{type.sub_type}" diff --git a/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go b/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go index dbd781b2e..fce392cb6 100644 --- a/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go @@ -505,7 +505,10 @@ func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } switch UnionKey(u.Type) { case UnionKeyOne: - u.One = new(MyUnionOne) + if err = xdr.TrackOutputBytesOf[MyUnionOne](d); err != nil { + return n, fmt.Errorf("decoding MyUnionOne: %w", err) + } + u.One = new(MyUnionOne) nTmp, err = (*u.One).DecodeFrom(d, maxDepth) n += nTmp if err != nil { @@ -513,7 +516,10 @@ switch UnionKey(u.Type) { } return n, nil case UnionKeyTwo: - u.Two = new(MyUnionTwo) + if err = xdr.TrackOutputBytesOf[MyUnionTwo](d); err != nil { + return n, fmt.Errorf("decoding MyUnionTwo: %w", err) + } + u.Two = new(MyUnionTwo) nTmp, err = (*u.Two).DecodeFrom(d, maxDepth) n += nTmp if err != nil { diff --git a/spec/output/generator_spec_go/optional.x/MyXDR_generated.go b/spec/output/generator_spec_go/optional.x/MyXDR_generated.go index 5ffb0fbd5..e63b3aaca 100644 --- a/spec/output/generator_spec_go/optional.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/optional.x/MyXDR_generated.go @@ -186,6 +186,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } s.FirstOption = nil if b { + if err = xdr.TrackOutputBytesOf[Int](d); err != nil { + return n, fmt.Errorf("decoding Int: %w", err) + } s.FirstOption = new(Int) s.FirstOption, nTmp, err = d.DecodeInt() n += nTmp @@ -200,6 +203,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } s.SecondOption = nil if b { + if err = xdr.TrackOutputBytesOf[Int](d); err != nil { + return n, fmt.Errorf("decoding Int: %w", err) + } s.SecondOption = new(Int) s.SecondOption, nTmp, err = d.DecodeInt() n += nTmp @@ -214,6 +220,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } s.ThirdOption = nil if b { + if err = xdr.TrackOutputBytesOf[Arr](d); err != nil { + return n, fmt.Errorf("decoding Arr: %w", err) + } s.ThirdOption = new(Arr) nTmp, err = s.ThirdOption.DecodeFrom(d, maxDepth) n += nTmp diff --git a/spec/output/generator_spec_go/test.x/MyXDR_generated.go b/spec/output/generator_spec_go/test.x/MyXDR_generated.go index 1a20a446a..c96fb7d3c 100644 --- a/spec/output/generator_spec_go/test.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/test.x/MyXDR_generated.go @@ -552,14 +552,25 @@ func (s *Hashes2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if il, ok := d.InputLen(); ok && uint(il) < uint(l) { return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, il) } - (*s) = make([]Hash, l) + { + initialCap := l + if initialCap > xdr.MaxPrealloc { + initialCap = xdr.MaxPrealloc + } + (*s) = make([]Hash, 0, initialCap) + var empty Hash for i := uint32(0); i < l; i++ { + if err = xdr.TrackOutputBytesOf[Hash](d); err != nil { + return n, fmt.Errorf("decoding Hash: %w", err) + } + (*s) = append((*s), empty) nTmp, err = (*s)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding Hash: %w", err) } } + } } return n, nil } @@ -631,14 +642,25 @@ func (s *Hashes3) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if il, ok := d.InputLen(); ok && uint(il) < uint(l) { return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, il) } - (*s) = make([]Hash, l) + { + initialCap := l + if initialCap > xdr.MaxPrealloc { + initialCap = xdr.MaxPrealloc + } + (*s) = make([]Hash, 0, initialCap) + var empty Hash for i := uint32(0); i < l; i++ { + if err = xdr.TrackOutputBytesOf[Hash](d); err != nil { + return n, fmt.Errorf("decoding Hash: %w", err) + } + (*s) = append((*s), empty) nTmp, err = (*s)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding Hash: %w", err) } } + } } return n, nil } @@ -1006,6 +1028,9 @@ func (s *MyStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } s.Field2 = nil if b { + if err = xdr.TrackOutputBytesOf[Hash](d); err != nil { + return n, fmt.Errorf("decoding Hash: %w", err) + } s.Field2 = new(Hash) nTmp, err = s.Field2.DecodeFrom(d, maxDepth) n += nTmp @@ -1114,14 +1139,25 @@ func (s *LotsOfMyStructs) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) if il, ok := d.InputLen(); ok && uint(il) < uint(l) { return n, fmt.Errorf("decoding MyStruct: length (%d) exceeds remaining input length (%d)", l, il) } - s.Members = make([]MyStruct, l) + { + initialCap := l + if initialCap > xdr.MaxPrealloc { + initialCap = xdr.MaxPrealloc + } + s.Members = make([]MyStruct, 0, initialCap) + var empty MyStruct for i := uint32(0); i < l; i++ { + if err = xdr.TrackOutputBytesOf[MyStruct](d); err != nil { + return n, fmt.Errorf("decoding MyStruct: %w", err) + } + s.Members = append(s.Members, empty) nTmp, err = s.Members[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding MyStruct: %w", err) } } + } } return n, nil } @@ -1571,7 +1607,10 @@ switch Color(u.Color) { // Void return n, nil default: - u.Blah2 = new(int32) + if err = xdr.TrackOutputBytesOf[int32](d); err != nil { + return n, fmt.Errorf("decoding int32: %w", err) + } + u.Blah2 = new(int32) (*u.Blah2), nTmp, err = d.DecodeInt() n += nTmp if err != nil { diff --git a/spec/output/generator_spec_go/union.x/MyXDR_generated.go b/spec/output/generator_spec_go/union.x/MyXDR_generated.go index 88281c4d5..62a76ab4f 100644 --- a/spec/output/generator_spec_go/union.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/union.x/MyXDR_generated.go @@ -417,7 +417,10 @@ func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } switch UnionKey(u.Type) { case UnionKeyError: - u.Error = new(Error) + if err = xdr.TrackOutputBytesOf[Error](d); err != nil { + return n, fmt.Errorf("decoding Error: %w", err) + } + u.Error = new(Error) nTmp, err = (*u.Error).DecodeFrom(d, maxDepth) n += nTmp if err != nil { @@ -425,7 +428,10 @@ switch UnionKey(u.Type) { } return n, nil case UnionKeyMulti: - u.Things = new([]Multi) + if err = xdr.TrackOutputBytesOf[[]Multi](d); err != nil { + return n, fmt.Errorf("decoding []Multi: %w", err) + } + u.Things = new([]Multi) var l uint32 l, nTmp, err = d.DecodeUint() n += nTmp @@ -437,14 +443,25 @@ switch UnionKey(u.Type) { if il, ok := d.InputLen(); ok && uint(il) < uint(l) { return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, il) } - (*u.Things) = make([]Multi, l) + { + initialCap := l + if initialCap > xdr.MaxPrealloc { + initialCap = xdr.MaxPrealloc + } + (*u.Things) = make([]Multi, 0, initialCap) + var empty Multi for i := uint32(0); i < l; i++ { + if err = xdr.TrackOutputBytesOf[Multi](d); err != nil { + return n, fmt.Errorf("decoding Multi: %w", err) + } + (*u.Things) = append((*u.Things), empty) nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding Multi: %w", err) } } + } } return n, nil } @@ -626,7 +643,10 @@ func (u *IntUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } switch int32(u.Type) { case 0: - u.Error = new(Error) + if err = xdr.TrackOutputBytesOf[Error](d); err != nil { + return n, fmt.Errorf("decoding Error: %w", err) + } + u.Error = new(Error) nTmp, err = (*u.Error).DecodeFrom(d, maxDepth) n += nTmp if err != nil { @@ -634,7 +654,10 @@ switch int32(u.Type) { } return n, nil case 1: - u.Things = new([]Multi) + if err = xdr.TrackOutputBytesOf[[]Multi](d); err != nil { + return n, fmt.Errorf("decoding []Multi: %w", err) + } + u.Things = new([]Multi) var l uint32 l, nTmp, err = d.DecodeUint() n += nTmp @@ -646,14 +669,25 @@ switch int32(u.Type) { if il, ok := d.InputLen(); ok && uint(il) < uint(l) { return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, il) } - (*u.Things) = make([]Multi, l) + { + initialCap := l + if initialCap > xdr.MaxPrealloc { + initialCap = xdr.MaxPrealloc + } + (*u.Things) = make([]Multi, 0, initialCap) + var empty Multi for i := uint32(0); i < l; i++ { + if err = xdr.TrackOutputBytesOf[Multi](d); err != nil { + return n, fmt.Errorf("decoding Multi: %w", err) + } + (*u.Things) = append((*u.Things), empty) nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding Multi: %w", err) } } + } } return n, nil }