authz_unix_test.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. // +build !windows
  2. // TODO Windows: This uses a Unix socket for testing. This might be possible
  3. // to port to Windows using a named pipe instead.
  4. package authorization
  5. import (
  6. "bytes"
  7. "encoding/json"
  8. "io/ioutil"
  9. "net"
  10. "net/http"
  11. "net/http/httptest"
  12. "os"
  13. "path"
  14. "reflect"
  15. "strings"
  16. "testing"
  17. "github.com/docker/docker/pkg/plugins"
  18. "github.com/docker/go-connections/tlsconfig"
  19. "github.com/gorilla/mux"
  20. )
  21. const (
  22. pluginAddress = "authz-test-plugin.sock"
  23. )
  24. func TestAuthZRequestPluginError(t *testing.T) {
  25. server := authZPluginTestServer{t: t}
  26. server.start()
  27. defer server.stop()
  28. authZPlugin := createTestPlugin(t)
  29. request := Request{
  30. User: "user",
  31. RequestBody: []byte("sample body"),
  32. RequestURI: "www.authz.com/auth",
  33. RequestMethod: "GET",
  34. RequestHeaders: map[string]string{"header": "value"},
  35. }
  36. server.replayResponse = Response{
  37. Err: "an error",
  38. }
  39. actualResponse, err := authZPlugin.AuthZRequest(&request)
  40. if err != nil {
  41. t.Fatalf("Failed to authorize request %v", err)
  42. }
  43. if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
  44. t.Fatal("Response must be equal")
  45. }
  46. if !reflect.DeepEqual(request, server.recordedRequest) {
  47. t.Fatal("Requests must be equal")
  48. }
  49. }
  50. func TestAuthZRequestPlugin(t *testing.T) {
  51. server := authZPluginTestServer{t: t}
  52. server.start()
  53. defer server.stop()
  54. authZPlugin := createTestPlugin(t)
  55. request := Request{
  56. User: "user",
  57. RequestBody: []byte("sample body"),
  58. RequestURI: "www.authz.com/auth",
  59. RequestMethod: "GET",
  60. RequestHeaders: map[string]string{"header": "value"},
  61. }
  62. server.replayResponse = Response{
  63. Allow: true,
  64. Msg: "Sample message",
  65. }
  66. actualResponse, err := authZPlugin.AuthZRequest(&request)
  67. if err != nil {
  68. t.Fatalf("Failed to authorize request %v", err)
  69. }
  70. if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
  71. t.Fatal("Response must be equal")
  72. }
  73. if !reflect.DeepEqual(request, server.recordedRequest) {
  74. t.Fatal("Requests must be equal")
  75. }
  76. }
  77. func TestAuthZResponsePlugin(t *testing.T) {
  78. server := authZPluginTestServer{t: t}
  79. server.start()
  80. defer server.stop()
  81. authZPlugin := createTestPlugin(t)
  82. request := Request{
  83. User: "user",
  84. RequestURI: "something.com/auth",
  85. RequestBody: []byte("sample body"),
  86. }
  87. server.replayResponse = Response{
  88. Allow: true,
  89. Msg: "Sample message",
  90. }
  91. actualResponse, err := authZPlugin.AuthZResponse(&request)
  92. if err != nil {
  93. t.Fatalf("Failed to authorize request %v", err)
  94. }
  95. if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
  96. t.Fatal("Response must be equal")
  97. }
  98. if !reflect.DeepEqual(request, server.recordedRequest) {
  99. t.Fatal("Requests must be equal")
  100. }
  101. }
  102. func TestResponseModifier(t *testing.T) {
  103. r := httptest.NewRecorder()
  104. m := NewResponseModifier(r)
  105. m.Header().Set("h1", "v1")
  106. m.Write([]byte("body"))
  107. m.WriteHeader(http.StatusInternalServerError)
  108. m.FlushAll()
  109. if r.Header().Get("h1") != "v1" {
  110. t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
  111. }
  112. if !reflect.DeepEqual(r.Body.Bytes(), []byte("body")) {
  113. t.Fatalf("Body value must exists %s", r.Body.Bytes())
  114. }
  115. if r.Code != http.StatusInternalServerError {
  116. t.Fatalf("Status code must be correct %d", r.Code)
  117. }
  118. }
  119. func TestDrainBody(t *testing.T) {
  120. tests := []struct {
  121. length int // length is the message length send to drainBody
  122. expectedBodyLength int // expectedBodyLength is the expected body length after drainBody is called
  123. }{
  124. {10, 10}, // Small message size
  125. {maxBodySize - 1, maxBodySize - 1}, // Max message size
  126. {maxBodySize * 2, 0}, // Large message size (skip copying body)
  127. }
  128. for _, test := range tests {
  129. msg := strings.Repeat("a", test.length)
  130. body, closer, err := drainBody(ioutil.NopCloser(bytes.NewReader([]byte(msg))))
  131. if err != nil {
  132. t.Fatal(err)
  133. }
  134. if len(body) != test.expectedBodyLength {
  135. t.Fatalf("Body must be copied, actual length: '%d'", len(body))
  136. }
  137. if closer == nil {
  138. t.Fatal("Closer must not be nil")
  139. }
  140. modified, err := ioutil.ReadAll(closer)
  141. if err != nil {
  142. t.Fatalf("Error must not be nil: '%v'", err)
  143. }
  144. if len(modified) != len(msg) {
  145. t.Fatalf("Result should not be truncated. Original length: '%d', new length: '%d'", len(msg), len(modified))
  146. }
  147. }
  148. }
  149. func TestResponseModifierOverride(t *testing.T) {
  150. r := httptest.NewRecorder()
  151. m := NewResponseModifier(r)
  152. m.Header().Set("h1", "v1")
  153. m.Write([]byte("body"))
  154. m.WriteHeader(http.StatusInternalServerError)
  155. overrideHeader := make(http.Header)
  156. overrideHeader.Add("h1", "v2")
  157. overrideHeaderBytes, err := json.Marshal(overrideHeader)
  158. if err != nil {
  159. t.Fatalf("override header failed %v", err)
  160. }
  161. m.OverrideHeader(overrideHeaderBytes)
  162. m.OverrideBody([]byte("override body"))
  163. m.OverrideStatusCode(http.StatusNotFound)
  164. m.FlushAll()
  165. if r.Header().Get("h1") != "v2" {
  166. t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
  167. }
  168. if !reflect.DeepEqual(r.Body.Bytes(), []byte("override body")) {
  169. t.Fatalf("Body value must exists %s", r.Body.Bytes())
  170. }
  171. if r.Code != http.StatusNotFound {
  172. t.Fatalf("Status code must be correct %d", r.Code)
  173. }
  174. }
  175. // createTestPlugin creates a new sample authorization plugin
  176. func createTestPlugin(t *testing.T) *authorizationPlugin {
  177. pwd, err := os.Getwd()
  178. if err != nil {
  179. t.Fatal(err)
  180. }
  181. client, err := plugins.NewClient("unix:///"+path.Join(pwd, pluginAddress), &tlsconfig.Options{InsecureSkipVerify: true})
  182. if err != nil {
  183. t.Fatalf("Failed to create client %v", err)
  184. }
  185. return &authorizationPlugin{name: "plugin", plugin: client}
  186. }
  187. // AuthZPluginTestServer is a simple server that implements the authZ plugin interface
  188. type authZPluginTestServer struct {
  189. listener net.Listener
  190. t *testing.T
  191. // request stores the request sent from the daemon to the plugin
  192. recordedRequest Request
  193. // response stores the response sent from the plugin to the daemon
  194. replayResponse Response
  195. server *httptest.Server
  196. }
  197. // start starts the test server that implements the plugin
  198. func (t *authZPluginTestServer) start() {
  199. r := mux.NewRouter()
  200. l, err := net.Listen("unix", pluginAddress)
  201. if err != nil {
  202. t.t.Fatal(err)
  203. }
  204. t.listener = l
  205. r.HandleFunc("/Plugin.Activate", t.activate)
  206. r.HandleFunc("/"+AuthZApiRequest, t.auth)
  207. r.HandleFunc("/"+AuthZApiResponse, t.auth)
  208. t.server = &httptest.Server{
  209. Listener: l,
  210. Config: &http.Server{
  211. Handler: r,
  212. Addr: pluginAddress,
  213. },
  214. }
  215. t.server.Start()
  216. }
  217. // stop stops the test server that implements the plugin
  218. func (t *authZPluginTestServer) stop() {
  219. t.server.Close()
  220. os.Remove(pluginAddress)
  221. if t.listener != nil {
  222. t.listener.Close()
  223. }
  224. }
  225. // auth is a used to record/replay the authentication api messages
  226. func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
  227. t.recordedRequest = Request{}
  228. body, err := ioutil.ReadAll(r.Body)
  229. if err != nil {
  230. t.t.Fatal(err)
  231. }
  232. r.Body.Close()
  233. json.Unmarshal(body, &t.recordedRequest)
  234. b, err := json.Marshal(t.replayResponse)
  235. if err != nil {
  236. t.t.Fatal(err)
  237. }
  238. w.Write(b)
  239. }
  240. func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
  241. b, err := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
  242. if err != nil {
  243. t.t.Fatal(err)
  244. }
  245. w.Write(b)
  246. }