瀏覽代碼

Do not try to refresh JWT token when doing a login request (#2059)

blotus 2 年之前
父節點
當前提交
83c3818504

+ 1 - 1
cmd/crowdsec-cli/capi.go

@@ -172,7 +172,7 @@ func NewCapiStatusCmd() *cobra.Command {
 			}
 			log.Infof("Loaded credentials from %s", csConfig.API.Server.OnlineClient.CredentialsFilePath)
 			log.Infof("Trying to authenticate with username %s on %s", csConfig.API.Server.OnlineClient.Credentials.Login, apiurl)
-			_, err = Client.Auth.AuthenticateWatcher(context.Background(), t)
+			_, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t)
 			if err != nil {
 				log.Fatalf("Failed to authenticate to Central API (CAPI) : %s", err)
 			}

+ 1 - 1
cmd/crowdsec-cli/lapi.go

@@ -63,7 +63,7 @@ func runLapiStatus(cmd *cobra.Command, args []string) error {
 	}
 	log.Infof("Loaded credentials from %s", csConfig.API.Client.CredentialsFilePath)
 	log.Infof("Trying to authenticate with username %s on %s", login, apiurl)
-	_, err = Client.Auth.AuthenticateWatcher(context.Background(), t)
+	_, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t)
 	if err != nil {
 		log.Fatalf("Failed to authenticate to Local API (LAPI) : %s", err)
 	} else {

+ 2 - 3
cmd/crowdsec-cli/support.go

@@ -102,7 +102,6 @@ func collectFeatures() []byte {
 	return w.Bytes()
 }
 
-
 func collectOSInfo() ([]byte, error) {
 	log.Info("Collecting OS info")
 	info, err := osinfo.GetOSInfo()
@@ -194,7 +193,7 @@ func collectAPIStatus(login string, password string, endpoint string, prefix str
 		Scenarios: scenarios,
 	}
 
-	_, err = Client.Auth.AuthenticateWatcher(context.Background(), t)
+	_, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t)
 	if err != nil {
 		return []byte(fmt.Sprintf("Could not authenticate to API: %s", err))
 	} else {
@@ -277,7 +276,7 @@ cscli support dump -f /tmp/crowdsec-support.zip
 			var err error
 			var skipHub, skipDB, skipCAPI, skipLAPI, skipAgent bool
 			infos := map[string][]byte{
-				SUPPORT_VERSION_PATH: collectVersion(),
+				SUPPORT_VERSION_PATH:  collectVersion(),
 				SUPPORT_FEATURES_PATH: collectFeatures(),
 			}
 

+ 10 - 2
cmd/crowdsec/output.go

@@ -97,13 +97,21 @@ func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky
 	if err != nil {
 		return errors.Wrapf(err, "new client api: %s", err)
 	}
-	if _, err = Client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
+	authResp, _, err := Client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
 		MachineID: &apiConfig.Login,
 		Password:  &password,
 		Scenarios: scenarios,
-	}); err != nil {
+	})
+	if err != nil {
 		return errors.Wrapf(err, "authenticate watcher (%s)", apiConfig.Login)
 	}
+
+	if err := Client.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil {
+		return errors.Wrap(err, "unable to parse jwt expiration")
+	}
+
+	Client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token
+
 	//start the heartbeat service
 	log.Debugf("Starting HeartBeat service")
 	Client.HeartBeat.StartHeartBeat(context.Background(), &outputsTomb)

+ 2 - 1
pkg/apiclient/auth.go

@@ -232,8 +232,9 @@ func (t *JWTTransport) refreshJwtToken() error {
 func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 	// in a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
 	// we use a mutex to avoid this
+	//We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
 	t.refreshTokenMutex.Lock()
-	if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
+	if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) {
 		if err := t.refreshJwtToken(); err != nil {
 			t.refreshTokenMutex.Unlock()
 			return nil, err

+ 7 - 5
pkg/apiclient/auth_service.go

@@ -51,18 +51,20 @@ func (s *AuthService) RegisterWatcher(ctx context.Context, registration models.W
 	return resp, nil
 }
 
-func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.WatcherAuthRequest) (*Response, error) {
+func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.WatcherAuthRequest) (models.WatcherAuthResponse, *Response, error) {
+	var authResp models.WatcherAuthResponse
+
 	u := fmt.Sprintf("%s/watchers/login", s.client.URLPrefix)
 	req, err := s.client.NewRequest(http.MethodPost, u, &auth)
 	if err != nil {
-		return nil, err
+		return authResp, nil, err
 	}
 
-	resp, err := s.client.Do(ctx, req, nil)
+	resp, err := s.client.Do(ctx, req, &authResp)
 	if err != nil {
-		return resp, err
+		return authResp, resp, err
 	}
-	return resp, nil
+	return authResp, resp, nil
 }
 
 func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name string, tags []string, overwrite bool) (*Response, error) {

+ 2 - 2
pkg/apiclient/auth_service_test.go

@@ -60,7 +60,7 @@ func TestWatcherAuth(t *testing.T) {
 		t.Fatalf("new api client: %s", err)
 	}
 
-	_, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
+	_, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
 		MachineID: &mycfg.MachineID,
 		Password:  &mycfg.Password,
 		Scenarios: mycfg.Scenarios,
@@ -84,7 +84,7 @@ func TestWatcherAuth(t *testing.T) {
 		t.Fatalf("new api client: %s", err)
 	}
 
-	_, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
+	_, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
 		MachineID: &mycfg.MachineID,
 		Password:  &mycfg.Password,
 	})

+ 19 - 12
pkg/apiserver/apic.go

@@ -201,18 +201,25 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con
 
 	// The watcher will be authenticated by the RoundTripper the first time it will call CAPI
 	// Explicit authentication will provoke an useless supplementary call to CAPI
-	// scenarios, err := ret.FetchScenariosListFromDB()
-	// if err != nil {
-	// 	return ret, errors.Wrapf(err, "get scenario in db: %s", err)
-	// }
-
-	// if _, err = ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
-	// 	MachineID: &config.Credentials.Login,
-	// 	Password:  &password,
-	// 	Scenarios: scenarios,
-	// }); err != nil {
-	// 	return ret, errors.Wrapf(err, "authenticate watcher (%s)", config.Credentials.Login)
-	// }
+	scenarios, err := ret.FetchScenariosListFromDB()
+	if err != nil {
+		return ret, errors.Wrapf(err, "get scenario in db: %s", err)
+	}
+
+	authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
+		MachineID: &config.Credentials.Login,
+		Password:  &password,
+		Scenarios: scenarios,
+	})
+	if err != nil {
+		return ret, errors.Wrapf(err, "authenticate watcher (%s)", config.Credentials.Login)
+	}
+
+	if err := ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil {
+		return ret, errors.Wrap(err, "unable to parse jwt expiration")
+	}
+
+	ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token
 
 	return ret, err
 }