helper.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. // Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license found in the LICENSE file.
  3. package codec
  4. // Contains code shared by both encode and decode.
  5. import (
  6. "encoding/binary"
  7. "fmt"
  8. "math"
  9. "reflect"
  10. "sort"
  11. "strings"
  12. "sync"
  13. "time"
  14. "unicode"
  15. "unicode/utf8"
  16. )
  17. const (
  18. structTagName = "codec"
  19. // Support
  20. // encoding.BinaryMarshaler: MarshalBinary() (data []byte, err error)
  21. // encoding.BinaryUnmarshaler: UnmarshalBinary(data []byte) error
  22. // This constant flag will enable or disable it.
  23. supportBinaryMarshal = true
  24. // Each Encoder or Decoder uses a cache of functions based on conditionals,
  25. // so that the conditionals are not run every time.
  26. //
  27. // Either a map or a slice is used to keep track of the functions.
  28. // The map is more natural, but has a higher cost than a slice/array.
  29. // This flag (useMapForCodecCache) controls which is used.
  30. useMapForCodecCache = false
  31. // For some common container types, we can short-circuit an elaborate
  32. // reflection dance and call encode/decode directly.
  33. // The currently supported types are:
  34. // - slices of strings, or id's (int64,uint64) or interfaces.
  35. // - maps of str->str, str->intf, id(int64,uint64)->intf, intf->intf
  36. shortCircuitReflectToFastPath = true
  37. // for debugging, set this to false, to catch panic traces.
  38. // Note that this will always cause rpc tests to fail, since they need io.EOF sent via panic.
  39. recoverPanicToErr = true
  40. )
  41. type charEncoding uint8
  42. const (
  43. c_RAW charEncoding = iota
  44. c_UTF8
  45. c_UTF16LE
  46. c_UTF16BE
  47. c_UTF32LE
  48. c_UTF32BE
  49. )
  50. // valueType is the stream type
  51. type valueType uint8
  52. const (
  53. valueTypeUnset valueType = iota
  54. valueTypeNil
  55. valueTypeInt
  56. valueTypeUint
  57. valueTypeFloat
  58. valueTypeBool
  59. valueTypeString
  60. valueTypeSymbol
  61. valueTypeBytes
  62. valueTypeMap
  63. valueTypeArray
  64. valueTypeTimestamp
  65. valueTypeExt
  66. valueTypeInvalid = 0xff
  67. )
  68. var (
  69. bigen = binary.BigEndian
  70. structInfoFieldName = "_struct"
  71. cachedTypeInfo = make(map[uintptr]*typeInfo, 4)
  72. cachedTypeInfoMutex sync.RWMutex
  73. intfSliceTyp = reflect.TypeOf([]interface{}(nil))
  74. intfTyp = intfSliceTyp.Elem()
  75. strSliceTyp = reflect.TypeOf([]string(nil))
  76. boolSliceTyp = reflect.TypeOf([]bool(nil))
  77. uintSliceTyp = reflect.TypeOf([]uint(nil))
  78. uint8SliceTyp = reflect.TypeOf([]uint8(nil))
  79. uint16SliceTyp = reflect.TypeOf([]uint16(nil))
  80. uint32SliceTyp = reflect.TypeOf([]uint32(nil))
  81. uint64SliceTyp = reflect.TypeOf([]uint64(nil))
  82. intSliceTyp = reflect.TypeOf([]int(nil))
  83. int8SliceTyp = reflect.TypeOf([]int8(nil))
  84. int16SliceTyp = reflect.TypeOf([]int16(nil))
  85. int32SliceTyp = reflect.TypeOf([]int32(nil))
  86. int64SliceTyp = reflect.TypeOf([]int64(nil))
  87. float32SliceTyp = reflect.TypeOf([]float32(nil))
  88. float64SliceTyp = reflect.TypeOf([]float64(nil))
  89. mapIntfIntfTyp = reflect.TypeOf(map[interface{}]interface{}(nil))
  90. mapStrIntfTyp = reflect.TypeOf(map[string]interface{}(nil))
  91. mapStrStrTyp = reflect.TypeOf(map[string]string(nil))
  92. mapIntIntfTyp = reflect.TypeOf(map[int]interface{}(nil))
  93. mapInt64IntfTyp = reflect.TypeOf(map[int64]interface{}(nil))
  94. mapUintIntfTyp = reflect.TypeOf(map[uint]interface{}(nil))
  95. mapUint64IntfTyp = reflect.TypeOf(map[uint64]interface{}(nil))
  96. stringTyp = reflect.TypeOf("")
  97. timeTyp = reflect.TypeOf(time.Time{})
  98. rawExtTyp = reflect.TypeOf(RawExt{})
  99. mapBySliceTyp = reflect.TypeOf((*MapBySlice)(nil)).Elem()
  100. binaryMarshalerTyp = reflect.TypeOf((*binaryMarshaler)(nil)).Elem()
  101. binaryUnmarshalerTyp = reflect.TypeOf((*binaryUnmarshaler)(nil)).Elem()
  102. rawExtTypId = reflect.ValueOf(rawExtTyp).Pointer()
  103. intfTypId = reflect.ValueOf(intfTyp).Pointer()
  104. timeTypId = reflect.ValueOf(timeTyp).Pointer()
  105. intfSliceTypId = reflect.ValueOf(intfSliceTyp).Pointer()
  106. strSliceTypId = reflect.ValueOf(strSliceTyp).Pointer()
  107. boolSliceTypId = reflect.ValueOf(boolSliceTyp).Pointer()
  108. uintSliceTypId = reflect.ValueOf(uintSliceTyp).Pointer()
  109. uint8SliceTypId = reflect.ValueOf(uint8SliceTyp).Pointer()
  110. uint16SliceTypId = reflect.ValueOf(uint16SliceTyp).Pointer()
  111. uint32SliceTypId = reflect.ValueOf(uint32SliceTyp).Pointer()
  112. uint64SliceTypId = reflect.ValueOf(uint64SliceTyp).Pointer()
  113. intSliceTypId = reflect.ValueOf(intSliceTyp).Pointer()
  114. int8SliceTypId = reflect.ValueOf(int8SliceTyp).Pointer()
  115. int16SliceTypId = reflect.ValueOf(int16SliceTyp).Pointer()
  116. int32SliceTypId = reflect.ValueOf(int32SliceTyp).Pointer()
  117. int64SliceTypId = reflect.ValueOf(int64SliceTyp).Pointer()
  118. float32SliceTypId = reflect.ValueOf(float32SliceTyp).Pointer()
  119. float64SliceTypId = reflect.ValueOf(float64SliceTyp).Pointer()
  120. mapStrStrTypId = reflect.ValueOf(mapStrStrTyp).Pointer()
  121. mapIntfIntfTypId = reflect.ValueOf(mapIntfIntfTyp).Pointer()
  122. mapStrIntfTypId = reflect.ValueOf(mapStrIntfTyp).Pointer()
  123. mapIntIntfTypId = reflect.ValueOf(mapIntIntfTyp).Pointer()
  124. mapInt64IntfTypId = reflect.ValueOf(mapInt64IntfTyp).Pointer()
  125. mapUintIntfTypId = reflect.ValueOf(mapUintIntfTyp).Pointer()
  126. mapUint64IntfTypId = reflect.ValueOf(mapUint64IntfTyp).Pointer()
  127. // Id = reflect.ValueOf().Pointer()
  128. // mapBySliceTypId = reflect.ValueOf(mapBySliceTyp).Pointer()
  129. binaryMarshalerTypId = reflect.ValueOf(binaryMarshalerTyp).Pointer()
  130. binaryUnmarshalerTypId = reflect.ValueOf(binaryUnmarshalerTyp).Pointer()
  131. intBitsize uint8 = uint8(reflect.TypeOf(int(0)).Bits())
  132. uintBitsize uint8 = uint8(reflect.TypeOf(uint(0)).Bits())
  133. bsAll0x00 = []byte{0, 0, 0, 0, 0, 0, 0, 0}
  134. bsAll0xff = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
  135. )
  136. type binaryUnmarshaler interface {
  137. UnmarshalBinary(data []byte) error
  138. }
  139. type binaryMarshaler interface {
  140. MarshalBinary() (data []byte, err error)
  141. }
  142. // MapBySlice represents a slice which should be encoded as a map in the stream.
  143. // The slice contains a sequence of key-value pairs.
  144. type MapBySlice interface {
  145. MapBySlice()
  146. }
  147. // WARNING: DO NOT USE DIRECTLY. EXPORTED FOR GODOC BENEFIT. WILL BE REMOVED.
  148. //
  149. // BasicHandle encapsulates the common options and extension functions.
  150. type BasicHandle struct {
  151. extHandle
  152. EncodeOptions
  153. DecodeOptions
  154. }
  155. // Handle is the interface for a specific encoding format.
  156. //
  157. // Typically, a Handle is pre-configured before first time use,
  158. // and not modified while in use. Such a pre-configured Handle
  159. // is safe for concurrent access.
  160. type Handle interface {
  161. writeExt() bool
  162. getBasicHandle() *BasicHandle
  163. newEncDriver(w encWriter) encDriver
  164. newDecDriver(r decReader) decDriver
  165. }
  166. // RawExt represents raw unprocessed extension data.
  167. type RawExt struct {
  168. Tag byte
  169. Data []byte
  170. }
  171. type extTypeTagFn struct {
  172. rtid uintptr
  173. rt reflect.Type
  174. tag byte
  175. encFn func(reflect.Value) ([]byte, error)
  176. decFn func(reflect.Value, []byte) error
  177. }
  178. type extHandle []*extTypeTagFn
  179. // AddExt registers an encode and decode function for a reflect.Type.
  180. // Note that the type must be a named type, and specifically not
  181. // a pointer or Interface. An error is returned if that is not honored.
  182. //
  183. // To Deregister an ext, call AddExt with 0 tag, nil encfn and nil decfn.
  184. func (o *extHandle) AddExt(
  185. rt reflect.Type,
  186. tag byte,
  187. encfn func(reflect.Value) ([]byte, error),
  188. decfn func(reflect.Value, []byte) error,
  189. ) (err error) {
  190. // o is a pointer, because we may need to initialize it
  191. if rt.PkgPath() == "" || rt.Kind() == reflect.Interface {
  192. err = fmt.Errorf("codec.Handle.AddExt: Takes named type, especially not a pointer or interface: %T",
  193. reflect.Zero(rt).Interface())
  194. return
  195. }
  196. // o cannot be nil, since it is always embedded in a Handle.
  197. // if nil, let it panic.
  198. // if o == nil {
  199. // err = errors.New("codec.Handle.AddExt: extHandle cannot be a nil pointer.")
  200. // return
  201. // }
  202. rtid := reflect.ValueOf(rt).Pointer()
  203. for _, v := range *o {
  204. if v.rtid == rtid {
  205. v.tag, v.encFn, v.decFn = tag, encfn, decfn
  206. return
  207. }
  208. }
  209. *o = append(*o, &extTypeTagFn{rtid, rt, tag, encfn, decfn})
  210. return
  211. }
  212. func (o extHandle) getExt(rtid uintptr) *extTypeTagFn {
  213. for _, v := range o {
  214. if v.rtid == rtid {
  215. return v
  216. }
  217. }
  218. return nil
  219. }
  220. func (o extHandle) getExtForTag(tag byte) *extTypeTagFn {
  221. for _, v := range o {
  222. if v.tag == tag {
  223. return v
  224. }
  225. }
  226. return nil
  227. }
  228. func (o extHandle) getDecodeExtForTag(tag byte) (
  229. rv reflect.Value, fn func(reflect.Value, []byte) error) {
  230. if x := o.getExtForTag(tag); x != nil {
  231. // ext is only registered for base
  232. rv = reflect.New(x.rt).Elem()
  233. fn = x.decFn
  234. }
  235. return
  236. }
  237. func (o extHandle) getDecodeExt(rtid uintptr) (tag byte, fn func(reflect.Value, []byte) error) {
  238. if x := o.getExt(rtid); x != nil {
  239. tag = x.tag
  240. fn = x.decFn
  241. }
  242. return
  243. }
  244. func (o extHandle) getEncodeExt(rtid uintptr) (tag byte, fn func(reflect.Value) ([]byte, error)) {
  245. if x := o.getExt(rtid); x != nil {
  246. tag = x.tag
  247. fn = x.encFn
  248. }
  249. return
  250. }
  251. type structFieldInfo struct {
  252. encName string // encode name
  253. // only one of 'i' or 'is' can be set. If 'i' is -1, then 'is' has been set.
  254. is []int // (recursive/embedded) field index in struct
  255. i int16 // field index in struct
  256. omitEmpty bool
  257. toArray bool // if field is _struct, is the toArray set?
  258. // tag string // tag
  259. // name string // field name
  260. // encNameBs []byte // encoded name as byte stream
  261. // ikind int // kind of the field as an int i.e. int(reflect.Kind)
  262. }
  263. func parseStructFieldInfo(fname string, stag string) *structFieldInfo {
  264. if fname == "" {
  265. panic("parseStructFieldInfo: No Field Name")
  266. }
  267. si := structFieldInfo{
  268. // name: fname,
  269. encName: fname,
  270. // tag: stag,
  271. }
  272. if stag != "" {
  273. for i, s := range strings.Split(stag, ",") {
  274. if i == 0 {
  275. if s != "" {
  276. si.encName = s
  277. }
  278. } else {
  279. switch s {
  280. case "omitempty":
  281. si.omitEmpty = true
  282. case "toarray":
  283. si.toArray = true
  284. }
  285. }
  286. }
  287. }
  288. // si.encNameBs = []byte(si.encName)
  289. return &si
  290. }
  291. type sfiSortedByEncName []*structFieldInfo
  292. func (p sfiSortedByEncName) Len() int {
  293. return len(p)
  294. }
  295. func (p sfiSortedByEncName) Less(i, j int) bool {
  296. return p[i].encName < p[j].encName
  297. }
  298. func (p sfiSortedByEncName) Swap(i, j int) {
  299. p[i], p[j] = p[j], p[i]
  300. }
  301. // typeInfo keeps information about each type referenced in the encode/decode sequence.
  302. //
  303. // During an encode/decode sequence, we work as below:
  304. // - If base is a built in type, en/decode base value
  305. // - If base is registered as an extension, en/decode base value
  306. // - If type is binary(M/Unm)arshaler, call Binary(M/Unm)arshal method
  307. // - Else decode appropriately based on the reflect.Kind
  308. type typeInfo struct {
  309. sfi []*structFieldInfo // sorted. Used when enc/dec struct to map.
  310. sfip []*structFieldInfo // unsorted. Used when enc/dec struct to array.
  311. rt reflect.Type
  312. rtid uintptr
  313. // baseId gives pointer to the base reflect.Type, after deferencing
  314. // the pointers. E.g. base type of ***time.Time is time.Time.
  315. base reflect.Type
  316. baseId uintptr
  317. baseIndir int8 // number of indirections to get to base
  318. mbs bool // base type (T or *T) is a MapBySlice
  319. m bool // base type (T or *T) is a binaryMarshaler
  320. unm bool // base type (T or *T) is a binaryUnmarshaler
  321. mIndir int8 // number of indirections to get to binaryMarshaler type
  322. unmIndir int8 // number of indirections to get to binaryUnmarshaler type
  323. toArray bool // whether this (struct) type should be encoded as an array
  324. }
  325. func (ti *typeInfo) indexForEncName(name string) int {
  326. //tisfi := ti.sfi
  327. const binarySearchThreshold = 16
  328. if sfilen := len(ti.sfi); sfilen < binarySearchThreshold {
  329. // linear search. faster than binary search in my testing up to 16-field structs.
  330. for i, si := range ti.sfi {
  331. if si.encName == name {
  332. return i
  333. }
  334. }
  335. } else {
  336. // binary search. adapted from sort/search.go.
  337. h, i, j := 0, 0, sfilen
  338. for i < j {
  339. h = i + (j-i)/2
  340. if ti.sfi[h].encName < name {
  341. i = h + 1
  342. } else {
  343. j = h
  344. }
  345. }
  346. if i < sfilen && ti.sfi[i].encName == name {
  347. return i
  348. }
  349. }
  350. return -1
  351. }
  352. func getTypeInfo(rtid uintptr, rt reflect.Type) (pti *typeInfo) {
  353. var ok bool
  354. cachedTypeInfoMutex.RLock()
  355. pti, ok = cachedTypeInfo[rtid]
  356. cachedTypeInfoMutex.RUnlock()
  357. if ok {
  358. return
  359. }
  360. cachedTypeInfoMutex.Lock()
  361. defer cachedTypeInfoMutex.Unlock()
  362. if pti, ok = cachedTypeInfo[rtid]; ok {
  363. return
  364. }
  365. ti := typeInfo{rt: rt, rtid: rtid}
  366. pti = &ti
  367. var indir int8
  368. if ok, indir = implementsIntf(rt, binaryMarshalerTyp); ok {
  369. ti.m, ti.mIndir = true, indir
  370. }
  371. if ok, indir = implementsIntf(rt, binaryUnmarshalerTyp); ok {
  372. ti.unm, ti.unmIndir = true, indir
  373. }
  374. if ok, _ = implementsIntf(rt, mapBySliceTyp); ok {
  375. ti.mbs = true
  376. }
  377. pt := rt
  378. var ptIndir int8
  379. // for ; pt.Kind() == reflect.Ptr; pt, ptIndir = pt.Elem(), ptIndir+1 { }
  380. for pt.Kind() == reflect.Ptr {
  381. pt = pt.Elem()
  382. ptIndir++
  383. }
  384. if ptIndir == 0 {
  385. ti.base = rt
  386. ti.baseId = rtid
  387. } else {
  388. ti.base = pt
  389. ti.baseId = reflect.ValueOf(pt).Pointer()
  390. ti.baseIndir = ptIndir
  391. }
  392. if rt.Kind() == reflect.Struct {
  393. var siInfo *structFieldInfo
  394. if f, ok := rt.FieldByName(structInfoFieldName); ok {
  395. siInfo = parseStructFieldInfo(structInfoFieldName, f.Tag.Get(structTagName))
  396. ti.toArray = siInfo.toArray
  397. }
  398. sfip := make([]*structFieldInfo, 0, rt.NumField())
  399. rgetTypeInfo(rt, nil, make(map[string]bool), &sfip, siInfo)
  400. // // try to put all si close together
  401. // const tryToPutAllStructFieldInfoTogether = true
  402. // if tryToPutAllStructFieldInfoTogether {
  403. // sfip2 := make([]structFieldInfo, len(sfip))
  404. // for i, si := range sfip {
  405. // sfip2[i] = *si
  406. // }
  407. // for i := range sfip {
  408. // sfip[i] = &sfip2[i]
  409. // }
  410. // }
  411. ti.sfip = make([]*structFieldInfo, len(sfip))
  412. ti.sfi = make([]*structFieldInfo, len(sfip))
  413. copy(ti.sfip, sfip)
  414. sort.Sort(sfiSortedByEncName(sfip))
  415. copy(ti.sfi, sfip)
  416. }
  417. // sfi = sfip
  418. cachedTypeInfo[rtid] = pti
  419. return
  420. }
  421. func rgetTypeInfo(rt reflect.Type, indexstack []int, fnameToHastag map[string]bool,
  422. sfi *[]*structFieldInfo, siInfo *structFieldInfo,
  423. ) {
  424. // for rt.Kind() == reflect.Ptr {
  425. // // indexstack = append(indexstack, 0)
  426. // rt = rt.Elem()
  427. // }
  428. for j := 0; j < rt.NumField(); j++ {
  429. f := rt.Field(j)
  430. stag := f.Tag.Get(structTagName)
  431. if stag == "-" {
  432. continue
  433. }
  434. if r1, _ := utf8.DecodeRuneInString(f.Name); r1 == utf8.RuneError || !unicode.IsUpper(r1) {
  435. continue
  436. }
  437. // if anonymous and there is no struct tag and its a struct (or pointer to struct), inline it.
  438. if f.Anonymous && stag == "" {
  439. ft := f.Type
  440. for ft.Kind() == reflect.Ptr {
  441. ft = ft.Elem()
  442. }
  443. if ft.Kind() == reflect.Struct {
  444. indexstack2 := append(append(make([]int, 0, len(indexstack)+4), indexstack...), j)
  445. rgetTypeInfo(ft, indexstack2, fnameToHastag, sfi, siInfo)
  446. continue
  447. }
  448. }
  449. // do not let fields with same name in embedded structs override field at higher level.
  450. // this must be done after anonymous check, to allow anonymous field
  451. // still include their child fields
  452. if _, ok := fnameToHastag[f.Name]; ok {
  453. continue
  454. }
  455. si := parseStructFieldInfo(f.Name, stag)
  456. // si.ikind = int(f.Type.Kind())
  457. if len(indexstack) == 0 {
  458. si.i = int16(j)
  459. } else {
  460. si.i = -1
  461. si.is = append(append(make([]int, 0, len(indexstack)+4), indexstack...), j)
  462. }
  463. if siInfo != nil {
  464. if siInfo.omitEmpty {
  465. si.omitEmpty = true
  466. }
  467. }
  468. *sfi = append(*sfi, si)
  469. fnameToHastag[f.Name] = stag != ""
  470. }
  471. }
  472. func panicToErr(err *error) {
  473. if recoverPanicToErr {
  474. if x := recover(); x != nil {
  475. //debug.PrintStack()
  476. panicValToErr(x, err)
  477. }
  478. }
  479. }
  480. func doPanic(tag string, format string, params ...interface{}) {
  481. params2 := make([]interface{}, len(params)+1)
  482. params2[0] = tag
  483. copy(params2[1:], params)
  484. panic(fmt.Errorf("%s: "+format, params2...))
  485. }
  486. func checkOverflowFloat32(f float64, doCheck bool) {
  487. if !doCheck {
  488. return
  489. }
  490. // check overflow (logic adapted from std pkg reflect/value.go OverflowFloat()
  491. f2 := f
  492. if f2 < 0 {
  493. f2 = -f
  494. }
  495. if math.MaxFloat32 < f2 && f2 <= math.MaxFloat64 {
  496. decErr("Overflow float32 value: %v", f2)
  497. }
  498. }
  499. func checkOverflow(ui uint64, i int64, bitsize uint8) {
  500. // check overflow (logic adapted from std pkg reflect/value.go OverflowUint()
  501. if bitsize == 0 {
  502. return
  503. }
  504. if i != 0 {
  505. if trunc := (i << (64 - bitsize)) >> (64 - bitsize); i != trunc {
  506. decErr("Overflow int value: %v", i)
  507. }
  508. }
  509. if ui != 0 {
  510. if trunc := (ui << (64 - bitsize)) >> (64 - bitsize); ui != trunc {
  511. decErr("Overflow uint value: %v", ui)
  512. }
  513. }
  514. }