main_test.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. //go:build !windows
  2. package authz // import "github.com/docker/docker/integration/plugin/authz"
  3. import (
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "os"
  10. "strings"
  11. "testing"
  12. "github.com/docker/docker/pkg/authorization"
  13. "github.com/docker/docker/pkg/plugins"
  14. "github.com/docker/docker/testutil/daemon"
  15. "github.com/docker/docker/testutil/environment"
  16. "gotest.tools/v3/skip"
  17. )
  18. var (
  19. testEnv *environment.Execution
  20. d *daemon.Daemon
  21. server *httptest.Server
  22. )
  23. func TestMain(m *testing.M) {
  24. var err error
  25. testEnv, err = environment.New()
  26. if err != nil {
  27. fmt.Println(err)
  28. os.Exit(1)
  29. }
  30. err = environment.EnsureFrozenImagesLinux(testEnv)
  31. if err != nil {
  32. fmt.Println(err)
  33. os.Exit(1)
  34. }
  35. testEnv.Print()
  36. setupSuite()
  37. exitCode := m.Run()
  38. teardownSuite()
  39. os.Exit(exitCode)
  40. }
  41. func setupTest(t *testing.T) func() {
  42. skip.If(t, testEnv.IsRemoteDaemon, "cannot run daemon when remote daemon")
  43. skip.If(t, testEnv.DaemonInfo.OSType == "windows")
  44. skip.If(t, testEnv.IsRootless, "rootless mode has different view of localhost")
  45. environment.ProtectAll(t, testEnv)
  46. d = daemon.New(t, daemon.WithExperimental())
  47. return func() {
  48. if d != nil {
  49. d.Stop(t)
  50. }
  51. testEnv.Clean(t)
  52. }
  53. }
  54. func setupSuite() {
  55. mux := http.NewServeMux()
  56. server = httptest.NewServer(mux)
  57. mux.HandleFunc("/Plugin.Activate", func(w http.ResponseWriter, r *http.Request) {
  58. b, err := json.Marshal(plugins.Manifest{Implements: []string{authorization.AuthZApiImplements}})
  59. if err != nil {
  60. panic("could not marshal json for /Plugin.Activate: " + err.Error())
  61. }
  62. w.Write(b)
  63. })
  64. mux.HandleFunc("/AuthZPlugin.AuthZReq", func(w http.ResponseWriter, r *http.Request) {
  65. defer r.Body.Close()
  66. body, err := io.ReadAll(r.Body)
  67. if err != nil {
  68. panic("could not read body for /AuthZPlugin.AuthZReq: " + err.Error())
  69. }
  70. authReq := authorization.Request{}
  71. err = json.Unmarshal(body, &authReq)
  72. if err != nil {
  73. panic("could not unmarshal json for /AuthZPlugin.AuthZReq: " + err.Error())
  74. }
  75. assertBody(authReq.RequestURI, authReq.RequestHeaders, authReq.RequestBody)
  76. assertAuthHeaders(authReq.RequestHeaders)
  77. // Count only server version api
  78. if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
  79. ctrl.versionReqCount++
  80. }
  81. ctrl.requestsURIs = append(ctrl.requestsURIs, authReq.RequestURI)
  82. reqRes := ctrl.reqRes
  83. if isAllowed(authReq.RequestURI) {
  84. reqRes = authorization.Response{Allow: true}
  85. }
  86. if reqRes.Err != "" {
  87. w.WriteHeader(http.StatusInternalServerError)
  88. }
  89. b, err := json.Marshal(reqRes)
  90. if err != nil {
  91. panic("could not marshal json for /AuthZPlugin.AuthZReq: " + err.Error())
  92. }
  93. ctrl.reqUser = authReq.User
  94. w.Write(b)
  95. })
  96. mux.HandleFunc("/AuthZPlugin.AuthZRes", func(w http.ResponseWriter, r *http.Request) {
  97. defer r.Body.Close()
  98. body, err := io.ReadAll(r.Body)
  99. if err != nil {
  100. panic("could not read body for /AuthZPlugin.AuthZRes: " + err.Error())
  101. }
  102. authReq := authorization.Request{}
  103. err = json.Unmarshal(body, &authReq)
  104. if err != nil {
  105. panic("could not unmarshal json for /AuthZPlugin.AuthZRes: " + err.Error())
  106. }
  107. assertBody(authReq.RequestURI, authReq.ResponseHeaders, authReq.ResponseBody)
  108. assertAuthHeaders(authReq.ResponseHeaders)
  109. // Count only server version api
  110. if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
  111. ctrl.versionResCount++
  112. }
  113. resRes := ctrl.resRes
  114. if isAllowed(authReq.RequestURI) {
  115. resRes = authorization.Response{Allow: true}
  116. }
  117. if resRes.Err != "" {
  118. w.WriteHeader(http.StatusInternalServerError)
  119. }
  120. b, err := json.Marshal(resRes)
  121. if err != nil {
  122. panic("could not marshal json for /AuthZPlugin.AuthZRes: " + err.Error())
  123. }
  124. ctrl.resUser = authReq.User
  125. w.Write(b)
  126. })
  127. }
  128. func teardownSuite() {
  129. if server == nil {
  130. return
  131. }
  132. server.Close()
  133. }
  // assertAuthHeaders validates that authentication headers were removed before
  // the message reached the plugin. It panics, printing the whole header map,
  // if any header name contains "auth" or "x-registry" (case-insensitively).
  // Otherwise it returns nil; the error result is never non-nil.
  func assertAuthHeaders(headers map[string]string) error {
  	for name := range headers {
  		lower := strings.ToLower(name)
  		if !strings.Contains(lower, "auth") && !strings.Contains(lower, "x-registry") {
  			continue
  		}
  		panic(fmt.Sprintf("Found authentication headers in request '%v'", headers))
  	}
  	return nil
  }
  // assertBody asserts that the body is removed for non text/json requests.
  //
  // It panics when:
  //   - requestURI contains "auth" (case-insensitively) and body is non-empty,
  //     because bodies must never be forwarded for authentication endpoints; or
  //   - body is non-empty and no header named Content-Type (matched
  //     case-insensitively) has a value that starts with "text/" or equals
  //     "application/json".
  func assertBody(requestURI string, headers map[string]string, body []byte) {
  	if strings.Contains(strings.ToLower(requestURI), "auth") && len(body) > 0 {
  		panic("Body included for authentication endpoint " + string(body))
  	}
  	for k, v := range headers {
  		// Only the Content-Type header decides. A JSON value under another
  		// header (e.g. "Accept: application/json") must not whitelist the body.
  		if !strings.EqualFold(k, "Content-Type") {
  			continue
  		}
  		if strings.HasPrefix(v, "text/") || v == "application/json" {
  			return
  		}
  	}
  	if len(body) > 0 {
  		panic(fmt.Sprintf("Body included while it should not (Headers: '%v')", headers))
  	}
  }