auth.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. package apiclient
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/rand"
  9. "net/http"
  10. "net/http/httputil"
  11. "net/url"
  12. "sync"
  13. "time"
  14. "github.com/go-openapi/strfmt"
  15. log "github.com/sirupsen/logrus"
  16. "github.com/crowdsecurity/crowdsec/pkg/fflag"
  17. "github.com/crowdsecurity/crowdsec/pkg/models"
  18. )
  19. type APIKeyTransport struct {
  20. APIKey string
  21. // Transport is the underlying HTTP transport to use when making requests.
  22. // It will default to http.DefaultTransport if nil.
  23. Transport http.RoundTripper
  24. URL *url.URL
  25. VersionPrefix string
  26. UserAgent string
  27. }
  28. // RoundTrip implements the RoundTripper interface.
  29. func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  30. if t.APIKey == "" {
  31. return nil, errors.New("APIKey is empty")
  32. }
  33. // We must make a copy of the Request so
  34. // that we don't modify the Request we were given. This is required by the
  35. // specification of http.RoundTripper.
  36. req = cloneRequest(req)
  37. req.Header.Add("X-Api-Key", t.APIKey)
  38. if t.UserAgent != "" {
  39. req.Header.Add("User-Agent", t.UserAgent)
  40. }
  41. log.Debugf("req-api: %s %s", req.Method, req.URL.String())
  42. if log.GetLevel() >= log.TraceLevel {
  43. dump, _ := httputil.DumpRequest(req, true)
  44. log.Tracef("auth-api request: %s", string(dump))
  45. }
  46. // Make the HTTP request.
  47. resp, err := t.transport().RoundTrip(req)
  48. if err != nil {
  49. log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err)
  50. return resp, err
  51. }
  52. if log.GetLevel() >= log.TraceLevel {
  53. dump, _ := httputil.DumpResponse(resp, true)
  54. log.Tracef("auth-api response: %s", string(dump))
  55. }
  56. log.Debugf("resp-api: http %d", resp.StatusCode)
  57. return resp, err
  58. }
  59. func (t *APIKeyTransport) Client() *http.Client {
  60. return &http.Client{Transport: t}
  61. }
  62. func (t *APIKeyTransport) transport() http.RoundTripper {
  63. if t.Transport != nil {
  64. return t.Transport
  65. }
  66. return http.DefaultTransport
  67. }
  68. type retryRoundTripper struct {
  69. next http.RoundTripper
  70. maxAttempts int
  71. retryStatusCodes []int
  72. withBackOff bool
  73. onBeforeRequest func(attempt int)
  74. }
  75. func (r retryRoundTripper) ShouldRetry(statusCode int) bool {
  76. for _, code := range r.retryStatusCodes {
  77. if code == statusCode {
  78. return true
  79. }
  80. }
  81. return false
  82. }
  83. func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  84. var (
  85. resp *http.Response
  86. err error
  87. )
  88. backoff := 0
  89. maxAttempts := r.maxAttempts
  90. if fflag.DisableHttpRetryBackoff.IsEnabled() {
  91. maxAttempts = 1
  92. }
  93. for i := 0; i < maxAttempts; i++ {
  94. if i > 0 {
  95. if r.withBackOff {
  96. //nolint:gosec
  97. backoff += 10 + rand.Intn(20)
  98. }
  99. log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)
  100. select {
  101. case <-req.Context().Done():
  102. return resp, req.Context().Err()
  103. case <-time.After(time.Duration(backoff) * time.Second):
  104. }
  105. }
  106. if r.onBeforeRequest != nil {
  107. r.onBeforeRequest(i)
  108. }
  109. clonedReq := cloneRequest(req)
  110. resp, err = r.next.RoundTrip(clonedReq)
  111. if err != nil {
  112. if left := maxAttempts - i - 1; left > 0 {
  113. log.Errorf("error while performing request: %s; %d retries left", err, left)
  114. }
  115. continue
  116. }
  117. if !r.ShouldRetry(resp.StatusCode) {
  118. return resp, nil
  119. }
  120. }
  121. return resp, err
  122. }
  123. type JWTTransport struct {
  124. MachineID *string
  125. Password *strfmt.Password
  126. Token string
  127. Expiration time.Time
  128. Scenarios []string
  129. URL *url.URL
  130. VersionPrefix string
  131. UserAgent string
  132. // Transport is the underlying HTTP transport to use when making requests.
  133. // It will default to http.DefaultTransport if nil.
  134. Transport http.RoundTripper
  135. UpdateScenario func() ([]string, error)
  136. refreshTokenMutex sync.Mutex
  137. }
  138. func (t *JWTTransport) refreshJwtToken() error {
  139. var err error
  140. if t.UpdateScenario != nil {
  141. t.Scenarios, err = t.UpdateScenario()
  142. if err != nil {
  143. return fmt.Errorf("can't update scenario list: %s", err)
  144. }
  145. log.Debugf("scenarios list updated for '%s'", *t.MachineID)
  146. }
  147. auth := models.WatcherAuthRequest{
  148. MachineID: t.MachineID,
  149. Password: t.Password,
  150. Scenarios: t.Scenarios,
  151. }
  152. var response models.WatcherAuthResponse
  153. /*
  154. we don't use the main client, so let's build the body
  155. */
  156. var buf io.ReadWriter = &bytes.Buffer{}
  157. enc := json.NewEncoder(buf)
  158. enc.SetEscapeHTML(false)
  159. err = enc.Encode(auth)
  160. if err != nil {
  161. return fmt.Errorf("could not encode jwt auth body: %w", err)
  162. }
  163. req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
  164. if err != nil {
  165. return fmt.Errorf("could not create request: %w", err)
  166. }
  167. req.Header.Add("Content-Type", "application/json")
  168. client := &http.Client{
  169. Transport: &retryRoundTripper{
  170. next: http.DefaultTransport,
  171. maxAttempts: 5,
  172. withBackOff: true,
  173. retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError},
  174. },
  175. }
  176. if t.UserAgent != "" {
  177. req.Header.Add("User-Agent", t.UserAgent)
  178. }
  179. if log.GetLevel() >= log.TraceLevel {
  180. dump, _ := httputil.DumpRequest(req, true)
  181. log.Tracef("auth-jwt request: %s", string(dump))
  182. }
  183. log.Debugf("auth-jwt(auth): %s %s", req.Method, req.URL.String())
  184. resp, err := client.Do(req)
  185. if err != nil {
  186. return fmt.Errorf("could not get jwt token: %w", err)
  187. }
  188. log.Debugf("auth-jwt : http %d", resp.StatusCode)
  189. if log.GetLevel() >= log.TraceLevel {
  190. dump, _ := httputil.DumpResponse(resp, true)
  191. log.Tracef("auth-jwt response: %s", string(dump))
  192. }
  193. defer resp.Body.Close()
  194. if resp.StatusCode < 200 || resp.StatusCode >= 300 {
  195. log.Debugf("received response status %q when fetching %v", resp.Status, req.URL)
  196. err = CheckResponse(resp)
  197. if err != nil {
  198. return err
  199. }
  200. }
  201. if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
  202. return fmt.Errorf("unable to decode response: %w", err)
  203. }
  204. if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
  205. return fmt.Errorf("unable to parse jwt expiration: %w", err)
  206. }
  207. t.Token = response.Token
  208. log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
  209. return nil
  210. }
  211. // RoundTrip implements the RoundTripper interface.
  212. func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  213. // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
  214. // we use a mutex to avoid this
  215. // We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
  216. t.refreshTokenMutex.Lock()
  217. if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) {
  218. if err := t.refreshJwtToken(); err != nil {
  219. t.refreshTokenMutex.Unlock()
  220. return nil, err
  221. }
  222. }
  223. t.refreshTokenMutex.Unlock()
  224. if t.UserAgent != "" {
  225. req.Header.Add("User-Agent", t.UserAgent)
  226. }
  227. req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
  228. if log.GetLevel() >= log.TraceLevel {
  229. //requestToDump := cloneRequest(req)
  230. dump, _ := httputil.DumpRequest(req, true)
  231. log.Tracef("req-jwt: %s", string(dump))
  232. }
  233. // Make the HTTP request.
  234. resp, err := t.transport().RoundTrip(req)
  235. if log.GetLevel() >= log.TraceLevel {
  236. dump, _ := httputil.DumpResponse(resp, true)
  237. log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
  238. }
  239. if err != nil {
  240. // we had an error (network error for example, or 401 because token is refused), reset the token ?
  241. t.Token = ""
  242. return resp, fmt.Errorf("performing jwt auth: %w", err)
  243. }
  244. if resp != nil {
  245. log.Debugf("resp-jwt: %d", resp.StatusCode)
  246. }
  247. return resp, nil
  248. }
  249. func (t *JWTTransport) Client() *http.Client {
  250. return &http.Client{Transport: t}
  251. }
  252. func (t *JWTTransport) ResetToken() {
  253. log.Debug("resetting jwt token")
  254. t.refreshTokenMutex.Lock()
  255. t.Token = ""
  256. t.refreshTokenMutex.Unlock()
  257. }
  258. func (t *JWTTransport) transport() http.RoundTripper {
  259. var transport http.RoundTripper
  260. if t.Transport != nil {
  261. transport = t.Transport
  262. } else {
  263. transport = http.DefaultTransport
  264. }
  265. // a round tripper that retries once when the status is unauthorized and 5 times when infrastructure is overloaded
  266. return &retryRoundTripper{
  267. next: &retryRoundTripper{
  268. next: transport,
  269. maxAttempts: 5,
  270. withBackOff: true,
  271. retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout},
  272. },
  273. maxAttempts: 2,
  274. withBackOff: false,
  275. retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden},
  276. onBeforeRequest: func(attempt int) {
  277. // reset the token only in the second attempt as this is when we know we had a 401 or 403
  278. // the second attempt is supposed to refresh the token
  279. if attempt > 0 {
  280. t.ResetToken()
  281. }
  282. },
  283. }
  284. }
  285. // cloneRequest returns a clone of the provided *http.Request. The clone is a
  286. // shallow copy of the struct and its Header map.
  287. func cloneRequest(r *http.Request) *http.Request {
  288. // shallow copy of the struct
  289. r2 := new(http.Request)
  290. *r2 = *r
  291. // deep copy of the Header
  292. r2.Header = make(http.Header, len(r.Header))
  293. for k, s := range r.Header {
  294. r2.Header[k] = append([]string(nil), s...)
  295. }
  296. if r.Body != nil {
  297. var b bytes.Buffer
  298. b.ReadFrom(r.Body)
  299. r.Body = io.NopCloser(&b)
  300. r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
  301. }
  302. return r2
  303. }