update.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. package source
  import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"runtime"
	"strconv"
	"strings"
	"unicode/utf8"
)
  // IsUpdate returns true if the -update flag is set, or if Update has been set
// to true. It indicates the user running the tests would like to update any
// golden values.
//
// If no -update flag can be found on flag.CommandLine (for example because a
// test replaced flag.CommandLine after this package was initialized), IsUpdate
// reports false instead of panicking.
func IsUpdate() bool {
	if Update {
		return true
	}
	f := flag.Lookup("update")
	if f == nil {
		return false
	}
	getter, ok := f.Value.(flag.Getter)
	if !ok {
		return false
	}
	value, ok := getter.Get().(bool)
	return ok && value
}

// Update is a shim for testing, and for compatibility with the old -update-golden
// flag. When it is true, IsUpdate returns true regardless of the -update flag.
var Update bool

// init registers a boolean -update flag on flag.CommandLine. If another
// package already registered a flag named "update", that flag is reused, but
// its value must be a bool; otherwise init panics rather than silently
// ignoring the conflicting definition.
func init() {
	const msg = "some other package defined an incompatible -update flag, expected a flag.Bool"
	f := flag.Lookup("update")
	if f == nil {
		flag.Bool("update", false, "update golden values")
		return
	}
	getter, ok := f.Value.(flag.Getter)
	if !ok {
		panic(msg)
	}
	if _, ok := getter.Get().(bool); !ok {
		panic(msg)
	}
}
  // ErrNotFound indicates that UpdateExpectedValue failed to find the
// variable to update, likely because it is not a package level variable.
// Callers should compare against it with errors.Is, since it may be wrapped.
var ErrNotFound = errors.New("failed to find variable for update of golden value")
  43. // UpdateExpectedValue looks for a package-level variable with a name that
  44. // starts with expected in the arguments to the caller. If the variable is
  45. // found, the value of the variable will be updated to value of the other
  46. // argument to the caller.
  47. func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
  48. _, filename, line, ok := runtime.Caller(stackIndex + 1)
  49. if !ok {
  50. return errors.New("failed to get call stack")
  51. }
  52. debug("call stack position: %s:%d", filename, line)
  53. fileset := token.NewFileSet()
  54. astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
  55. if err != nil {
  56. return fmt.Errorf("failed to parse source file %s: %w", filename, err)
  57. }
  58. expr, err := getCallExprArgs(fileset, astFile, line)
  59. if err != nil {
  60. return fmt.Errorf("call from %s:%d: %w", filename, line, err)
  61. }
  62. if len(expr) < 3 {
  63. debug("not enough arguments %d: %v",
  64. len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}})
  65. return ErrNotFound
  66. }
  67. argIndex, ident := getIdentForExpectedValueArg(expr)
  68. if argIndex < 0 || ident == nil {
  69. debug("no arguments started with the word 'expected': %v",
  70. debugFormatNode{Node: &ast.CallExpr{Args: expr}})
  71. return ErrNotFound
  72. }
  73. value := x
  74. if argIndex == 1 {
  75. value = y
  76. }
  77. strValue, ok := value.(string)
  78. if !ok {
  79. debug("value must be type string, got %T", value)
  80. return ErrNotFound
  81. }
  82. return UpdateVariable(filename, fileset, astFile, ident, strValue)
  83. }
  84. // UpdateVariable writes to filename the contents of astFile with the value of
  85. // the variable updated to value.
  86. func UpdateVariable(
  87. filename string,
  88. fileset *token.FileSet,
  89. astFile *ast.File,
  90. ident *ast.Ident,
  91. value string,
  92. ) error {
  93. obj := ident.Obj
  94. if obj == nil {
  95. return ErrNotFound
  96. }
  97. if obj.Kind != ast.Con && obj.Kind != ast.Var {
  98. debug("can only update var and const, found %v", obj.Kind)
  99. return ErrNotFound
  100. }
  101. switch decl := obj.Decl.(type) {
  102. case *ast.ValueSpec:
  103. if len(decl.Names) != 1 {
  104. debug("more than one name in ast.ValueSpec")
  105. return ErrNotFound
  106. }
  107. decl.Values[0] = &ast.BasicLit{
  108. Kind: token.STRING,
  109. Value: "`" + value + "`",
  110. }
  111. case *ast.AssignStmt:
  112. if len(decl.Lhs) != 1 {
  113. debug("more than one name in ast.AssignStmt")
  114. return ErrNotFound
  115. }
  116. decl.Rhs[0] = &ast.BasicLit{
  117. Kind: token.STRING,
  118. Value: "`" + value + "`",
  119. }
  120. default:
  121. debug("can only update *ast.ValueSpec, found %T", obj.Decl)
  122. return ErrNotFound
  123. }
  124. var buf bytes.Buffer
  125. if err := format.Node(&buf, fileset, astFile); err != nil {
  126. return fmt.Errorf("failed to format file after update: %w", err)
  127. }
  128. fh, err := os.Create(filename)
  129. if err != nil {
  130. return fmt.Errorf("failed to open file %v: %w", filename, err)
  131. }
  132. if _, err = fh.Write(buf.Bytes()); err != nil {
  133. return fmt.Errorf("failed to write file %v: %w", filename, err)
  134. }
  135. if err := fh.Sync(); err != nil {
  136. return fmt.Errorf("failed to sync file %v: %w", filename, err)
  137. }
  138. return nil
  139. }
  // getIdentForExpectedValueArg inspects the second and third call arguments
// (expr[1] and expr[2]; expr[0] is never considered). It returns the index and
// identifier of the first one that is a plain identifier whose name starts
// with "expected", compared case-insensitively. It returns -1, nil when
// neither qualifies. It never indexes past the end of expr: with fewer than
// three elements, only the arguments that are present are inspected.
func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
	for i := 1; i < 3 && i < len(expr); i++ {
		ident, ok := expr[i].(*ast.Ident)
		if !ok {
			continue
		}
		if strings.HasPrefix(strings.ToLower(ident.Name), "expected") {
			return i, ident
		}
	}
	return -1, nil
}