singleflight.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package singleflight provides a duplicate function call suppression
  5. // mechanism.
  6. package singleflight
  7. import (
  8. "bytes"
  9. "errors"
  10. "fmt"
  11. "runtime"
  12. "runtime/debug"
  13. "sync"
  14. )
  15. // errGoexit indicates the runtime.Goexit was called in
  16. // the user given function.
  17. var errGoexit = errors.New("runtime.Goexit was called")
  18. // A panicError is an arbitrary value recovered from a panic
  19. // with the stack trace during the execution of given function.
  20. type panicError struct {
  21. value any
  22. stack []byte
  23. }
  24. // Error implements error interface.
  25. func (p *panicError) Error() string {
  26. return fmt.Sprintf("%v\n\n%s", p.value, p.stack)
  27. }
  28. func newPanicError(v any) error {
  29. stack := debug.Stack()
  30. // The first line of the stack trace is of the form "goroutine N [status]:"
  31. // but by the time the panic reaches Do the goroutine may no longer exist
  32. // and its status will have changed. Trim out the misleading line.
  33. if line := bytes.IndexByte(stack[:], '\n'); line >= 0 {
  34. stack = stack[line+1:]
  35. }
  36. return &panicError{value: v, stack: stack}
  37. }
  38. // call is an in-flight or completed singleflight.Do call
  39. type call[T any] struct {
  40. wg sync.WaitGroup
  41. // These fields are written once before the WaitGroup is done
  42. // and are only read after the WaitGroup is done.
  43. val T
  44. err error
  45. // forgotten indicates whether Forget was called with this call's key
  46. // while the call was still in flight.
  47. forgotten bool
  48. // These fields are read and written with the singleflight
  49. // mutex held before the WaitGroup is done, and are read but
  50. // not written after the WaitGroup is done.
  51. dups int
  52. chans []chan<- Result[T]
  53. }
  54. // Group represents a class of work and forms a namespace in
  55. // which units of work can be executed with duplicate suppression.
  56. type Group[T any] struct {
  57. mu sync.Mutex // protects m
  58. m map[string]*call[T] // lazily initialized
  59. }
  60. // Result holds the results of Do, so they can be passed
  61. // on a channel.
  62. type Result[T any] struct {
  63. Val T
  64. Err error
  65. Shared bool
  66. }
  67. // Do executes and returns the results of the given function, making
  68. // sure that only one execution is in-flight for a given key at a
  69. // time. If a duplicate comes in, the duplicate caller waits for the
  70. // original to complete and receives the same results.
  71. // The return value shared indicates whether v was given to multiple callers.
  72. func (g *Group[T]) Do(key string, fn func() (T, error)) (v T, err error, shared bool) {
  73. g.mu.Lock()
  74. if g.m == nil {
  75. g.m = make(map[string]*call[T])
  76. }
  77. if c, ok := g.m[key]; ok {
  78. c.dups++
  79. g.mu.Unlock()
  80. c.wg.Wait()
  81. if e, ok := c.err.(*panicError); ok {
  82. panic(e)
  83. } else if c.err == errGoexit {
  84. runtime.Goexit()
  85. }
  86. return c.val, c.err, true
  87. }
  88. c := new(call[T])
  89. c.wg.Add(1)
  90. g.m[key] = c
  91. g.mu.Unlock()
  92. g.doCall(c, key, fn)
  93. return c.val, c.err, c.dups > 0
  94. }
  95. // DoChan is like Do but returns a channel that will receive the
  96. // results when they are ready.
  97. //
  98. // The returned channel will not be closed.
  99. func (g *Group[T]) DoChan(key string, fn func() (T, error)) <-chan Result[T] {
  100. ch := make(chan Result[T], 1)
  101. g.mu.Lock()
  102. if g.m == nil {
  103. g.m = make(map[string]*call[T])
  104. }
  105. if c, ok := g.m[key]; ok {
  106. c.dups++
  107. c.chans = append(c.chans, ch)
  108. g.mu.Unlock()
  109. return ch
  110. }
  111. c := &call[T]{chans: []chan<- Result[T]{ch}}
  112. c.wg.Add(1)
  113. g.m[key] = c
  114. g.mu.Unlock()
  115. go g.doCall(c, key, fn)
  116. return ch
  117. }
// doCall runs fn for the call c registered under key and publishes the
// outcome. It distinguishes three ways fn can finish:
//
//   - normal return: c.val and c.err hold fn's results, and each channel in
//     c.chans receives a Result;
//   - panic: c.err becomes a *panicError and the panic is raised again (on a
//     fresh goroutine when DoChan channels are waiting, so that it cannot be
//     recovered and leave those receivers blocked forever);
//   - runtime.Goexit: c.err becomes errGoexit and nothing is sent on c.chans.
//
// In every case the WaitGroup is released, and key is removed from g.m
// unless Forget was called for this call.
func (g *Group[T]) doCall(c *call[T], key string, fn func() (T, error)) {
	// normalReturn is set only when fn returned without panicking or
	// calling runtime.Goexit.
	normalReturn := false
	// recovered is set only when a panic from fn was recovered by the inner
	// deferred function and execution continued past it.
	recovered := false

	// Use a double defer to distinguish a panic from runtime.Goexit;
	// for more details see https://golang.org/cl/134395
	defer func() {
		// Neither a normal return nor a recovered panic: the given function
		// invoked runtime.Goexit.
		if !normalReturn && !recovered {
			c.err = errGoexit
		}

		// Release Do waiters before taking the mutex. Until the delete below,
		// a late Do/DoChan can still find c in g.m and join it; DoChan
		// joiners are still served because the send loop runs under the lock.
		c.wg.Done()
		g.mu.Lock()
		defer g.mu.Unlock()
		// If Forget ran, g.m[key] may already hold a newer call for the same
		// key, so the entry must not be deleted here.
		if !c.forgotten {
			delete(g.m, key)
		}

		if e, ok := c.err.(*panicError); ok {
			// To prevent the waiting channels from being blocked forever,
			// this panic must not be recoverable.
			if len(c.chans) > 0 {
				go panic(e)
				select {} // Keep this goroutine around so that it will appear in the crash dump.
			} else {
				panic(e)
			}
		} else if c.err == errGoexit {
			// Already in the process of goexit, no need to call again.
			// NOTE(review): DoChan receivers of this call never get a Result.
		} else {
			// Normal return: deliver to every DoChan caller. DoChan allocates
			// each channel with capacity 1 and each channel is registered on
			// exactly one call, so these sends do not block.
			for _, ch := range c.chans {
				ch <- Result[T]{c.val, c.err, c.dups > 0}
			}
		}
	}()

	func() {
		defer func() {
			if !normalReturn {
				// Ideally, we would wait to take a stack trace until we've determined
				// whether this is a panic or a runtime.Goexit.
				//
				// Unfortunately, the only way we can distinguish the two is to see
				// whether the recover stopped the goroutine from terminating, and by
				// the time we know that, the part of the stack trace relevant to the
				// panic has been discarded.
				if r := recover(); r != nil {
					c.err = newPanicError(r)
				}
			}
		}()
		c.val, c.err = fn()
		normalReturn = true
	}()

	// Reached only if the inner function returned: either fn returned
	// normally or its panic was recovered above. After runtime.Goexit
	// execution never gets here, so recovered stays false.
	if !normalReturn {
		recovered = true
	}
}
  175. // Forget tells the singleflight to forget about a key. Future calls
  176. // to Do for this key will call the function rather than waiting for
  177. // an earlier call to complete.
  178. func (g *Group[T]) Forget(key string) {
  179. g.mu.Lock()
  180. if c, ok := g.m[key]; ok {
  181. c.forgotten = true
  182. }
  183. delete(g.m, key)
  184. g.mu.Unlock()
  185. }