auth.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. package apiclient
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "math/rand"
  6. "sync"
  7. "time"
  8. //"errors"
  9. "fmt"
  10. "io"
  11. "net/http"
  12. "net/http/httputil"
  13. "net/url"
  14. "github.com/crowdsecurity/crowdsec/pkg/fflag"
  15. "github.com/crowdsecurity/crowdsec/pkg/models"
  16. "github.com/go-openapi/strfmt"
  17. "github.com/pkg/errors"
  18. log "github.com/sirupsen/logrus"
  19. //"google.golang.org/appengine/log"
  20. )
  21. type APIKeyTransport struct {
  22. APIKey string
  23. // Transport is the underlying HTTP transport to use when making requests.
  24. // It will default to http.DefaultTransport if nil.
  25. Transport http.RoundTripper
  26. URL *url.URL
  27. VersionPrefix string
  28. UserAgent string
  29. }
  30. // RoundTrip implements the RoundTripper interface.
  31. func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  32. if t.APIKey == "" {
  33. return nil, errors.New("APIKey is empty")
  34. }
  35. // We must make a copy of the Request so
  36. // that we don't modify the Request we were given. This is required by the
  37. // specification of http.RoundTripper.
  38. req = cloneRequest(req)
  39. req.Header.Add("X-Api-Key", t.APIKey)
  40. if t.UserAgent != "" {
  41. req.Header.Add("User-Agent", t.UserAgent)
  42. }
  43. log.Debugf("req-api: %s %s", req.Method, req.URL.String())
  44. if log.GetLevel() >= log.TraceLevel {
  45. dump, _ := httputil.DumpRequest(req, true)
  46. log.Tracef("auth-api request: %s", string(dump))
  47. }
  48. // Make the HTTP request.
  49. resp, err := t.transport().RoundTrip(req)
  50. if err != nil {
  51. log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err)
  52. return resp, err
  53. }
  54. if log.GetLevel() >= log.TraceLevel {
  55. dump, _ := httputil.DumpResponse(resp, true)
  56. log.Tracef("auth-api response: %s", string(dump))
  57. }
  58. log.Debugf("resp-api: http %d", resp.StatusCode)
  59. return resp, err
  60. }
  61. func (t *APIKeyTransport) Client() *http.Client {
  62. return &http.Client{Transport: t}
  63. }
  64. func (t *APIKeyTransport) transport() http.RoundTripper {
  65. if t.Transport != nil {
  66. return t.Transport
  67. }
  68. return http.DefaultTransport
  69. }
  70. type retryRoundTripper struct {
  71. next http.RoundTripper
  72. maxAttempts int
  73. retryStatusCodes []int
  74. withBackOff bool
  75. onBeforeRequest func(attempt int)
  76. }
  77. func (r retryRoundTripper) ShouldRetry(statusCode int) bool {
  78. for _, code := range r.retryStatusCodes {
  79. if code == statusCode {
  80. return true
  81. }
  82. }
  83. return false
  84. }
  85. func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  86. var resp *http.Response
  87. var err error
  88. backoff := 0
  89. for i := 0; i < r.maxAttempts; i++ {
  90. if i > 0 {
  91. if r.withBackOff && !fflag.DisableHttpRetryBackoff.IsEnabled() {
  92. backoff += 10 + rand.Intn(20)
  93. }
  94. log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)
  95. select {
  96. case <-req.Context().Done():
  97. return resp, req.Context().Err()
  98. case <-time.After(time.Duration(backoff) * time.Second):
  99. }
  100. }
  101. if r.onBeforeRequest != nil {
  102. r.onBeforeRequest(i)
  103. }
  104. clonedReq := cloneRequest(req)
  105. resp, err = r.next.RoundTrip(clonedReq)
  106. if err != nil {
  107. log.Errorf("error while performing request: %s; %d retries left", err, r.maxAttempts-i-1)
  108. continue
  109. }
  110. if !r.ShouldRetry(resp.StatusCode) {
  111. return resp, nil
  112. }
  113. }
  114. return resp, err
  115. }
  116. type JWTTransport struct {
  117. MachineID *string
  118. Password *strfmt.Password
  119. Token string
  120. Expiration time.Time
  121. Scenarios []string
  122. URL *url.URL
  123. VersionPrefix string
  124. UserAgent string
  125. // Transport is the underlying HTTP transport to use when making requests.
  126. // It will default to http.DefaultTransport if nil.
  127. Transport http.RoundTripper
  128. UpdateScenario func() ([]string, error)
  129. refreshTokenMutex sync.Mutex
  130. }
  131. func (t *JWTTransport) refreshJwtToken() error {
  132. var err error
  133. if t.UpdateScenario != nil {
  134. t.Scenarios, err = t.UpdateScenario()
  135. if err != nil {
  136. return fmt.Errorf("can't update scenario list: %s", err)
  137. }
  138. log.Debugf("scenarios list updated for '%s'", *t.MachineID)
  139. }
  140. var auth = models.WatcherAuthRequest{
  141. MachineID: t.MachineID,
  142. Password: t.Password,
  143. Scenarios: t.Scenarios,
  144. }
  145. var response models.WatcherAuthResponse
  146. /*
  147. we don't use the main client, so let's build the body
  148. */
  149. var buf io.ReadWriter = &bytes.Buffer{}
  150. enc := json.NewEncoder(buf)
  151. enc.SetEscapeHTML(false)
  152. err = enc.Encode(auth)
  153. if err != nil {
  154. return errors.Wrap(err, "could not encode jwt auth body")
  155. }
  156. req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
  157. if err != nil {
  158. return errors.Wrap(err, "could not create request")
  159. }
  160. req.Header.Add("Content-Type", "application/json")
  161. client := &http.Client{
  162. Transport: &retryRoundTripper{
  163. next: http.DefaultTransport,
  164. maxAttempts: 5,
  165. withBackOff: true,
  166. retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError},
  167. },
  168. }
  169. if t.UserAgent != "" {
  170. req.Header.Add("User-Agent", t.UserAgent)
  171. }
  172. if log.GetLevel() >= log.TraceLevel {
  173. dump, _ := httputil.DumpRequest(req, true)
  174. log.Tracef("auth-jwt request: %s", string(dump))
  175. }
  176. log.Debugf("auth-jwt(auth): %s %s", req.Method, req.URL.String())
  177. resp, err := client.Do(req)
  178. if err != nil {
  179. return errors.Wrap(err, "could not get jwt token")
  180. }
  181. log.Debugf("auth-jwt : http %d", resp.StatusCode)
  182. if log.GetLevel() >= log.TraceLevel {
  183. dump, _ := httputil.DumpResponse(resp, true)
  184. log.Tracef("auth-jwt response: %s", string(dump))
  185. }
  186. defer resp.Body.Close()
  187. if resp.StatusCode < 200 || resp.StatusCode >= 300 {
  188. log.Debugf("received response status %q when fetching %v", resp.Status, req.URL)
  189. err = CheckResponse(resp)
  190. if err != nil {
  191. return err
  192. }
  193. }
  194. if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
  195. return errors.Wrap(err, "unable to decode response")
  196. }
  197. if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
  198. return errors.Wrap(err, "unable to parse jwt expiration")
  199. }
  200. t.Token = response.Token
  201. log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
  202. return nil
  203. }
  204. // RoundTrip implements the RoundTripper interface.
  205. func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  206. // in a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
  207. // we use a mutex to avoid this
  208. //We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
  209. t.refreshTokenMutex.Lock()
  210. if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) {
  211. if err := t.refreshJwtToken(); err != nil {
  212. t.refreshTokenMutex.Unlock()
  213. return nil, err
  214. }
  215. }
  216. t.refreshTokenMutex.Unlock()
  217. if t.UserAgent != "" {
  218. req.Header.Add("User-Agent", t.UserAgent)
  219. }
  220. req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
  221. if log.GetLevel() >= log.TraceLevel {
  222. //requestToDump := cloneRequest(req)
  223. dump, _ := httputil.DumpRequest(req, true)
  224. log.Tracef("req-jwt: %s", string(dump))
  225. }
  226. // Make the HTTP request.
  227. resp, err := t.transport().RoundTrip(req)
  228. if log.GetLevel() >= log.TraceLevel {
  229. dump, _ := httputil.DumpResponse(resp, true)
  230. log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
  231. }
  232. if err != nil {
  233. /*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
  234. t.Token = ""
  235. return resp, errors.Wrapf(err, "performing jwt auth")
  236. }
  237. log.Debugf("resp-jwt: %d", resp.StatusCode)
  238. return resp, nil
  239. }
  240. func (t *JWTTransport) Client() *http.Client {
  241. return &http.Client{Transport: t}
  242. }
  243. func (t *JWTTransport) ResetToken() {
  244. log.Debug("resetting jwt token")
  245. t.refreshTokenMutex.Lock()
  246. t.Token = ""
  247. t.refreshTokenMutex.Unlock()
  248. }
  249. func (t *JWTTransport) transport() http.RoundTripper {
  250. var transport http.RoundTripper
  251. if t.Transport != nil {
  252. transport = t.Transport
  253. } else {
  254. transport = http.DefaultTransport
  255. }
  256. // a round tripper that retries once when the status is unauthorized and 5 times when infrastructure is overloaded
  257. return &retryRoundTripper{
  258. next: &retryRoundTripper{
  259. next: transport,
  260. maxAttempts: 5,
  261. withBackOff: true,
  262. retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout},
  263. },
  264. maxAttempts: 2,
  265. withBackOff: false,
  266. retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden},
  267. onBeforeRequest: func(attempt int) {
  268. // reset the token only in the second attempt as this is when we know we had a 401 or 403
  269. // the second attempt is supposed to refresh the token
  270. if attempt > 0 {
  271. t.ResetToken()
  272. }
  273. },
  274. }
  275. }
  276. // cloneRequest returns a clone of the provided *http.Request. The clone is a
  277. // shallow copy of the struct and its Header map.
  278. func cloneRequest(r *http.Request) *http.Request {
  279. // shallow copy of the struct
  280. r2 := new(http.Request)
  281. *r2 = *r
  282. // deep copy of the Header
  283. r2.Header = make(http.Header, len(r.Header))
  284. for k, s := range r.Header {
  285. r2.Header[k] = append([]string(nil), s...)
  286. }
  287. if r.Body != nil {
  288. var b bytes.Buffer
  289. b.ReadFrom(r.Body)
  290. r.Body = io.NopCloser(&b)
  291. r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
  292. }
  293. return r2
  294. }