api_utils: set a timeout for http requests

These methods are used for test cases only, anyway a timeout is not bad :)
This commit is contained in:
Nicola Murino 2019-07-27 20:42:45 +02:00
parent 4a46b84dd5
commit c9e6fa0dd6

View file

@ -7,7 +7,9 @@ import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"time"
"github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/sftpd"
@ -25,6 +27,12 @@ func SetBaseURL(url string) {
httpBaseURL = url
}
func getHTTPClient() *http.Client {
return &http.Client{
Timeout: 15 * time.Second,
}
}
// AddUser add a new user, useful for tests
func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) {
var newUser dataprovider.User
@ -32,7 +40,7 @@ func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User,
if err != nil {
return newUser, err
}
resp, err := http.Post(httpBaseURL+userPath, "application/json", bytes.NewBuffer(userAsJSON))
resp, err := getHTTPClient().Post(httpBaseURL+userPath, "application/json", bytes.NewBuffer(userAsJSON))
if err != nil {
return newUser, err
}
@ -61,8 +69,7 @@ func UpdateUser(user dataprovider.User, expectedStatusCode int) (dataprovider.Us
if err != nil {
return user, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
resp, err := getHTTPClient().Do(req)
if err != nil {
return user, err
}
@ -86,7 +93,7 @@ func RemoveUser(user dataprovider.User, expectedStatusCode int) error {
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
resp, err := getHTTPClient().Do(req)
if err != nil {
return err
}
@ -97,7 +104,7 @@ func RemoveUser(user dataprovider.User, expectedStatusCode int) error {
// GetUserByID get user by id, useful for tests
func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, error) {
var user dataprovider.User
resp, err := http.Get(httpBaseURL + userPath + "/" + strconv.FormatInt(userID, 10))
resp, err := getHTTPClient().Get(httpBaseURL + userPath + "/" + strconv.FormatInt(userID, 10))
if err != nil {
return user, err
}
@ -112,11 +119,11 @@ func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, error
// GetUsers useful for tests
func GetUsers(limit int64, offset int64, username string, expectedStatusCode int) ([]dataprovider.User, error) {
var users []dataprovider.User
req, err := http.NewRequest(http.MethodGet, httpBaseURL+userPath, nil)
url, err := url.Parse(httpBaseURL + userPath)
if err != nil {
return users, err
}
q := req.URL.Query()
q := url.Query()
if limit > 0 {
q.Add("limit", strconv.FormatInt(limit, 10))
}
@ -126,8 +133,8 @@ func GetUsers(limit int64, offset int64, username string, expectedStatusCode int
if len(username) > 0 {
q.Add("username", username)
}
req.URL.RawQuery = q.Encode()
resp, err := http.DefaultClient.Do(req)
url.RawQuery = q.Encode()
resp, err := getHTTPClient().Get(url.String())
if err != nil {
return users, err
}
@ -142,7 +149,7 @@ func GetUsers(limit int64, offset int64, username string, expectedStatusCode int
// GetQuotaScans get active quota scans, useful for tests
func GetQuotaScans(expectedStatusCode int) ([]sftpd.ActiveQuotaScan, error) {
var quotaScans []sftpd.ActiveQuotaScan
resp, err := http.Get(httpBaseURL + quotaScanPath)
resp, err := getHTTPClient().Get(httpBaseURL + quotaScanPath)
if err != nil {
return quotaScans, err
}
@ -160,7 +167,7 @@ func StartQuotaScan(user dataprovider.User, expectedStatusCode int) error {
if err != nil {
return err
}
resp, err := http.Post(httpBaseURL+quotaScanPath, "application/json", bytes.NewBuffer(userAsJSON))
resp, err := getHTTPClient().Post(httpBaseURL+quotaScanPath, "application/json", bytes.NewBuffer(userAsJSON))
if err != nil {
return err
}