main.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package main
  2. import (
  3. "bytes"
  4. "flag"
  5. "fmt"
  6. "go/format"
  7. "io/ioutil"
  8. "os"
  9. "unicode"
  10. "unicode/utf8"
  11. )
  // stringSet is a flag.Value that gathers every occurrence of a repeatable
// command-line flag into a set of distinct strings.
//
// The map must be allocated up front, e.g. stringSet{make(map[string]struct{})}:
// the methods use value receivers, so Set cannot allocate it lazily, and Set on
// a zero stringSet panics by writing to a nil map.
type stringSet struct {
	values map[string]struct{}
}

// String implements flag.Value. It never reports the collected values; the
// result is always the empty string.
func (s stringSet) String() string {
	var empty string
	return empty
}

// Set implements flag.Value by adding value to the set. Repeated values are
// stored once. It never returns an error.
func (s stringSet) Set(value string) error {
	var member struct{}
	s.values[value] = member
	return nil
}

// GetValues returns the backing map itself rather than a copy, so the caller
// shares (and can modify) the live set.
func (s stringSet) GetValues() map[string]struct{} {
	set := s.values
	return set
}
  25. var (
  26. typeName = flag.String("type", "", "interface type to generate plugin rpc proxy for")
  27. rpcName = flag.String("name", *typeName, "RPC name, set if different from type")
  28. inputFile = flag.String("i", "", "input file path")
  29. outputFile = flag.String("o", *inputFile+"_proxy.go", "output file path")
  30. skipFuncs map[string]struct{}
  31. flSkipFuncs = stringSet{make(map[string]struct{})}
  32. flBuildTags = stringSet{make(map[string]struct{})}
  33. )
  // errorOut prints "msg: err" to stderr and terminates the process with exit
// status 1. When err is nil it returns immediately, which lets callers pass a
// fallible call's error result straight through.
func errorOut(msg string, err error) {
	if err != nil {
		fmt.Fprintf(os.Stderr, "%s: %v\n", msg, err)
		os.Exit(1)
	}
}
  41. func checkFlags() error {
  42. if *outputFile == "" {
  43. return fmt.Errorf("missing required flag `-o`")
  44. }
  45. if *inputFile == "" {
  46. return fmt.Errorf("missing required flag `-i`")
  47. }
  48. return nil
  49. }
  50. func main() {
  51. flag.Var(flSkipFuncs, "skip", "skip parsing for function")
  52. flag.Var(flBuildTags, "tag", "build tags to add to generated files")
  53. flag.Parse()
  54. skipFuncs = flSkipFuncs.GetValues()
  55. errorOut("error", checkFlags())
  56. pkg, err := Parse(*inputFile, *typeName)
  57. errorOut(fmt.Sprintf("error parsing requested type %s", *typeName), err)
  58. var analysis = struct {
  59. InterfaceType string
  60. RPCName string
  61. BuildTags map[string]struct{}
  62. *ParsedPkg
  63. }{toLower(*typeName), *rpcName, flBuildTags.GetValues(), pkg}
  64. var buf bytes.Buffer
  65. errorOut("parser error", generatedTempl.Execute(&buf, analysis))
  66. src, err := format.Source(buf.Bytes())
  67. errorOut("error formatting generated source:\n"+buf.String(), err)
  68. errorOut("error writing file", ioutil.WriteFile(*outputFile, src, 0644))
  69. }
  // toLower returns s with only its first rune converted to lower case; the
// remainder of s is returned unchanged (so "FooBar" becomes "fooBar"). An
// empty string is returned as is. An invalid leading byte decodes as
// utf8.RuneError and is therefore replaced by U+FFFD.
func toLower(s string) string {
	first, size := utf8.DecodeRuneInString(s)
	if size == 0 {
		return s
	}
	return string(unicode.ToLower(first)) + s[size:]
}