main.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package main
  2. import (
  3. "bytes"
  4. "flag"
  5. "fmt"
  6. "go/format"
  7. "os"
  8. "unicode"
  9. "unicode/utf8"
  10. )
  // stringSet is a repeatable command-line flag (it implements flag.Value)
// that accumulates every value passed for the flag into a set. Values must be
// created with a non-nil map (e.g. stringSet{make(map[string]struct{})});
// calling Set on a zero stringSet writes to a nil map and panics.
type stringSet struct {
	values map[string]struct{}
}

// String implements flag.Value. It always reports the empty string.
func (set stringSet) String() string { return "" }

// Set implements flag.Value by adding value to the set. Duplicate values
// collapse into one entry. It never returns an error.
func (set stringSet) Set(value string) error {
	var present struct{}
	set.values[value] = present
	return nil
}

// GetValues returns the backing map itself (not a copy), so later calls to
// Set are visible through the returned map.
func (set stringSet) GetValues() map[string]struct{} { return set.values }
  24. var (
  25. typeName = flag.String("type", "", "interface type to generate plugin rpc proxy for")
  26. rpcName = flag.String("name", *typeName, "RPC name, set if different from type")
  27. inputFile = flag.String("i", "", "input file path")
  28. outputFile = flag.String("o", *inputFile+"_proxy.go", "output file path")
  29. skipFuncs map[string]struct{}
  30. flSkipFuncs = stringSet{make(map[string]struct{})}
  31. flBuildTags = stringSet{make(map[string]struct{})}
  32. )
  // errorOut reports a fatal error. When err is non-nil it prints "msg: err"
// to stderr and terminates the process with exit status 1. A nil err is a
// no-op, which lets callers pass a call's error result straight through,
// e.g. errorOut("error", checkFlags()).
func errorOut(msg string, err error) {
	if err != nil {
		fmt.Fprintf(os.Stderr, "%s: %v\n", msg, err)
		os.Exit(1)
	}
}
  40. func checkFlags() error {
  41. if *outputFile == "" {
  42. return fmt.Errorf("missing required flag `-o`")
  43. }
  44. if *inputFile == "" {
  45. return fmt.Errorf("missing required flag `-i`")
  46. }
  47. return nil
  48. }
  49. func main() {
  50. flag.Var(flSkipFuncs, "skip", "skip parsing for function")
  51. flag.Var(flBuildTags, "tag", "build tags to add to generated files")
  52. flag.Parse()
  53. skipFuncs = flSkipFuncs.GetValues()
  54. errorOut("error", checkFlags())
  55. pkg, err := Parse(*inputFile, *typeName)
  56. errorOut(fmt.Sprintf("error parsing requested type %s", *typeName), err)
  57. analysis := struct {
  58. InterfaceType string
  59. RPCName string
  60. BuildTags map[string]struct{}
  61. *ParsedPkg
  62. }{toLower(*typeName), *rpcName, flBuildTags.GetValues(), pkg}
  63. var buf bytes.Buffer
  64. errorOut("parser error", generatedTempl.Execute(&buf, analysis))
  65. src, err := format.Source(buf.Bytes())
  66. errorOut("error formatting generated source:\n"+buf.String(), err)
  67. errorOut("error writing file", os.WriteFile(*outputFile, src, 0o644))
  68. }
  // toLower returns s with only its first rune converted to lower case, for
// example "FooBar" -> "fooBar". The remainder of s is left byte-for-byte
// untouched. An empty string, a string whose first rune is already lower
// case (or has no case), and a string whose first byte is not valid UTF-8
// are all returned unchanged.
func toLower(s string) string {
	r, n := utf8.DecodeRuneInString(s)
	// n == 0 means s is empty. RuneError with n == 1 means an invalid
	// encoding; converting that rune back with string(r) would silently
	// replace the original byte with U+FFFD.
	if n == 0 || (r == utf8.RuneError && n == 1) {
		return s
	}
	lower := unicode.ToLower(r)
	if lower == r {
		return s
	}
	return string(lower) + s[n:]
}