mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-21 23:20:24 +00:00
Improve test cases and logging
This commit is contained in:
parent
e0a550b216
commit
417b173c78
11 changed files with 654 additions and 56 deletions
11
.travis.yml
11
.travis.yml
|
@ -13,12 +13,7 @@ install:
|
||||||
- go get -v -t ./...
|
- go get -v -t ./...
|
||||||
|
|
||||||
script:
|
script:
|
||||||
- sftpgo &
|
- go test -v ./... -coverprofile=coverage.txt -covermode=atomic
|
||||||
# Wait some seconds to be sure that sftpgo is started and listening
|
|
||||||
- sleep 2
|
|
||||||
- go test -v api/api_test.go
|
|
||||||
- go test -v sftpd/sftpd_test.go
|
|
||||||
|
|
||||||
#test cases run against a real server and the coverage is not detected
|
after_success:
|
||||||
#after_success:
|
- bash <(curl -s https://codecov.io/bash)
|
||||||
# - bash <(curl -s https://codecov.io/bash)
|
|
388
api/api_test.go
388
api/api_test.go
|
@ -1,34 +1,112 @@
|
||||||
package api_test
|
package api_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/render"
|
||||||
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"github.com/drakkan/sftpgo/api"
|
"github.com/drakkan/sftpgo/api"
|
||||||
|
"github.com/drakkan/sftpgo/config"
|
||||||
"github.com/drakkan/sftpgo/dataprovider"
|
"github.com/drakkan/sftpgo/dataprovider"
|
||||||
|
"github.com/drakkan/sftpgo/logger"
|
||||||
|
"github.com/drakkan/sftpgo/sftpd"
|
||||||
)
|
)
|
||||||
|
|
||||||
// To run test cases you need to manually start sftpgo using port 2022 for sftp and 8080 for http API
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultUsername = "test_user"
|
defaultUsername = "test_user"
|
||||||
defaultPassword = "test_password"
|
defaultPassword = "test_password"
|
||||||
testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1"
|
testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1"
|
||||||
|
logSender = "APITesting"
|
||||||
|
userPath = "/api/v1/user"
|
||||||
|
activeConnectionsPath = "/api/v1/sftp_connection"
|
||||||
|
quotaScanPath = "/api/v1/quota_scan"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultPerms = []string{dataprovider.PermAny}
|
defaultPerms = []string{dataprovider.PermAny}
|
||||||
homeBasePath string
|
homeBasePath string
|
||||||
|
testServer *httptest.Server
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func TestMain(m *testing.M) {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
homeBasePath = "C:\\"
|
homeBasePath = "C:\\"
|
||||||
} else {
|
} else {
|
||||||
homeBasePath = "/tmp"
|
homeBasePath = "/tmp"
|
||||||
}
|
}
|
||||||
|
configDir := ".."
|
||||||
|
logfilePath := filepath.Join(configDir, "sftpgo_api_test.log")
|
||||||
|
confName := "sftpgo.conf"
|
||||||
|
logger.InitLogger(logfilePath, zerolog.DebugLevel)
|
||||||
|
configFilePath := filepath.Join(configDir, confName)
|
||||||
|
config.LoadConfig(configFilePath)
|
||||||
|
providerConf := config.GetProviderConf()
|
||||||
|
|
||||||
|
err := dataprovider.Initialize(providerConf, configDir)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn(logSender, "error initializing data provider: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
dataProvider := dataprovider.GetProvider()
|
||||||
|
httpdConf := config.GetHTTPDConfig()
|
||||||
|
router := api.GetHTTPRouter()
|
||||||
|
|
||||||
|
httpdConf.BindPort = 8081
|
||||||
|
api.SetBaseURL("http://127.0.0.1:8081")
|
||||||
|
|
||||||
|
sftpd.SetDataProvider(dataProvider)
|
||||||
|
api.SetDataProvider(dataProvider)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
logger.Debug(logSender, "initializing HTTP server with config %+v", httpdConf)
|
||||||
|
s := &http.Server{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort),
|
||||||
|
Handler: router,
|
||||||
|
ReadTimeout: 300 * time.Second,
|
||||||
|
WriteTimeout: 300 * time.Second,
|
||||||
|
MaxHeaderBytes: 1 << 20, // 1MB
|
||||||
|
}
|
||||||
|
if err := s.ListenAndServe(); err != nil {
|
||||||
|
logger.Error(logSender, "could not start HTTP server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
testServer = httptest.NewServer(api.GetHTTPRouter())
|
||||||
|
defer testServer.Close()
|
||||||
|
|
||||||
|
waitTCPListening(fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort))
|
||||||
|
|
||||||
|
exitCode := m.Run()
|
||||||
|
os.Remove(logfilePath)
|
||||||
|
os.Exit(exitCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitTCPListening(address string) {
|
||||||
|
for {
|
||||||
|
conn, err := net.Dial("tcp", address)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("tcp server %v not listening: %v\n", address, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Printf("tcp server %v now listening\n", address)
|
||||||
|
defer conn.Close()
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTestUser() dataprovider.User {
|
func getTestUser() dataprovider.User {
|
||||||
|
@ -40,6 +118,27 @@ func getTestUser() dataprovider.User {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getUserAsJSON(t *testing.T, user dataprovider.User) []byte {
|
||||||
|
json, err := json.Marshal(user)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error get user as json: %v", err)
|
||||||
|
return []byte("{}")
|
||||||
|
}
|
||||||
|
return json
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeRequest(req *http.Request) *httptest.ResponseRecorder {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
testServer.Config.Handler.ServeHTTP(rr, req)
|
||||||
|
return rr
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkResponseCode(t *testing.T, expected, actual int) {
|
||||||
|
if expected != actual {
|
||||||
|
t.Errorf("Expected response code %d. Got %d", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBasicUserHandling(t *testing.T) {
|
func TestBasicUserHandling(t *testing.T) {
|
||||||
user, err := api.AddUser(getTestUser(), http.StatusOK)
|
user, err := api.AddUser(getTestUser(), http.StatusOK)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -256,14 +355,14 @@ func TestGetUsers(t *testing.T) {
|
||||||
t.Errorf("unable to get users: %v", err)
|
t.Errorf("unable to get users: %v", err)
|
||||||
}
|
}
|
||||||
if len(users) != 1 {
|
if len(users) != 1 {
|
||||||
t.Errorf("1 user are expected")
|
t.Errorf("1 user is expected")
|
||||||
}
|
}
|
||||||
users, err = api.GetUsers(1, 1, "", http.StatusOK)
|
users, err = api.GetUsers(1, 1, "", http.StatusOK)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to get users: %v", err)
|
t.Errorf("unable to get users: %v", err)
|
||||||
}
|
}
|
||||||
if len(users) != 1 {
|
if len(users) != 1 {
|
||||||
t.Errorf("1 user are expected")
|
t.Errorf("1 user is expected")
|
||||||
}
|
}
|
||||||
err = api.RemoveUser(user1, http.StatusOK)
|
err = api.RemoveUser(user1, http.StatusOK)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -296,3 +395,276 @@ func TestStartQuotaScan(t *testing.T) {
|
||||||
t.Errorf("unable to remove user: %v", err)
|
t.Errorf("unable to remove user: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test using mock http server
|
||||||
|
|
||||||
|
func TestBasicUserHandlingMock(t *testing.T) {
|
||||||
|
user := getTestUser()
|
||||||
|
userAsJSON := getUserAsJSON(t, user)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
err := render.DecodeJSON(rr.Body, &user)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error get user: %v", err)
|
||||||
|
}
|
||||||
|
req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusInternalServerError, rr.Code)
|
||||||
|
user.MaxSessions = 10
|
||||||
|
user.UploadBandwidth = 128
|
||||||
|
userAsJSON = getUserAsJSON(t, user)
|
||||||
|
req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(userAsJSON))
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
|
||||||
|
req, _ = http.NewRequest(http.MethodGet, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
|
||||||
|
var updatedUser dataprovider.User
|
||||||
|
err = render.DecodeJSON(rr.Body, &updatedUser)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error decoding updated user: %v", err)
|
||||||
|
}
|
||||||
|
if user.MaxSessions != updatedUser.MaxSessions || user.UploadBandwidth != updatedUser.UploadBandwidth {
|
||||||
|
t.Errorf("Error modifying user actual: %v, %v", updatedUser.MaxSessions, updatedUser.UploadBandwidth)
|
||||||
|
}
|
||||||
|
req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserByIdInvalidParamsMock(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, userPath+"/0", nil)
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusNotFound, rr.Code)
|
||||||
|
req, _ = http.NewRequest(http.MethodGet, userPath+"/a", nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddUserNoUsernameMock(t *testing.T) {
|
||||||
|
user := getTestUser()
|
||||||
|
user.Username = ""
|
||||||
|
userAsJSON := getUserAsJSON(t, user)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddUserInvalidHomeDirMock(t *testing.T) {
|
||||||
|
user := getTestUser()
|
||||||
|
user.HomeDir = "relative_path"
|
||||||
|
userAsJSON := getUserAsJSON(t, user)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddUserInvalidPermsMock(t *testing.T) {
|
||||||
|
user := getTestUser()
|
||||||
|
user.Permissions = []string{}
|
||||||
|
userAsJSON := getUserAsJSON(t, user)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddUserInvalidJsonMock(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer([]byte("invalid json")))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateUserInvalidJsonMock(t *testing.T) {
|
||||||
|
user := getTestUser()
|
||||||
|
userAsJSON := getUserAsJSON(t, user)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
err := render.DecodeJSON(rr.Body, &user)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error get user: %v", err)
|
||||||
|
}
|
||||||
|
req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer([]byte("Invalid json")))
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateUserInvalidParamsMock(t *testing.T) {
|
||||||
|
user := getTestUser()
|
||||||
|
userAsJSON := getUserAsJSON(t, user)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
err := render.DecodeJSON(rr.Body, &user)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error get user: %v", err)
|
||||||
|
}
|
||||||
|
user.HomeDir = ""
|
||||||
|
userAsJSON = getUserAsJSON(t, user)
|
||||||
|
req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(userAsJSON))
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
userID := user.ID
|
||||||
|
user.ID = 0
|
||||||
|
userAsJSON = getUserAsJSON(t, user)
|
||||||
|
req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(userID, 10), bytes.NewBuffer(userAsJSON))
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
user.ID = userID
|
||||||
|
req, _ = http.NewRequest(http.MethodPut, userPath+"/0", bytes.NewBuffer(userAsJSON))
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusNotFound, rr.Code)
|
||||||
|
req, _ = http.NewRequest(http.MethodPut, userPath+"/a", bytes.NewBuffer(userAsJSON))
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUsersMock(t *testing.T) {
|
||||||
|
user := getTestUser()
|
||||||
|
userAsJSON := getUserAsJSON(t, user)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
err := render.DecodeJSON(rr.Body, &user)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error get user: %v", err)
|
||||||
|
}
|
||||||
|
req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=510&offset=0&order=ASC&username="+defaultUsername, nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
var users []dataprovider.User
|
||||||
|
err = render.DecodeJSON(rr.Body, &users)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error decoding users: %v", err)
|
||||||
|
}
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("1 user is expected")
|
||||||
|
}
|
||||||
|
req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=a&offset=0&order=ASC", nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=a&order=ASC", nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=0&order=ASCa", nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
|
||||||
|
req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteUserInvalidParamsMock(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, userPath+"/0", nil)
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusNotFound, rr.Code)
|
||||||
|
req, _ = http.NewRequest(http.MethodDelete, userPath+"/a", nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetQuotaScansMock(t *testing.T) {
|
||||||
|
req, err := http.NewRequest("GET", quotaScanPath, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error get quota scan: %v", err)
|
||||||
|
}
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartQuotaScanMock(t *testing.T) {
|
||||||
|
user := getTestUser()
|
||||||
|
userAsJSON := getUserAsJSON(t, user)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
err := render.DecodeJSON(rr.Body, &user)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error get user: %v", err)
|
||||||
|
}
|
||||||
|
_, err = os.Stat(user.HomeDir)
|
||||||
|
if err == nil {
|
||||||
|
os.Remove(user.HomeDir)
|
||||||
|
}
|
||||||
|
userAsJSON = getUserAsJSON(t, user)
|
||||||
|
req, _ = http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusCreated, rr.Code)
|
||||||
|
|
||||||
|
req, _ = http.NewRequest(http.MethodGet, quotaScanPath, nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
var scans []sftpd.ActiveQuotaScan
|
||||||
|
err = render.DecodeJSON(rr.Body, &scans)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error get active scans: %v", err)
|
||||||
|
}
|
||||||
|
for len(scans) > 0 {
|
||||||
|
req, _ = http.NewRequest(http.MethodGet, quotaScanPath, nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
err = render.DecodeJSON(rr.Body, &scans)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error get active scans: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = os.Stat(user.HomeDir)
|
||||||
|
if err != nil && os.IsNotExist(err) {
|
||||||
|
os.MkdirAll(user.HomeDir, 0777)
|
||||||
|
}
|
||||||
|
req, _ = http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusCreated, rr.Code)
|
||||||
|
req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
|
||||||
|
rr = executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartQuotaScanBadUserMock(t *testing.T) {
|
||||||
|
user := getTestUser()
|
||||||
|
userAsJSON := getUserAsJSON(t, user)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusNotFound, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartQuotaScanNonExistentUserMock(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer([]byte("invalid json")))
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusBadRequest, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSFTPConnectionsMock(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, activeConnectionsPath, nil)
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteActiveConnectionMock(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil)
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusNotFound, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotFoundMock(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/non/existing/path", nil)
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusNotFound, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMethodNotAllowedMock(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, activeConnectionsPath, nil)
|
||||||
|
rr := executeRequest(req)
|
||||||
|
checkResponseCode(t, http.StatusMethodNotAllowed, rr.Code)
|
||||||
|
}
|
||||||
|
|
|
@ -15,14 +15,16 @@ import (
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
httpBaseURL = "http://127.0.0.1:8080"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultPerms = []string{dataprovider.PermAny}
|
defaultPerms = []string{dataprovider.PermAny}
|
||||||
|
httpBaseURL = "http://127.0.0.1:8080"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SetBaseURL sets the url to use for HTTP request, default is "http://127.0.0.1:8080"
|
||||||
|
func SetBaseURL(url string) {
|
||||||
|
httpBaseURL = url
|
||||||
|
}
|
||||||
|
|
||||||
// AddUser add a new user, useful for tests
|
// AddUser add a new user, useful for tests
|
||||||
func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) {
|
func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) {
|
||||||
var newUser dataprovider.User
|
var newUser dataprovider.User
|
||||||
|
@ -170,7 +172,7 @@ func checkResponse(actual int, expected int, resp *http.Response) error {
|
||||||
if expected != actual {
|
if expected != actual {
|
||||||
return fmt.Errorf("wrong status code: got %v want %v", actual, expected)
|
return fmt.Errorf("wrong status code: got %v want %v", actual, expected)
|
||||||
}
|
}
|
||||||
if expected != http.StatusOK {
|
if expected != http.StatusOK && resp != nil {
|
||||||
b, err := ioutil.ReadAll(resp.Body)
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
fmt.Printf("request: %v, response body: %v", resp.Request.URL, string(b))
|
fmt.Printf("request: %v, response body: %v", resp.Request.URL, string(b))
|
||||||
|
@ -189,8 +191,6 @@ func checkUser(expected dataprovider.User, actual dataprovider.User) error {
|
||||||
if expected.ID <= 0 {
|
if expected.ID <= 0 {
|
||||||
if actual.ID <= 0 {
|
if actual.ID <= 0 {
|
||||||
return errors.New("actual user ID must be > 0")
|
return errors.New("actual user ID must be > 0")
|
||||||
} else if actual.ID <= 0 {
|
|
||||||
return errors.New("user ID must be >=0")
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if actual.ID != expected.ID {
|
if actual.ID != expected.ID {
|
||||||
|
|
138
api/internal_test.go
Normal file
138
api/internal_test.go
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/drakkan/sftpgo/dataprovider"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetRespStatus(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
err = &dataprovider.MethodDisabledError{}
|
||||||
|
respStatus := getRespStatus(err)
|
||||||
|
if respStatus != http.StatusForbidden {
|
||||||
|
t.Errorf("wrong resp status extected: %d got: %d", http.StatusForbidden, respStatus)
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("generic error")
|
||||||
|
respStatus = getRespStatus(err)
|
||||||
|
if respStatus != http.StatusInternalServerError {
|
||||||
|
t.Errorf("wrong resp status extected: %d got: %d", http.StatusInternalServerError, respStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckResponse(t *testing.T) {
|
||||||
|
err := checkResponse(200, 201, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("check must fail")
|
||||||
|
}
|
||||||
|
err = checkResponse(400, 400, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("test must succeed, error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckUser(t *testing.T) {
|
||||||
|
expected := dataprovider.User{}
|
||||||
|
actual := dataprovider.User{}
|
||||||
|
actual.Password = "password"
|
||||||
|
err := checkUser(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("actual password must be nil")
|
||||||
|
}
|
||||||
|
actual.Password = ""
|
||||||
|
actual.PublicKey = "pub key"
|
||||||
|
err = checkUser(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("actual public key must be nil")
|
||||||
|
}
|
||||||
|
actual.PublicKey = ""
|
||||||
|
err = checkUser(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("actual ID must be > 0")
|
||||||
|
}
|
||||||
|
expected.ID = 1
|
||||||
|
actual.ID = 2
|
||||||
|
err = checkUser(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("actual ID must be equal to expected ID")
|
||||||
|
}
|
||||||
|
expected.ID = 2
|
||||||
|
actual.ID = 2
|
||||||
|
expected.Permissions = []string{dataprovider.PermCreateDirs, dataprovider.PermDelete, dataprovider.PermDownload}
|
||||||
|
actual.Permissions = []string{dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks}
|
||||||
|
err = checkUser(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Permissions are not equal")
|
||||||
|
}
|
||||||
|
expected.Permissions = append(expected.Permissions, dataprovider.PermRename)
|
||||||
|
err = checkUser(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Permissions are not equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareUserFields(t *testing.T) {
|
||||||
|
expected := dataprovider.User{}
|
||||||
|
actual := dataprovider.User{}
|
||||||
|
expected.Username = "test"
|
||||||
|
err := compareEqualsUserFields(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Username does not match")
|
||||||
|
}
|
||||||
|
expected.Username = ""
|
||||||
|
expected.HomeDir = "homedir"
|
||||||
|
err = compareEqualsUserFields(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("HomeDir does not match")
|
||||||
|
}
|
||||||
|
expected.HomeDir = ""
|
||||||
|
expected.UID = 1
|
||||||
|
err = compareEqualsUserFields(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("UID does not match")
|
||||||
|
}
|
||||||
|
expected.UID = 0
|
||||||
|
expected.GID = 1
|
||||||
|
err = compareEqualsUserFields(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("GID does not match")
|
||||||
|
}
|
||||||
|
expected.GID = 0
|
||||||
|
expected.MaxSessions = 2
|
||||||
|
err = compareEqualsUserFields(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("MaxSessions do not match")
|
||||||
|
}
|
||||||
|
expected.MaxSessions = 0
|
||||||
|
expected.QuotaSize = 4096
|
||||||
|
err = compareEqualsUserFields(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("QuotaSize does not match")
|
||||||
|
}
|
||||||
|
expected.QuotaSize = 0
|
||||||
|
expected.QuotaFiles = 2
|
||||||
|
err = compareEqualsUserFields(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("QuotaFiles do not match")
|
||||||
|
}
|
||||||
|
expected.QuotaFiles = 0
|
||||||
|
expected.Permissions = []string{dataprovider.PermCreateDirs}
|
||||||
|
err = compareEqualsUserFields(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Permissions are not equal")
|
||||||
|
}
|
||||||
|
expected.Permissions = nil
|
||||||
|
expected.UploadBandwidth = 64
|
||||||
|
err = compareEqualsUserFields(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("UploadBandwidth does not match")
|
||||||
|
}
|
||||||
|
expected.UploadBandwidth = 0
|
||||||
|
expected.DownloadBandwidth = 128
|
||||||
|
err = compareEqualsUserFields(expected, actual)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("DownloadBandwidth does not match")
|
||||||
|
}
|
||||||
|
}
|
47
config/config_test.go
Normal file
47
config/config_test.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package config_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/drakkan/sftpgo/api"
|
||||||
|
"github.com/drakkan/sftpgo/config"
|
||||||
|
"github.com/drakkan/sftpgo/dataprovider"
|
||||||
|
"github.com/drakkan/sftpgo/sftpd"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadConfigTest(t *testing.T) {
|
||||||
|
configDir := ".."
|
||||||
|
confName := "sftpgo.conf"
|
||||||
|
configFilePath := filepath.Join(configDir, confName)
|
||||||
|
err := config.LoadConfig(configFilePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error loading config")
|
||||||
|
}
|
||||||
|
emptyHTTPDConf := api.HTTPDConf{}
|
||||||
|
if config.GetHTTPDConfig() == emptyHTTPDConf {
|
||||||
|
t.Errorf("error loading httpd conf")
|
||||||
|
}
|
||||||
|
emptyProviderConf := dataprovider.Config{}
|
||||||
|
if config.GetProviderConf() == emptyProviderConf {
|
||||||
|
t.Errorf("error loading provider conf")
|
||||||
|
}
|
||||||
|
emptySFTPDConf := sftpd.Configuration{}
|
||||||
|
if config.GetSFTPDConfig() == emptySFTPDConf {
|
||||||
|
t.Errorf("error loading SFTPD conf")
|
||||||
|
}
|
||||||
|
confName = "sftpgo.conf.missing"
|
||||||
|
configFilePath = filepath.Join(configDir, confName)
|
||||||
|
err = config.LoadConfig(configFilePath)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("loading a non existent config file must fail")
|
||||||
|
}
|
||||||
|
ioutil.WriteFile(configFilePath, []byte("{invalid json}"), 0666)
|
||||||
|
err = config.LoadConfig(configFilePath)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("loading an invalid config file must fail")
|
||||||
|
}
|
||||||
|
os.Remove(configFilePath)
|
||||||
|
}
|
|
@ -25,12 +25,12 @@ func initializeMySQLProvider() error {
|
||||||
dbHandle, err = sql.Open("mysql", connectionString)
|
dbHandle, err = sql.Open("mysql", connectionString)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
numCPU := runtime.NumCPU()
|
numCPU := runtime.NumCPU()
|
||||||
logger.Debug(logSender, "mysql database handle created, connection string: %v, connections: %v", connectionString, numCPU)
|
logger.Debug(logSender, "mysql database handle created, connection string: \"%v\", pool size: %v", connectionString, numCPU)
|
||||||
dbHandle.SetMaxIdleConns(numCPU)
|
dbHandle.SetMaxIdleConns(numCPU)
|
||||||
dbHandle.SetMaxOpenConns(numCPU)
|
dbHandle.SetMaxOpenConns(numCPU)
|
||||||
dbHandle.SetConnMaxLifetime(1800 * time.Second)
|
dbHandle.SetConnMaxLifetime(1800 * time.Second)
|
||||||
} else {
|
} else {
|
||||||
logger.Warn(logSender, "error creating mysql database handler, connection string: %v, error: %v", connectionString, err)
|
logger.Warn(logSender, "error creating mysql database handler, connection string: \"%v\", error: %v", connectionString, err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,11 +24,11 @@ func initializePGSQLProvider() error {
|
||||||
dbHandle, err = sql.Open("postgres", connectionString)
|
dbHandle, err = sql.Open("postgres", connectionString)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
numCPU := runtime.NumCPU()
|
numCPU := runtime.NumCPU()
|
||||||
logger.Debug(logSender, "postgres database handle created, connection string: %v, connections: %v", connectionString, numCPU)
|
logger.Debug(logSender, "postgres database handle created, connection string: \"%v\", pool size: %v", connectionString, numCPU)
|
||||||
dbHandle.SetMaxIdleConns(numCPU)
|
dbHandle.SetMaxIdleConns(numCPU)
|
||||||
dbHandle.SetMaxOpenConns(numCPU)
|
dbHandle.SetMaxOpenConns(numCPU)
|
||||||
} else {
|
} else {
|
||||||
logger.Warn(logSender, "error creating postgres database handler, connection string: %v, error: %v", connectionString, err)
|
logger.Warn(logSender, "error creating postgres database handler, connection string: \"%v\", error: %v", connectionString, err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,10 +35,10 @@ func initializeSQLiteProvider(basePath string) error {
|
||||||
}
|
}
|
||||||
dbHandle, err = sql.Open("sqlite3", connectionString)
|
dbHandle, err = sql.Open("sqlite3", connectionString)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logger.Debug(logSender, "sqlite database handle created, connection string: %v", connectionString)
|
logger.Debug(logSender, "sqlite database handle created, connection string: \"%v\"", connectionString)
|
||||||
dbHandle.SetMaxOpenConns(1)
|
dbHandle.SetMaxOpenConns(1)
|
||||||
} else {
|
} else {
|
||||||
logger.Warn(logSender, "error creating sqlite database handler, connection string: %v, error: %v", connectionString, err)
|
logger.Warn(logSender, "error creating sqlite database handler, connection string: \"%v\", error: %v", connectionString, err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,7 +88,7 @@ func (c Configuration) Initialize(configDir string) error {
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort))
|
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(logSender, "error starting listener on address %v: %v", listener.Addr().String(), err)
|
logger.Warn(logSender, "error starting listener on address %s:%d: %v", c.BindAddress, c.BindPort, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -247,7 +247,6 @@ func removeTransfer(transfer *Transfer) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if indexToRemove >= 0 {
|
if indexToRemove >= 0 {
|
||||||
//logger.Debug(logSender, "remove index %v from active transfer, size: %v", indexToRemove, len(activeTransfers))
|
|
||||||
activeTransfers[indexToRemove] = activeTransfers[len(activeTransfers)-1]
|
activeTransfers[indexToRemove] = activeTransfers[len(activeTransfers)-1]
|
||||||
activeTransfers = activeTransfers[:len(activeTransfers)-1]
|
activeTransfers = activeTransfers[:len(activeTransfers)-1]
|
||||||
} else {
|
} else {
|
||||||
|
@ -259,29 +258,7 @@ func updateConnectionActivity(id string) {
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
defer mutex.Unlock()
|
defer mutex.Unlock()
|
||||||
if c, ok := openConnections[id]; ok {
|
if c, ok := openConnections[id]; ok {
|
||||||
//logger.Debug(logSender, "update connection activity, id: %v", id)
|
|
||||||
c.lastActivity = time.Now()
|
c.lastActivity = time.Now()
|
||||||
openConnections[id] = c
|
openConnections[id] = c
|
||||||
}
|
}
|
||||||
//logger.Debug(logSender, "connection activity updated: %+v", openConnections)
|
|
||||||
}
|
|
||||||
|
|
||||||
func logConnections() {
|
|
||||||
mutex.RLock()
|
|
||||||
defer mutex.RUnlock()
|
|
||||||
for _, c := range openConnections {
|
|
||||||
logger.Debug(logSender, "active connection %+v", c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func logTransfers() {
|
|
||||||
mutex.RLock()
|
|
||||||
defer mutex.RUnlock()
|
|
||||||
if len(activeTransfers) > 0 {
|
|
||||||
for _, v := range activeTransfers {
|
|
||||||
logger.Debug(logSender, "active transfer: %+v", v)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Debug(logSender, "no active transfer")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,17 +11,25 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/drakkan/sftpgo/api"
|
"github.com/drakkan/sftpgo/api"
|
||||||
|
"github.com/drakkan/sftpgo/config"
|
||||||
"github.com/drakkan/sftpgo/dataprovider"
|
"github.com/drakkan/sftpgo/dataprovider"
|
||||||
|
"github.com/drakkan/sftpgo/logger"
|
||||||
|
"github.com/drakkan/sftpgo/sftpd"
|
||||||
"github.com/pkg/sftp"
|
"github.com/pkg/sftp"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// To run test cases you need to manually start sftpgo using port 2022 for sftp and 8080 for http API
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
logSender = "sftpdTesting"
|
||||||
sftpServerAddr = "127.0.0.1:2022"
|
sftpServerAddr = "127.0.0.1:2022"
|
||||||
defaultUsername = "test_user_sftp"
|
defaultUsername = "test_user_sftp"
|
||||||
defaultPassword = "test_password"
|
defaultPassword = "test_password"
|
||||||
|
@ -71,12 +79,73 @@ var (
|
||||||
homeBasePath string
|
homeBasePath string
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func TestMain(m *testing.M) {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
homeBasePath = "C:\\"
|
homeBasePath = "C:\\"
|
||||||
} else {
|
} else {
|
||||||
homeBasePath = "/tmp"
|
homeBasePath = "/tmp"
|
||||||
}
|
}
|
||||||
|
configDir := ".."
|
||||||
|
logfilePath := filepath.Join(configDir, "sftpgo_sftpd_test.log")
|
||||||
|
confName := "sftpgo.conf"
|
||||||
|
logger.InitLogger(logfilePath, zerolog.DebugLevel)
|
||||||
|
configFilePath := filepath.Join(configDir, confName)
|
||||||
|
config.LoadConfig(configFilePath)
|
||||||
|
providerConf := config.GetProviderConf()
|
||||||
|
|
||||||
|
err := dataprovider.Initialize(providerConf, configDir)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn(logSender, "error initializing data provider: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
dataProvider := dataprovider.GetProvider()
|
||||||
|
sftpdConf := config.GetSFTPDConfig()
|
||||||
|
httpdConf := config.GetHTTPDConfig()
|
||||||
|
router := api.GetHTTPRouter()
|
||||||
|
|
||||||
|
sftpd.SetDataProvider(dataProvider)
|
||||||
|
api.SetDataProvider(dataProvider)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
logger.Debug(logSender, "initializing SFTP server with config %+v", sftpdConf)
|
||||||
|
if err := sftpdConf.Initialize(configDir); err != nil {
|
||||||
|
logger.Error(logSender, "could not start SFTP server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
logger.Debug(logSender, "initializing HTTP server with config %+v", httpdConf)
|
||||||
|
s := &http.Server{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort),
|
||||||
|
Handler: router,
|
||||||
|
ReadTimeout: 300 * time.Second,
|
||||||
|
WriteTimeout: 300 * time.Second,
|
||||||
|
MaxHeaderBytes: 1 << 20, // 1MB
|
||||||
|
}
|
||||||
|
if err := s.ListenAndServe(); err != nil {
|
||||||
|
logger.Error(logSender, "could not start HTTP server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitTCPListening(fmt.Sprintf("%s:%d", sftpdConf.BindAddress, sftpdConf.BindPort))
|
||||||
|
waitTCPListening(fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort))
|
||||||
|
|
||||||
|
exitCode := m.Run()
|
||||||
|
os.Remove(logfilePath)
|
||||||
|
os.Exit(exitCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitTCPListening(address string) {
|
||||||
|
for {
|
||||||
|
conn, err := net.Dial("tcp", address)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("tcp server %v not listening: %v\n", address, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Printf("tcp server %v now listening\n", address)
|
||||||
|
defer conn.Close()
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) {
|
func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) {
|
||||||
|
|
Loading…
Reference in a new issue