Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

encoding/json: detect cyclic maps and slices #40756

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/encoding/json/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,16 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
e.WriteString("null")
return
}
if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
// We're a large number of nested ptrEncoder.encode calls deep;
// start checking if we've run into a pointer cycle.
ptr := v.Pointer()
if _, ok := e.ptrSeen[ptr]; ok {
e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
}
e.ptrSeen[ptr] = struct{}{}
defer delete(e.ptrSeen, ptr)
}
e.WriteByte('{')

// Extract and sort the keys.
Expand All @@ -801,6 +811,7 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
me.elemEnc(e, v.MapIndex(kv.v), opts)
}
e.WriteByte('}')
e.ptrLevel--
}

func newMapEncoder(t reflect.Type) encoderFunc {
Expand Down Expand Up @@ -857,7 +868,23 @@ func (se sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
e.WriteString("null")
return
}
if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
// We're a large number of nested ptrEncoder.encode calls deep;
// start checking if we've run into a pointer cycle.
// Here we use a struct to memorize the pointer to the first element of the slice
// and its length.
ptr := struct {
ptr uintptr
len int
}{v.Pointer(), v.Len()}
if _, ok := e.ptrSeen[ptr]; ok {
e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
}
e.ptrSeen[ptr] = struct{}{}
defer delete(e.ptrSeen, ptr)
}
se.arrayEnc(e, v, opts)
e.ptrLevel--
}

func newSliceEncoder(t reflect.Type) encoderFunc {
Expand Down
27 changes: 26 additions & 1 deletion src/encoding/json/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,15 @@ type PointerCycleIndirect struct {
Ptrs []interface{}
}

var pointerCycleIndirect = &PointerCycleIndirect{}
type RecursiveSlice []RecursiveSlice

var (
pointerCycleIndirect = &PointerCycleIndirect{}
mapCycle = make(map[string]interface{})
sliceCycle = []interface{}{nil}
sliceNoCycle = []interface{}{nil, nil}
recursiveSliceCycle = []RecursiveSlice{nil}
)

func init() {
ptr := &SamePointerNoCycle{}
Expand All @@ -192,6 +200,14 @@ func init() {

pointerCycle.Ptr = pointerCycle
pointerCycleIndirect.Ptrs = []interface{}{pointerCycleIndirect}

mapCycle["x"] = mapCycle
sliceCycle[0] = sliceCycle
sliceNoCycle[1] = sliceNoCycle[:1]
for i := startDetectingCyclesAfter; i > 0; i-- {
sliceNoCycle = []interface{}{sliceNoCycle}
}
recursiveSliceCycle[0] = recursiveSliceCycle
}

func TestSamePointerNoCycle(t *testing.T) {
Expand All @@ -200,12 +216,21 @@ func TestSamePointerNoCycle(t *testing.T) {
}
}

func TestSliceNoCycle(t *testing.T) {
if _, err := Marshal(sliceNoCycle); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}

var unsupportedValues = []interface{}{
math.NaN(),
math.Inf(-1),
math.Inf(1),
pointerCycle,
pointerCycleIndirect,
mapCycle,
sliceCycle,
recursiveSliceCycle,
}

func TestUnsupportedValues(t *testing.T) {
Expand Down