123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- package source
import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"runtime"
	"strconv"
	"strings"
	"unicode/utf8"
)
- // IsUpdate is returns true if the -update flag is set. It indicates the user
- // running the tests would like to update any golden values.
- func IsUpdate() bool {
- if Update {
- return true
- }
- return flag.Lookup("update").Value.(flag.Getter).Get().(bool)
- }
// Update, when set to true, makes IsUpdate report true regardless of the
// -update flag. It exists as a shim for testing, and for compatibility with
// the old -update-golden flag.
var (
	Update bool
)
// init registers a boolean -update flag on the default flag set. If another
// package already registered -update, it is reused only when its value is a
// flag.Getter that yields a bool; any other definition is treated as a
// conflicting programmer error and causes a panic.
func init() {
	existing := flag.Lookup("update")
	if existing == nil {
		flag.Bool("update", false, "update golden values")
		return
	}

	const msg = "some other package defined an incompatible -update flag, expected a flag.Bool"
	getter, isGetter := existing.Value.(flag.Getter)
	if !isGetter {
		panic(msg)
	}
	if _, isBool := getter.Get().(bool); !isBool {
		panic(msg)
	}
}
// ErrNotFound is returned by UpdateExpectedValue and UpdateVariable when no
// variable could be found (or safely rewritten) for the golden value. Examples:
// neither candidate argument is an identifier starting with "expected", the
// value is not a string, or the variable's declaration has an unsupported
// shape. Compare against it with errors.Is.
var ErrNotFound = errors.New("failed to find variable for update of golden value")
- // UpdateExpectedValue looks for a package-level variable with a name that
- // starts with expected in the arguments to the caller. If the variable is
- // found, the value of the variable will be updated to value of the other
- // argument to the caller.
- func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
- _, filename, line, ok := runtime.Caller(stackIndex + 1)
- if !ok {
- return errors.New("failed to get call stack")
- }
- debug("call stack position: %s:%d", filename, line)
- fileset := token.NewFileSet()
- astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
- if err != nil {
- return fmt.Errorf("failed to parse source file %s: %w", filename, err)
- }
- expr, err := getCallExprArgs(fileset, astFile, line)
- if err != nil {
- return fmt.Errorf("call from %s:%d: %w", filename, line, err)
- }
- if len(expr) < 3 {
- debug("not enough arguments %d: %v",
- len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}})
- return ErrNotFound
- }
- argIndex, ident := getIdentForExpectedValueArg(expr)
- if argIndex < 0 || ident == nil {
- debug("no arguments started with the word 'expected': %v",
- debugFormatNode{Node: &ast.CallExpr{Args: expr}})
- return ErrNotFound
- }
- value := x
- if argIndex == 1 {
- value = y
- }
- strValue, ok := value.(string)
- if !ok {
- debug("value must be type string, got %T", value)
- return ErrNotFound
- }
- return UpdateVariable(filename, fileset, astFile, ident, strValue)
- }
- // UpdateVariable writes to filename the contents of astFile with the value of
- // the variable updated to value.
- func UpdateVariable(
- filename string,
- fileset *token.FileSet,
- astFile *ast.File,
- ident *ast.Ident,
- value string,
- ) error {
- obj := ident.Obj
- if obj == nil {
- return ErrNotFound
- }
- if obj.Kind != ast.Con && obj.Kind != ast.Var {
- debug("can only update var and const, found %v", obj.Kind)
- return ErrNotFound
- }
- switch decl := obj.Decl.(type) {
- case *ast.ValueSpec:
- if len(decl.Names) != 1 {
- debug("more than one name in ast.ValueSpec")
- return ErrNotFound
- }
- decl.Values[0] = &ast.BasicLit{
- Kind: token.STRING,
- Value: "`" + value + "`",
- }
- case *ast.AssignStmt:
- if len(decl.Lhs) != 1 {
- debug("more than one name in ast.AssignStmt")
- return ErrNotFound
- }
- decl.Rhs[0] = &ast.BasicLit{
- Kind: token.STRING,
- Value: "`" + value + "`",
- }
- default:
- debug("can only update *ast.ValueSpec, found %T", obj.Decl)
- return ErrNotFound
- }
- var buf bytes.Buffer
- if err := format.Node(&buf, fileset, astFile); err != nil {
- return fmt.Errorf("failed to format file after update: %w", err)
- }
- fh, err := os.Create(filename)
- if err != nil {
- return fmt.Errorf("failed to open file %v: %w", filename, err)
- }
- if _, err = fh.Write(buf.Bytes()); err != nil {
- return fmt.Errorf("failed to write file %v: %w", filename, err)
- }
- if err := fh.Sync(); err != nil {
- return fmt.Errorf("failed to sync file %v: %w", filename, err)
- }
- return nil
- }
// getIdentForExpectedValueArg searches the second and third call arguments
// (expr[1] and expr[2]) for a bare identifier whose name starts with
// "expected", compared case-insensitively. expr[0] is skipped; presumably it
// is the testing.T argument, but verify against callers.
//
// It returns the index and identifier of the first match, or (-1, nil) when
// neither argument qualifies. Positions missing from a short expr count as
// non-matching, so this no longer panics with index out of range.
func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
	for i := 1; i < 3 && i < len(expr); i++ {
		ident, ok := expr[i].(*ast.Ident)
		if !ok {
			continue
		}
		if strings.HasPrefix(strings.ToLower(ident.Name), "expected") {
			return i, ident
		}
	}
	return -1, nil
}
|