main_test.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. //go:build !windows
  2. package authz // import "github.com/docker/docker/integration/plugin/authz"
  3. import (
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/http/httptest"
  10. "os"
  11. "strings"
  12. "testing"
  13. "github.com/docker/docker/pkg/authorization"
  14. "github.com/docker/docker/pkg/plugins"
  15. "github.com/docker/docker/testutil"
  16. "github.com/docker/docker/testutil/daemon"
  17. "github.com/docker/docker/testutil/environment"
  18. "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
  19. "go.opentelemetry.io/otel"
  20. "go.opentelemetry.io/otel/attribute"
  21. "go.opentelemetry.io/otel/codes"
  22. "gotest.tools/v3/skip"
  23. )
  24. var (
  25. testEnv *environment.Execution
  26. d *daemon.Daemon
  27. server *httptest.Server
  28. baseContext context.Context
  29. )
  30. func TestMain(m *testing.M) {
  31. shutdown := testutil.ConfigureTracing()
  32. ctx, span := otel.Tracer("").Start(context.Background(), "integration/plugin/authz.TestMain")
  33. baseContext = ctx
  34. var err error
  35. testEnv, err = environment.New(ctx)
  36. if err != nil {
  37. span.SetStatus(codes.Error, err.Error())
  38. span.End()
  39. shutdown(ctx)
  40. panic(err)
  41. }
  42. err = environment.EnsureFrozenImagesLinux(ctx, testEnv)
  43. if err != nil {
  44. span.SetStatus(codes.Error, err.Error())
  45. span.End()
  46. shutdown(ctx)
  47. panic(err)
  48. }
  49. testEnv.Print()
  50. setupSuite()
  51. exitCode := m.Run()
  52. teardownSuite()
  53. if exitCode != 0 {
  54. span.SetAttributes(attribute.Int("exit", exitCode))
  55. span.SetStatus(codes.Error, "m.Run() exited with non-zero exit code")
  56. }
  57. shutdown(ctx)
  58. os.Exit(exitCode)
  59. }
  60. func setupTest(t *testing.T) context.Context {
  61. skip.If(t, testEnv.IsRemoteDaemon, "cannot run daemon when remote daemon")
  62. skip.If(t, testEnv.DaemonInfo.OSType == "windows")
  63. skip.If(t, testEnv.IsRootless, "rootless mode has different view of localhost")
  64. ctx := testutil.StartSpan(baseContext, t)
  65. environment.ProtectAll(ctx, t, testEnv)
  66. d = daemon.New(t, daemon.WithExperimental())
  67. t.Cleanup(func() {
  68. if d != nil {
  69. d.Stop(t)
  70. }
  71. testEnv.Clean(ctx, t)
  72. })
  73. return ctx
  74. }
  75. func setupSuite() {
  76. mux := http.NewServeMux()
  77. server = httptest.NewServer(otelhttp.NewHandler(mux, ""))
  78. mux.HandleFunc("/Plugin.Activate", func(w http.ResponseWriter, r *http.Request) {
  79. b, err := json.Marshal(plugins.Manifest{Implements: []string{authorization.AuthZApiImplements}})
  80. if err != nil {
  81. panic("could not marshal json for /Plugin.Activate: " + err.Error())
  82. }
  83. w.Write(b)
  84. })
  85. mux.HandleFunc("/AuthZPlugin.AuthZReq", func(w http.ResponseWriter, r *http.Request) {
  86. defer r.Body.Close()
  87. body, err := io.ReadAll(r.Body)
  88. if err != nil {
  89. panic("could not read body for /AuthZPlugin.AuthZReq: " + err.Error())
  90. }
  91. authReq := authorization.Request{}
  92. err = json.Unmarshal(body, &authReq)
  93. if err != nil {
  94. panic("could not unmarshal json for /AuthZPlugin.AuthZReq: " + err.Error())
  95. }
  96. assertBody(authReq.RequestURI, authReq.RequestHeaders, authReq.RequestBody)
  97. assertAuthHeaders(authReq.RequestHeaders)
  98. // Count only server version api
  99. if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
  100. ctrl.versionReqCount++
  101. }
  102. ctrl.requestsURIs = append(ctrl.requestsURIs, authReq.RequestURI)
  103. reqRes := ctrl.reqRes
  104. if isAllowed(authReq.RequestURI) {
  105. reqRes = authorization.Response{Allow: true}
  106. }
  107. if reqRes.Err != "" {
  108. w.WriteHeader(http.StatusInternalServerError)
  109. }
  110. b, err := json.Marshal(reqRes)
  111. if err != nil {
  112. panic("could not marshal json for /AuthZPlugin.AuthZReq: " + err.Error())
  113. }
  114. ctrl.reqUser = authReq.User
  115. w.Write(b)
  116. })
  117. mux.HandleFunc("/AuthZPlugin.AuthZRes", func(w http.ResponseWriter, r *http.Request) {
  118. defer r.Body.Close()
  119. body, err := io.ReadAll(r.Body)
  120. if err != nil {
  121. panic("could not read body for /AuthZPlugin.AuthZRes: " + err.Error())
  122. }
  123. authReq := authorization.Request{}
  124. err = json.Unmarshal(body, &authReq)
  125. if err != nil {
  126. panic("could not unmarshal json for /AuthZPlugin.AuthZRes: " + err.Error())
  127. }
  128. assertBody(authReq.RequestURI, authReq.ResponseHeaders, authReq.ResponseBody)
  129. assertAuthHeaders(authReq.ResponseHeaders)
  130. // Count only server version api
  131. if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
  132. ctrl.versionResCount++
  133. }
  134. resRes := ctrl.resRes
  135. if isAllowed(authReq.RequestURI) {
  136. resRes = authorization.Response{Allow: true}
  137. }
  138. if resRes.Err != "" {
  139. w.WriteHeader(http.StatusInternalServerError)
  140. }
  141. b, err := json.Marshal(resRes)
  142. if err != nil {
  143. panic("could not marshal json for /AuthZPlugin.AuthZRes: " + err.Error())
  144. }
  145. ctrl.resUser = authReq.User
  146. w.Write(b)
  147. })
  148. }
  149. func teardownSuite() {
  150. if server == nil {
  151. return
  152. }
  153. server.Close()
  154. }
  155. // assertAuthHeaders validates authentication headers are removed
  156. func assertAuthHeaders(headers map[string]string) error {
  157. for k := range headers {
  158. if strings.Contains(strings.ToLower(k), "auth") || strings.Contains(strings.ToLower(k), "x-registry") {
  159. panic(fmt.Sprintf("Found authentication headers in request '%v'", headers))
  160. }
  161. }
  162. return nil
  163. }
  164. // assertBody asserts that body is removed for non text/json requests
  165. func assertBody(requestURI string, headers map[string]string, body []byte) {
  166. if strings.Contains(strings.ToLower(requestURI), "auth") && len(body) > 0 {
  167. panic("Body included for authentication endpoint " + string(body))
  168. }
  169. for k, v := range headers {
  170. if strings.EqualFold(k, "Content-Type") && strings.HasPrefix(v, "text/") || v == "application/json" {
  171. return
  172. }
  173. }
  174. if len(body) > 0 {
  175. panic(fmt.Sprintf("Body included while it should not (Headers: '%v')", headers))
  176. }
  177. }