// unmarshal.go (7.1 KB)
  1. package jsonutil
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "math/big"
  9. "reflect"
  10. "strings"
  11. "time"
  12. "github.com/aws/aws-sdk-go/aws"
  13. "github.com/aws/aws-sdk-go/aws/awserr"
  14. "github.com/aws/aws-sdk-go/private/protocol"
  15. )
  // millisecondsFloat holds the number of milliseconds in one second. It is used
// to scale (possibly fractional) epoch-second timestamps up to milliseconds.
var millisecondsFloat = new(big.Float).SetUint64(1000)
  17. // UnmarshalJSONError unmarshal's the reader's JSON document into the passed in
  18. // type. The value to unmarshal the json document into must be a pointer to the
  19. // type.
  20. func UnmarshalJSONError(v interface{}, stream io.Reader) error {
  21. var errBuf bytes.Buffer
  22. body := io.TeeReader(stream, &errBuf)
  23. err := json.NewDecoder(body).Decode(v)
  24. if err != nil {
  25. msg := "failed decoding error message"
  26. if err == io.EOF {
  27. msg = "error message missing"
  28. err = nil
  29. }
  30. return awserr.NewUnmarshalError(err, msg, errBuf.Bytes())
  31. }
  32. return nil
  33. }
  34. // UnmarshalJSON reads a stream and unmarshals the results in object v.
  35. func UnmarshalJSON(v interface{}, stream io.Reader) error {
  36. var out interface{}
  37. decoder := json.NewDecoder(stream)
  38. decoder.UseNumber()
  39. err := decoder.Decode(&out)
  40. if err == io.EOF {
  41. return nil
  42. } else if err != nil {
  43. return err
  44. }
  45. return unmarshaler{}.unmarshalAny(reflect.ValueOf(v), out, "")
  46. }
  47. // UnmarshalJSONCaseInsensitive reads a stream and unmarshals the result into the
  48. // object v. Ignores casing for structure members.
  49. func UnmarshalJSONCaseInsensitive(v interface{}, stream io.Reader) error {
  50. var out interface{}
  51. decoder := json.NewDecoder(stream)
  52. decoder.UseNumber()
  53. err := decoder.Decode(&out)
  54. if err == io.EOF {
  55. return nil
  56. } else if err != nil {
  57. return err
  58. }
  59. return unmarshaler{
  60. caseInsensitive: true,
  61. }.unmarshalAny(reflect.ValueOf(v), out, "")
  62. }
  // unmarshaler assigns a decoded JSON tree into SDK shapes via reflection. The
// tree is made of map[string]interface{}, []interface{}, string, json.Number
// and bool values.
type unmarshaler struct {
	// caseInsensitive enables a case-folded fallback lookup when a structure
	// member's name has no exact match among the JSON object's keys.
	caseInsensitive bool
}
  66. func (u unmarshaler) unmarshalAny(value reflect.Value, data interface{}, tag reflect.StructTag) error {
  67. vtype := value.Type()
  68. if vtype.Kind() == reflect.Ptr {
  69. vtype = vtype.Elem() // check kind of actual element type
  70. }
  71. t := tag.Get("type")
  72. if t == "" {
  73. switch vtype.Kind() {
  74. case reflect.Struct:
  75. // also it can't be a time object
  76. if _, ok := value.Interface().(*time.Time); !ok {
  77. t = "structure"
  78. }
  79. case reflect.Slice:
  80. // also it can't be a byte slice
  81. if _, ok := value.Interface().([]byte); !ok {
  82. t = "list"
  83. }
  84. case reflect.Map:
  85. // cannot be a JSONValue map
  86. if _, ok := value.Interface().(aws.JSONValue); !ok {
  87. t = "map"
  88. }
  89. }
  90. }
  91. switch t {
  92. case "structure":
  93. if field, ok := vtype.FieldByName("_"); ok {
  94. tag = field.Tag
  95. }
  96. return u.unmarshalStruct(value, data, tag)
  97. case "list":
  98. return u.unmarshalList(value, data, tag)
  99. case "map":
  100. return u.unmarshalMap(value, data, tag)
  101. default:
  102. return u.unmarshalScalar(value, data, tag)
  103. }
  104. }
  105. func (u unmarshaler) unmarshalStruct(value reflect.Value, data interface{}, tag reflect.StructTag) error {
  106. if data == nil {
  107. return nil
  108. }
  109. mapData, ok := data.(map[string]interface{})
  110. if !ok {
  111. return fmt.Errorf("JSON value is not a structure (%#v)", data)
  112. }
  113. t := value.Type()
  114. if value.Kind() == reflect.Ptr {
  115. if value.IsNil() { // create the structure if it's nil
  116. s := reflect.New(value.Type().Elem())
  117. value.Set(s)
  118. value = s
  119. }
  120. value = value.Elem()
  121. t = t.Elem()
  122. }
  123. // unwrap any payloads
  124. if payload := tag.Get("payload"); payload != "" {
  125. field, _ := t.FieldByName(payload)
  126. return u.unmarshalAny(value.FieldByName(payload), data, field.Tag)
  127. }
  128. for i := 0; i < t.NumField(); i++ {
  129. field := t.Field(i)
  130. if field.PkgPath != "" {
  131. continue // ignore unexported fields
  132. }
  133. // figure out what this field is called
  134. name := field.Name
  135. if locName := field.Tag.Get("locationName"); locName != "" {
  136. name = locName
  137. }
  138. if u.caseInsensitive {
  139. if _, ok := mapData[name]; !ok {
  140. // Fallback to uncased name search if the exact name didn't match.
  141. for kn, v := range mapData {
  142. if strings.EqualFold(kn, name) {
  143. mapData[name] = v
  144. }
  145. }
  146. }
  147. }
  148. member := value.FieldByIndex(field.Index)
  149. err := u.unmarshalAny(member, mapData[name], field.Tag)
  150. if err != nil {
  151. return err
  152. }
  153. }
  154. return nil
  155. }
  156. func (u unmarshaler) unmarshalList(value reflect.Value, data interface{}, tag reflect.StructTag) error {
  157. if data == nil {
  158. return nil
  159. }
  160. listData, ok := data.([]interface{})
  161. if !ok {
  162. return fmt.Errorf("JSON value is not a list (%#v)", data)
  163. }
  164. if value.IsNil() {
  165. l := len(listData)
  166. value.Set(reflect.MakeSlice(value.Type(), l, l))
  167. }
  168. for i, c := range listData {
  169. err := u.unmarshalAny(value.Index(i), c, "")
  170. if err != nil {
  171. return err
  172. }
  173. }
  174. return nil
  175. }
  176. func (u unmarshaler) unmarshalMap(value reflect.Value, data interface{}, tag reflect.StructTag) error {
  177. if data == nil {
  178. return nil
  179. }
  180. mapData, ok := data.(map[string]interface{})
  181. if !ok {
  182. return fmt.Errorf("JSON value is not a map (%#v)", data)
  183. }
  184. if value.IsNil() {
  185. value.Set(reflect.MakeMap(value.Type()))
  186. }
  187. for k, v := range mapData {
  188. kvalue := reflect.ValueOf(k)
  189. vvalue := reflect.New(value.Type().Elem()).Elem()
  190. u.unmarshalAny(vvalue, v, "")
  191. value.SetMapIndex(kvalue, vvalue)
  192. }
  193. return nil
  194. }
  195. func (u unmarshaler) unmarshalScalar(value reflect.Value, data interface{}, tag reflect.StructTag) error {
  196. switch d := data.(type) {
  197. case nil:
  198. return nil // nothing to do here
  199. case string:
  200. switch value.Interface().(type) {
  201. case *string:
  202. value.Set(reflect.ValueOf(&d))
  203. case []byte:
  204. b, err := base64.StdEncoding.DecodeString(d)
  205. if err != nil {
  206. return err
  207. }
  208. value.Set(reflect.ValueOf(b))
  209. case *time.Time:
  210. format := tag.Get("timestampFormat")
  211. if len(format) == 0 {
  212. format = protocol.ISO8601TimeFormatName
  213. }
  214. t, err := protocol.ParseTime(format, d)
  215. if err != nil {
  216. return err
  217. }
  218. value.Set(reflect.ValueOf(&t))
  219. case aws.JSONValue:
  220. // No need to use escaping as the value is a non-quoted string.
  221. v, err := protocol.DecodeJSONValue(d, protocol.NoEscape)
  222. if err != nil {
  223. return err
  224. }
  225. value.Set(reflect.ValueOf(v))
  226. default:
  227. return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
  228. }
  229. case json.Number:
  230. switch value.Interface().(type) {
  231. case *int64:
  232. // Retain the old behavior where we would just truncate the float64
  233. // calling d.Int64() here could cause an invalid syntax error due to the usage of strconv.ParseInt
  234. f, err := d.Float64()
  235. if err != nil {
  236. return err
  237. }
  238. di := int64(f)
  239. value.Set(reflect.ValueOf(&di))
  240. case *float64:
  241. f, err := d.Float64()
  242. if err != nil {
  243. return err
  244. }
  245. value.Set(reflect.ValueOf(&f))
  246. case *time.Time:
  247. float, ok := new(big.Float).SetString(d.String())
  248. if !ok {
  249. return fmt.Errorf("unsupported float time representation: %v", d.String())
  250. }
  251. float = float.Mul(float, millisecondsFloat)
  252. ms, _ := float.Int64()
  253. t := time.Unix(0, ms*1e6).UTC()
  254. value.Set(reflect.ValueOf(&t))
  255. default:
  256. return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
  257. }
  258. case bool:
  259. switch value.Interface().(type) {
  260. case *bool:
  261. value.Set(reflect.ValueOf(&d))
  262. default:
  263. return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
  264. }
  265. default:
  266. return fmt.Errorf("unsupported JSON value (%v)", data)
  267. }
  268. return nil
  269. }