diff --git a/api/api_utils.go b/api/api_utils.go index f850b9e9..693f3bfe 100644 --- a/api/api_utils.go +++ b/api/api_utils.go @@ -7,7 +7,9 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "strconv" + "time" "github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/sftpd" @@ -25,6 +27,12 @@ func SetBaseURL(url string) { httpBaseURL = url } +func getHTTPClient() *http.Client { + return &http.Client{ + Timeout: 15 * time.Second, + } +} + // AddUser add a new user, useful for tests func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) { var newUser dataprovider.User @@ -32,7 +40,7 @@ func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, if err != nil { return newUser, err } - resp, err := http.Post(httpBaseURL+userPath, "application/json", bytes.NewBuffer(userAsJSON)) + resp, err := getHTTPClient().Post(httpBaseURL+userPath, "application/json", bytes.NewBuffer(userAsJSON)) if err != nil { return newUser, err } @@ -61,8 +69,7 @@ func UpdateUser(user dataprovider.User, expectedStatusCode int) (dataprovider.Us if err != nil { return user, err } - req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + resp, err := getHTTPClient().Do(req) if err != nil { return user, err } @@ -86,7 +93,7 @@ func RemoveUser(user dataprovider.User, expectedStatusCode int) error { if err != nil { return err } - resp, err := http.DefaultClient.Do(req) + resp, err := getHTTPClient().Do(req) if err != nil { return err } @@ -97,7 +104,7 @@ func RemoveUser(user dataprovider.User, expectedStatusCode int) error { // GetUserByID get user by id, useful for tests func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, error) { var user dataprovider.User - resp, err := http.Get(httpBaseURL + userPath + "/" + strconv.FormatInt(userID, 10)) + resp, err := getHTTPClient().Get(httpBaseURL + userPath + "/" + strconv.FormatInt(userID, 10)) if err != nil { return user, err } @@ -112,11 +119,11 @@ func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, error // GetUsers useful for tests func GetUsers(limit int64, offset int64, username string, expectedStatusCode int) ([]dataprovider.User, error) { var users []dataprovider.User - req, err := http.NewRequest(http.MethodGet, httpBaseURL+userPath, nil) + url, err := url.Parse(httpBaseURL + userPath) if err != nil { return users, err } - q := req.URL.Query() + q := url.Query() if limit > 0 { q.Add("limit", strconv.FormatInt(limit, 10)) } @@ -126,8 +133,8 @@ func GetUsers(limit int64, offset int64, username string, expectedStatusCode int if len(username) > 0 { q.Add("username", username) } - req.URL.RawQuery = q.Encode() - resp, err := http.DefaultClient.Do(req) + url.RawQuery = q.Encode() + resp, err := getHTTPClient().Get(url.String()) if err != nil { return users, err } @@ -142,7 +149,7 @@ func GetUsers(limit int64, offset int64, username string, expectedStatusCode int // GetQuotaScans get active quota scans, useful for tests func GetQuotaScans(expectedStatusCode int) ([]sftpd.ActiveQuotaScan, error) { var quotaScans []sftpd.ActiveQuotaScan - resp, err := http.Get(httpBaseURL + quotaScanPath) + resp, err := getHTTPClient().Get(httpBaseURL + quotaScanPath) if err != nil { return quotaScans, err } @@ -160,7 +167,7 @@ func StartQuotaScan(user dataprovider.User, expectedStatusCode int) error { if err != nil { return err } - resp, err := http.Post(httpBaseURL+quotaScanPath, "application/json", bytes.NewBuffer(userAsJSON)) + resp, err := getHTTPClient().Post(httpBaseURL+quotaScanPath, "application/json", bytes.NewBuffer(userAsJSON)) if err != nil { return err }