authz_unix_test.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. // +build !windows
  2. // TODO Windows: This uses a Unix socket for testing. This might be possible
  3. // to port to Windows using a named pipe instead.
  4. package authorization // import "github.com/docker/docker/pkg/authorization"
  5. import (
  6. "bytes"
  7. "encoding/json"
  8. "io/ioutil"
  9. "net"
  10. "net/http"
  11. "net/http/httptest"
  12. "os"
  13. "path"
  14. "reflect"
  15. "strings"
  16. "testing"
  17. "github.com/docker/docker/pkg/plugins"
  18. "github.com/docker/go-connections/tlsconfig"
  19. "github.com/gorilla/mux"
  20. )
const (
	// pluginAddress is the Unix socket path used by these tests: the test
	// plugin server listens on it and the test plugin client connects to it
	// after joining it with the current working directory.
	pluginAddress = "authz-test-plugin.sock"
)
  24. func TestAuthZRequestPluginError(t *testing.T) {
  25. server := authZPluginTestServer{t: t}
  26. server.start()
  27. defer server.stop()
  28. authZPlugin := createTestPlugin(t)
  29. request := Request{
  30. User: "user",
  31. RequestBody: []byte("sample body"),
  32. RequestURI: "www.authz.com/auth",
  33. RequestMethod: "GET",
  34. RequestHeaders: map[string]string{"header": "value"},
  35. }
  36. server.replayResponse = Response{
  37. Err: "an error",
  38. }
  39. actualResponse, err := authZPlugin.AuthZRequest(&request)
  40. if err != nil {
  41. t.Fatalf("Failed to authorize request %v", err)
  42. }
  43. if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
  44. t.Fatal("Response must be equal")
  45. }
  46. if !reflect.DeepEqual(request, server.recordedRequest) {
  47. t.Fatal("Requests must be equal")
  48. }
  49. }
  50. func TestAuthZRequestPlugin(t *testing.T) {
  51. server := authZPluginTestServer{t: t}
  52. server.start()
  53. defer server.stop()
  54. authZPlugin := createTestPlugin(t)
  55. request := Request{
  56. User: "user",
  57. RequestBody: []byte("sample body"),
  58. RequestURI: "www.authz.com/auth",
  59. RequestMethod: "GET",
  60. RequestHeaders: map[string]string{"header": "value"},
  61. }
  62. server.replayResponse = Response{
  63. Allow: true,
  64. Msg: "Sample message",
  65. }
  66. actualResponse, err := authZPlugin.AuthZRequest(&request)
  67. if err != nil {
  68. t.Fatalf("Failed to authorize request %v", err)
  69. }
  70. if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
  71. t.Fatal("Response must be equal")
  72. }
  73. if !reflect.DeepEqual(request, server.recordedRequest) {
  74. t.Fatal("Requests must be equal")
  75. }
  76. }
  77. func TestAuthZResponsePlugin(t *testing.T) {
  78. server := authZPluginTestServer{t: t}
  79. server.start()
  80. defer server.stop()
  81. authZPlugin := createTestPlugin(t)
  82. request := Request{
  83. User: "user",
  84. RequestURI: "something.com/auth",
  85. RequestBody: []byte("sample body"),
  86. }
  87. server.replayResponse = Response{
  88. Allow: true,
  89. Msg: "Sample message",
  90. }
  91. actualResponse, err := authZPlugin.AuthZResponse(&request)
  92. if err != nil {
  93. t.Fatalf("Failed to authorize request %v", err)
  94. }
  95. if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
  96. t.Fatal("Response must be equal")
  97. }
  98. if !reflect.DeepEqual(request, server.recordedRequest) {
  99. t.Fatal("Requests must be equal")
  100. }
  101. }
  102. func TestResponseModifier(t *testing.T) {
  103. r := httptest.NewRecorder()
  104. m := NewResponseModifier(r)
  105. m.Header().Set("h1", "v1")
  106. m.Write([]byte("body"))
  107. m.WriteHeader(http.StatusInternalServerError)
  108. m.FlushAll()
  109. if r.Header().Get("h1") != "v1" {
  110. t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
  111. }
  112. if !reflect.DeepEqual(r.Body.Bytes(), []byte("body")) {
  113. t.Fatalf("Body value must exists %s", r.Body.Bytes())
  114. }
  115. if r.Code != http.StatusInternalServerError {
  116. t.Fatalf("Status code must be correct %d", r.Code)
  117. }
  118. }
  119. func TestDrainBody(t *testing.T) {
  120. tests := []struct {
  121. length int // length is the message length send to drainBody
  122. expectedBodyLength int // expectedBodyLength is the expected body length after drainBody is called
  123. }{
  124. {10, 10}, // Small message size
  125. {maxBodySize - 1, maxBodySize - 1}, // Max message size
  126. {maxBodySize * 2, 0}, // Large message size (skip copying body)
  127. }
  128. for _, test := range tests {
  129. msg := strings.Repeat("a", test.length)
  130. body, closer, err := drainBody(ioutil.NopCloser(bytes.NewReader([]byte(msg))))
  131. if err != nil {
  132. t.Fatal(err)
  133. }
  134. if len(body) != test.expectedBodyLength {
  135. t.Fatalf("Body must be copied, actual length: '%d'", len(body))
  136. }
  137. if closer == nil {
  138. t.Fatal("Closer must not be nil")
  139. }
  140. modified, err := ioutil.ReadAll(closer)
  141. if err != nil {
  142. t.Fatalf("Error must not be nil: '%v'", err)
  143. }
  144. if len(modified) != len(msg) {
  145. t.Fatalf("Result should not be truncated. Original length: '%d', new length: '%d'", len(msg), len(modified))
  146. }
  147. }
  148. }
  149. func TestSendBody(t *testing.T) {
  150. var (
  151. url = "nothing.com"
  152. testcases = []struct {
  153. contentType string
  154. expected bool
  155. }{
  156. {
  157. contentType: "application/json",
  158. expected: true,
  159. },
  160. {
  161. contentType: "Application/json",
  162. expected: true,
  163. },
  164. {
  165. contentType: "application/JSON",
  166. expected: true,
  167. },
  168. {
  169. contentType: "APPLICATION/JSON",
  170. expected: true,
  171. },
  172. {
  173. contentType: "application/json; charset=utf-8",
  174. expected: true,
  175. },
  176. {
  177. contentType: "application/json;charset=utf-8",
  178. expected: true,
  179. },
  180. {
  181. contentType: "application/json; charset=UTF8",
  182. expected: true,
  183. },
  184. {
  185. contentType: "application/json;charset=UTF8",
  186. expected: true,
  187. },
  188. {
  189. contentType: "text/html",
  190. expected: false,
  191. },
  192. {
  193. contentType: "",
  194. expected: false,
  195. },
  196. }
  197. )
  198. for _, testcase := range testcases {
  199. header := http.Header{}
  200. header.Set("Content-Type", testcase.contentType)
  201. if b := sendBody(url, header); b != testcase.expected {
  202. t.Fatalf("Unexpected Content-Type; Expected: %t, Actual: %t", testcase.expected, b)
  203. }
  204. }
  205. }
  206. func TestResponseModifierOverride(t *testing.T) {
  207. r := httptest.NewRecorder()
  208. m := NewResponseModifier(r)
  209. m.Header().Set("h1", "v1")
  210. m.Write([]byte("body"))
  211. m.WriteHeader(http.StatusInternalServerError)
  212. overrideHeader := make(http.Header)
  213. overrideHeader.Add("h1", "v2")
  214. overrideHeaderBytes, err := json.Marshal(overrideHeader)
  215. if err != nil {
  216. t.Fatalf("override header failed %v", err)
  217. }
  218. m.OverrideHeader(overrideHeaderBytes)
  219. m.OverrideBody([]byte("override body"))
  220. m.OverrideStatusCode(http.StatusNotFound)
  221. m.FlushAll()
  222. if r.Header().Get("h1") != "v2" {
  223. t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
  224. }
  225. if !reflect.DeepEqual(r.Body.Bytes(), []byte("override body")) {
  226. t.Fatalf("Body value must exists %s", r.Body.Bytes())
  227. }
  228. if r.Code != http.StatusNotFound {
  229. t.Fatalf("Status code must be correct %d", r.Code)
  230. }
  231. }
  232. // createTestPlugin creates a new sample authorization plugin
  233. func createTestPlugin(t *testing.T) *authorizationPlugin {
  234. pwd, err := os.Getwd()
  235. if err != nil {
  236. t.Fatal(err)
  237. }
  238. client, err := plugins.NewClient("unix:///"+path.Join(pwd, pluginAddress), &tlsconfig.Options{InsecureSkipVerify: true})
  239. if err != nil {
  240. t.Fatalf("Failed to create client %v", err)
  241. }
  242. return &authorizationPlugin{name: "plugin", plugin: client}
  243. }
// authZPluginTestServer is a simple server that implements the authZ plugin
// interface. It records the last authorization request it received and
// answers every authorization call with a preconfigured response.
type authZPluginTestServer struct {
	// listener is the Unix socket listener created by start and closed by stop
	listener net.Listener
	// t is the test that owns the server; it is used to report failures
	t *testing.T
	// recordedRequest stores the request sent from the daemon to the plugin
	recordedRequest Request
	// replayResponse stores the response sent from the plugin to the daemon
	replayResponse Response
	// server is the HTTP test server serving the plugin endpoints on listener
	server *httptest.Server
}
  254. // start starts the test server that implements the plugin
  255. func (t *authZPluginTestServer) start() {
  256. r := mux.NewRouter()
  257. l, err := net.Listen("unix", pluginAddress)
  258. if err != nil {
  259. t.t.Fatal(err)
  260. }
  261. t.listener = l
  262. r.HandleFunc("/Plugin.Activate", t.activate)
  263. r.HandleFunc("/"+AuthZApiRequest, t.auth)
  264. r.HandleFunc("/"+AuthZApiResponse, t.auth)
  265. t.server = &httptest.Server{
  266. Listener: l,
  267. Config: &http.Server{
  268. Handler: r,
  269. Addr: pluginAddress,
  270. },
  271. }
  272. t.server.Start()
  273. }
  274. // stop stops the test server that implements the plugin
  275. func (t *authZPluginTestServer) stop() {
  276. t.server.Close()
  277. os.Remove(pluginAddress)
  278. if t.listener != nil {
  279. t.listener.Close()
  280. }
  281. }
  282. // auth is a used to record/replay the authentication api messages
  283. func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
  284. t.recordedRequest = Request{}
  285. body, err := ioutil.ReadAll(r.Body)
  286. if err != nil {
  287. t.t.Fatal(err)
  288. }
  289. r.Body.Close()
  290. json.Unmarshal(body, &t.recordedRequest)
  291. b, err := json.Marshal(t.replayResponse)
  292. if err != nil {
  293. t.t.Fatal(err)
  294. }
  295. w.Write(b)
  296. }
  297. func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
  298. b, err := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
  299. if err != nil {
  300. t.t.Fatal(err)
  301. }
  302. w.Write(b)
  303. }