extension.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. package msgp
  2. import (
  3. "errors"
  4. "math"
  5. "strconv"
  6. )
  // Extension type numbers that msgp uses for types it encodes natively.
// They are reserved: RegisterExtension refuses to register any of them.
const (
	// Complex64Extension is the extension number used for complex64
	Complex64Extension = iota + 3
	// Complex128Extension is the extension number used for complex128
	Complex128Extension
	// TimeExtension is the extension number used for time.Time
	TimeExtension
)
  15. // our extensions live here
  16. var extensionReg = make(map[int8]func() Extension)
  17. // RegisterExtension registers extensions so that they
  18. // can be initialized and returned by methods that
  19. // decode `interface{}` values. This should only
  20. // be called during initialization. f() should return
  21. // a newly-initialized zero value of the extension. Keep in
  22. // mind that extensions 3, 4, and 5 are reserved for
  23. // complex64, complex128, and time.Time, respectively,
  24. // and that MessagePack reserves extension types from -127 to -1.
  25. //
  26. // For example, if you wanted to register a user-defined struct:
  27. //
  28. // msgp.RegisterExtension(10, func() msgp.Extension { &MyExtension{} })
  29. //
  30. // RegisterExtension will panic if you call it multiple times
  31. // with the same 'typ' argument, or if you use a reserved
  32. // type (3, 4, or 5).
  33. func RegisterExtension(typ int8, f func() Extension) {
  34. switch typ {
  35. case Complex64Extension, Complex128Extension, TimeExtension:
  36. panic(errors.New("msgp: forbidden extension type: " + strconv.Itoa(int(typ))))
  37. }
  38. if _, ok := extensionReg[typ]; ok {
  39. panic(errors.New("msgp: RegisterExtension() called with typ " + strconv.Itoa(int(typ)) + " more than once"))
  40. }
  41. extensionReg[typ] = f
  42. }
  // ExtensionTypeError is returned when the extension type number read
// from the encoded data differs from the one the destination expects.
type ExtensionTypeError struct {
	Got  int8 // type number found in the encoded data
	Want int8 // type number the destination Extension reports
}

// Error implements the error interface.
func (e ExtensionTypeError) Error() string {
	wantStr := strconv.Itoa(int(e.Want))
	gotStr := strconv.Itoa(int(e.Got))
	return "msgp: error decoding extension: wanted type " + wantStr + "; got type " + gotStr
}

// Resumable always returns true for an ExtensionTypeError.
func (e ExtensionTypeError) Resumable() bool {
	return true
}

// errExt wraps a type mismatch (wire type got, expected type wanted)
// as an ExtensionTypeError value.
func errExt(got int8, wanted int8) error {
	var mismatch ExtensionTypeError
	mismatch.Got, mismatch.Want = got, wanted
	return mismatch
}
  // Extension is implemented by types that supply their own binary
// encoding, carried on the wire as a MessagePack extension.
type Extension interface {
	// ExtensionType returns the int8 type number identifying the
	// concrete type on the wire. Negative numbers are reserved by the
	// MessagePack specification.
	ExtensionType() int8

	// Len returns the number of payload bytes MarshalBinaryTo writes.
	Len() int

	// MarshalBinaryTo copies the encoded payload into the supplied
	// slice, which may be assumed to have length Len().
	MarshalBinaryTo([]byte) error

	// UnmarshalBinary decodes the payload from the supplied slice. The
	// slice can alias the caller's buffer, so an implementation that
	// keeps the data should copy it.
	UnmarshalBinary([]byte) error
}
  // RawExtension is a generic Extension holding an arbitrary type number
// together with its undecoded payload bytes.
type RawExtension struct {
	Data []byte
	Type int8
}

// ExtensionType implements Extension.ExtensionType by reporting r.Type.
func (r *RawExtension) ExtensionType() int8 {
	return r.Type
}

// Len implements Extension.Len by reporting the payload length.
func (r *RawExtension) Len() int {
	return len(r.Data)
}

// MarshalBinaryTo implements Extension.MarshalBinaryTo by copying r.Data
// into d. It never returns an error.
func (r *RawExtension) MarshalBinaryTo(d []byte) error {
	copy(d, r.Data)
	return nil
}

// UnmarshalBinary implements Extension.UnmarshalBinary. It copies b into
// r.Data, reusing r.Data's backing array when its capacity suffices, so
// r never aliases b.
func (r *RawExtension) UnmarshalBinary(b []byte) error {
	need := len(b)
	if cap(r.Data) < need {
		r.Data = make([]byte, need)
	} else {
		r.Data = r.Data[:need]
	}
	copy(r.Data, b)
	return nil
}
  104. // WriteExtension writes an extension type to the writer
  105. func (mw *Writer) WriteExtension(e Extension) error {
  106. l := e.Len()
  107. var err error
  108. switch l {
  109. case 0:
  110. o, err := mw.require(3)
  111. if err != nil {
  112. return err
  113. }
  114. mw.buf[o] = mext8
  115. mw.buf[o+1] = 0
  116. mw.buf[o+2] = byte(e.ExtensionType())
  117. case 1:
  118. o, err := mw.require(2)
  119. if err != nil {
  120. return err
  121. }
  122. mw.buf[o] = mfixext1
  123. mw.buf[o+1] = byte(e.ExtensionType())
  124. case 2:
  125. o, err := mw.require(2)
  126. if err != nil {
  127. return err
  128. }
  129. mw.buf[o] = mfixext2
  130. mw.buf[o+1] = byte(e.ExtensionType())
  131. case 4:
  132. o, err := mw.require(2)
  133. if err != nil {
  134. return err
  135. }
  136. mw.buf[o] = mfixext4
  137. mw.buf[o+1] = byte(e.ExtensionType())
  138. case 8:
  139. o, err := mw.require(2)
  140. if err != nil {
  141. return err
  142. }
  143. mw.buf[o] = mfixext8
  144. mw.buf[o+1] = byte(e.ExtensionType())
  145. case 16:
  146. o, err := mw.require(2)
  147. if err != nil {
  148. return err
  149. }
  150. mw.buf[o] = mfixext16
  151. mw.buf[o+1] = byte(e.ExtensionType())
  152. default:
  153. switch {
  154. case l < math.MaxUint8:
  155. o, err := mw.require(3)
  156. if err != nil {
  157. return err
  158. }
  159. mw.buf[o] = mext8
  160. mw.buf[o+1] = byte(uint8(l))
  161. mw.buf[o+2] = byte(e.ExtensionType())
  162. case l < math.MaxUint16:
  163. o, err := mw.require(4)
  164. if err != nil {
  165. return err
  166. }
  167. mw.buf[o] = mext16
  168. big.PutUint16(mw.buf[o+1:], uint16(l))
  169. mw.buf[o+3] = byte(e.ExtensionType())
  170. default:
  171. o, err := mw.require(6)
  172. if err != nil {
  173. return err
  174. }
  175. mw.buf[o] = mext32
  176. big.PutUint32(mw.buf[o+1:], uint32(l))
  177. mw.buf[o+5] = byte(e.ExtensionType())
  178. }
  179. }
  180. // we can only write directly to the
  181. // buffer if we're sure that it
  182. // fits the object
  183. if l <= mw.bufsize() {
  184. o, err := mw.require(l)
  185. if err != nil {
  186. return err
  187. }
  188. return e.MarshalBinaryTo(mw.buf[o:])
  189. }
  190. // here we create a new buffer
  191. // just large enough for the body
  192. // and save it as the write buffer
  193. err = mw.flush()
  194. if err != nil {
  195. return err
  196. }
  197. buf := make([]byte, l)
  198. err = e.MarshalBinaryTo(buf)
  199. if err != nil {
  200. return err
  201. }
  202. mw.buf = buf
  203. mw.wloc = l
  204. return nil
  205. }
  206. // peek at the extension type, assuming the next
  207. // kind to be read is Extension
  208. func (m *Reader) peekExtensionType() (int8, error) {
  209. p, err := m.R.Peek(2)
  210. if err != nil {
  211. return 0, err
  212. }
  213. spec := getBytespec(p[0])
  214. if spec.typ != ExtensionType {
  215. return 0, badPrefix(ExtensionType, p[0])
  216. }
  217. if spec.extra == constsize {
  218. return int8(p[1]), nil
  219. }
  220. size := spec.size
  221. p, err = m.R.Peek(int(size))
  222. if err != nil {
  223. return 0, err
  224. }
  225. return int8(p[size-1]), nil
  226. }
  227. // peekExtension peeks at the extension encoding type
  228. // (must guarantee at least 1 byte in 'b')
  229. func peekExtension(b []byte) (int8, error) {
  230. spec := getBytespec(b[0])
  231. size := spec.size
  232. if spec.typ != ExtensionType {
  233. return 0, badPrefix(ExtensionType, b[0])
  234. }
  235. if len(b) < int(size) {
  236. return 0, ErrShortBytes
  237. }
  238. // for fixed extensions,
  239. // the type information is in
  240. // the second byte
  241. if spec.extra == constsize {
  242. return int8(b[1]), nil
  243. }
  244. // otherwise, it's in the last
  245. // part of the prefix
  246. return int8(b[size-1]), nil
  247. }
  248. // ReadExtension reads the next object from the reader
  249. // as an extension. ReadExtension will fail if the next
  250. // object in the stream is not an extension, or if
  251. // e.Type() is not the same as the wire type.
  252. func (m *Reader) ReadExtension(e Extension) (err error) {
  253. var p []byte
  254. p, err = m.R.Peek(2)
  255. if err != nil {
  256. return
  257. }
  258. lead := p[0]
  259. var read int
  260. var off int
  261. switch lead {
  262. case mfixext1:
  263. if int8(p[1]) != e.ExtensionType() {
  264. err = errExt(int8(p[1]), e.ExtensionType())
  265. return
  266. }
  267. p, err = m.R.Peek(3)
  268. if err != nil {
  269. return
  270. }
  271. err = e.UnmarshalBinary(p[2:])
  272. if err == nil {
  273. _, err = m.R.Skip(3)
  274. }
  275. return
  276. case mfixext2:
  277. if int8(p[1]) != e.ExtensionType() {
  278. err = errExt(int8(p[1]), e.ExtensionType())
  279. return
  280. }
  281. p, err = m.R.Peek(4)
  282. if err != nil {
  283. return
  284. }
  285. err = e.UnmarshalBinary(p[2:])
  286. if err == nil {
  287. _, err = m.R.Skip(4)
  288. }
  289. return
  290. case mfixext4:
  291. if int8(p[1]) != e.ExtensionType() {
  292. err = errExt(int8(p[1]), e.ExtensionType())
  293. return
  294. }
  295. p, err = m.R.Peek(6)
  296. if err != nil {
  297. return
  298. }
  299. err = e.UnmarshalBinary(p[2:])
  300. if err == nil {
  301. _, err = m.R.Skip(6)
  302. }
  303. return
  304. case mfixext8:
  305. if int8(p[1]) != e.ExtensionType() {
  306. err = errExt(int8(p[1]), e.ExtensionType())
  307. return
  308. }
  309. p, err = m.R.Peek(10)
  310. if err != nil {
  311. return
  312. }
  313. err = e.UnmarshalBinary(p[2:])
  314. if err == nil {
  315. _, err = m.R.Skip(10)
  316. }
  317. return
  318. case mfixext16:
  319. if int8(p[1]) != e.ExtensionType() {
  320. err = errExt(int8(p[1]), e.ExtensionType())
  321. return
  322. }
  323. p, err = m.R.Peek(18)
  324. if err != nil {
  325. return
  326. }
  327. err = e.UnmarshalBinary(p[2:])
  328. if err == nil {
  329. _, err = m.R.Skip(18)
  330. }
  331. return
  332. case mext8:
  333. p, err = m.R.Peek(3)
  334. if err != nil {
  335. return
  336. }
  337. if int8(p[2]) != e.ExtensionType() {
  338. err = errExt(int8(p[2]), e.ExtensionType())
  339. return
  340. }
  341. read = int(uint8(p[1]))
  342. off = 3
  343. case mext16:
  344. p, err = m.R.Peek(4)
  345. if err != nil {
  346. return
  347. }
  348. if int8(p[3]) != e.ExtensionType() {
  349. err = errExt(int8(p[3]), e.ExtensionType())
  350. return
  351. }
  352. read = int(big.Uint16(p[1:]))
  353. off = 4
  354. case mext32:
  355. p, err = m.R.Peek(6)
  356. if err != nil {
  357. return
  358. }
  359. if int8(p[5]) != e.ExtensionType() {
  360. err = errExt(int8(p[5]), e.ExtensionType())
  361. return
  362. }
  363. read = int(big.Uint32(p[1:]))
  364. off = 6
  365. default:
  366. err = badPrefix(ExtensionType, lead)
  367. return
  368. }
  369. p, err = m.R.Peek(read + off)
  370. if err != nil {
  371. return
  372. }
  373. err = e.UnmarshalBinary(p[off:])
  374. if err == nil {
  375. _, err = m.R.Skip(read + off)
  376. }
  377. return
  378. }
  379. // AppendExtension appends a MessagePack extension to the provided slice
  380. func AppendExtension(b []byte, e Extension) ([]byte, error) {
  381. l := e.Len()
  382. var o []byte
  383. var n int
  384. switch l {
  385. case 0:
  386. o, n = ensure(b, 3)
  387. o[n] = mext8
  388. o[n+1] = 0
  389. o[n+2] = byte(e.ExtensionType())
  390. return o[:n+3], nil
  391. case 1:
  392. o, n = ensure(b, 3)
  393. o[n] = mfixext1
  394. o[n+1] = byte(e.ExtensionType())
  395. n += 2
  396. case 2:
  397. o, n = ensure(b, 4)
  398. o[n] = mfixext2
  399. o[n+1] = byte(e.ExtensionType())
  400. n += 2
  401. case 4:
  402. o, n = ensure(b, 6)
  403. o[n] = mfixext4
  404. o[n+1] = byte(e.ExtensionType())
  405. n += 2
  406. case 8:
  407. o, n = ensure(b, 10)
  408. o[n] = mfixext8
  409. o[n+1] = byte(e.ExtensionType())
  410. n += 2
  411. case 16:
  412. o, n = ensure(b, 18)
  413. o[n] = mfixext16
  414. o[n+1] = byte(e.ExtensionType())
  415. n += 2
  416. default:
  417. switch {
  418. case l < math.MaxUint8:
  419. o, n = ensure(b, l+3)
  420. o[n] = mext8
  421. o[n+1] = byte(uint8(l))
  422. o[n+2] = byte(e.ExtensionType())
  423. n += 3
  424. case l < math.MaxUint16:
  425. o, n = ensure(b, l+4)
  426. o[n] = mext16
  427. big.PutUint16(o[n+1:], uint16(l))
  428. o[n+3] = byte(e.ExtensionType())
  429. n += 4
  430. default:
  431. o, n = ensure(b, l+6)
  432. o[n] = mext32
  433. big.PutUint32(o[n+1:], uint32(l))
  434. o[n+5] = byte(e.ExtensionType())
  435. n += 6
  436. }
  437. }
  438. return o, e.MarshalBinaryTo(o[n:])
  439. }
  440. // ReadExtensionBytes reads an extension from 'b' into 'e'
  441. // and returns any remaining bytes.
  442. // Possible errors:
  443. // - ErrShortBytes ('b' not long enough)
  444. // - ExtensionTypeError{} (wire type not the same as e.Type())
  445. // - TypeError{} (next object not an extension)
  446. // - InvalidPrefixError
  447. // - An umarshal error returned from e.UnmarshalBinary
  448. func ReadExtensionBytes(b []byte, e Extension) ([]byte, error) {
  449. l := len(b)
  450. if l < 3 {
  451. return b, ErrShortBytes
  452. }
  453. lead := b[0]
  454. var (
  455. sz int // size of 'data'
  456. off int // offset of 'data'
  457. typ int8
  458. )
  459. switch lead {
  460. case mfixext1:
  461. typ = int8(b[1])
  462. sz = 1
  463. off = 2
  464. case mfixext2:
  465. typ = int8(b[1])
  466. sz = 2
  467. off = 2
  468. case mfixext4:
  469. typ = int8(b[1])
  470. sz = 4
  471. off = 2
  472. case mfixext8:
  473. typ = int8(b[1])
  474. sz = 8
  475. off = 2
  476. case mfixext16:
  477. typ = int8(b[1])
  478. sz = 16
  479. off = 2
  480. case mext8:
  481. sz = int(uint8(b[1]))
  482. typ = int8(b[2])
  483. off = 3
  484. if sz == 0 {
  485. return b[3:], e.UnmarshalBinary(b[3:3])
  486. }
  487. case mext16:
  488. if l < 4 {
  489. return b, ErrShortBytes
  490. }
  491. sz = int(big.Uint16(b[1:]))
  492. typ = int8(b[3])
  493. off = 4
  494. case mext32:
  495. if l < 6 {
  496. return b, ErrShortBytes
  497. }
  498. sz = int(big.Uint32(b[1:]))
  499. typ = int8(b[5])
  500. off = 6
  501. default:
  502. return b, badPrefix(ExtensionType, lead)
  503. }
  504. if typ != e.ExtensionType() {
  505. return b, errExt(typ, e.ExtensionType())
  506. }
  507. // the data of the extension starts
  508. // at 'off' and is 'sz' bytes long
  509. if len(b[off:]) < sz {
  510. return b, ErrShortBytes
  511. }
  512. tot := off + sz
  513. return b[tot:], e.UnmarshalBinary(b[off:tot])
  514. }