auth_service_test.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. package apiclient
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/url"
  10. "testing"
  11. "github.com/crowdsecurity/go-cs-lib/pkg/version"
  12. "github.com/crowdsecurity/crowdsec/pkg/models"
  13. log "github.com/sirupsen/logrus"
  14. "github.com/stretchr/testify/assert"
  15. )
  // BasicMockPayload is the request body shape that the mock handler installed
// by initBasicMuxMock decodes from incoming JSON.
type BasicMockPayload struct {
	// MachineID maps to the "machine_id" JSON key.
	MachineID string `json:"machine_id"`

	// Password maps to the "password" JSON key.
	Password string `json:"password"`
}
  // getLoginsForMockErrorCases returns a fresh map from machine ID to the HTTP
// status code the mock handler answers with for that ID. Machine IDs that
// are absent from the map are not error cases.
func getLoginsForMockErrorCases() map[string]int {
	return map[string]int{
		"login_400": http.StatusBadRequest,
		"login_409": http.StatusConflict,
		"login_500": http.StatusInternalServerError,
	}
}
  28. func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
  29. loginsForMockErrorCases := getLoginsForMockErrorCases()
  30. mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
  31. testMethod(t, r, "POST")
  32. buf := new(bytes.Buffer)
  33. _, _ = buf.ReadFrom(r.Body)
  34. newStr := buf.String()
  35. var payload BasicMockPayload
  36. err := json.Unmarshal([]byte(newStr), &payload)
  37. if err != nil || payload.MachineID == "" || payload.Password == "" {
  38. log.Printf("Bad payload")
  39. w.WriteHeader(http.StatusBadRequest)
  40. }
  41. responseBody := ""
  42. responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID]
  43. if !hasFoundErrorMock {
  44. responseCode = http.StatusOK
  45. responseBody = `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`
  46. } else {
  47. responseBody = fmt.Sprintf("Error %d", responseCode)
  48. }
  49. log.Printf("MockServerReceived > %s // Login : [%s] => Mux response [%d]", newStr, payload.MachineID, responseCode)
  50. w.WriteHeader(responseCode)
  51. fmt.Fprintf(w, `%s`, responseBody)
  52. })
  53. }
  54. /**
  55. * Test the RegisterClient function
  56. * Making sure it handles the different response code potentially coming from CAPI properly
  57. * 200 => OK
  58. * 400, 409, 500 => Error
  59. */
  60. func TestWatcherRegister(t *testing.T) {
  61. log.SetLevel(log.DebugLevel)
  62. mux, urlx, teardown := setup()
  63. defer teardown()
  64. //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
  65. initBasicMuxMock(t, mux, "/watchers")
  66. log.Printf("URL is %s", urlx)
  67. apiURL, err := url.Parse(urlx + "/")
  68. if err != nil {
  69. t.Fatalf("parsing api url: %s", apiURL)
  70. }
  71. // Valid Registration : should retrieve the client and no err
  72. clientconfig := Config{
  73. MachineID: "test_login",
  74. Password: "test_password",
  75. UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
  76. URL: apiURL,
  77. VersionPrefix: "v1",
  78. }
  79. client, err := RegisterClient(&clientconfig, &http.Client{})
  80. if client == nil || err != nil {
  81. t.Fatalf("while registering client : %s", err)
  82. }
  83. log.Printf("->%T", client)
  84. // Testing error handling on Registration (400, 409, 500): should retrieve an error
  85. errorCodesToTest := [3]int{http.StatusBadRequest, http.StatusConflict, http.StatusInternalServerError}
  86. for _, errorCodeToTest := range errorCodesToTest {
  87. clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
  88. client, err = RegisterClient(&clientconfig, &http.Client{})
  89. if client != nil || err == nil {
  90. t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest)
  91. } else {
  92. log.Printf("The RegisterClient function handled the error code %d as expected \n\r", errorCodeToTest)
  93. }
  94. }
  95. }
  96. func TestWatcherAuth(t *testing.T) {
  97. log.SetLevel(log.DebugLevel)
  98. mux, urlx, teardown := setup()
  99. defer teardown()
  100. //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
  101. initBasicMuxMock(t, mux, "/watchers/login")
  102. log.Printf("URL is %s", urlx)
  103. apiURL, err := url.Parse(urlx + "/")
  104. if err != nil {
  105. t.Fatalf("parsing api url: %s", apiURL)
  106. }
  107. //ok auth
  108. clientConfig := &Config{
  109. MachineID: "test_login",
  110. Password: "test_password",
  111. UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
  112. URL: apiURL,
  113. VersionPrefix: "v1",
  114. Scenarios: []string{"crowdsecurity/test"},
  115. }
  116. client, err := NewClient(clientConfig)
  117. if err != nil {
  118. t.Fatalf("new api client: %s", err)
  119. }
  120. _, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
  121. MachineID: &clientConfig.MachineID,
  122. Password: &clientConfig.Password,
  123. Scenarios: clientConfig.Scenarios,
  124. })
  125. if err != nil {
  126. t.Fatalf("unexpect auth err 0: %s", err)
  127. }
  128. // Testing error handling on AuthenticateWatcher (400, 409): should retrieve an error
  129. // Not testing 500 because it loops and try to re-autehnticate. But you can test it manually by adding it in array
  130. errorCodesToTest := [2]int{http.StatusBadRequest, http.StatusConflict}
  131. for _, errorCodeToTest := range errorCodesToTest {
  132. clientConfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
  133. client, err := NewClient(clientConfig)
  134. if err != nil {
  135. t.Fatalf("new api client: %s", err)
  136. }
  137. var resp *Response
  138. _, resp, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
  139. MachineID: &clientConfig.MachineID,
  140. Password: &clientConfig.Password,
  141. })
  142. if err == nil {
  143. resp.Response.Body.Close()
  144. bodyBytes, err := io.ReadAll(resp.Response.Body)
  145. if err != nil {
  146. t.Fatalf("error while reading body: %s", err.Error())
  147. }
  148. log.Printf(string(bodyBytes))
  149. t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest)
  150. } else {
  151. log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest)
  152. }
  153. }
  154. }
  155. func TestWatcherUnregister(t *testing.T) {
  156. log.SetLevel(log.DebugLevel)
  157. mux, urlx, teardown := setup()
  158. defer teardown()
  159. //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
  160. mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) {
  161. testMethod(t, r, "DELETE")
  162. assert.Equal(t, r.ContentLength, int64(0))
  163. w.WriteHeader(http.StatusOK)
  164. })
  165. mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
  166. testMethod(t, r, "POST")
  167. buf := new(bytes.Buffer)
  168. _, _ = buf.ReadFrom(r.Body)
  169. newStr := buf.String()
  170. if newStr == `{"machine_id":"test_login","password":"test_password","scenarios":["crowdsecurity/test"]}
  171. ` {
  172. w.WriteHeader(http.StatusOK)
  173. fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`)
  174. } else {
  175. w.WriteHeader(http.StatusForbidden)
  176. fmt.Fprintf(w, `{"message":"access forbidden"}`)
  177. }
  178. })
  179. log.Printf("URL is %s", urlx)
  180. apiURL, err := url.Parse(urlx + "/")
  181. if err != nil {
  182. t.Fatalf("parsing api url: %s", apiURL)
  183. }
  184. mycfg := &Config{
  185. MachineID: "test_login",
  186. Password: "test_password",
  187. UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
  188. URL: apiURL,
  189. VersionPrefix: "v1",
  190. Scenarios: []string{"crowdsecurity/test"},
  191. }
  192. client, err := NewClient(mycfg)
  193. if err != nil {
  194. t.Fatalf("new api client: %s", err)
  195. }
  196. _, err = client.Auth.UnregisterWatcher(context.Background())
  197. if err != nil {
  198. t.Fatalf("while registering client : %s", err)
  199. }
  200. log.Printf("->%T", client)
  201. }
  202. func TestWatcherEnroll(t *testing.T) {
  203. log.SetLevel(log.DebugLevel)
  204. mux, urlx, teardown := setup()
  205. defer teardown()
  206. mux.HandleFunc("/watchers/enroll", func(w http.ResponseWriter, r *http.Request) {
  207. testMethod(t, r, "POST")
  208. buf := new(bytes.Buffer)
  209. _, _ = buf.ReadFrom(r.Body)
  210. newStr := buf.String()
  211. log.Debugf("body -> %s", newStr)
  212. if newStr == `{"attachment_key":"goodkey","name":"","tags":[],"overwrite":false}
  213. ` {
  214. log.Print("good key")
  215. w.WriteHeader(http.StatusOK)
  216. fmt.Fprintf(w, `{"statusCode": 200, "message": "OK"}`)
  217. } else {
  218. log.Print("bad key")
  219. w.WriteHeader(http.StatusForbidden)
  220. fmt.Fprintf(w, `{"message":"the attachment key provided is not valid"}`)
  221. }
  222. })
  223. mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
  224. testMethod(t, r, "POST")
  225. w.WriteHeader(http.StatusOK)
  226. fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`)
  227. })
  228. log.Printf("URL is %s", urlx)
  229. apiURL, err := url.Parse(urlx + "/")
  230. if err != nil {
  231. t.Fatalf("parsing api url: %s", apiURL)
  232. }
  233. mycfg := &Config{
  234. MachineID: "test_login",
  235. Password: "test_password",
  236. UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
  237. URL: apiURL,
  238. VersionPrefix: "v1",
  239. Scenarios: []string{"crowdsecurity/test"},
  240. }
  241. client, err := NewClient(mycfg)
  242. if err != nil {
  243. t.Fatalf("new api client: %s", err)
  244. }
  245. _, err = client.Auth.EnrollWatcher(context.Background(), "goodkey", "", []string{}, false)
  246. if err != nil {
  247. t.Fatalf("unexpect enroll err: %s", err)
  248. }
  249. _, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
  250. assert.Contains(t, err.Error(), "the attachment key provided is not valid", "got %s", err.Error())
  251. }