summaryrefslogtreecommitdiff
path: root/libgo/go/encoding/json/encode.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/encoding/json/encode.go')
-rw-r--r--libgo/go/encoding/json/encode.go154
1 files changed, 131 insertions, 23 deletions
diff --git a/libgo/go/encoding/json/encode.go b/libgo/go/encoding/json/encode.go
index 67412763d64..39cdaebde7b 100644
--- a/libgo/go/encoding/json/encode.go
+++ b/libgo/go/encoding/json/encode.go
@@ -153,7 +153,7 @@ import (
//
// JSON cannot represent cyclic data structures and Marshal does not
// handle them. Passing cyclic structures to Marshal will result in
-// an infinite recursion.
+// an error.
//
func Marshal(v interface{}) ([]byte, error) {
e := newEncodeState()
@@ -164,7 +164,6 @@ func Marshal(v interface{}) ([]byte, error) {
}
buf := append([]byte(nil), e.Bytes()...)
- e.Reset()
encodeStatePool.Put(e)
return buf, nil
@@ -262,14 +261,22 @@ func (e *InvalidUTF8Error) Error() string {
// A MarshalerError represents an error from calling a MarshalJSON or MarshalText method.
type MarshalerError struct {
- Type reflect.Type
- Err error
+ Type reflect.Type
+ Err error
+ sourceFunc string
}
func (e *MarshalerError) Error() string {
- return "json: error calling MarshalJSON for type " + e.Type.String() + ": " + e.Err.Error()
+ srcFunc := e.sourceFunc
+ if srcFunc == "" {
+ srcFunc = "MarshalJSON"
+ }
+ return "json: error calling " + srcFunc +
+ " for type " + e.Type.String() +
+ ": " + e.Err.Error()
}
+// Unwrap returns the underlying error.
func (e *MarshalerError) Unwrap() error { return e.Err }
var hex = "0123456789abcdef"
@@ -278,17 +285,31 @@ var hex = "0123456789abcdef"
type encodeState struct {
bytes.Buffer // accumulated output
scratch [64]byte
+
+ // Keep track of what pointers we've seen in the current recursive call
+ // path, to avoid cycles that could lead to a stack overflow. Only do
+ // the relatively expensive map operations if ptrLevel is larger than
+ // startDetectingCyclesAfter, so that we skip the work if we're within a
+ // reasonable amount of nested pointers deep.
+ ptrLevel uint
+ ptrSeen map[interface{}]struct{}
}
+const startDetectingCyclesAfter = 1000
+
var encodeStatePool sync.Pool
func newEncodeState() *encodeState {
if v := encodeStatePool.Get(); v != nil {
e := v.(*encodeState)
e.Reset()
+ if len(e.ptrSeen) > 0 {
+ panic("ptrEncoder.encode should have emptied ptrSeen via defers")
+ }
+ e.ptrLevel = 0
return e
}
- return new(encodeState)
+ return &encodeState{ptrSeen: make(map[interface{}]struct{})}
}
// jsonError is an error wrapper type for internal use only.
@@ -392,19 +413,22 @@ var (
// newTypeEncoder constructs an encoderFunc for a type.
// The returned encoder only checks CanAddr when allowAddr is true.
func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
- if t.Implements(marshalerType) {
- return marshalerEncoder
- }
+ // If we have a non-pointer value whose type implements
+ // Marshaler with a value receiver, then we're better off taking
+ // the address of the value - otherwise we end up with an
+ // allocation as we cast the value to an interface.
if t.Kind() != reflect.Ptr && allowAddr && reflect.PtrTo(t).Implements(marshalerType) {
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
}
-
- if t.Implements(textMarshalerType) {
- return textMarshalerEncoder
+ if t.Implements(marshalerType) {
+ return marshalerEncoder
}
if t.Kind() != reflect.Ptr && allowAddr && reflect.PtrTo(t).Implements(textMarshalerType) {
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
}
+ if t.Implements(textMarshalerType) {
+ return textMarshalerEncoder
+ }
switch t.Kind() {
case reflect.Bool:
@@ -456,7 +480,7 @@ func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
err = compact(&e.Buffer, b, opts.escapeHTML)
}
if err != nil {
- e.error(&MarshalerError{v.Type(), err})
+ e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
@@ -473,7 +497,7 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
err = compact(&e.Buffer, b, opts.escapeHTML)
}
if err != nil {
- e.error(&MarshalerError{v.Type(), err})
+ e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
@@ -482,10 +506,14 @@ func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
e.WriteString("null")
return
}
- m := v.Interface().(encoding.TextMarshaler)
+ m, ok := v.Interface().(encoding.TextMarshaler)
+ if !ok {
+ e.WriteString("null")
+ return
+ }
b, err := m.MarshalText()
if err != nil {
- e.error(&MarshalerError{v.Type(), err})
+ e.error(&MarshalerError{v.Type(), err, "MarshalText"})
}
e.stringBytes(b, opts.escapeHTML)
}
@@ -499,7 +527,7 @@ func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
m := va.Interface().(encoding.TextMarshaler)
b, err := m.MarshalText()
if err != nil {
- e.error(&MarshalerError{v.Type(), err})
+ e.error(&MarshalerError{v.Type(), err, "MarshalText"})
}
e.stringBytes(b, opts.escapeHTML)
}
@@ -597,20 +625,86 @@ func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if !isValidNumber(numStr) {
e.error(fmt.Errorf("json: invalid number literal %q", numStr))
}
+ if opts.quoted {
+ e.WriteByte('"')
+ }
e.WriteString(numStr)
+ if opts.quoted {
+ e.WriteByte('"')
+ }
return
}
if opts.quoted {
- sb, err := Marshal(v.String())
- if err != nil {
- e.error(err)
- }
- e.string(string(sb), opts.escapeHTML)
+ b := make([]byte, 0, v.Len()+2)
+ b = append(b, '"')
+ b = append(b, []byte(v.String())...)
+ b = append(b, '"')
+ e.stringBytes(b, opts.escapeHTML)
} else {
e.string(v.String(), opts.escapeHTML)
}
}
+// isValidNumber reports whether s is a valid JSON number literal.
+func isValidNumber(s string) bool {
+ // This function implements the JSON numbers grammar.
+ // See https://tools.ietf.org/html/rfc7159#section-6
+ // and https://json.org/number.gif
+
+ if s == "" {
+ return false
+ }
+
+ // Optional -
+ if s[0] == '-' {
+ s = s[1:]
+ if s == "" {
+ return false
+ }
+ }
+
+ // Digits
+ switch {
+ default:
+ return false
+
+ case s[0] == '0':
+ s = s[1:]
+
+ case '1' <= s[0] && s[0] <= '9':
+ s = s[1:]
+ for len(s) > 0 && '0' <= s[0] && s[0] <= '9' {
+ s = s[1:]
+ }
+ }
+
+ // . followed by 1 or more digits.
+ if len(s) >= 2 && s[0] == '.' && '0' <= s[1] && s[1] <= '9' {
+ s = s[2:]
+ for len(s) > 0 && '0' <= s[0] && s[0] <= '9' {
+ s = s[1:]
+ }
+ }
+
+ // e or E followed by an optional - or + and
+ // 1 or more digits.
+ if len(s) >= 2 && (s[0] == 'e' || s[0] == 'E') {
+ s = s[1:]
+ if s[0] == '+' || s[0] == '-' {
+ s = s[1:]
+ if s == "" {
+ return false
+ }
+ }
+ for len(s) > 0 && '0' <= s[0] && s[0] <= '9' {
+ s = s[1:]
+ }
+ }
+
+ // Make sure we are at the end.
+ return s == ""
+}
+
func interfaceEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
e.WriteString("null")
@@ -692,7 +786,7 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
for i, v := range keys {
sv[i].v = v
if err := sv[i].resolve(); err != nil {
- e.error(&MarshalerError{v.Type(), err})
+ e.error(fmt.Errorf("json: encoding error for type %q: %q", v.Type().String(), err.Error()))
}
}
sort.Slice(sv, func(i, j int) bool { return sv[i].s < sv[j].s })
@@ -807,7 +901,18 @@ func (pe ptrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
e.WriteString("null")
return
}
+ if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
+ // We're a large number of nested ptrEncoder.encode calls deep;
+ // start checking if we've run into a pointer cycle.
+ ptr := v.Interface()
+ if _, ok := e.ptrSeen[ptr]; ok {
+ e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
+ }
+ e.ptrSeen[ptr] = struct{}{}
+ defer delete(e.ptrSeen, ptr)
+ }
pe.elemEnc(e, v.Elem(), opts)
+ e.ptrLevel--
}
func newPtrEncoder(t reflect.Type) encoderFunc {
@@ -872,6 +977,9 @@ func (w *reflectWithString) resolve() error {
return nil
}
if tm, ok := w.v.Interface().(encoding.TextMarshaler); ok {
+ if w.v.Kind() == reflect.Ptr && w.v.IsNil() {
+ return nil
+ }
buf, err := tm.MarshalText()
w.s = string(buf)
return err