diff --git a/lib/xdrgen/generators/go.rb b/lib/xdrgen/generators/go.rb index 4d0ec53b2..ccb8ab8de 100644 --- a/lib/xdrgen/generators/go.rb +++ b/lib/xdrgen/generators/go.rb @@ -571,6 +571,9 @@ def render_union_decode_from_interface(out, union) else mn = name(arm) type = arm.type + out2.puts " if err = xdr.TrackOutputBytesOf[#{reference arm.type}](d); err != nil {" + out2.puts " return n, fmt.Errorf(\"decoding #{reference arm.type}: %w\", err)" + out2.puts " }" out2.puts " u.#{mn} = new(#{reference arm.type})" render_decode_from_body(out2, "(*u.#{mn})",type, declared_variables: [], self_encode: false) end @@ -662,6 +665,9 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts tail out.puts " #{var} = nil" out.puts " if b {" + out.puts " if err = xdr.TrackOutputBytesOf[#{name type}](d); err != nil {" + out.puts " return n, fmt.Errorf(\"decoding #{name type}: %w\", err)" + out.puts " }" out.puts " #{var} = new(#{name type})" end case type @@ -704,6 +710,9 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts tail out.puts " #{var} = nil" out.puts " if b {" + out.puts " if err = xdr.TrackOutputBytesOf[#{name type.resolved_type.declaration.type}](d); err != nil {" + out.puts " return n, fmt.Errorf(\"decoding #{name type.resolved_type.declaration.type}: %w\", err)" + out.puts " }" out.puts " #{var} = new(#{name type.resolved_type.declaration.type})" end var = "(*#{name type})(#{var})" if self_encode @@ -744,9 +753,25 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts " if il, ok := d.InputLen(); ok && uint(il) < uint(l) {" out.puts " return n, fmt.Errorf(\"decoding #{name type}: length (%d) exceeds remaining input length (%d)\", l, il)" out.puts " }" - out.puts " #{var} = make([]#{name type}, l)" + # Cap pre-allocation to avoid memory amplification from untrusted inputs. + # The InputLen check above compares element count against remaining + # input bytes, but each element may be much larger in memory than on + # the wire. Capping initial allocation and growing via append ensures + # memory usage stays proportional to data actually decoded. + slice_var = var # save before optional handling may reassign var + out.puts " {" + out.puts " initialCap := l" + out.puts " if initialCap > xdr.MaxPrealloc {" + out.puts " initialCap = xdr.MaxPrealloc" + out.puts " }" + out.puts " #{slice_var} = make([]#{name type}, 0, initialCap)" + out.puts " var empty #{name type}" out.puts " for i := uint32(0); i < l; i++ {" - element_var = "#{var}[i]" + out.puts " if err = xdr.TrackOutputBytesOf[#{name type}](d); err != nil {" + out.puts " return n, fmt.Errorf(\"decoding #{name type}: %w\", err)" + out.puts " }" + out.puts " #{slice_var} = append(#{slice_var}, empty)" + element_var = "#{slice_var}[i]" optional_within = type.is_a?(AST::Identifier) && type.resolved_type.sub_type == :optional if optional_within out.puts " var eb bool" @@ -754,6 +779,9 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts tail out.puts " #{element_var} = nil" out.puts " if eb {" + out.puts " if err = xdr.TrackOutputBytesOf[#{name type.resolved_type.declaration.type}](d); err != nil {" + out.puts " return n, fmt.Errorf(\"decoding #{name type.resolved_type.declaration.type}: %w\", err)" + out.puts " }" out.puts " #{element_var} = new(#{name type.resolved_type.declaration.type})" var = "(*#{element_var})" end @@ -763,6 +791,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) out.puts " }" end out.puts " }" + out.puts " }" out.puts " }" else raise "Unknown sub_type: #{type.sub_type}" diff --git a/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go b/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go index dbd781b2e..fce392cb6 100644 --- a/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/nesting.x/MyXDR_generated.go @@ -505,7 +505,10 @@ func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } switch UnionKey(u.Type) { case UnionKeyOne: - u.One = new(MyUnionOne) + if err = xdr.TrackOutputBytesOf[MyUnionOne](d); err != nil { + return n, fmt.Errorf("decoding MyUnionOne: %w", err) + } + u.One = new(MyUnionOne) nTmp, err = (*u.One).DecodeFrom(d, maxDepth) n += nTmp if err != nil { @@ -513,7 +516,10 @@ switch UnionKey(u.Type) { } return n, nil case UnionKeyTwo: - u.Two = new(MyUnionTwo) + if err = xdr.TrackOutputBytesOf[MyUnionTwo](d); err != nil { + return n, fmt.Errorf("decoding MyUnionTwo: %w", err) + } + u.Two = new(MyUnionTwo) nTmp, err = (*u.Two).DecodeFrom(d, maxDepth) n += nTmp if err != nil { diff --git a/spec/output/generator_spec_go/optional.x/MyXDR_generated.go b/spec/output/generator_spec_go/optional.x/MyXDR_generated.go index 5ffb0fbd5..e63b3aaca 100644 --- a/spec/output/generator_spec_go/optional.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/optional.x/MyXDR_generated.go @@ -186,6 +186,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } s.FirstOption = nil if b { + if err = xdr.TrackOutputBytesOf[Int](d); err != nil { + return n, fmt.Errorf("decoding Int: %w", err) + } s.FirstOption = new(Int) s.FirstOption, nTmp, err = d.DecodeInt() n += nTmp @@ -200,6 +203,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } s.SecondOption = nil if b { + if err = xdr.TrackOutputBytesOf[Int](d); err != nil { + return n, fmt.Errorf("decoding Int: %w", err) + } s.SecondOption = new(Int) s.SecondOption, nTmp, err = d.DecodeInt() n += nTmp @@ -214,6 +220,9 @@ func (s *HasOptions) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } s.ThirdOption = nil if b { + if err = xdr.TrackOutputBytesOf[Arr](d); err != nil { + return n, fmt.Errorf("decoding Arr: %w", err) + } s.ThirdOption = new(Arr) nTmp, err = s.ThirdOption.DecodeFrom(d, maxDepth) n += nTmp diff --git a/spec/output/generator_spec_go/test.x/MyXDR_generated.go b/spec/output/generator_spec_go/test.x/MyXDR_generated.go index 1a20a446a..c96fb7d3c 100644 --- a/spec/output/generator_spec_go/test.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/test.x/MyXDR_generated.go @@ -552,14 +552,25 @@ func (s *Hashes2) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if il, ok := d.InputLen(); ok && uint(il) < uint(l) { return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, il) } - (*s) = make([]Hash, l) + { + initialCap := l + if initialCap > xdr.MaxPrealloc { + initialCap = xdr.MaxPrealloc + } + (*s) = make([]Hash, 0, initialCap) + var empty Hash for i := uint32(0); i < l; i++ { + if err = xdr.TrackOutputBytesOf[Hash](d); err != nil { + return n, fmt.Errorf("decoding Hash: %w", err) + } + (*s) = append((*s), empty) nTmp, err = (*s)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding Hash: %w", err) } } + } } return n, nil } @@ -631,14 +642,25 @@ func (s *Hashes3) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { if il, ok := d.InputLen(); ok && uint(il) < uint(l) { return n, fmt.Errorf("decoding Hash: length (%d) exceeds remaining input length (%d)", l, il) } - (*s) = make([]Hash, l) + { + initialCap := l + if initialCap > xdr.MaxPrealloc { + initialCap = xdr.MaxPrealloc + } + (*s) = make([]Hash, 0, initialCap) + var empty Hash for i := uint32(0); i < l; i++ { + if err = xdr.TrackOutputBytesOf[Hash](d); err != nil { + return n, fmt.Errorf("decoding Hash: %w", err) + } + (*s) = append((*s), empty) nTmp, err = (*s)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding Hash: %w", err) } } + } } return n, nil } @@ -1006,6 +1028,9 @@ func (s *MyStruct) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } s.Field2 = nil if b { + if err = xdr.TrackOutputBytesOf[Hash](d); err != nil { + return n, fmt.Errorf("decoding Hash: %w", err) + } s.Field2 = new(Hash) nTmp, err = s.Field2.DecodeFrom(d, maxDepth) n += nTmp @@ -1114,14 +1139,25 @@ func (s *LotsOfMyStructs) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) if il, ok := d.InputLen(); ok && uint(il) < uint(l) { return n, fmt.Errorf("decoding MyStruct: length (%d) exceeds remaining input length (%d)", l, il) } - s.Members = make([]MyStruct, l) + { + initialCap := l + if initialCap > xdr.MaxPrealloc { + initialCap = xdr.MaxPrealloc + } + s.Members = make([]MyStruct, 0, initialCap) + var empty MyStruct for i := uint32(0); i < l; i++ { + if err = xdr.TrackOutputBytesOf[MyStruct](d); err != nil { + return n, fmt.Errorf("decoding MyStruct: %w", err) + } + s.Members = append(s.Members, empty) nTmp, err = s.Members[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding MyStruct: %w", err) } } + } } return n, nil } @@ -1571,7 +1607,10 @@ switch Color(u.Color) { // Void return n, nil default: - u.Blah2 = new(int32) + if err = xdr.TrackOutputBytesOf[int32](d); err != nil { + return n, fmt.Errorf("decoding int32: %w", err) + } + u.Blah2 = new(int32) (*u.Blah2), nTmp, err = d.DecodeInt() n += nTmp if err != nil { diff --git a/spec/output/generator_spec_go/union.x/MyXDR_generated.go b/spec/output/generator_spec_go/union.x/MyXDR_generated.go index 88281c4d5..62a76ab4f 100644 --- a/spec/output/generator_spec_go/union.x/MyXDR_generated.go +++ b/spec/output/generator_spec_go/union.x/MyXDR_generated.go @@ -417,7 +417,10 @@ func (u *MyUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } switch UnionKey(u.Type) { case UnionKeyError: - u.Error = new(Error) + if err = xdr.TrackOutputBytesOf[Error](d); err != nil { + return n, fmt.Errorf("decoding Error: %w", err) + } + u.Error = new(Error) nTmp, err = (*u.Error).DecodeFrom(d, maxDepth) n += nTmp if err != nil { @@ -425,7 +428,10 @@ switch UnionKey(u.Type) { } return n, nil case UnionKeyMulti: - u.Things = new([]Multi) + if err = xdr.TrackOutputBytesOf[[]Multi](d); err != nil { + return n, fmt.Errorf("decoding []Multi: %w", err) + } + u.Things = new([]Multi) var l uint32 l, nTmp, err = d.DecodeUint() n += nTmp @@ -437,14 +443,25 @@ switch UnionKey(u.Type) { if il, ok := d.InputLen(); ok && uint(il) < uint(l) { return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, il) } - (*u.Things) = make([]Multi, l) + { + initialCap := l + if initialCap > xdr.MaxPrealloc { + initialCap = xdr.MaxPrealloc + } + (*u.Things) = make([]Multi, 0, initialCap) + var empty Multi for i := uint32(0); i < l; i++ { + if err = xdr.TrackOutputBytesOf[Multi](d); err != nil { + return n, fmt.Errorf("decoding Multi: %w", err) + } + (*u.Things) = append((*u.Things), empty) nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding Multi: %w", err) } } + } } return n, nil } @@ -626,7 +643,10 @@ func (u *IntUnion) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) { } switch int32(u.Type) { case 0: - u.Error = new(Error) + if err = xdr.TrackOutputBytesOf[Error](d); err != nil { + return n, fmt.Errorf("decoding Error: %w", err) + } + u.Error = new(Error) nTmp, err = (*u.Error).DecodeFrom(d, maxDepth) n += nTmp if err != nil { @@ -634,7 +654,10 @@ switch int32(u.Type) { } return n, nil case 1: - u.Things = new([]Multi) + if err = xdr.TrackOutputBytesOf[[]Multi](d); err != nil { + return n, fmt.Errorf("decoding []Multi: %w", err) + } + u.Things = new([]Multi) var l uint32 l, nTmp, err = d.DecodeUint() n += nTmp @@ -646,14 +669,25 @@ switch int32(u.Type) { if il, ok := d.InputLen(); ok && uint(il) < uint(l) { return n, fmt.Errorf("decoding Multi: length (%d) exceeds remaining input length (%d)", l, il) } - (*u.Things) = make([]Multi, l) + { + initialCap := l + if initialCap > xdr.MaxPrealloc { + initialCap = xdr.MaxPrealloc + } + (*u.Things) = make([]Multi, 0, initialCap) + var empty Multi for i := uint32(0); i < l; i++ { + if err = xdr.TrackOutputBytesOf[Multi](d); err != nil { + return n, fmt.Errorf("decoding Multi: %w", err) + } + (*u.Things) = append((*u.Things), empty) nTmp, err = (*u.Things)[i].DecodeFrom(d, maxDepth) n += nTmp if err != nil { return n, fmt.Errorf("decoding Multi: %w", err) } } + } } return n, nil }