update.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package source
  import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"runtime"
	"strconv"
	"strings"
	"unicode/utf8"
)
  // Update reports whether the user running the tests asked for golden values
// to be rewritten. It is bound to the -update command-line flag, registered
// on the default flag set when the package is initialized, and defaults to
// false.
var Update bool

func init() {
	flag.CommandLine.BoolVar(&Update, "update", false, "update golden values")
}
  // ErrNotFound is returned by UpdateExpectedValue and UpdateVariable when the
// variable holding the golden value cannot be located or cannot be updated.
// Examples are a call without an argument named expected*, a non-string
// value, or a declaration of an unsupported shape. Callers can compare
// against it with errors.Is.
var ErrNotFound = errors.New("failed to find variable for update of golden value")
  24. // UpdateExpectedValue looks for a package-level variable with a name that
  25. // starts with expected in the arguments to the caller. If the variable is
  26. // found, the value of the variable will be updated to value of the other
  27. // argument to the caller.
  28. func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
  29. _, filename, line, ok := runtime.Caller(stackIndex + 1)
  30. if !ok {
  31. return errors.New("failed to get call stack")
  32. }
  33. debug("call stack position: %s:%d", filename, line)
  34. fileset := token.NewFileSet()
  35. astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
  36. if err != nil {
  37. return fmt.Errorf("failed to parse source file %s: %w", filename, err)
  38. }
  39. expr, err := getCallExprArgs(fileset, astFile, line)
  40. if err != nil {
  41. return fmt.Errorf("call from %s:%d: %w", filename, line, err)
  42. }
  43. if len(expr) < 3 {
  44. debug("not enough arguments %d: %v",
  45. len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}})
  46. return ErrNotFound
  47. }
  48. argIndex, ident := getIdentForExpectedValueArg(expr)
  49. if argIndex < 0 || ident == nil {
  50. debug("no arguments started with the word 'expected': %v",
  51. debugFormatNode{Node: &ast.CallExpr{Args: expr}})
  52. return ErrNotFound
  53. }
  54. value := x
  55. if argIndex == 1 {
  56. value = y
  57. }
  58. strValue, ok := value.(string)
  59. if !ok {
  60. debug("value must be type string, got %T", value)
  61. return ErrNotFound
  62. }
  63. return UpdateVariable(filename, fileset, astFile, ident, strValue)
  64. }
  65. // UpdateVariable writes to filename the contents of astFile with the value of
  66. // the variable updated to value.
  67. func UpdateVariable(
  68. filename string,
  69. fileset *token.FileSet,
  70. astFile *ast.File,
  71. ident *ast.Ident,
  72. value string,
  73. ) error {
  74. obj := ident.Obj
  75. if obj == nil {
  76. return ErrNotFound
  77. }
  78. if obj.Kind != ast.Con && obj.Kind != ast.Var {
  79. debug("can only update var and const, found %v", obj.Kind)
  80. return ErrNotFound
  81. }
  82. switch decl := obj.Decl.(type) {
  83. case *ast.ValueSpec:
  84. if len(decl.Names) != 1 {
  85. debug("more than one name in ast.ValueSpec")
  86. return ErrNotFound
  87. }
  88. decl.Values[0] = &ast.BasicLit{
  89. Kind: token.STRING,
  90. Value: "`" + value + "`",
  91. }
  92. case *ast.AssignStmt:
  93. if len(decl.Lhs) != 1 {
  94. debug("more than one name in ast.AssignStmt")
  95. return ErrNotFound
  96. }
  97. decl.Rhs[0] = &ast.BasicLit{
  98. Kind: token.STRING,
  99. Value: "`" + value + "`",
  100. }
  101. default:
  102. debug("can only update *ast.ValueSpec, found %T", obj.Decl)
  103. return ErrNotFound
  104. }
  105. var buf bytes.Buffer
  106. if err := format.Node(&buf, fileset, astFile); err != nil {
  107. return fmt.Errorf("failed to format file after update: %w", err)
  108. }
  109. fh, err := os.Create(filename)
  110. if err != nil {
  111. return fmt.Errorf("failed to open file %v: %w", filename, err)
  112. }
  113. if _, err = fh.Write(buf.Bytes()); err != nil {
  114. return fmt.Errorf("failed to write file %v: %w", filename, err)
  115. }
  116. if err := fh.Sync(); err != nil {
  117. return fmt.Errorf("failed to sync file %v: %w", filename, err)
  118. }
  119. return nil
  120. }
  // getIdentForExpectedValueArg looks at expr[1] and then expr[2] (expr[0] is
// skipped). It returns the index and identifier of the first one that is a
// bare identifier whose name starts with "expected", compared
// case-insensitively. It returns -1, nil when neither position matches. This
// includes the case where expr is too short to have those positions, which
// previously caused an index-out-of-range panic.
func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
	for i := 1; i < 3 && i < len(expr); i++ {
		ident, ok := expr[i].(*ast.Ident)
		if ok && strings.HasPrefix(strings.ToLower(ident.Name), "expected") {
			return i, ident
		}
	}
	return -1, nil
}