compare.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. /*Package cmp provides Comparisons for Assert and Check*/
  2. package cmp // import "gotest.tools/assert/cmp"
  3. import (
  4. "fmt"
  5. "reflect"
  6. "strings"
  7. "github.com/google/go-cmp/cmp"
  8. "gotest.tools/internal/format"
  9. )
  10. // Comparison is a function which compares values and returns ResultSuccess if
  11. // the actual value matches the expected value. If the values do not match the
  12. // Result will contain a message about why it failed.
  13. type Comparison func() Result
  14. // DeepEqual compares two values using google/go-cmp (http://bit.do/go-cmp)
  15. // and succeeds if the values are equal.
  16. //
  17. // The comparison can be customized using comparison Options.
  18. // Package https://godoc.org/gotest.tools/assert/opt provides some additional
  19. // commonly used Options.
  20. func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
  21. return func() (result Result) {
  22. defer func() {
  23. if panicmsg, handled := handleCmpPanic(recover()); handled {
  24. result = ResultFailure(panicmsg)
  25. }
  26. }()
  27. diff := cmp.Diff(x, y, opts...)
  28. if diff == "" {
  29. return ResultSuccess
  30. }
  31. return multiLineDiffResult(diff)
  32. }
  33. }
  34. func handleCmpPanic(r interface{}) (string, bool) {
  35. if r == nil {
  36. return "", false
  37. }
  38. panicmsg, ok := r.(string)
  39. if !ok {
  40. panic(r)
  41. }
  42. switch {
  43. case strings.HasPrefix(panicmsg, "cannot handle unexported field"):
  44. return panicmsg, true
  45. }
  46. panic(r)
  47. }
  48. func toResult(success bool, msg string) Result {
  49. if success {
  50. return ResultSuccess
  51. }
  52. return ResultFailure(msg)
  53. }
  54. // Equal succeeds if x == y. See assert.Equal for full documentation.
  55. func Equal(x, y interface{}) Comparison {
  56. return func() Result {
  57. switch {
  58. case x == y:
  59. return ResultSuccess
  60. case isMultiLineStringCompare(x, y):
  61. diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
  62. return multiLineDiffResult(diff)
  63. }
  64. return ResultFailureTemplate(`
  65. {{- .Data.x}} (
  66. {{- with callArg 0 }}{{ formatNode . }} {{end -}}
  67. {{- printf "%T" .Data.x -}}
  68. ) != {{ .Data.y}} (
  69. {{- with callArg 1 }}{{ formatNode . }} {{end -}}
  70. {{- printf "%T" .Data.y -}}
  71. )`,
  72. map[string]interface{}{"x": x, "y": y})
  73. }
  74. }
  75. func isMultiLineStringCompare(x, y interface{}) bool {
  76. strX, ok := x.(string)
  77. if !ok {
  78. return false
  79. }
  80. strY, ok := y.(string)
  81. if !ok {
  82. return false
  83. }
  84. return strings.Contains(strX, "\n") || strings.Contains(strY, "\n")
  85. }
  86. func multiLineDiffResult(diff string) Result {
  87. return ResultFailureTemplate(`
  88. --- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
  89. +++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
  90. {{ .Data.diff }}`,
  91. map[string]interface{}{"diff": diff})
  92. }
  93. // Len succeeds if the sequence has the expected length.
  94. func Len(seq interface{}, expected int) Comparison {
  95. return func() (result Result) {
  96. defer func() {
  97. if e := recover(); e != nil {
  98. result = ResultFailure(fmt.Sprintf("type %T does not have a length", seq))
  99. }
  100. }()
  101. value := reflect.ValueOf(seq)
  102. length := value.Len()
  103. if length == expected {
  104. return ResultSuccess
  105. }
  106. msg := fmt.Sprintf("expected %s (length %d) to have length %d", seq, length, expected)
  107. return ResultFailure(msg)
  108. }
  109. }
  110. // Contains succeeds if item is in collection. Collection may be a string, map,
  111. // slice, or array.
  112. //
  113. // If collection is a string, item must also be a string, and is compared using
  114. // strings.Contains().
  115. // If collection is a Map, contains will succeed if item is a key in the map.
  116. // If collection is a slice or array, item is compared to each item in the
  117. // sequence using reflect.DeepEqual().
  118. func Contains(collection interface{}, item interface{}) Comparison {
  119. return func() Result {
  120. colValue := reflect.ValueOf(collection)
  121. if !colValue.IsValid() {
  122. return ResultFailure(fmt.Sprintf("nil does not contain items"))
  123. }
  124. msg := fmt.Sprintf("%v does not contain %v", collection, item)
  125. itemValue := reflect.ValueOf(item)
  126. switch colValue.Type().Kind() {
  127. case reflect.String:
  128. if itemValue.Type().Kind() != reflect.String {
  129. return ResultFailure("string may only contain strings")
  130. }
  131. return toResult(
  132. strings.Contains(colValue.String(), itemValue.String()),
  133. fmt.Sprintf("string %q does not contain %q", collection, item))
  134. case reflect.Map:
  135. if itemValue.Type() != colValue.Type().Key() {
  136. return ResultFailure(fmt.Sprintf(
  137. "%v can not contain a %v key", colValue.Type(), itemValue.Type()))
  138. }
  139. return toResult(colValue.MapIndex(itemValue).IsValid(), msg)
  140. case reflect.Slice, reflect.Array:
  141. for i := 0; i < colValue.Len(); i++ {
  142. if reflect.DeepEqual(colValue.Index(i).Interface(), item) {
  143. return ResultSuccess
  144. }
  145. }
  146. return ResultFailure(msg)
  147. default:
  148. return ResultFailure(fmt.Sprintf("type %T does not contain items", collection))
  149. }
  150. }
  151. }
  152. // Panics succeeds if f() panics.
  153. func Panics(f func()) Comparison {
  154. return func() (result Result) {
  155. defer func() {
  156. if err := recover(); err != nil {
  157. result = ResultSuccess
  158. }
  159. }()
  160. f()
  161. return ResultFailure("did not panic")
  162. }
  163. }
  164. // Error succeeds if err is a non-nil error, and the error message equals the
  165. // expected message.
  166. func Error(err error, message string) Comparison {
  167. return func() Result {
  168. switch {
  169. case err == nil:
  170. return ResultFailure("expected an error, got nil")
  171. case err.Error() != message:
  172. return ResultFailure(fmt.Sprintf(
  173. "expected error %q, got %+v", message, err))
  174. }
  175. return ResultSuccess
  176. }
  177. }
  178. // ErrorContains succeeds if err is a non-nil error, and the error message contains
  179. // the expected substring.
  180. func ErrorContains(err error, substring string) Comparison {
  181. return func() Result {
  182. switch {
  183. case err == nil:
  184. return ResultFailure("expected an error, got nil")
  185. case !strings.Contains(err.Error(), substring):
  186. return ResultFailure(fmt.Sprintf(
  187. "expected error to contain %q, got %+v", substring, err))
  188. }
  189. return ResultSuccess
  190. }
  191. }
  192. // Nil succeeds if obj is a nil interface, pointer, or function.
  193. //
  194. // Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices,
  195. // maps, and channels.
  196. func Nil(obj interface{}) Comparison {
  197. msgFunc := func(value reflect.Value) string {
  198. return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type())
  199. }
  200. return isNil(obj, msgFunc)
  201. }
  202. func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
  203. return func() Result {
  204. if obj == nil {
  205. return ResultSuccess
  206. }
  207. value := reflect.ValueOf(obj)
  208. kind := value.Type().Kind()
  209. if kind >= reflect.Chan && kind <= reflect.Slice {
  210. if value.IsNil() {
  211. return ResultSuccess
  212. }
  213. return ResultFailure(msgFunc(value))
  214. }
  215. return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", value, value.Type()))
  216. }
  217. }
  218. // ErrorType succeeds if err is not nil and is of the expected type.
  219. //
  220. // Expected can be one of:
  221. // a func(error) bool which returns true if the error is the expected type,
  222. // an instance of (or a pointer to) a struct of the expected type,
  223. // a pointer to an interface the error is expected to implement,
  224. // a reflect.Type of the expected struct or interface.
  225. func ErrorType(err error, expected interface{}) Comparison {
  226. return func() Result {
  227. switch expectedType := expected.(type) {
  228. case func(error) bool:
  229. return cmpErrorTypeFunc(err, expectedType)
  230. case reflect.Type:
  231. if expectedType.Kind() == reflect.Interface {
  232. return cmpErrorTypeImplementsType(err, expectedType)
  233. }
  234. return cmpErrorTypeEqualType(err, expectedType)
  235. case nil:
  236. return ResultFailure(fmt.Sprintf("invalid type for expected: nil"))
  237. }
  238. expectedType := reflect.TypeOf(expected)
  239. switch {
  240. case expectedType.Kind() == reflect.Struct, isPtrToStruct(expectedType):
  241. return cmpErrorTypeEqualType(err, expectedType)
  242. case isPtrToInterface(expectedType):
  243. return cmpErrorTypeImplementsType(err, expectedType.Elem())
  244. }
  245. return ResultFailure(fmt.Sprintf("invalid type for expected: %T", expected))
  246. }
  247. }
  248. func cmpErrorTypeFunc(err error, f func(error) bool) Result {
  249. if f(err) {
  250. return ResultSuccess
  251. }
  252. actual := "nil"
  253. if err != nil {
  254. actual = fmt.Sprintf("%s (%T)", err, err)
  255. }
  256. return ResultFailureTemplate(`error is {{ .Data.actual }}
  257. {{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`,
  258. map[string]interface{}{"actual": actual})
  259. }
  260. func cmpErrorTypeEqualType(err error, expectedType reflect.Type) Result {
  261. if err == nil {
  262. return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
  263. }
  264. errValue := reflect.ValueOf(err)
  265. if errValue.Type() == expectedType {
  266. return ResultSuccess
  267. }
  268. return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
  269. }
  270. func cmpErrorTypeImplementsType(err error, expectedType reflect.Type) Result {
  271. if err == nil {
  272. return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
  273. }
  274. errValue := reflect.ValueOf(err)
  275. if errValue.Type().Implements(expectedType) {
  276. return ResultSuccess
  277. }
  278. return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
  279. }
  280. func isPtrToInterface(typ reflect.Type) bool {
  281. return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Interface
  282. }
  283. func isPtrToStruct(typ reflect.Type) bool {
  284. return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct
  285. }