decode.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "math/bits"
  7. "google.golang.org/protobuf/encoding/protowire"
  8. "google.golang.org/protobuf/internal/errors"
  9. "google.golang.org/protobuf/internal/flags"
  10. "google.golang.org/protobuf/proto"
  11. "google.golang.org/protobuf/reflect/protoreflect"
  12. preg "google.golang.org/protobuf/reflect/protoregistry"
  13. "google.golang.org/protobuf/runtime/protoiface"
  14. piface "google.golang.org/protobuf/runtime/protoiface"
  15. )
  16. var errDecode = errors.New("cannot parse invalid wire-format data")
  17. var errRecursionDepth = errors.New("exceeded maximum recursion depth")
  18. type unmarshalOptions struct {
  19. flags protoiface.UnmarshalInputFlags
  20. resolver interface {
  21. FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
  22. FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
  23. }
  24. depth int
  25. }
  26. func (o unmarshalOptions) Options() proto.UnmarshalOptions {
  27. return proto.UnmarshalOptions{
  28. Merge: true,
  29. AllowPartial: true,
  30. DiscardUnknown: o.DiscardUnknown(),
  31. Resolver: o.resolver,
  32. }
  33. }
  34. func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 }
  35. func (o unmarshalOptions) IsDefault() bool {
  36. return o.flags == 0 && o.resolver == preg.GlobalTypes
  37. }
  38. var lazyUnmarshalOptions = unmarshalOptions{
  39. resolver: preg.GlobalTypes,
  40. depth: protowire.DefaultRecursionLimit,
  41. }
  // unmarshalOutput is the internal result of decoding a message or field.
  type unmarshalOutput struct {
  	// n is how many input bytes were consumed.
  	n int

  	// initialized reports whether the decoded data is fully initialized:
  	// every required field is present and every nested value that can be
  	// checked reported itself initialized.
  	initialized bool
  }
  46. // unmarshal is protoreflect.Methods.Unmarshal.
  47. func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
  48. var p pointer
  49. if ms, ok := in.Message.(*messageState); ok {
  50. p = ms.pointer()
  51. } else {
  52. p = in.Message.(*messageReflectWrapper).pointer()
  53. }
  54. out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
  55. flags: in.Flags,
  56. resolver: in.Resolver,
  57. depth: in.Depth,
  58. })
  59. var flags piface.UnmarshalOutputFlags
  60. if out.initialized {
  61. flags |= piface.UnmarshalInitialized
  62. }
  63. return piface.UnmarshalOutput{
  64. Flags: flags,
  65. }, err
  66. }
  67. // errUnknown is returned during unmarshaling to indicate a parse error that
  68. // should result in a field being placed in the unknown fields section (for example,
  69. // when the wire type doesn't match) as opposed to the entire unmarshal operation
  70. // failing (for example, when a field extends past the available input).
  71. //
  72. // This is a sentinel error which should never be visible to the user.
  73. var errUnknown = errors.New("unknown")
  // unmarshalPointer decodes the wire-format bytes in b into the message at p.
  //
  // Each call consumes one unit of opts.depth, and errRecursionDepth is
  // returned once the budget goes negative. When legacy support is enabled
  // and mi describes a MessageSet, decoding is delegated to
  // unmarshalMessageSet.
  //
  // A non-zero groupTag means the message is encoded as a group. Decoding
  // then stops at the END_GROUP tag with that field number, and the tag
  // bytes are counted as consumed. A mismatched or stray END_GROUP, or
  // running out of input before the expected END_GROUP, yields errDecode.
  //
  // The following fields are preserved as raw bytes in the unknown-field set
  // (unless opts.DiscardUnknown() is true or mi has no unknown-field
  // storage):
  //   - fields without a decoder,
  //   - fields that are not extensions,
  //   - extensions the resolver cannot find,
  //   - fields whose decoder returns errUnknown.
  //
  // On success, out.n is the number of bytes consumed. out.initialized is
  // true when all required fields were seen and every decoded sub-value
  // reported itself initialized.
  func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
  	mi.init()
  	opts.depth--
  	if opts.depth < 0 {
  		return out, errRecursionDepth
  	}
  	if flags.ProtoLegacy && mi.isMessageSet {
  		return unmarshalMessageSet(mi, b, p, opts)
  	}

  	var (
  		allInit      = true
  		seenRequired uint64                     // one bit per required field decoded
  		extMap       *map[int32]ExtensionField // fetched lazily on first extension
  		origLen      = len(b)
  	)

  	for len(b) > 0 {
  		// Decode the tag varint, with inline fast paths for the common
  		// one- and two-byte encodings.
  		var tag uint64
  		switch {
  		case b[0] < 0x80:
  			tag = uint64(b[0])
  			b = b[1:]
  		case len(b) >= 2 && b[1] < 128:
  			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
  			b = b[2:]
  		default:
  			v, tagLen := protowire.ConsumeVarint(b)
  			if tagLen < 0 {
  				return out, errDecode
  			}
  			tag = v
  			b = b[tagLen:]
  		}

  		rawNum := tag >> 3
  		if rawNum < uint64(protowire.MinValidNumber) || rawNum > uint64(protowire.MaxValidNumber) {
  			return out, errDecode
  		}
  		num := protowire.Number(rawNum)
  		wtyp := protowire.Type(tag & 7)

  		if wtyp == protowire.EndGroupType {
  			if num != groupTag {
  				return out, errDecode
  			}
  			// Matching END_GROUP: clear groupTag so the post-loop check
  			// passes, then leave the loop.
  			groupTag = 0
  			break
  		}

  		var fi *coderFieldInfo
  		if int(num) < len(mi.denseCoderFields) {
  			fi = mi.denseCoderFields[num]
  		} else {
  			fi = mi.coderFields[num]
  		}

  		// ferr stays errUnknown unless some decoder actually runs.
  		consumed := 0
  		ferr := errUnknown
  		if fi != nil {
  			if fi.funcs.unmarshal != nil {
  				var o unmarshalOutput
  				o, ferr = fi.funcs.unmarshal(b, p.Apply(fi.offset), wtyp, fi, opts)
  				consumed = o.n
  				if ferr == nil {
  					seenRequired |= fi.validation.requiredBit
  					if fi.funcs.isInit != nil && !o.initialized {
  						allInit = false
  					}
  				}
  			}
  		} else {
  			// Not a known field: it may be an extension.
  			if extMap == nil && mi.extensionOffset.IsValid() {
  				extMap = p.Apply(mi.extensionOffset).Extensions()
  				if *extMap == nil {
  					*extMap = make(map[int32]ExtensionField)
  				}
  			}
  			if extMap != nil {
  				var o unmarshalOutput
  				o, ferr = mi.unmarshalExtension(b, num, wtyp, *extMap, opts)
  				if ferr == nil {
  					consumed = o.n
  					if !o.initialized {
  						allInit = false
  					}
  				}
  			}
  		}

  		if ferr != nil {
  			if ferr != errUnknown {
  				return out, ferr
  			}
  			// Skip over the raw field value and, if permitted, keep it
  			// (tag included) as unknown bytes.
  			consumed = protowire.ConsumeFieldValue(num, wtyp, b)
  			if consumed < 0 {
  				return out, errDecode
  			}
  			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
  				u := mi.mutableUnknownBytes(p)
  				*u = protowire.AppendTag(*u, num, wtyp)
  				*u = append(*u, b[:consumed]...)
  			}
  		}
  		b = b[consumed:]
  	}

  	// Still expecting an END_GROUP means the group was truncated.
  	if groupTag != 0 {
  		return out, errDecode
  	}
  	if mi.numRequiredFields > 0 && bits.OnesCount64(seenRequired) != int(mi.numRequiredFields) {
  		allInit = false
  	}
  	out.initialized = allInit
  	out.n = origLen - len(b)
  	return out, nil
  }
  190. func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
  191. x := exts[int32(num)]
  192. xt := x.Type()
  193. if xt == nil {
  194. var err error
  195. xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
  196. if err != nil {
  197. if err == preg.NotFound {
  198. return out, errUnknown
  199. }
  200. return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
  201. }
  202. }
  203. xi := getExtensionFieldInfo(xt)
  204. if xi.funcs.unmarshal == nil {
  205. return out, errUnknown
  206. }
  207. if flags.LazyUnmarshalExtensions {
  208. if opts.IsDefault() && x.canLazy(xt) {
  209. out, valid := skipExtension(b, xi, num, wtyp, opts)
  210. switch valid {
  211. case ValidationValid:
  212. if out.initialized {
  213. x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
  214. exts[int32(num)] = x
  215. return out, nil
  216. }
  217. case ValidationInvalid:
  218. return out, errDecode
  219. case ValidationUnknown:
  220. }
  221. }
  222. }
  223. ival := x.Value()
  224. if !ival.IsValid() && xi.unmarshalNeedsValue {
  225. // Create a new message, list, or map value to fill in.
  226. // For enums, create a prototype value to let the unmarshal func know the
  227. // concrete type.
  228. ival = xt.New()
  229. }
  230. v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
  231. if err != nil {
  232. return out, err
  233. }
  234. if xi.funcs.isInit == nil {
  235. out.initialized = true
  236. }
  237. x.Set(xt, v)
  238. exts[int32(num)] = x
  239. return out, nil
  240. }
  241. func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
  242. if xi.validation.mi == nil {
  243. return out, ValidationUnknown
  244. }
  245. xi.validation.mi.init()
  246. switch xi.validation.typ {
  247. case validationTypeMessage:
  248. if wtyp != protowire.BytesType {
  249. return out, ValidationUnknown
  250. }
  251. v, n := protowire.ConsumeBytes(b)
  252. if n < 0 {
  253. return out, ValidationUnknown
  254. }
  255. out, st := xi.validation.mi.validate(v, 0, opts)
  256. out.n = n
  257. return out, st
  258. case validationTypeGroup:
  259. if wtyp != protowire.StartGroupType {
  260. return out, ValidationUnknown
  261. }
  262. out, st := xi.validation.mi.validate(b, num, opts)
  263. return out, st
  264. default:
  265. return out, ValidationUnknown
  266. }
  267. }