diff --git a/internal/dataprovider/node.go b/internal/dataprovider/node.go index 8934d91a..231baf09 100644 --- a/internal/dataprovider/node.go +++ b/internal/dataprovider/node.go @@ -1,3 +1,17 @@ +// Copyright (C) 2019-2022 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + package dataprovider import ( @@ -106,22 +120,27 @@ func (n *Node) validate() error { return n.Data.validate() } -func (n *Node) authenticate(token string) error { +func (n *Node) authenticate(token string) (string, error) { if err := n.Data.Key.TryDecrypt(); err != nil { providerLog(logger.LevelError, "unable to decrypt node key: %v", err) - return err + return "", err } if token == "" { - return ErrInvalidCredentials + return "", ErrInvalidCredentials } t, err := jwt.Parse([]byte(token), jwt.WithVerify(jwa.HS256, []byte(n.Data.Key.GetPayload()))) if err != nil { - return fmt.Errorf("unable to parse token: %v", err) + return "", fmt.Errorf("unable to parse token: %v", err) } if err := jwt.Validate(t); err != nil { - return fmt.Errorf("unable to validate token: %v", err) + return "", fmt.Errorf("unable to validate token: %v", err) } - return nil + if admin, ok := t.Get("admin"); ok { + if val, ok := admin.(string); ok && val != "" { + return val, nil + } + } + return "", errors.New("no admin username associated with node token") } // getBaseURL returns the base URL for this node @@ -138,13 +157,14 @@ func (n *Node) getBaseURL() string { } // generateAuthToken generates a new auth token -func (n *Node) generateAuthToken() (string, error) { +func (n *Node) generateAuthToken(username string) (string, error) { if err := n.Data.Key.TryDecrypt(); err != nil { return "", fmt.Errorf("unable to decrypt node key: %w", err) } now := time.Now().UTC() t := jwt.New() + t.Set("admin", username) //nolint:errcheck t.Set(jwt.JwtIDKey, xid.New().String()) //nolint:errcheck t.Set(jwt.NotBeforeKey, now.Add(-30*time.Second)) //nolint:errcheck t.Set(jwt.ExpirationKey, now.Add(1*time.Minute)) //nolint:errcheck @@ -156,13 +176,15 @@ func (n *Node) generateAuthToken() (string, error) { return string(payload), nil } -func (n *Node) prepareRequest(ctx context.Context, relativeURL, method string, body io.Reader) (*http.Request, error) { +func (n *Node) prepareRequest(ctx context.Context, username, relativeURL, method string, + body io.Reader, +) (*http.Request, error) { url := fmt.Sprintf("%s%s", n.getBaseURL(), relativeURL) req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, err } - token, err := n.generateAuthToken() + token, err := n.generateAuthToken(username) if err != nil { return nil, err } @@ -172,11 +194,11 @@ func (n *Node) prepareRequest(ctx context.Context, relativeURL, method string, b // SendGetRequest sends an HTTP GET request to this node. // The responseHolder must be a pointer -func (n *Node) SendGetRequest(relativeURL string, responseHolder any) error { +func (n *Node) SendGetRequest(username, relativeURL string, responseHolder any) error { ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout) defer cancel() - req, err := n.prepareRequest(ctx, relativeURL, http.MethodGet, nil) + req, err := n.prepareRequest(ctx, username, relativeURL, http.MethodGet, nil) if err != nil { return err } @@ -200,11 +222,11 @@ func (n *Node) SendGetRequest(relativeURL string, responseHolder any) error { } // SendDeleteRequest sends an HTTP DELETE request to this node -func (n *Node) SendDeleteRequest(relativeURL string) error { +func (n *Node) SendDeleteRequest(username, relativeURL string) error { ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout) defer cancel() - req, err := n.prepareRequest(ctx, relativeURL, http.MethodDelete, nil) + req, err := n.prepareRequest(ctx, username, relativeURL, http.MethodDelete, nil) if err != nil { return err } @@ -224,9 +246,9 @@ func (n *Node) SendDeleteRequest(relativeURL string) error { } // AuthenticateNodeToken check the validity of the provided token -func AuthenticateNodeToken(token string) error { +func AuthenticateNodeToken(token string) (string, error) { if currentNode == nil { - return errNoClusterNodes + return "", errNoClusterNodes } return currentNode.authenticate(token) } diff --git a/internal/httpd/api_utils.go b/internal/httpd/api_utils.go index d8748ce5..46bbb1c6 100644 --- a/internal/httpd/api_utils.go +++ b/internal/httpd/api_utils.go @@ -162,13 +162,18 @@ func getActiveConnections(w http.ResponseWriter, r *http.Request) { } stats := common.Connections.GetStats() if claims.NodeID == "" { - stats = append(stats, getNodesConnections()...) + stats = append(stats, getNodesConnections(claims.Username)...) } render.JSON(w, r, stats) } func handleCloseConnection(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := getTokenClaims(r) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } connectionID := getURLParam(r, "connectionID") if connectionID == "" { sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest) @@ -190,7 +195,7 @@ func handleCloseConnection(w http.ResponseWriter, r *http.Request) { sendAPIResponse(w, r, nil, http.StatusText(status), status) return } - if err := n.SendDeleteRequest(fmt.Sprintf("%s/%s", activeConnectionsPath, connectionID)); err != nil { + if err := n.SendDeleteRequest(claims.Username, fmt.Sprintf("%s/%s", activeConnectionsPath, connectionID)); err != nil { logger.Warn(logSender, "", "unable to delete connection id %q from node %q: %v", connectionID, n.Name, err) sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) return @@ -200,7 +205,7 @@ func handleCloseConnection(w http.ResponseWriter, r *http.Request) { // getNodesConnections returns the active connections from other nodes. // Errors are silently ignored -func getNodesConnections() []common.ConnectionStatus { +func getNodesConnections(admin string) []common.ConnectionStatus { nodes, err := dataprovider.GetNodes() if err != nil || len(nodes) == 0 { return nil @@ -216,7 +221,7 @@ func getNodesConnections() []common.ConnectionStatus { defer wg.Done() var stats []common.ConnectionStatus - if err := node.SendGetRequest(activeConnectionsPath, &stats); err != nil { + if err := node.SendGetRequest(admin, activeConnectionsPath, &stats); err != nil { logger.Warn(logSender, "", "unable to get connections from node %s: %v", node.Name, err) return } diff --git a/internal/httpd/internal_test.go b/internal/httpd/internal_test.go index 9a53cc09..7f2acd85 100644 --- a/internal/httpd/internal_test.go +++ b/internal/httpd/internal_test.go @@ -556,6 +556,11 @@ func TestInvalidToken(t *testing.T) { assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") + rr = httptest.NewRecorder() + handleCloseConnection(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + rr = httptest.NewRecorder() server.handleWebRestore(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -606,6 +611,11 @@ func TestInvalidToken(t *testing.T) { assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "invalid token claims") + rr = httptest.NewRecorder() + server.handleWebGetConnections(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "invalid token claims") + rr = httptest.NewRecorder() addAdmin(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -1398,13 +1408,22 @@ func TestRenderUnexistingFolder(t *testing.T) { } func TestCloseConnectionHandler(t *testing.T) { - req, _ := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil) + tokenAuth := jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil) + claims := make(map[string]any) + claims["username"] = defaultAdminUsername + claims[jwt.ExpirationKey] = time.Now().UTC().Add(1 * time.Hour) + token, _, err := tokenAuth.Encode(claims) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil) + assert.NoError(t, err) rctx := chi.NewRouteContext() rctx.URLParams.Add("connectionID", "") req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + req = req.WithContext(context.WithValue(req.Context(), jwtauth.TokenCtxKey, token)) rr := httptest.NewRecorder() handleCloseConnection(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "connectionID is mandatory") } func TestRenderInvalidTemplate(t *testing.T) { diff --git a/internal/httpd/middleware.go b/internal/httpd/middleware.go index 85c42a96..e608b387 100644 --- a/internal/httpd/middleware.go +++ b/internal/httpd/middleware.go @@ -320,18 +320,17 @@ func checkNodeToken(tokenAuth *jwtauth.JWTAuth) func(next http.Handler) http.Han if len(token) > 7 && strings.ToUpper(token[0:6]) == "BEARER" { token = token[7:] } - if err := dataprovider.AuthenticateNodeToken(token); err != nil { + admin, err := dataprovider.AuthenticateNodeToken(token) + if err != nil { logger.Debug(logSender, "", "unable to authenticate node token %q: %v", token, err) sendAPIResponse(w, r, fmt.Errorf("the provided token cannot be authenticated"), "", http.StatusUnauthorized) return } - c := jwtTokenClaims{ - Username: fmt.Sprintf("node %s", dataprovider.GetNodeName()), + Username: admin, Permissions: []string{dataprovider.PermAdminViewConnections, dataprovider.PermAdminCloseConnections}, NodeID: dataprovider.GetNodeName(), } - resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr)) if err != nil { sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) diff --git a/internal/httpd/webadmin.go b/internal/httpd/webadmin.go index 7dbf188a..a6c3507c 100644 --- a/internal/httpd/webadmin.go +++ b/internal/httpd/webadmin.go @@ -2844,8 +2844,13 @@ func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request) func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := getTokenClaims(r) + if err != nil || claims.Username == "" { + s.renderBadRequestPage(w, r, errors.New("invalid token claims")) + return + } connectionStats := common.Connections.GetStats() - connectionStats = append(connectionStats, getNodesConnections()...) + connectionStats = append(connectionStats, getNodesConnections(claims.Username)...) data := connectionsPage{ basePage: s.getBasePageData(pageConnectionsTitle, webConnectionsPath, r), Connections: connectionStats,