context_unix.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. //go:build linux || freebsd
  2. package testutils
  3. import (
  4. "fmt"
  5. "runtime"
  6. "strconv"
  7. "testing"
  8. "github.com/docker/docker/libnetwork/ns"
  9. "github.com/pkg/errors"
  10. "github.com/vishvananda/netns"
  11. "golang.org/x/sys/unix"
  12. )
  13. // OSContext is a handle to a test OS context.
  14. type OSContext struct {
  15. origNS, newNS netns.NsHandle
  16. tid int
  17. caller string // The file:line where SetupTestOSContextEx was called, for interpolating into error messages.
  18. }
  19. // SetupTestOSContextEx joins the current goroutine to a new network namespace.
  20. //
  21. // Compared to [SetupTestOSContext], this function allows goroutines to be
  22. // spawned which are associated with the same OS context via the returned
  23. // OSContext value.
  24. //
  25. // Example usage:
  26. //
  27. // c := SetupTestOSContext(t)
  28. // defer c.Cleanup(t)
  29. func SetupTestOSContextEx(t *testing.T) *OSContext {
  30. runtime.LockOSThread()
  31. origNS, err := netns.Get()
  32. if err != nil {
  33. runtime.UnlockOSThread()
  34. t.Fatalf("Failed to open initial netns: %v", err)
  35. }
  36. c := OSContext{
  37. tid: unix.Gettid(),
  38. origNS: origNS,
  39. }
  40. c.newNS, err = netns.New()
  41. if err != nil {
  42. // netns.New() is not atomic: it could have encountered an error
  43. // after unsharing the current thread's network namespace.
  44. c.restore(t)
  45. t.Fatalf("Failed to enter netns: %v", err)
  46. }
  47. // Since we are switching to a new test namespace make
  48. // sure to re-initialize initNs context
  49. ns.Init()
  50. nl := ns.NlHandle()
  51. lo, err := nl.LinkByName("lo")
  52. if err != nil {
  53. c.restore(t)
  54. t.Fatalf("Failed to get handle to loopback interface 'lo' in new netns: %v", err)
  55. }
  56. if err := nl.LinkSetUp(lo); err != nil {
  57. c.restore(t)
  58. t.Fatalf("Failed to enable loopback interface in new netns: %v", err)
  59. }
  60. _, file, line, ok := runtime.Caller(0)
  61. if ok {
  62. c.caller = file + ":" + strconv.Itoa(line)
  63. }
  64. return &c
  65. }
  66. // Cleanup tears down the OS context. It must be called from the same goroutine
  67. // as the [SetupTestOSContextEx] call which returned c.
  68. //
  69. // Explicit cleanup is required as (*testing.T).Cleanup() makes no guarantees
  70. // about which goroutine the cleanup functions are invoked on.
  71. func (c *OSContext) Cleanup(t *testing.T) {
  72. t.Helper()
  73. if unix.Gettid() != c.tid {
  74. t.Fatalf("c.Cleanup() must be called from the same goroutine as SetupTestOSContextEx() (%s)", c.caller)
  75. }
  76. if err := c.newNS.Close(); err != nil {
  77. t.Logf("Warning: netns closing failed (%v)", err)
  78. }
  79. c.restore(t)
  80. ns.Init()
  81. }
  82. func (c *OSContext) restore(t *testing.T) {
  83. t.Helper()
  84. if err := netns.Set(c.origNS); err != nil {
  85. t.Logf("Warning: failed to restore thread netns (%v)", err)
  86. } else {
  87. runtime.UnlockOSThread()
  88. }
  89. if err := c.origNS.Close(); err != nil {
  90. t.Logf("Warning: netns closing failed (%v)", err)
  91. }
  92. }
  93. // Set sets the OS context of the calling goroutine to c and returns a teardown
  94. // function to restore the calling goroutine's OS context and release resources.
  95. // The teardown function accepts an optional Logger argument.
  96. //
  97. // This is a lower-level interface which is less ergonomic than c.Go() but more
  98. // composable with other goroutine-spawning utilities such as [sync.WaitGroup]
  99. // or [golang.org/x/sync/errgroup.Group].
  100. //
  101. // Example usage:
  102. //
  103. // func TestFoo(t *testing.T) {
  104. // osctx := testutils.SetupTestOSContextEx(t)
  105. // defer osctx.Cleanup(t)
  106. // var eg errgroup.Group
  107. // eg.Go(func() error {
  108. // teardown, err := osctx.Set()
  109. // if err != nil {
  110. // return err
  111. // }
  112. // defer teardown(t)
  113. // // ...
  114. // })
  115. // if err := eg.Wait(); err != nil {
  116. // t.Fatalf("%+v", err)
  117. // }
  118. // }
  119. func (c *OSContext) Set() (func(Logger), error) {
  120. runtime.LockOSThread()
  121. orig, err := netns.Get()
  122. if err != nil {
  123. runtime.UnlockOSThread()
  124. return nil, errors.Wrap(err, "failed to open initial netns for goroutine")
  125. }
  126. if err := errors.WithStack(netns.Set(c.newNS)); err != nil {
  127. runtime.UnlockOSThread()
  128. return nil, errors.Wrap(err, "failed to set goroutine network namespace")
  129. }
  130. tid := unix.Gettid()
  131. _, file, line, callerOK := runtime.Caller(0)
  132. return func(log Logger) {
  133. if unix.Gettid() != tid {
  134. msg := "teardown function must be called from the same goroutine as c.Set()"
  135. if callerOK {
  136. msg += fmt.Sprintf(" (%s:%d)", file, line)
  137. }
  138. panic(msg)
  139. }
  140. if err := netns.Set(orig); err != nil && log != nil {
  141. log.Logf("Warning: failed to restore goroutine thread netns (%v)", err)
  142. } else {
  143. runtime.UnlockOSThread()
  144. }
  145. if err := orig.Close(); err != nil && log != nil {
  146. log.Logf("Warning: netns closing failed (%v)", err)
  147. }
  148. }, nil
  149. }