context_unix.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. //go:build linux || freebsd
  2. package netnsutils
  3. import (
  4. "fmt"
  5. "runtime"
  6. "strconv"
  7. "testing"
  8. "github.com/docker/docker/internal/testutils"
  9. "github.com/docker/docker/libnetwork/ns"
  10. "github.com/pkg/errors"
  11. "github.com/vishvananda/netns"
  12. "golang.org/x/sys/unix"
  13. )
  14. // OSContext is a handle to a test OS context.
  15. type OSContext struct {
  16. origNS, newNS netns.NsHandle
  17. tid int
  18. caller string // The file:line where SetupTestOSContextEx was called, for interpolating into error messages.
  19. }
  // SetupTestOSContext joins the current goroutine to a new network namespace,
// and returns its associated teardown function.
//
// It is a thin wrapper around [SetupTestOSContextEx]. The returned function
// calls [OSContext.Cleanup] and therefore must be invoked on the same goroutine
// that called SetupTestOSContext.
//
// Both this function and the returned teardown are marked as test helpers, so
// failures are reported at the caller's line rather than inside this package.
//
// Example usage:
//
//	defer SetupTestOSContext(t)()
func SetupTestOSContext(t *testing.T) func() {
	t.Helper()
	c := SetupTestOSContextEx(t)
	return func() {
		t.Helper()
		c.Cleanup(t)
	}
}
  // SetupTestOSContextEx joins the current goroutine to a new network namespace.
//
// Compared to [SetupTestOSContext], this function allows goroutines to be
// spawned which are associated with the same OS context via the returned
// OSContext value.
//
// The calling goroutine is locked to its OS thread, the thread is moved into a
// newly created network namespace, and the loopback interface "lo" is brought
// up there. On any failure the setup done so far is undone and the test is
// failed with t.Fatalf.
//
// Example usage:
//
//	c := SetupTestOSContextEx(t)
//	defer c.Cleanup(t)
func SetupTestOSContextEx(t *testing.T) *OSContext {
	t.Helper()
	runtime.LockOSThread()
	origNS, err := netns.Get()
	if err != nil {
		runtime.UnlockOSThread()
		t.Fatalf("Failed to open initial netns: %v", err)
	}

	c := OSContext{
		tid:    unix.Gettid(),
		origNS: origNS,
	}

	c.newNS, err = netns.New()
	if err != nil {
		// netns.New() is not atomic: it could have encountered an error
		// after unsharing the current thread's network namespace.
		c.restore(t)
		t.Fatalf("Failed to enter netns: %v", err)
	}

	// Since we are switching to a new test namespace make
	// sure to re-initialize initNs context
	ns.Init()

	// abort undoes everything done since netns.New() succeeded, in the same
	// order as Cleanup: close the new namespace handle (previously leaked on
	// these paths), restore the original namespace, and re-initialize ns. It
	// then fails the test with "<msg>: <err>".
	abort := func(msg string, err error) {
		t.Helper()
		if cerr := c.newNS.Close(); cerr != nil {
			t.Logf("Warning: netns closing failed (%v)", cerr)
		}
		c.restore(t)
		ns.Init()
		t.Fatalf("%s: %v", msg, err)
	}

	nl := ns.NlHandle()
	lo, err := nl.LinkByName("lo")
	if err != nil {
		abort("Failed to get handle to loopback interface 'lo' in new netns", err)
	}
	if err := nl.LinkSetUp(lo); err != nil {
		abort("Failed to enable loopback interface in new netns", err)
	}

	// Skip one frame so c.caller names the code that called this function,
	// rather than this line inside it.
	_, file, line, ok := runtime.Caller(1)
	if ok {
		c.caller = file + ":" + strconv.Itoa(line)
	}
	return &c
}
  // Cleanup tears down the OS context. It must be called from the same goroutine
// as the [SetupTestOSContextEx] call which returned c.
//
// Explicit cleanup is required as (*testing.T).Cleanup() makes no guarantees
// about which goroutine the cleanup functions are invoked on.
//
// Cleanup closes the handle to the test namespace, switches the thread back to
// its original namespace (see restore), and re-initializes the ns package.
// Handle-close failures are only logged. Calling Cleanup on a different OS
// thread than setup fails the test.
func (c *OSContext) Cleanup(t *testing.T) {
	t.Helper()

	if tid := unix.Gettid(); tid != c.tid {
		t.Fatalf("c.Cleanup() must be called from the same goroutine as SetupTestOSContextEx() (%s)", c.caller)
	}

	closeErr := c.newNS.Close()
	if closeErr != nil {
		t.Logf("Warning: netns closing failed (%v)", closeErr)
	}

	c.restore(t)

	// ns was initialized for the test namespace during setup; re-initialize
	// it after attempting to return to the original namespace.
	ns.Init()
}
  // restore switches the current OS thread back to c.origNS and then closes
// that handle. Errors are logged through t, never fatal.
//
// The goroutine is unlocked from its OS thread only when the switch back
// succeeds. On failure it stays locked, because the thread may still be in the
// test namespace. Per [runtime.LockOSThread], a thread whose goroutine exits
// while still locked is terminated instead of being reused.
func (c *OSContext) restore(t *testing.T) {
	t.Helper()

	if setErr := netns.Set(c.origNS); setErr == nil {
		runtime.UnlockOSThread()
	} else {
		t.Logf("Warning: failed to restore thread netns (%v)", setErr)
	}

	closeErr := c.origNS.Close()
	if closeErr != nil {
		t.Logf("Warning: netns closing failed (%v)", closeErr)
	}
}
  // Set sets the OS context of the calling goroutine to c and returns a teardown
// function to restore the calling goroutine's OS context and release resources.
// The teardown function accepts an optional Logger argument. Pass nil to
// discard warnings about failures during restore.
//
// The calling goroutine stays locked to its OS thread until teardown runs.
// Teardown must be called from the same goroutine that called Set; otherwise
// it panics. If the original network namespace cannot be restored, teardown
// leaves the goroutine locked to the thread (regardless of whether a Logger was
// given), so the thread is not handed back to the scheduler while still inside
// c's namespace.
//
// This is a lower-level interface which is less ergonomic than c.Go() but more
// composable with other goroutine-spawning utilities such as [sync.WaitGroup]
// or [golang.org/x/sync/errgroup.Group].
//
// Example usage:
//
//	func TestFoo(t *testing.T) {
//		osctx := netnsutils.SetupTestOSContextEx(t)
//		defer osctx.Cleanup(t)
//		var eg errgroup.Group
//		eg.Go(func() error {
//			teardown, err := osctx.Set()
//			if err != nil {
//				return err
//			}
//			defer teardown(t)
//			// ...
//			return nil
//		})
//		if err := eg.Wait(); err != nil {
//			t.Fatalf("%+v", err)
//		}
//	}
func (c *OSContext) Set() (func(testutils.Logger), error) {
	runtime.LockOSThread()
	orig, err := netns.Get()
	if err != nil {
		runtime.UnlockOSThread()
		return nil, errors.Wrap(err, "failed to open initial netns for goroutine")
	}
	if err := netns.Set(c.newNS); err != nil {
		// The switch failed, so the thread presumably never left its original
		// namespace. Release the handle opened above and unlock the thread.
		// The close is best-effort: the Set error is what the caller needs.
		_ = orig.Close()
		runtime.UnlockOSThread()
		// errors.Wrap already records a stack trace; no extra WithStack needed.
		return nil, errors.Wrap(err, "failed to set goroutine network namespace")
	}

	tid := unix.Gettid()
	// Skip one frame so a misplaced teardown call points at the code that
	// called Set, not at this line inside Set.
	_, file, line, callerOK := runtime.Caller(1)
	return func(log testutils.Logger) {
		// The goroutine is locked to its thread, so a different thread ID
		// means teardown is running on a different goroutine.
		if unix.Gettid() != tid {
			msg := "teardown function must be called from the same goroutine as c.Set()"
			if callerOK {
				msg += fmt.Sprintf(" (%s:%d)", file, line)
			}
			panic(msg)
		}
		if err := netns.Set(orig); err != nil {
			// Keep the thread locked even when no logger was supplied: it is
			// still in c's namespace and must not be reused.
			if log != nil {
				log.Logf("Warning: failed to restore goroutine thread netns (%v)", err)
			}
		} else {
			runtime.UnlockOSThread()
		}
		if err := orig.Close(); err != nil && log != nil {
			log.Logf("Warning: netns closing failed (%v)", err)
		}
	}, nil
}
  // Go starts running fn in a new goroutine inside the test OS context.
//
// Go blocks only until the new goroutine has entered c's network namespace via
// [OSContext.Set], and fails the test if that step fails. fn then runs
// asynchronously; Go does not wait for it to finish. When fn returns, the
// goroutine's original namespace is restored, with any warnings logged to t.
func (c *OSContext) Go(t *testing.T, fn func()) {
	t.Helper()

	// ready carries Set's error on failure. It is closed without a value once
	// the goroutine is inside the OS context.
	ready := make(chan error, 1)
	go func() {
		teardown, err := c.Set()
		if err == nil {
			defer teardown(t)
			close(ready)
			fn()
			return
		}
		ready <- err
	}()

	if err, failed := <-ready; failed {
		t.Fatalf("%+v", err)
	}
}