diff options
Diffstat (limited to 'libgo/go/encoding/json/encode.go')
-rw-r--r-- | libgo/go/encoding/json/encode.go | 154 |
1 files changed, 131 insertions, 23 deletions
diff --git a/libgo/go/encoding/json/encode.go b/libgo/go/encoding/json/encode.go index 67412763d64..39cdaebde7b 100644 --- a/libgo/go/encoding/json/encode.go +++ b/libgo/go/encoding/json/encode.go @@ -153,7 +153,7 @@ import ( // // JSON cannot represent cyclic data structures and Marshal does not // handle them. Passing cyclic structures to Marshal will result in -// an infinite recursion. +// an error. // func Marshal(v interface{}) ([]byte, error) { e := newEncodeState() @@ -164,7 +164,6 @@ func Marshal(v interface{}) ([]byte, error) { } buf := append([]byte(nil), e.Bytes()...) - e.Reset() encodeStatePool.Put(e) return buf, nil @@ -262,14 +261,22 @@ func (e *InvalidUTF8Error) Error() string { // A MarshalerError represents an error from calling a MarshalJSON or MarshalText method. type MarshalerError struct { - Type reflect.Type - Err error + Type reflect.Type + Err error + sourceFunc string } func (e *MarshalerError) Error() string { - return "json: error calling MarshalJSON for type " + e.Type.String() + ": " + e.Err.Error() + srcFunc := e.sourceFunc + if srcFunc == "" { + srcFunc = "MarshalJSON" + } + return "json: error calling " + srcFunc + + " for type " + e.Type.String() + + ": " + e.Err.Error() } +// Unwrap returns the underlying error. func (e *MarshalerError) Unwrap() error { return e.Err } var hex = "0123456789abcdef" @@ -278,17 +285,31 @@ var hex = "0123456789abcdef" type encodeState struct { bytes.Buffer // accumulated output scratch [64]byte + + // Keep track of what pointers we've seen in the current recursive call + // path, to avoid cycles that could lead to a stack overflow. Only do + // the relatively expensive map operations if ptrLevel is larger than + // startDetectingCyclesAfter, so that we skip the work if we're within a + // reasonable amount of nested pointers deep. + ptrLevel uint + ptrSeen map[interface{}]struct{} } +const startDetectingCyclesAfter = 1000 + var encodeStatePool sync.Pool func newEncodeState() *encodeState { if v := encodeStatePool.Get(); v != nil { e := v.(*encodeState) e.Reset() + if len(e.ptrSeen) > 0 { + panic("ptrEncoder.encode should have emptied ptrSeen via defers") + } + e.ptrLevel = 0 return e } - return new(encodeState) + return &encodeState{ptrSeen: make(map[interface{}]struct{})} } // jsonError is an error wrapper type for internal use only. @@ -392,19 +413,22 @@ var ( // newTypeEncoder constructs an encoderFunc for a type. // The returned encoder only checks CanAddr when allowAddr is true. func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { - if t.Implements(marshalerType) { - return marshalerEncoder - } + // If we have a non-pointer value whose type implements + // Marshaler with a value receiver, then we're better off taking + // the address of the value - otherwise we end up with an + // allocation as we cast the value to an interface. if t.Kind() != reflect.Ptr && allowAddr && reflect.PtrTo(t).Implements(marshalerType) { return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) } - - if t.Implements(textMarshalerType) { - return textMarshalerEncoder + if t.Implements(marshalerType) { + return marshalerEncoder } if t.Kind() != reflect.Ptr && allowAddr && reflect.PtrTo(t).Implements(textMarshalerType) { return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false)) } + if t.Implements(textMarshalerType) { + return textMarshalerEncoder + } switch t.Kind() { case reflect.Bool: @@ -456,7 +480,7 @@ func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { err = compact(&e.Buffer, b, opts.escapeHTML) } if err != nil { - e.error(&MarshalerError{v.Type(), err}) + e.error(&MarshalerError{v.Type(), err, "MarshalJSON"}) } } @@ -473,7 +497,7 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { err = compact(&e.Buffer, b, opts.escapeHTML) } if err != nil { - e.error(&MarshalerError{v.Type(), err}) + e.error(&MarshalerError{v.Type(), err, "MarshalJSON"}) } } @@ -482,10 +506,14 @@ func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { e.WriteString("null") return } - m := v.Interface().(encoding.TextMarshaler) + m, ok := v.Interface().(encoding.TextMarshaler) + if !ok { + e.WriteString("null") + return + } b, err := m.MarshalText() if err != nil { - e.error(&MarshalerError{v.Type(), err}) + e.error(&MarshalerError{v.Type(), err, "MarshalText"}) } e.stringBytes(b, opts.escapeHTML) } @@ -499,7 +527,7 @@ func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { m := va.Interface().(encoding.TextMarshaler) b, err := m.MarshalText() if err != nil { - e.error(&MarshalerError{v.Type(), err}) + e.error(&MarshalerError{v.Type(), err, "MarshalText"}) } e.stringBytes(b, opts.escapeHTML) } @@ -597,20 +625,86 @@ func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) { if !isValidNumber(numStr) { e.error(fmt.Errorf("json: invalid number literal %q", numStr)) } + if opts.quoted { + e.WriteByte('"') + } e.WriteString(numStr) + if opts.quoted { + e.WriteByte('"') + } return } if opts.quoted { - sb, err := Marshal(v.String()) - if err != nil { - e.error(err) - } - e.string(string(sb), opts.escapeHTML) + b := make([]byte, 0, v.Len()+2) + b = append(b, '"') + b = append(b, []byte(v.String())...) + b = append(b, '"') + e.stringBytes(b, opts.escapeHTML) } else { e.string(v.String(), opts.escapeHTML) } } +// isValidNumber reports whether s is a valid JSON number literal. +func isValidNumber(s string) bool { + // This function implements the JSON numbers grammar. + // See https://tools.ietf.org/html/rfc7159#section-6 + // and https://json.org/number.gif + + if s == "" { + return false + } + + // Optional - + if s[0] == '-' { + s = s[1:] + if s == "" { + return false + } + } + + // Digits + switch { + default: + return false + + case s[0] == '0': + s = s[1:] + + case '1' <= s[0] && s[0] <= '9': + s = s[1:] + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // . followed by 1 or more digits. + if len(s) >= 2 && s[0] == '.' && '0' <= s[1] && s[1] <= '9' { + s = s[2:] + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // e or E followed by an optional - or + and + // 1 or more digits. + if len(s) >= 2 && (s[0] == 'e' || s[0] == 'E') { + s = s[1:] + if s[0] == '+' || s[0] == '-' { + s = s[1:] + if s == "" { + return false + } + } + for len(s) > 0 && '0' <= s[0] && s[0] <= '9' { + s = s[1:] + } + } + + // Make sure we are at the end. + return s == "" +} + func interfaceEncoder(e *encodeState, v reflect.Value, opts encOpts) { if v.IsNil() { e.WriteString("null") @@ -692,7 +786,7 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { for i, v := range keys { sv[i].v = v if err := sv[i].resolve(); err != nil { - e.error(&MarshalerError{v.Type(), err}) + e.error(fmt.Errorf("json: encoding error for type %q: %q", v.Type().String(), err.Error())) } } sort.Slice(sv, func(i, j int) bool { return sv[i].s < sv[j].s }) @@ -807,7 +901,18 @@ func (pe ptrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { e.WriteString("null") return } + if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter { + // We're a large number of nested ptrEncoder.encode calls deep; + // start checking if we've run into a pointer cycle. + ptr := v.Interface() + if _, ok := e.ptrSeen[ptr]; ok { + e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())}) + } + e.ptrSeen[ptr] = struct{}{} + defer delete(e.ptrSeen, ptr) + } pe.elemEnc(e, v.Elem(), opts) + e.ptrLevel-- } func newPtrEncoder(t reflect.Type) encoderFunc { @@ -872,6 +977,9 @@ func (w *reflectWithString) resolve() error { return nil } if tm, ok := w.v.Interface().(encoding.TextMarshaler); ok { + if w.v.Kind() == reflect.Ptr && w.v.IsNil() { + return nil + } buf, err := tm.MarshalText() w.s = string(buf) return err |