query.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. package runtime
  2. import (
  3. "encoding/base64"
  4. "fmt"
  5. "net/url"
  6. "reflect"
  7. "regexp"
  8. "strconv"
  9. "strings"
  10. "time"
  11. "github.com/golang/protobuf/proto"
  12. "github.com/grpc-ecosystem/grpc-gateway/utilities"
  13. "google.golang.org/grpc/grpclog"
  14. )
  // valuesKeyRegexp recognizes query keys written as "name[index]". The first
  // capture group is everything before the last "[" and the second group is the
  // text between that "[" and the closing "]" that ends the key.
  var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
  16. var currentQueryParser QueryParameterParser = &defaultQueryParser{}
  17. // QueryParameterParser defines interface for all query parameter parsers
  18. type QueryParameterParser interface {
  19. Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
  20. }
  21. // PopulateQueryParameters parses query parameters
  22. // into "msg" using current query parser
  23. func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
  24. return currentQueryParser.Parse(msg, values, filter)
  25. }
  26. type defaultQueryParser struct{}
  27. // Parse populates "values" into "msg".
  28. // A value is ignored if its key starts with one of the elements in "filter".
  29. func (*defaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
  30. for key, values := range values {
  31. match := valuesKeyRegexp.FindStringSubmatch(key)
  32. if len(match) == 3 {
  33. key = match[1]
  34. values = append([]string{match[2]}, values...)
  35. }
  36. fieldPath := strings.Split(key, ".")
  37. if filter.HasCommonPrefix(fieldPath) {
  38. continue
  39. }
  40. if err := populateFieldValueFromPath(msg, fieldPath, values); err != nil {
  41. return err
  42. }
  43. }
  44. return nil
  45. }
  46. // PopulateFieldFromPath sets a value in a nested Protobuf structure.
  47. // It instantiates missing protobuf fields as it goes.
  48. func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
  49. fieldPath := strings.Split(fieldPathString, ".")
  50. return populateFieldValueFromPath(msg, fieldPath, []string{value})
  51. }
  52. func populateFieldValueFromPath(msg proto.Message, fieldPath []string, values []string) error {
  53. m := reflect.ValueOf(msg)
  54. if m.Kind() != reflect.Ptr {
  55. return fmt.Errorf("unexpected type %T: %v", msg, msg)
  56. }
  57. var props *proto.Properties
  58. m = m.Elem()
  59. for i, fieldName := range fieldPath {
  60. isLast := i == len(fieldPath)-1
  61. if !isLast && m.Kind() != reflect.Struct {
  62. return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, "."))
  63. }
  64. var f reflect.Value
  65. var err error
  66. f, props, err = fieldByProtoName(m, fieldName)
  67. if err != nil {
  68. return err
  69. } else if !f.IsValid() {
  70. grpclog.Infof("field not found in %T: %s", msg, strings.Join(fieldPath, "."))
  71. return nil
  72. }
  73. switch f.Kind() {
  74. case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64:
  75. if !isLast {
  76. return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
  77. }
  78. m = f
  79. case reflect.Slice:
  80. if !isLast {
  81. return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, "."))
  82. }
  83. // Handle []byte
  84. if f.Type().Elem().Kind() == reflect.Uint8 {
  85. m = f
  86. break
  87. }
  88. return populateRepeatedField(f, values, props)
  89. case reflect.Ptr:
  90. if f.IsNil() {
  91. m = reflect.New(f.Type().Elem())
  92. f.Set(m.Convert(f.Type()))
  93. }
  94. m = f.Elem()
  95. continue
  96. case reflect.Struct:
  97. m = f
  98. continue
  99. case reflect.Map:
  100. if !isLast {
  101. return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
  102. }
  103. return populateMapField(f, values, props)
  104. default:
  105. return fmt.Errorf("unexpected type %s in %T", f.Type(), msg)
  106. }
  107. }
  108. switch len(values) {
  109. case 0:
  110. return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, "."))
  111. case 1:
  112. default:
  113. grpclog.Infof("too many field values: %s", strings.Join(fieldPath, "."))
  114. }
  115. return populateField(m, values[0], props)
  116. }
  117. // fieldByProtoName looks up a field whose corresponding protobuf field name is "name".
  118. // "m" must be a struct value. It returns zero reflect.Value if no such field found.
  119. func fieldByProtoName(m reflect.Value, name string) (reflect.Value, *proto.Properties, error) {
  120. props := proto.GetProperties(m.Type())
  121. // look up field name in oneof map
  122. for _, op := range props.OneofTypes {
  123. if name == op.Prop.OrigName || name == op.Prop.JSONName {
  124. v := reflect.New(op.Type.Elem())
  125. field := m.Field(op.Field)
  126. if !field.IsNil() {
  127. return reflect.Value{}, nil, fmt.Errorf("field already set for %s oneof", props.Prop[op.Field].OrigName)
  128. }
  129. field.Set(v)
  130. return v.Elem().Field(0), op.Prop, nil
  131. }
  132. }
  133. for _, p := range props.Prop {
  134. if p.OrigName == name {
  135. return m.FieldByName(p.Name), p, nil
  136. }
  137. if p.JSONName == name {
  138. return m.FieldByName(p.Name), p, nil
  139. }
  140. }
  141. return reflect.Value{}, nil, nil
  142. }
  143. func populateMapField(f reflect.Value, values []string, props *proto.Properties) error {
  144. if len(values) != 2 {
  145. return fmt.Errorf("more than one value provided for key %s in map %s", values[0], props.Name)
  146. }
  147. key, value := values[0], values[1]
  148. keyType := f.Type().Key()
  149. valueType := f.Type().Elem()
  150. if f.IsNil() {
  151. f.Set(reflect.MakeMap(f.Type()))
  152. }
  153. keyConv, ok := convFromType[keyType.Kind()]
  154. if !ok {
  155. return fmt.Errorf("unsupported key type %s in map %s", keyType, props.Name)
  156. }
  157. valueConv, ok := convFromType[valueType.Kind()]
  158. if !ok {
  159. return fmt.Errorf("unsupported value type %s in map %s", valueType, props.Name)
  160. }
  161. keyV := keyConv.Call([]reflect.Value{reflect.ValueOf(key)})
  162. if err := keyV[1].Interface(); err != nil {
  163. return err.(error)
  164. }
  165. valueV := valueConv.Call([]reflect.Value{reflect.ValueOf(value)})
  166. if err := valueV[1].Interface(); err != nil {
  167. return err.(error)
  168. }
  169. f.SetMapIndex(keyV[0].Convert(keyType), valueV[0].Convert(valueType))
  170. return nil
  171. }
  172. func populateRepeatedField(f reflect.Value, values []string, props *proto.Properties) error {
  173. elemType := f.Type().Elem()
  174. // is the destination field a slice of an enumeration type?
  175. if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
  176. return populateFieldEnumRepeated(f, values, enumValMap)
  177. }
  178. conv, ok := convFromType[elemType.Kind()]
  179. if !ok {
  180. return fmt.Errorf("unsupported field type %s", elemType)
  181. }
  182. f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
  183. for i, v := range values {
  184. result := conv.Call([]reflect.Value{reflect.ValueOf(v)})
  185. if err := result[1].Interface(); err != nil {
  186. return err.(error)
  187. }
  188. f.Index(i).Set(result[0].Convert(f.Index(i).Type()))
  189. }
  190. return nil
  191. }
  192. func populateField(f reflect.Value, value string, props *proto.Properties) error {
  193. i := f.Addr().Interface()
  194. // Handle protobuf well known types
  195. var name string
  196. switch m := i.(type) {
  197. case interface{ XXX_WellKnownType() string }:
  198. name = m.XXX_WellKnownType()
  199. case proto.Message:
  200. const wktPrefix = "google.protobuf."
  201. if fullName := proto.MessageName(m); strings.HasPrefix(fullName, wktPrefix) {
  202. name = fullName[len(wktPrefix):]
  203. }
  204. }
  205. switch name {
  206. case "Timestamp":
  207. if value == "null" {
  208. f.FieldByName("Seconds").SetInt(0)
  209. f.FieldByName("Nanos").SetInt(0)
  210. return nil
  211. }
  212. t, err := time.Parse(time.RFC3339Nano, value)
  213. if err != nil {
  214. return fmt.Errorf("bad Timestamp: %v", err)
  215. }
  216. f.FieldByName("Seconds").SetInt(int64(t.Unix()))
  217. f.FieldByName("Nanos").SetInt(int64(t.Nanosecond()))
  218. return nil
  219. case "Duration":
  220. if value == "null" {
  221. f.FieldByName("Seconds").SetInt(0)
  222. f.FieldByName("Nanos").SetInt(0)
  223. return nil
  224. }
  225. d, err := time.ParseDuration(value)
  226. if err != nil {
  227. return fmt.Errorf("bad Duration: %v", err)
  228. }
  229. ns := d.Nanoseconds()
  230. s := ns / 1e9
  231. ns %= 1e9
  232. f.FieldByName("Seconds").SetInt(s)
  233. f.FieldByName("Nanos").SetInt(ns)
  234. return nil
  235. case "DoubleValue":
  236. fallthrough
  237. case "FloatValue":
  238. float64Val, err := strconv.ParseFloat(value, 64)
  239. if err != nil {
  240. return fmt.Errorf("bad DoubleValue: %s", value)
  241. }
  242. f.FieldByName("Value").SetFloat(float64Val)
  243. return nil
  244. case "Int64Value":
  245. fallthrough
  246. case "Int32Value":
  247. int64Val, err := strconv.ParseInt(value, 10, 64)
  248. if err != nil {
  249. return fmt.Errorf("bad DoubleValue: %s", value)
  250. }
  251. f.FieldByName("Value").SetInt(int64Val)
  252. return nil
  253. case "UInt64Value":
  254. fallthrough
  255. case "UInt32Value":
  256. uint64Val, err := strconv.ParseUint(value, 10, 64)
  257. if err != nil {
  258. return fmt.Errorf("bad DoubleValue: %s", value)
  259. }
  260. f.FieldByName("Value").SetUint(uint64Val)
  261. return nil
  262. case "BoolValue":
  263. if value == "true" {
  264. f.FieldByName("Value").SetBool(true)
  265. } else if value == "false" {
  266. f.FieldByName("Value").SetBool(false)
  267. } else {
  268. return fmt.Errorf("bad BoolValue: %s", value)
  269. }
  270. return nil
  271. case "StringValue":
  272. f.FieldByName("Value").SetString(value)
  273. return nil
  274. case "BytesValue":
  275. bytesVal, err := base64.StdEncoding.DecodeString(value)
  276. if err != nil {
  277. return fmt.Errorf("bad BytesValue: %s", value)
  278. }
  279. f.FieldByName("Value").SetBytes(bytesVal)
  280. return nil
  281. case "FieldMask":
  282. p := f.FieldByName("Paths")
  283. for _, v := range strings.Split(value, ",") {
  284. if v != "" {
  285. p.Set(reflect.Append(p, reflect.ValueOf(v)))
  286. }
  287. }
  288. return nil
  289. }
  290. // Handle Time and Duration stdlib types
  291. switch t := i.(type) {
  292. case *time.Time:
  293. pt, err := time.Parse(time.RFC3339Nano, value)
  294. if err != nil {
  295. return fmt.Errorf("bad Timestamp: %v", err)
  296. }
  297. *t = pt
  298. return nil
  299. case *time.Duration:
  300. d, err := time.ParseDuration(value)
  301. if err != nil {
  302. return fmt.Errorf("bad Duration: %v", err)
  303. }
  304. *t = d
  305. return nil
  306. }
  307. // is the destination field an enumeration type?
  308. if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
  309. return populateFieldEnum(f, value, enumValMap)
  310. }
  311. conv, ok := convFromType[f.Kind()]
  312. if !ok {
  313. return fmt.Errorf("field type %T is not supported in query parameters", i)
  314. }
  315. result := conv.Call([]reflect.Value{reflect.ValueOf(value)})
  316. if err := result[1].Interface(); err != nil {
  317. return err.(error)
  318. }
  319. f.Set(result[0].Convert(f.Type()))
  320. return nil
  321. }
  // convertEnum turns value into a reflect.Value of enum type t.
  //
  // value may be an enum name (a key of enumValMap) or the decimal form of a
  // number that appears as a value in enumValMap. The number is parsed as a
  // 32-bit integer, so input outside the int32 range is rejected instead of
  // wrapping around onto a valid enum number.
  //
  // It returns an error when value is neither a known name nor a known number.
  func convertEnum(value string, t reflect.Type, enumValMap map[string]int32) (reflect.Value, error) {
  	// see if it's an enumeration string
  	if enumVal, ok := enumValMap[value]; ok {
  		return reflect.ValueOf(enumVal).Convert(t), nil
  	}

  	// check for an integer that matches an enumeration value
  	parsed, err := strconv.ParseInt(value, 10, 32)
  	if err != nil {
  		return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
  	}
  	eVal := int32(parsed)
  	for _, v := range enumValMap {
  		if v == eVal {
  			return reflect.ValueOf(eVal).Convert(t), nil
  		}
  	}
  	return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
  }
  339. func populateFieldEnum(f reflect.Value, value string, enumValMap map[string]int32) error {
  340. cval, err := convertEnum(value, f.Type(), enumValMap)
  341. if err != nil {
  342. return err
  343. }
  344. f.Set(cval)
  345. return nil
  346. }
  347. func populateFieldEnumRepeated(f reflect.Value, values []string, enumValMap map[string]int32) error {
  348. elemType := f.Type().Elem()
  349. f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
  350. for i, v := range values {
  351. result, err := convertEnum(v, elemType, enumValMap)
  352. if err != nil {
  353. return err
  354. }
  355. f.Index(i).Set(result)
  356. }
  357. return nil
  358. }
  359. var (
  360. convFromType = map[reflect.Kind]reflect.Value{
  361. reflect.String: reflect.ValueOf(String),
  362. reflect.Bool: reflect.ValueOf(Bool),
  363. reflect.Float64: reflect.ValueOf(Float64),
  364. reflect.Float32: reflect.ValueOf(Float32),
  365. reflect.Int64: reflect.ValueOf(Int64),
  366. reflect.Int32: reflect.ValueOf(Int32),
  367. reflect.Uint64: reflect.ValueOf(Uint64),
  368. reflect.Uint32: reflect.ValueOf(Uint32),
  369. reflect.Slice: reflect.ValueOf(Bytes),
  370. }
  371. )