diff --git a/.travis.yml b/.travis.yml index 6e02b3d2..c23ead77 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,12 +13,7 @@ install: - go get -v -t ./... script: - - sftpgo & - # Wait some seconds to be sure that sftpgo is started and listening - - sleep 2 - - go test -v api/api_test.go - - go test -v sftpd/sftpd_test.go + - go test -v ./... -coverprofile=coverage.txt -covermode=atomic -#test cases run against a real server and the coverage is not detected -#after_success: -# - bash <(curl -s https://codecov.io/bash) \ No newline at end of file +after_success: + - bash <(curl -s https://codecov.io/bash) \ No newline at end of file diff --git a/api/api_test.go b/api/api_test.go index e55cfa01..1fefb724 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -1,34 +1,112 @@ package api_test import ( + "bytes" + "encoding/json" + "fmt" + "net" "net/http" + "net/http/httptest" + "os" "path/filepath" "runtime" + "strconv" "testing" + "time" + + "github.com/go-chi/render" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" "github.com/drakkan/sftpgo/api" + "github.com/drakkan/sftpgo/config" "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/sftpd" ) -// To run test cases you need to manually start sftpgo using port 2022 for sftp and 8080 for http API - const ( - defaultUsername = "test_user" - defaultPassword = "test_password" - testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" + defaultUsername = "test_user" + defaultPassword = "test_password" + testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" + logSender = "APITesting" + userPath = "/api/v1/user" + activeConnectionsPath = "/api/v1/sftp_connection" + quotaScanPath = "/api/v1/quota_scan" ) var ( defaultPerms = []string{dataprovider.PermAny} homeBasePath string + testServer *httptest.Server ) -func init() { +func TestMain(m *testing.M) { if runtime.GOOS == "windows" { homeBasePath = "C:\\" } else { homeBasePath = "/tmp" } + configDir := ".." + logfilePath := filepath.Join(configDir, "sftpgo_api_test.log") + confName := "sftpgo.conf" + logger.InitLogger(logfilePath, zerolog.DebugLevel) + configFilePath := filepath.Join(configDir, confName) + config.LoadConfig(configFilePath) + providerConf := config.GetProviderConf() + + err := dataprovider.Initialize(providerConf, configDir) + if err != nil { + logger.Warn(logSender, "error initializing data provider: %v", err) + os.Exit(1) + } + dataProvider := dataprovider.GetProvider() + httpdConf := config.GetHTTPDConfig() + router := api.GetHTTPRouter() + + httpdConf.BindPort = 8081 + api.SetBaseURL("http://127.0.0.1:8081") + + sftpd.SetDataProvider(dataProvider) + api.SetDataProvider(dataProvider) + + go func() { + logger.Debug(logSender, "initializing HTTP server with config %+v", httpdConf) + s := &http.Server{ + Addr: fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort), + Handler: router, + ReadTimeout: 300 * time.Second, + WriteTimeout: 300 * time.Second, + MaxHeaderBytes: 1 << 20, // 1MB + } + if err := s.ListenAndServe(); err != nil { + logger.Error(logSender, "could not start HTTP server: %v", err) + } + }() + + testServer = httptest.NewServer(api.GetHTTPRouter()) + defer testServer.Close() + + waitTCPListening(fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort)) + + exitCode := m.Run() + os.Remove(logfilePath) + os.Exit(exitCode) +} + +func waitTCPListening(address string) { + for { + conn, err := net.Dial("tcp", address) + if err != nil { + fmt.Printf("tcp server %v not listening: %v\n", address, err) + continue + } + fmt.Printf("tcp server %v now listening\n", address) + defer conn.Close() + break + } } func getTestUser() dataprovider.User { @@ -40,6 +118,27 @@ func getTestUser() dataprovider.User { } } +func getUserAsJSON(t *testing.T, user dataprovider.User) []byte { + json, err := json.Marshal(user) + if err != nil { + t.Errorf("error get user as json: %v", err) + return []byte("{}") + } + return json +} + +func executeRequest(req *http.Request) *httptest.ResponseRecorder { + rr := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + return rr +} + +func checkResponseCode(t *testing.T, expected, actual int) { + if expected != actual { + t.Errorf("Expected response code %d. Got %d", expected, actual) + } +} + func TestBasicUserHandling(t *testing.T) { user, err := api.AddUser(getTestUser(), http.StatusOK) if err != nil { @@ -256,14 +355,14 @@ func TestGetUsers(t *testing.T) { t.Errorf("unable to get users: %v", err) } if len(users) != 1 { - t.Errorf("1 user are expected") + t.Errorf("1 user is expected") } users, err = api.GetUsers(1, 1, "", http.StatusOK) if err != nil { t.Errorf("unable to get users: %v", err) } if len(users) != 1 { - t.Errorf("1 user are expected") + t.Errorf("1 user is expected") } err = api.RemoveUser(user1, http.StatusOK) if err != nil { @@ -296,3 +395,276 @@ func TestStartQuotaScan(t *testing.T) { t.Errorf("unable to remove user: %v", err) } } + +// test using mock http server + +func TestBasicUserHandlingMock(t *testing.T) { + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) + err := render.DecodeJSON(rr.Body, &user) + if err != nil { + t.Errorf("Error get user: %v", err) + } + req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr.Code) + user.MaxSessions = 10 + user.UploadBandwidth = 128 + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(userAsJSON)) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) + + req, _ = http.NewRequest(http.MethodGet, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) + + var updatedUser dataprovider.User + err = render.DecodeJSON(rr.Body, &updatedUser) + if err != nil { + t.Errorf("Error decoding updated user: %v", err) + } + if user.MaxSessions != updatedUser.MaxSessions || user.UploadBandwidth != updatedUser.UploadBandwidth { + t.Errorf("Error modifying user actual: %v, %v", updatedUser.MaxSessions, updatedUser.UploadBandwidth) + } + req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) +} + +func TestGetUserByIdInvalidParamsMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, userPath+"/0", nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr.Code) + req, _ = http.NewRequest(http.MethodGet, userPath+"/a", nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) +} + +func TestAddUserNoUsernameMock(t *testing.T) { + user := getTestUser() + user.Username = "" + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) +} + +func TestAddUserInvalidHomeDirMock(t *testing.T) { + user := getTestUser() + user.HomeDir = "relative_path" + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) +} + +func TestAddUserInvalidPermsMock(t *testing.T) { + user := getTestUser() + user.Permissions = []string{} + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) +} + +func TestAddUserInvalidJsonMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer([]byte("invalid json"))) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) +} + +func TestUpdateUserInvalidJsonMock(t *testing.T) { + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) + err := render.DecodeJSON(rr.Body, &user) + if err != nil { + t.Errorf("Error get user: %v", err) + } + req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer([]byte("Invalid json"))) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) + req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) +} + +func TestUpdateUserInvalidParamsMock(t *testing.T) { + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) + err := render.DecodeJSON(rr.Body, &user) + if err != nil { + t.Errorf("Error get user: %v", err) + } + user.HomeDir = "" + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(userAsJSON)) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) + userID := user.ID + user.ID = 0 + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(userID, 10), bytes.NewBuffer(userAsJSON)) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) + user.ID = userID + req, _ = http.NewRequest(http.MethodPut, userPath+"/0", bytes.NewBuffer(userAsJSON)) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr.Code) + req, _ = http.NewRequest(http.MethodPut, userPath+"/a", bytes.NewBuffer(userAsJSON)) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) + req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) +} + +func TestGetUsersMock(t *testing.T) { + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) + err := render.DecodeJSON(rr.Body, &user) + if err != nil { + t.Errorf("Error get user: %v", err) + } + req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=510&offset=0&order=ASC&username="+defaultUsername, nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) + var users []dataprovider.User + err = render.DecodeJSON(rr.Body, &users) + if err != nil { + t.Errorf("Error decoding users: %v", err) + } + if len(users) != 1 { + t.Errorf("1 user is expected") + } + req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=a&offset=0&order=ASC", nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) + req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=a&order=ASC", nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) + req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=0&order=ASCa", nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) + + req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) +} + +func TestDeleteUserInvalidParamsMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodDelete, userPath+"/0", nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr.Code) + req, _ = http.NewRequest(http.MethodDelete, userPath+"/a", nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) +} + +func TestGetQuotaScansMock(t *testing.T) { + req, err := http.NewRequest("GET", quotaScanPath, nil) + if err != nil { + t.Errorf("error get quota scan: %v", err) + } + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) +} + +func TestStartQuotaScanMock(t *testing.T) { + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) + err := render.DecodeJSON(rr.Body, &user) + if err != nil { + t.Errorf("Error get user: %v", err) + } + _, err = os.Stat(user.HomeDir) + if err == nil { + os.Remove(user.HomeDir) + } + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON)) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr.Code) + + req, _ = http.NewRequest(http.MethodGet, quotaScanPath, nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) + var scans []sftpd.ActiveQuotaScan + err = render.DecodeJSON(rr.Body, &scans) + if err != nil { + t.Errorf("Error get active scans: %v", err) + } + for len(scans) > 0 { + req, _ = http.NewRequest(http.MethodGet, quotaScanPath, nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) + err = render.DecodeJSON(rr.Body, &scans) + if err != nil { + t.Errorf("Error get active scans: %v", err) + break + } + } + _, err = os.Stat(user.HomeDir) + if err != nil && os.IsNotExist(err) { + os.MkdirAll(user.HomeDir, 0777) + } + req, _ = http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON)) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr.Code) + req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) +} + +func TestStartQuotaScanBadUserMock(t *testing.T) { + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON)) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr.Code) +} + +func TestStartQuotaScanNonExistentUserMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer([]byte("invalid json"))) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr.Code) +} + +func TestGetSFTPConnectionsMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, activeConnectionsPath, nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr.Code) +} + +func TestDeleteActiveConnectionMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr.Code) +} + +func TestNotFoundMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "/non/existing/path", nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr.Code) +} + +func TestMethodNotAllowedMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, activeConnectionsPath, nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusMethodNotAllowed, rr.Code) +} diff --git a/api/api_utils.go b/api/api_utils.go index da288830..f850b9e9 100644 --- a/api/api_utils.go +++ b/api/api_utils.go @@ -15,14 +15,16 @@ import ( "github.com/go-chi/render" ) -const ( - httpBaseURL = "http://127.0.0.1:8080" -) - var ( defaultPerms = []string{dataprovider.PermAny} + httpBaseURL = "http://127.0.0.1:8080" ) +// SetBaseURL sets the url to use for HTTP request, default is "http://127.0.0.1:8080" +func SetBaseURL(url string) { + httpBaseURL = url +} + // AddUser add a new user, useful for tests func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) { var newUser dataprovider.User @@ -170,7 +172,7 @@ func checkResponse(actual int, expected int, resp *http.Response) error { if expected != actual { return fmt.Errorf("wrong status code: got %v want %v", actual, expected) } - if expected != http.StatusOK { + if expected != http.StatusOK && resp != nil { b, err := ioutil.ReadAll(resp.Body) if err == nil { fmt.Printf("request: %v, response body: %v", resp.Request.URL, string(b)) @@ -189,8 +191,6 @@ func checkUser(expected dataprovider.User, actual dataprovider.User) error { if expected.ID <= 0 { if actual.ID <= 0 { return errors.New("actual user ID must be > 0") - } else if actual.ID <= 0 { - return errors.New("user ID must be >=0") } } else { if actual.ID != expected.ID { diff --git a/api/internal_test.go b/api/internal_test.go new file mode 100644 index 00000000..49fe639c --- /dev/null +++ b/api/internal_test.go @@ -0,0 +1,138 @@ +package api + +import ( + "fmt" + "net/http" + "testing" + + "github.com/drakkan/sftpgo/dataprovider" +) + +func TestGetRespStatus(t *testing.T) { + var err error + err = &dataprovider.MethodDisabledError{} + respStatus := getRespStatus(err) + if respStatus != http.StatusForbidden { + t.Errorf("wrong resp status extected: %d got: %d", http.StatusForbidden, respStatus) + } + err = fmt.Errorf("generic error") + respStatus = getRespStatus(err) + if respStatus != http.StatusInternalServerError { + t.Errorf("wrong resp status extected: %d got: %d", http.StatusInternalServerError, respStatus) + } +} + +func TestCheckResponse(t *testing.T) { + err := checkResponse(200, 201, nil) + if err == nil { + t.Errorf("check must fail") + } + err = checkResponse(400, 400, nil) + if err != nil { + t.Errorf("test must succeed, error: %v", err) + } +} + +func TestCheckUser(t *testing.T) { + expected := dataprovider.User{} + actual := dataprovider.User{} + actual.Password = "password" + err := checkUser(expected, actual) + if err == nil { + t.Errorf("actual password must be nil") + } + actual.Password = "" + actual.PublicKey = "pub key" + err = checkUser(expected, actual) + if err == nil { + t.Errorf("actual public key must be nil") + } + actual.PublicKey = "" + err = checkUser(expected, actual) + if err == nil { + t.Errorf("actual ID must be > 0") + } + expected.ID = 1 + actual.ID = 2 + err = checkUser(expected, actual) + if err == nil { + t.Errorf("actual ID must be equal to expected ID") + } + expected.ID = 2 + actual.ID = 2 + expected.Permissions = []string{dataprovider.PermCreateDirs, dataprovider.PermDelete, dataprovider.PermDownload} + actual.Permissions = []string{dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks} + err = checkUser(expected, actual) + if err == nil { + t.Errorf("Permissions are not equal") + } + expected.Permissions = append(expected.Permissions, dataprovider.PermRename) + err = checkUser(expected, actual) + if err == nil { + t.Errorf("Permissions are not equal") + } +} + +func TestCompareUserFields(t *testing.T) { + expected := dataprovider.User{} + actual := dataprovider.User{} + expected.Username = "test" + err := compareEqualsUserFields(expected, actual) + if err == nil { + t.Errorf("Username does not match") + } + expected.Username = "" + expected.HomeDir = "homedir" + err = compareEqualsUserFields(expected, actual) + if err == nil { + t.Errorf("HomeDir does not match") + } + expected.HomeDir = "" + expected.UID = 1 + err = compareEqualsUserFields(expected, actual) + if err == nil { + t.Errorf("UID does not match") + } + expected.UID = 0 + expected.GID = 1 + err = compareEqualsUserFields(expected, actual) + if err == nil { + t.Errorf("GID does not match") + } + expected.GID = 0 + expected.MaxSessions = 2 + err = compareEqualsUserFields(expected, actual) + if err == nil { + t.Errorf("MaxSessions do not match") + } + expected.MaxSessions = 0 + expected.QuotaSize = 4096 + err = compareEqualsUserFields(expected, actual) + if err == nil { + t.Errorf("QuotaSize does not match") + } + expected.QuotaSize = 0 + expected.QuotaFiles = 2 + err = compareEqualsUserFields(expected, actual) + if err == nil { + t.Errorf("QuotaFiles do not match") + } + expected.QuotaFiles = 0 + expected.Permissions = []string{dataprovider.PermCreateDirs} + err = compareEqualsUserFields(expected, actual) + if err == nil { + t.Errorf("Permissions are not equal") + } + expected.Permissions = nil + expected.UploadBandwidth = 64 + err = compareEqualsUserFields(expected, actual) + if err == nil { + t.Errorf("UploadBandwidth does not match") + } + expected.UploadBandwidth = 0 + expected.DownloadBandwidth = 128 + err = compareEqualsUserFields(expected, actual) + if err == nil { + t.Errorf("DownloadBandwidth does not match") + } +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 00000000..32246381 --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,47 @@ +package config_test + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/drakkan/sftpgo/api" + "github.com/drakkan/sftpgo/config" + "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/sftpd" +) + +func TestLoadConfigTest(t *testing.T) { + configDir := ".." + confName := "sftpgo.conf" + configFilePath := filepath.Join(configDir, confName) + err := config.LoadConfig(configFilePath) + if err != nil { + t.Errorf("error loading config") + } + emptyHTTPDConf := api.HTTPDConf{} + if config.GetHTTPDConfig() == emptyHTTPDConf { + t.Errorf("error loading httpd conf") + } + emptyProviderConf := dataprovider.Config{} + if config.GetProviderConf() == emptyProviderConf { + t.Errorf("error loading provider conf") + } + emptySFTPDConf := sftpd.Configuration{} + if config.GetSFTPDConfig() == emptySFTPDConf { + t.Errorf("error loading SFTPD conf") + } + confName = "sftpgo.conf.missing" + configFilePath = filepath.Join(configDir, confName) + err = config.LoadConfig(configFilePath) + if err == nil { + t.Errorf("loading a non existent config file must fail") + } + ioutil.WriteFile(configFilePath, []byte("{invalid json}"), 0666) + err = config.LoadConfig(configFilePath) + if err == nil { + t.Errorf("loading an invalid config file must fail") + } + os.Remove(configFilePath) +} diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index 218cb090..dda6553d 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -25,12 +25,12 @@ func initializeMySQLProvider() error { dbHandle, err = sql.Open("mysql", connectionString) if err == nil { numCPU := runtime.NumCPU() - logger.Debug(logSender, "mysql database handle created, connection string: %v, connections: %v", connectionString, numCPU) + logger.Debug(logSender, "mysql database handle created, connection string: \"%v\", pool size: %v", connectionString, numCPU) dbHandle.SetMaxIdleConns(numCPU) dbHandle.SetMaxOpenConns(numCPU) dbHandle.SetConnMaxLifetime(1800 * time.Second) } else { - logger.Warn(logSender, "error creating mysql database handler, connection string: %v, error: %v", connectionString, err) + logger.Warn(logSender, "error creating mysql database handler, connection string: \"%v\", error: %v", connectionString, err) } return err } diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index bbc8d7cc..753bfff9 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -24,11 +24,11 @@ func initializePGSQLProvider() error { dbHandle, err = sql.Open("postgres", connectionString) if err == nil { numCPU := runtime.NumCPU() - logger.Debug(logSender, "postgres database handle created, connection string: %v, connections: %v", connectionString, numCPU) + logger.Debug(logSender, "postgres database handle created, connection string: \"%v\", pool size: %v", connectionString, numCPU) dbHandle.SetMaxIdleConns(numCPU) dbHandle.SetMaxOpenConns(numCPU) } else { - logger.Warn(logSender, "error creating postgres database handler, connection string: %v, error: %v", connectionString, err) + logger.Warn(logSender, "error creating postgres database handler, connection string: \"%v\", error: %v", connectionString, err) } return err } diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index 1179244b..0c0af379 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -35,10 +35,10 @@ func initializeSQLiteProvider(basePath string) error { } dbHandle, err = sql.Open("sqlite3", connectionString) if err == nil { - logger.Debug(logSender, "sqlite database handle created, connection string: %v", connectionString) + logger.Debug(logSender, "sqlite database handle created, connection string: \"%v\"", connectionString) dbHandle.SetMaxOpenConns(1) } else { - logger.Warn(logSender, "error creating sqlite database handler, connection string: %v, error: %v", connectionString, err) + logger.Warn(logSender, "error creating sqlite database handler, connection string: \"%v\", error: %v", connectionString, err) } return err } diff --git a/sftpd/server.go b/sftpd/server.go index f015e51b..85a347b9 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -88,7 +88,7 @@ func (c Configuration) Initialize(configDir string) error { listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort)) if err != nil { - logger.Warn(logSender, "error starting listener on address %v: %v", listener.Addr().String(), err) + logger.Warn(logSender, "error starting listener on address %s:%d: %v", c.BindAddress, c.BindPort, err) return err } diff --git a/sftpd/sftpd.go b/sftpd/sftpd.go index d06d9832..0fdbe897 100644 --- a/sftpd/sftpd.go +++ b/sftpd/sftpd.go @@ -247,7 +247,6 @@ func removeTransfer(transfer *Transfer) { } } if indexToRemove >= 0 { - //logger.Debug(logSender, "remove index %v from active transfer, size: %v", indexToRemove, len(activeTransfers)) activeTransfers[indexToRemove] = activeTransfers[len(activeTransfers)-1] activeTransfers = activeTransfers[:len(activeTransfers)-1] } else { @@ -259,29 +258,7 @@ func updateConnectionActivity(id string) { mutex.Lock() defer mutex.Unlock() if c, ok := openConnections[id]; ok { - //logger.Debug(logSender, "update connection activity, id: %v", id) c.lastActivity = time.Now() openConnections[id] = c } - //logger.Debug(logSender, "connection activity updated: %+v", openConnections) -} - -func logConnections() { - mutex.RLock() - defer mutex.RUnlock() - for _, c := range openConnections { - logger.Debug(logSender, "active connection %+v", c) - } -} - -func logTransfers() { - mutex.RLock() - defer mutex.RUnlock() - if len(activeTransfers) > 0 { - for _, v := range activeTransfers { - logger.Debug(logSender, "active transfer: %+v", v) - } - } else { - logger.Debug(logSender, "no active transfer") - } } diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 45269f19..5d3e8930 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -11,17 +11,25 @@ import ( "path/filepath" "runtime" "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/api" + "github.com/drakkan/sftpgo/config" "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/sftpd" "github.com/pkg/sftp" + "github.com/rs/zerolog" ) -// To run test cases you need to manually start sftpgo using port 2022 for sftp and 8080 for http API - const ( + logSender = "sftpdTesting" sftpServerAddr = "127.0.0.1:2022" defaultUsername = "test_user_sftp" defaultPassword = "test_password" @@ -71,12 +79,73 @@ var ( homeBasePath string ) -func init() { +func TestMain(m *testing.M) { if runtime.GOOS == "windows" { homeBasePath = "C:\\" } else { homeBasePath = "/tmp" } + configDir := ".." + logfilePath := filepath.Join(configDir, "sftpgo_sftpd_test.log") + confName := "sftpgo.conf" + logger.InitLogger(logfilePath, zerolog.DebugLevel) + configFilePath := filepath.Join(configDir, confName) + config.LoadConfig(configFilePath) + providerConf := config.GetProviderConf() + + err := dataprovider.Initialize(providerConf, configDir) + if err != nil { + logger.Warn(logSender, "error initializing data provider: %v", err) + os.Exit(1) + } + dataProvider := dataprovider.GetProvider() + sftpdConf := config.GetSFTPDConfig() + httpdConf := config.GetHTTPDConfig() + router := api.GetHTTPRouter() + + sftpd.SetDataProvider(dataProvider) + api.SetDataProvider(dataProvider) + + go func() { + logger.Debug(logSender, "initializing SFTP server with config %+v", sftpdConf) + if err := sftpdConf.Initialize(configDir); err != nil { + logger.Error(logSender, "could not start SFTP server: %v", err) + } + }() + + go func() { + logger.Debug(logSender, "initializing HTTP server with config %+v", httpdConf) + s := &http.Server{ + Addr: fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort), + Handler: router, + ReadTimeout: 300 * time.Second, + WriteTimeout: 300 * time.Second, + MaxHeaderBytes: 1 << 20, // 1MB + } + if err := s.ListenAndServe(); err != nil { + logger.Error(logSender, "could not start HTTP server: %v", err) + } + }() + + waitTCPListening(fmt.Sprintf("%s:%d", sftpdConf.BindAddress, sftpdConf.BindPort)) + waitTCPListening(fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort)) + + exitCode := m.Run() + os.Remove(logfilePath) + os.Exit(exitCode) +} + +func waitTCPListening(address string) { + for { + conn, err := net.Dial("tcp", address) + if err != nil { + fmt.Printf("tcp server %v not listening: %v\n", address, err) + continue + } + fmt.Printf("tcp server %v now listening\n", address) + defer conn.Close() + break + } } func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) {