middleware_test.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. package authorization // import "github.com/docker/docker/pkg/authorization"
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "strings"
  6. "testing"
  7. "github.com/docker/docker/pkg/plugingetter"
  8. "gotest.tools/v3/assert"
  9. )
  10. func TestMiddleware(t *testing.T) {
  11. pluginNames := []string{"testPlugin1", "testPlugin2"}
  12. var pluginGetter plugingetter.PluginGetter
  13. m := NewMiddleware(pluginNames, pluginGetter)
  14. authPlugins := m.getAuthzPlugins()
  15. assert.Equal(t, 2, len(authPlugins))
  16. assert.Equal(t, pluginNames[0], authPlugins[0].Name())
  17. assert.Equal(t, pluginNames[1], authPlugins[1].Name())
  18. }
  19. func TestNewResponseModifier(t *testing.T) {
  20. recorder := httptest.NewRecorder()
  21. modifier := NewResponseModifier(recorder)
  22. modifier.Header().Set("H1", "V1")
  23. modifier.Write([]byte("body"))
  24. assert.Assert(t, !modifier.Hijacked())
  25. modifier.WriteHeader(http.StatusInternalServerError)
  26. assert.Assert(t, modifier.RawBody() != nil)
  27. raw, err := modifier.RawHeaders()
  28. assert.Assert(t, raw != nil)
  29. assert.NilError(t, err)
  30. headerData := strings.Split(strings.TrimSpace(string(raw)), ":")
  31. assert.Equal(t, "H1", strings.TrimSpace(headerData[0]))
  32. assert.Equal(t, "V1", strings.TrimSpace(headerData[1]))
  33. modifier.Flush()
  34. modifier.FlushAll()
  35. if recorder.Header().Get("H1") != "V1" {
  36. t.Fatalf("Header value must exists %s", recorder.Header().Get("H1"))
  37. }
  38. }
  39. func setAuthzPlugins(m *Middleware, plugins []Plugin) {
  40. m.mu.Lock()
  41. m.plugins = plugins
  42. m.mu.Unlock()
  43. }