diff --git a/api/api_test.go b/api/api_test.go index c048ad6d..fb34d6c0 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -383,6 +383,20 @@ func TestStartQuotaScan(t *testing.T) { } } +func TestGetSFTPConnections(t *testing.T) { + _, err := api.GetSFTPConnections(http.StatusOK) + if err != nil { + t.Errorf("unable to get sftp connections: %v", err) + } +} + +func TestCloseActiveSFTPConnection(t *testing.T) { + err := api.CloseSFTPConnection("non_existent_id", http.StatusNotFound) + if err != nil { + t.Errorf("unexpected error closing non existent sftp connection: %v", err) + } +} + // test using mock http server func TestBasicUserHandlingMock(t *testing.T) { diff --git a/api/api_utils.go b/api/api_utils.go index 5873215b..0c9d7601 100644 --- a/api/api_utils.go +++ b/api/api_utils.go @@ -180,6 +180,35 @@ func StartQuotaScan(user dataprovider.User, expectedStatusCode int) error { return checkResponse(resp.StatusCode, expectedStatusCode, resp) } +// GetSFTPConnections returns status and stats for active SFTP connections +func GetSFTPConnections(expectedStatusCode int) ([]sftpd.ConnectionStatus, error) { + var connections []sftpd.ConnectionStatus + resp, err := getHTTPClient().Get(httpBaseURL + activeConnectionsPath) + if err != nil { + return connections, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode, resp) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &connections) + } + return connections, err +} + +// CloseSFTPConnection closes an active SFTP connection identified by connectionID +func CloseSFTPConnection(connectionID string, expectedStatusCode int) error { + req, err := http.NewRequest(http.MethodDelete, httpBaseURL+activeConnectionsPath+"/"+connectionID, nil) + if err != nil { + return err + } + resp, err := getHTTPClient().Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + return checkResponse(resp.StatusCode, expectedStatusCode, resp) +} + func checkResponse(actual int, expected int, resp *http.Response) error { if expected != actual { return fmt.Errorf("wrong status code: got %v want %v", actual, expected) diff --git a/api/internal_test.go b/api/internal_test.go index db7abf68..b4e32c51 100644 --- a/api/internal_test.go +++ b/api/internal_test.go @@ -8,6 +8,11 @@ import ( "github.com/drakkan/sftpgo/dataprovider" ) +const ( + invalidURL = "http://foo\x7f.com/" + inactiveURL = "http://127.0.0.1:12345" +) + func TestGetRespStatus(t *testing.T) { var err error err = &dataprovider.MethodDisabledError{} @@ -136,3 +141,65 @@ func TestCompareUserFields(t *testing.T) { t.Errorf("DownloadBandwidth does not match") } } + +func TestApiCallsWithBadURL(t *testing.T) { + oldBaseURL := httpBaseURL + SetBaseURL(invalidURL) + u := dataprovider.User{} + _, err := UpdateUser(u, http.StatusBadRequest) + if err == nil { + t.Errorf("request with invalid URL must fail") + } + err = RemoveUser(u, http.StatusNotFound) + if err == nil { + t.Errorf("request with invalid URL must fail") + } + err = CloseSFTPConnection("non_existent_id", http.StatusNotFound) + if err == nil { + t.Errorf("request with invalid URL must fail") + } + SetBaseURL(oldBaseURL) +} + +func TestApiCallToNotListeningServer(t *testing.T) { + oldBaseURL := httpBaseURL + SetBaseURL(inactiveURL) + u := dataprovider.User{} + _, err := AddUser(u, http.StatusBadRequest) + if err == nil { + t.Errorf("request to an inactive URL must fail") + } + _, err = UpdateUser(u, http.StatusNotFound) + if err == nil { + t.Errorf("request to an inactive URL must fail") + } + err = RemoveUser(u, http.StatusNotFound) + if err == nil { + t.Errorf("request to an inactive URL must fail") + } + _, err = GetUserByID(-1, http.StatusNotFound) + if err == nil { + t.Errorf("request to an inactive URL must fail") + } + _, err = GetUsers(100, 0, "", http.StatusOK) + if err == nil { + t.Errorf("request to an inactive URL must fail") + } + _, err = GetQuotaScans(http.StatusOK) + if err == nil { + t.Errorf("request to an inactive URL must fail") + } + err = StartQuotaScan(u, http.StatusNotFound) + if err == nil { + t.Errorf("request to an inactive URL must fail") + } + _, err = GetSFTPConnections(http.StatusOK) + if err == nil { + t.Errorf("request to an inactive URL must fail") + } + err = CloseSFTPConnection("non_existent_id", http.StatusNotFound) + if err == nil { + t.Errorf("request to an inactive URL must fail") + } + SetBaseURL(oldBaseURL) +}