Skip to content

Commit 966e0f4

Browse files
committed
Ensure decoded enum values are valid
1 parent b6f99ee commit 966e0f4

File tree

1 file changed

+18
-16
lines changed
  • lib/xdrgen/generators

1 file changed

+18
-16
lines changed

Diff for: lib/xdrgen/generators/go.rb

+18-16
Original file line numberDiff line numberDiff line change
@@ -551,11 +551,9 @@ def render_union_decode_from_interface(out, union)
551551
name = name(union)
552552
out.puts "// DecodeFrom decodes this value using the Decoder."
553553
out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {"
554-
out.puts " disc, n, err := d.DecodeInt()"
555-
out.puts " if err != nil {"
556-
out.puts " return 0, err"
557-
out.puts " }"
558-
out.puts " u.#{name(union.discriminant)} = #{reference union.discriminant.type}(disc)"
554+
out.puts " var err error"
555+
out.puts " var n, nTmp int"
556+
render_decode_from_body(out, "u.#{name(union.discriminant)}", union.discriminant.type, declared_variables: [], self_encode: false)
559557
switch_for(out, union, "u.#{name(union.discriminant)}") do |arm, kase|
560558
out2 = StringIO.new
561559
if arm.void?
@@ -580,16 +578,20 @@ def render_union_decode_from_interface(out, union)
580578
def render_enum_decode_from_interface(out, typedef)
581579
name = name(typedef)
582580
type = typedef
583-
out.puts "// DecodeFrom decodes this value using the Decoder."
584-
out.puts "func (e *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {"
585-
out.puts " var err error"
586-
out.puts " var n, nTmp int"
587-
out.puts " var i int32"
588-
render_decode_from_body(out, "i", type, declared_variables: [], self_encode: true)
589-
out.puts " *e = #{name}(i)"
590-
out.puts " return n, nil"
591-
out.puts "}"
592-
out.break
581+
out.puts <<-EOS.strip_heredoc
582+
// DecodeFrom decodes this value using the Decoder.
583+
func (e *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {
584+
v, n, err := d.DecodeInt()
585+
if err != nil {
586+
return n, err
587+
}
588+
if _, ok := #{private_name type}Map[v]; !ok {
589+
return n, fmt.Errorf("'%d' is not a valid #{name} enum value", v)
590+
}
591+
*e = #{name}(v)
592+
return n, nil
593+
}
594+
EOS
593595
end
594596

595597
def render_typedef_decode_from_interface(out, typedef)
@@ -655,7 +657,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
655657
when AST::Typespecs::UnsignedInt
656658
out.puts " #{var}, nTmp, err = d.DecodeUint()"
657659
out.puts tail
658-
when AST::Typespecs::Int, AST::Definitions::Enum
660+
when AST::Typespecs::Int
659661
out.puts " #{var}, nTmp, err = d.DecodeInt()"
660662
out.puts tail
661663
when AST::Typespecs::String

0 commit comments

Comments
 (0)