api_utils: set a timeout for http requests

These methods are used for test cases only, anyway a timeout is not bad :)
This commit is contained in:
Nicola Murino 2019-07-27 20:42:45 +02:00
parent 4a46b84dd5
commit c9e6fa0dd6

View file

@ -7,7 +7,9 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"time"
"github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/sftpd" "github.com/drakkan/sftpgo/sftpd"
@ -25,6 +27,12 @@ func SetBaseURL(url string) {
httpBaseURL = url httpBaseURL = url
} }
func getHTTPClient() *http.Client {
return &http.Client{
Timeout: 15 * time.Second,
}
}
// AddUser add a new user, useful for tests // AddUser add a new user, useful for tests
func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) { func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) {
var newUser dataprovider.User var newUser dataprovider.User
@ -32,7 +40,7 @@ func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User,
if err != nil { if err != nil {
return newUser, err return newUser, err
} }
resp, err := http.Post(httpBaseURL+userPath, "application/json", bytes.NewBuffer(userAsJSON)) resp, err := getHTTPClient().Post(httpBaseURL+userPath, "application/json", bytes.NewBuffer(userAsJSON))
if err != nil { if err != nil {
return newUser, err return newUser, err
} }
@ -61,8 +69,7 @@ func UpdateUser(user dataprovider.User, expectedStatusCode int) (dataprovider.Us
if err != nil { if err != nil {
return user, err return user, err
} }
req.Header.Set("Content-Type", "application/json") resp, err := getHTTPClient().Do(req)
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return user, err return user, err
} }
@ -86,7 +93,7 @@ func RemoveUser(user dataprovider.User, expectedStatusCode int) error {
if err != nil { if err != nil {
return err return err
} }
resp, err := http.DefaultClient.Do(req) resp, err := getHTTPClient().Do(req)
if err != nil { if err != nil {
return err return err
} }
@ -97,7 +104,7 @@ func RemoveUser(user dataprovider.User, expectedStatusCode int) error {
// GetUserByID get user by id, useful for tests // GetUserByID get user by id, useful for tests
func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, error) { func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, error) {
var user dataprovider.User var user dataprovider.User
resp, err := http.Get(httpBaseURL + userPath + "/" + strconv.FormatInt(userID, 10)) resp, err := getHTTPClient().Get(httpBaseURL + userPath + "/" + strconv.FormatInt(userID, 10))
if err != nil { if err != nil {
return user, err return user, err
} }
@ -112,11 +119,11 @@ func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, error
// GetUsers useful for tests // GetUsers useful for tests
func GetUsers(limit int64, offset int64, username string, expectedStatusCode int) ([]dataprovider.User, error) { func GetUsers(limit int64, offset int64, username string, expectedStatusCode int) ([]dataprovider.User, error) {
var users []dataprovider.User var users []dataprovider.User
req, err := http.NewRequest(http.MethodGet, httpBaseURL+userPath, nil) url, err := url.Parse(httpBaseURL + userPath)
if err != nil { if err != nil {
return users, err return users, err
} }
q := req.URL.Query() q := url.Query()
if limit > 0 { if limit > 0 {
q.Add("limit", strconv.FormatInt(limit, 10)) q.Add("limit", strconv.FormatInt(limit, 10))
} }
@ -126,8 +133,8 @@ func GetUsers(limit int64, offset int64, username string, expectedStatusCode int
if len(username) > 0 { if len(username) > 0 {
q.Add("username", username) q.Add("username", username)
} }
req.URL.RawQuery = q.Encode() url.RawQuery = q.Encode()
resp, err := http.DefaultClient.Do(req) resp, err := getHTTPClient().Get(url.String())
if err != nil { if err != nil {
return users, err return users, err
} }
@ -142,7 +149,7 @@ func GetUsers(limit int64, offset int64, username string, expectedStatusCode int
// GetQuotaScans get active quota scans, useful for tests // GetQuotaScans get active quota scans, useful for tests
func GetQuotaScans(expectedStatusCode int) ([]sftpd.ActiveQuotaScan, error) { func GetQuotaScans(expectedStatusCode int) ([]sftpd.ActiveQuotaScan, error) {
var quotaScans []sftpd.ActiveQuotaScan var quotaScans []sftpd.ActiveQuotaScan
resp, err := http.Get(httpBaseURL + quotaScanPath) resp, err := getHTTPClient().Get(httpBaseURL + quotaScanPath)
if err != nil { if err != nil {
return quotaScans, err return quotaScans, err
} }
@ -160,7 +167,7 @@ func StartQuotaScan(user dataprovider.User, expectedStatusCode int) error {
if err != nil { if err != nil {
return err return err
} }
resp, err := http.Post(httpBaseURL+quotaScanPath, "application/json", bytes.NewBuffer(userAsJSON)) resp, err := getHTTPClient().Post(httpBaseURL+quotaScanPath, "application/json", bytes.NewBuffer(userAsJSON))
if err != nil { if err != nil {
return err return err
} }