فهرست منبع

Improve test cases and logging

Nicola Murino 6 سال پیش
والد
کامیت
417b173c78
11فایلهای تغییر یافته به همراه654 افزوده شده و 56 حذف شده
  1. 3 8
      .travis.yml
  2. 380 8
      api/api_test.go
  3. 7 7
      api/api_utils.go
  4. 138 0
      api/internal_test.go
  5. 47 0
      config/config_test.go
  6. 2 2
      dataprovider/mysql.go
  7. 2 2
      dataprovider/pgsql.go
  8. 2 2
      dataprovider/sqlite.go
  9. 1 1
      sftpd/server.go
  10. 0 23
      sftpd/sftpd.go
  11. 72 3
      sftpd/sftpd_test.go

+ 3 - 8
.travis.yml

@@ -13,12 +13,7 @@ install:
   - go get -v -t ./...
 
 script:
-  - sftpgo &
-  # Wait some seconds to be sure that sftpgo is started and listening 
-  - sleep 2
-  - go test -v api/api_test.go
-  - go test -v sftpd/sftpd_test.go
+  - go test -v ./... -coverprofile=coverage.txt -covermode=atomic
 
-#test cases run against a real server and the coverage is not detected
-#after_success:
-#  - bash <(curl -s https://codecov.io/bash)
+after_success:
+  - bash <(curl -s https://codecov.io/bash)

+ 380 - 8
api/api_test.go

@@ -1,34 +1,112 @@
 package api_test
 
 import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"net"
 	"net/http"
+	"net/http/httptest"
+	"os"
 	"path/filepath"
 	"runtime"
+	"strconv"
 	"testing"
+	"time"
+
+	"github.com/go-chi/render"
+	_ "github.com/go-sql-driver/mysql"
+	_ "github.com/lib/pq"
+	_ "github.com/mattn/go-sqlite3"
+	"github.com/rs/zerolog"
 
 	"github.com/drakkan/sftpgo/api"
+	"github.com/drakkan/sftpgo/config"
 	"github.com/drakkan/sftpgo/dataprovider"
+	"github.com/drakkan/sftpgo/logger"
+	"github.com/drakkan/sftpgo/sftpd"
 )
 
-// To run test cases you need to manually start sftpgo using port 2022 for sftp and 8080 for http API
-
 const (
-	defaultUsername = "test_user"
-	defaultPassword = "test_password"
-	testPubKey      = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1"
+	defaultUsername       = "test_user"
+	defaultPassword       = "test_password"
+	testPubKey            = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1"
+	logSender             = "APITesting"
+	userPath              = "/api/v1/user"
+	activeConnectionsPath = "/api/v1/sftp_connection"
+	quotaScanPath         = "/api/v1/quota_scan"
 )
 
 var (
 	defaultPerms = []string{dataprovider.PermAny}
 	homeBasePath string
+	testServer   *httptest.Server
 )
 
-func init() {
+func TestMain(m *testing.M) {
 	if runtime.GOOS == "windows" {
 		homeBasePath = "C:\\"
 	} else {
 		homeBasePath = "/tmp"
 	}
+	configDir := ".."
+	logfilePath := filepath.Join(configDir, "sftpgo_api_test.log")
+	confName := "sftpgo.conf"
+	logger.InitLogger(logfilePath, zerolog.DebugLevel)
+	configFilePath := filepath.Join(configDir, confName)
+	config.LoadConfig(configFilePath)
+	providerConf := config.GetProviderConf()
+
+	err := dataprovider.Initialize(providerConf, configDir)
+	if err != nil {
+		logger.Warn(logSender, "error initializing data provider: %v", err)
+		os.Exit(1)
+	}
+	dataProvider := dataprovider.GetProvider()
+	httpdConf := config.GetHTTPDConfig()
+	router := api.GetHTTPRouter()
+
+	httpdConf.BindPort = 8081
+	api.SetBaseURL("http://127.0.0.1:8081")
+
+	sftpd.SetDataProvider(dataProvider)
+	api.SetDataProvider(dataProvider)
+
+	go func() {
+		logger.Debug(logSender, "initializing HTTP server with config %+v", httpdConf)
+		s := &http.Server{
+			Addr:           fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort),
+			Handler:        router,
+			ReadTimeout:    300 * time.Second,
+			WriteTimeout:   300 * time.Second,
+			MaxHeaderBytes: 1 << 20, // 1MB
+		}
+		if err := s.ListenAndServe(); err != nil {
+			logger.Error(logSender, "could not start HTTP server: %v", err)
+		}
+	}()
+
+	testServer = httptest.NewServer(api.GetHTTPRouter())
+	defer testServer.Close()
+
+	waitTCPListening(fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort))
+
+	exitCode := m.Run()
+	os.Remove(logfilePath)
+	os.Exit(exitCode)
+}
+
+func waitTCPListening(address string) {
+	for {
+		conn, err := net.Dial("tcp", address)
+		if err != nil {
+			fmt.Printf("tcp server %v not listening: %v\n", address, err)
+			continue
+		}
+		fmt.Printf("tcp server %v now listening\n", address)
+		defer conn.Close()
+		break
+	}
 }
 
 func getTestUser() dataprovider.User {
@@ -40,6 +118,27 @@ func getTestUser() dataprovider.User {
 	}
 }
 
+func getUserAsJSON(t *testing.T, user dataprovider.User) []byte {
+	json, err := json.Marshal(user)
+	if err != nil {
+		t.Errorf("error get user as json: %v", err)
+		return []byte("{}")
+	}
+	return json
+}
+
+func executeRequest(req *http.Request) *httptest.ResponseRecorder {
+	rr := httptest.NewRecorder()
+	testServer.Config.Handler.ServeHTTP(rr, req)
+	return rr
+}
+
+func checkResponseCode(t *testing.T, expected, actual int) {
+	if expected != actual {
+		t.Errorf("Expected response code %d. Got %d", expected, actual)
+	}
+}
+
 func TestBasicUserHandling(t *testing.T) {
 	user, err := api.AddUser(getTestUser(), http.StatusOK)
 	if err != nil {
@@ -256,14 +355,14 @@ func TestGetUsers(t *testing.T) {
 		t.Errorf("unable to get users: %v", err)
 	}
 	if len(users) != 1 {
-		t.Errorf("1 user are expected")
+		t.Errorf("1 user is expected")
 	}
 	users, err = api.GetUsers(1, 1, "", http.StatusOK)
 	if err != nil {
 		t.Errorf("unable to get users: %v", err)
 	}
 	if len(users) != 1 {
-		t.Errorf("1 user are expected")
+		t.Errorf("1 user is expected")
 	}
 	err = api.RemoveUser(user1, http.StatusOK)
 	if err != nil {
@@ -296,3 +395,276 @@ func TestStartQuotaScan(t *testing.T) {
 		t.Errorf("unable to remove user: %v", err)
 	}
 }
+
+// test using mock http server
+
+func TestBasicUserHandlingMock(t *testing.T) {
+	user := getTestUser()
+	userAsJSON := getUserAsJSON(t, user)
+	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+	err := render.DecodeJSON(rr.Body, &user)
+	if err != nil {
+		t.Errorf("Error get user: %v", err)
+	}
+	req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusInternalServerError, rr.Code)
+	user.MaxSessions = 10
+	user.UploadBandwidth = 128
+	userAsJSON = getUserAsJSON(t, user)
+	req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(userAsJSON))
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+
+	req, _ = http.NewRequest(http.MethodGet, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+
+	var updatedUser dataprovider.User
+	err = render.DecodeJSON(rr.Body, &updatedUser)
+	if err != nil {
+		t.Errorf("Error decoding updated user: %v", err)
+	}
+	if user.MaxSessions != updatedUser.MaxSessions || user.UploadBandwidth != updatedUser.UploadBandwidth {
+		t.Errorf("Error modifying user actual: %v, %v", updatedUser.MaxSessions, updatedUser.UploadBandwidth)
+	}
+	req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+}
+
+func TestGetUserByIdInvalidParamsMock(t *testing.T) {
+	req, _ := http.NewRequest(http.MethodGet, userPath+"/0", nil)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusNotFound, rr.Code)
+	req, _ = http.NewRequest(http.MethodGet, userPath+"/a", nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+}
+
+func TestAddUserNoUsernameMock(t *testing.T) {
+	user := getTestUser()
+	user.Username = ""
+	userAsJSON := getUserAsJSON(t, user)
+	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+}
+
+func TestAddUserInvalidHomeDirMock(t *testing.T) {
+	user := getTestUser()
+	user.HomeDir = "relative_path"
+	userAsJSON := getUserAsJSON(t, user)
+	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+}
+
+func TestAddUserInvalidPermsMock(t *testing.T) {
+	user := getTestUser()
+	user.Permissions = []string{}
+	userAsJSON := getUserAsJSON(t, user)
+	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+}
+
+func TestAddUserInvalidJsonMock(t *testing.T) {
+	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer([]byte("invalid json")))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+}
+
+func TestUpdateUserInvalidJsonMock(t *testing.T) {
+	user := getTestUser()
+	userAsJSON := getUserAsJSON(t, user)
+	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+	err := render.DecodeJSON(rr.Body, &user)
+	if err != nil {
+		t.Errorf("Error get user: %v", err)
+	}
+	req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer([]byte("Invalid json")))
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+	req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+}
+
+func TestUpdateUserInvalidParamsMock(t *testing.T) {
+	user := getTestUser()
+	userAsJSON := getUserAsJSON(t, user)
+	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+	err := render.DecodeJSON(rr.Body, &user)
+	if err != nil {
+		t.Errorf("Error get user: %v", err)
+	}
+	user.HomeDir = ""
+	userAsJSON = getUserAsJSON(t, user)
+	req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(userAsJSON))
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+	userID := user.ID
+	user.ID = 0
+	userAsJSON = getUserAsJSON(t, user)
+	req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(userID, 10), bytes.NewBuffer(userAsJSON))
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+	user.ID = userID
+	req, _ = http.NewRequest(http.MethodPut, userPath+"/0", bytes.NewBuffer(userAsJSON))
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusNotFound, rr.Code)
+	req, _ = http.NewRequest(http.MethodPut, userPath+"/a", bytes.NewBuffer(userAsJSON))
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+	req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+}
+
+func TestGetUsersMock(t *testing.T) {
+	user := getTestUser()
+	userAsJSON := getUserAsJSON(t, user)
+	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+	err := render.DecodeJSON(rr.Body, &user)
+	if err != nil {
+		t.Errorf("Error get user: %v", err)
+	}
+	req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=510&offset=0&order=ASC&username="+defaultUsername, nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+	var users []dataprovider.User
+	err = render.DecodeJSON(rr.Body, &users)
+	if err != nil {
+		t.Errorf("Error decoding users: %v", err)
+	}
+	if len(users) != 1 {
+		t.Errorf("1 user is expected")
+	}
+	req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=a&offset=0&order=ASC", nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+	req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=a&order=ASC", nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+	req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=0&order=ASCa", nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+
+	req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+}
+
+func TestDeleteUserInvalidParamsMock(t *testing.T) {
+	req, _ := http.NewRequest(http.MethodDelete, userPath+"/0", nil)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusNotFound, rr.Code)
+	req, _ = http.NewRequest(http.MethodDelete, userPath+"/a", nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+}
+
+func TestGetQuotaScansMock(t *testing.T) {
+	req, err := http.NewRequest("GET", quotaScanPath, nil)
+	if err != nil {
+		t.Errorf("error get quota scan: %v", err)
+	}
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+}
+
+func TestStartQuotaScanMock(t *testing.T) {
+	user := getTestUser()
+	userAsJSON := getUserAsJSON(t, user)
+	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+	err := render.DecodeJSON(rr.Body, &user)
+	if err != nil {
+		t.Errorf("Error get user: %v", err)
+	}
+	_, err = os.Stat(user.HomeDir)
+	if err == nil {
+		os.Remove(user.HomeDir)
+	}
+	userAsJSON = getUserAsJSON(t, user)
+	req, _ = http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON))
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusCreated, rr.Code)
+
+	req, _ = http.NewRequest(http.MethodGet, quotaScanPath, nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+	var scans []sftpd.ActiveQuotaScan
+	err = render.DecodeJSON(rr.Body, &scans)
+	if err != nil {
+		t.Errorf("Error get active scans: %v", err)
+	}
+	for len(scans) > 0 {
+		req, _ = http.NewRequest(http.MethodGet, quotaScanPath, nil)
+		rr = executeRequest(req)
+		checkResponseCode(t, http.StatusOK, rr.Code)
+		err = render.DecodeJSON(rr.Body, &scans)
+		if err != nil {
+			t.Errorf("Error get active scans: %v", err)
+			break
+		}
+	}
+	_, err = os.Stat(user.HomeDir)
+	if err != nil && os.IsNotExist(err) {
+		os.MkdirAll(user.HomeDir, 0777)
+	}
+	req, _ = http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON))
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusCreated, rr.Code)
+	req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+}
+
+func TestStartQuotaScanBadUserMock(t *testing.T) {
+	user := getTestUser()
+	userAsJSON := getUserAsJSON(t, user)
+	req, _ := http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusNotFound, rr.Code)
+}
+
+func TestStartQuotaScanNonExistentUserMock(t *testing.T) {
+	req, _ := http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer([]byte("invalid json")))
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusBadRequest, rr.Code)
+}
+
+func TestGetSFTPConnectionsMock(t *testing.T) {
+	req, _ := http.NewRequest(http.MethodGet, activeConnectionsPath, nil)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr.Code)
+}
+
+func TestDeleteActiveConnectionMock(t *testing.T) {
+	req, _ := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusNotFound, rr.Code)
+}
+
+func TestNotFoundMock(t *testing.T) {
+	req, _ := http.NewRequest(http.MethodGet, "/non/existing/path", nil)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusNotFound, rr.Code)
+}
+
+func TestMethodNotAllowedMock(t *testing.T) {
+	req, _ := http.NewRequest(http.MethodPost, activeConnectionsPath, nil)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusMethodNotAllowed, rr.Code)
+}

+ 7 - 7
api/api_utils.go

@@ -15,14 +15,16 @@ import (
 	"github.com/go-chi/render"
 )
 
-const (
-	httpBaseURL = "http://127.0.0.1:8080"
-)
-
 var (
 	defaultPerms = []string{dataprovider.PermAny}
+	httpBaseURL  = "http://127.0.0.1:8080"
 )
 
+// SetBaseURL sets the url to use for HTTP request, default is "http://127.0.0.1:8080"
+func SetBaseURL(url string) {
+	httpBaseURL = url
+}
+
 // AddUser add a new user, useful for tests
 func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) {
 	var newUser dataprovider.User
@@ -170,7 +172,7 @@ func checkResponse(actual int, expected int, resp *http.Response) error {
 	if expected != actual {
 		return fmt.Errorf("wrong status code: got %v want %v", actual, expected)
 	}
-	if expected != http.StatusOK {
+	if expected != http.StatusOK && resp != nil {
 		b, err := ioutil.ReadAll(resp.Body)
 		if err == nil {
 			fmt.Printf("request: %v, response body: %v", resp.Request.URL, string(b))
@@ -189,8 +191,6 @@ func checkUser(expected dataprovider.User, actual dataprovider.User) error {
 	if expected.ID <= 0 {
 		if actual.ID <= 0 {
 			return errors.New("actual user ID must be > 0")
-		} else if actual.ID <= 0 {
-			return errors.New("user ID must be >=0")
 		}
 	} else {
 		if actual.ID != expected.ID {

+ 138 - 0
api/internal_test.go

@@ -0,0 +1,138 @@
+package api
+
+import (
+	"fmt"
+	"net/http"
+	"testing"
+
+	"github.com/drakkan/sftpgo/dataprovider"
+)
+
+func TestGetRespStatus(t *testing.T) {
+	var err error
+	err = &dataprovider.MethodDisabledError{}
+	respStatus := getRespStatus(err)
+	if respStatus != http.StatusForbidden {
+		t.Errorf("wrong resp status extected: %d got: %d", http.StatusForbidden, respStatus)
+	}
+	err = fmt.Errorf("generic error")
+	respStatus = getRespStatus(err)
+	if respStatus != http.StatusInternalServerError {
+		t.Errorf("wrong resp status extected: %d got: %d", http.StatusInternalServerError, respStatus)
+	}
+}
+
+func TestCheckResponse(t *testing.T) {
+	err := checkResponse(200, 201, nil)
+	if err == nil {
+		t.Errorf("check must fail")
+	}
+	err = checkResponse(400, 400, nil)
+	if err != nil {
+		t.Errorf("test must succeed, error: %v", err)
+	}
+}
+
+func TestCheckUser(t *testing.T) {
+	expected := dataprovider.User{}
+	actual := dataprovider.User{}
+	actual.Password = "password"
+	err := checkUser(expected, actual)
+	if err == nil {
+		t.Errorf("actual password must be nil")
+	}
+	actual.Password = ""
+	actual.PublicKey = "pub key"
+	err = checkUser(expected, actual)
+	if err == nil {
+		t.Errorf("actual public key must be nil")
+	}
+	actual.PublicKey = ""
+	err = checkUser(expected, actual)
+	if err == nil {
+		t.Errorf("actual ID must be > 0")
+	}
+	expected.ID = 1
+	actual.ID = 2
+	err = checkUser(expected, actual)
+	if err == nil {
+		t.Errorf("actual ID must be equal to expected ID")
+	}
+	expected.ID = 2
+	actual.ID = 2
+	expected.Permissions = []string{dataprovider.PermCreateDirs, dataprovider.PermDelete, dataprovider.PermDownload}
+	actual.Permissions = []string{dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks}
+	err = checkUser(expected, actual)
+	if err == nil {
+		t.Errorf("Permissions are not equal")
+	}
+	expected.Permissions = append(expected.Permissions, dataprovider.PermRename)
+	err = checkUser(expected, actual)
+	if err == nil {
+		t.Errorf("Permissions are not equal")
+	}
+}
+
+func TestCompareUserFields(t *testing.T) {
+	expected := dataprovider.User{}
+	actual := dataprovider.User{}
+	expected.Username = "test"
+	err := compareEqualsUserFields(expected, actual)
+	if err == nil {
+		t.Errorf("Username does not match")
+	}
+	expected.Username = ""
+	expected.HomeDir = "homedir"
+	err = compareEqualsUserFields(expected, actual)
+	if err == nil {
+		t.Errorf("HomeDir does not match")
+	}
+	expected.HomeDir = ""
+	expected.UID = 1
+	err = compareEqualsUserFields(expected, actual)
+	if err == nil {
+		t.Errorf("UID does not match")
+	}
+	expected.UID = 0
+	expected.GID = 1
+	err = compareEqualsUserFields(expected, actual)
+	if err == nil {
+		t.Errorf("GID does not match")
+	}
+	expected.GID = 0
+	expected.MaxSessions = 2
+	err = compareEqualsUserFields(expected, actual)
+	if err == nil {
+		t.Errorf("MaxSessions do not match")
+	}
+	expected.MaxSessions = 0
+	expected.QuotaSize = 4096
+	err = compareEqualsUserFields(expected, actual)
+	if err == nil {
+		t.Errorf("QuotaSize does not match")
+	}
+	expected.QuotaSize = 0
+	expected.QuotaFiles = 2
+	err = compareEqualsUserFields(expected, actual)
+	if err == nil {
+		t.Errorf("QuotaFiles do not match")
+	}
+	expected.QuotaFiles = 0
+	expected.Permissions = []string{dataprovider.PermCreateDirs}
+	err = compareEqualsUserFields(expected, actual)
+	if err == nil {
+		t.Errorf("Permissions are not equal")
+	}
+	expected.Permissions = nil
+	expected.UploadBandwidth = 64
+	err = compareEqualsUserFields(expected, actual)
+	if err == nil {
+		t.Errorf("UploadBandwidth does not match")
+	}
+	expected.UploadBandwidth = 0
+	expected.DownloadBandwidth = 128
+	err = compareEqualsUserFields(expected, actual)
+	if err == nil {
+		t.Errorf("DownloadBandwidth does not match")
+	}
+}

+ 47 - 0
config/config_test.go

@@ -0,0 +1,47 @@
+package config_test
+
+import (
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/drakkan/sftpgo/api"
+	"github.com/drakkan/sftpgo/config"
+	"github.com/drakkan/sftpgo/dataprovider"
+	"github.com/drakkan/sftpgo/sftpd"
+)
+
+func TestLoadConfigTest(t *testing.T) {
+	configDir := ".."
+	confName := "sftpgo.conf"
+	configFilePath := filepath.Join(configDir, confName)
+	err := config.LoadConfig(configFilePath)
+	if err != nil {
+		t.Errorf("error loading config")
+	}
+	emptyHTTPDConf := api.HTTPDConf{}
+	if config.GetHTTPDConfig() == emptyHTTPDConf {
+		t.Errorf("error loading httpd conf")
+	}
+	emptyProviderConf := dataprovider.Config{}
+	if config.GetProviderConf() == emptyProviderConf {
+		t.Errorf("error loading provider conf")
+	}
+	emptySFTPDConf := sftpd.Configuration{}
+	if config.GetSFTPDConfig() == emptySFTPDConf {
+		t.Errorf("error loading SFTPD conf")
+	}
+	confName = "sftpgo.conf.missing"
+	configFilePath = filepath.Join(configDir, confName)
+	err = config.LoadConfig(configFilePath)
+	if err == nil {
+		t.Errorf("loading a non existent config file must fail")
+	}
+	ioutil.WriteFile(configFilePath, []byte("{invalid json}"), 0666)
+	err = config.LoadConfig(configFilePath)
+	if err == nil {
+		t.Errorf("loading an invalid config file must fail")
+	}
+	os.Remove(configFilePath)
+}

+ 2 - 2
dataprovider/mysql.go

@@ -25,12 +25,12 @@ func initializeMySQLProvider() error {
 	dbHandle, err = sql.Open("mysql", connectionString)
 	if err == nil {
 		numCPU := runtime.NumCPU()
-		logger.Debug(logSender, "mysql database handle created, connection string: %v, connections: %v", connectionString, numCPU)
+		logger.Debug(logSender, "mysql database handle created, connection string: \"%v\", pool size: %v", connectionString, numCPU)
 		dbHandle.SetMaxIdleConns(numCPU)
 		dbHandle.SetMaxOpenConns(numCPU)
 		dbHandle.SetConnMaxLifetime(1800 * time.Second)
 	} else {
-		logger.Warn(logSender, "error creating mysql database handler, connection string: %v, error: %v", connectionString, err)
+		logger.Warn(logSender, "error creating mysql database handler, connection string: \"%v\", error: %v", connectionString, err)
 	}
 	return err
 }

+ 2 - 2
dataprovider/pgsql.go

@@ -24,11 +24,11 @@ func initializePGSQLProvider() error {
 	dbHandle, err = sql.Open("postgres", connectionString)
 	if err == nil {
 		numCPU := runtime.NumCPU()
-		logger.Debug(logSender, "postgres database handle created, connection string: %v, connections: %v", connectionString, numCPU)
+		logger.Debug(logSender, "postgres database handle created, connection string: \"%v\", pool size: %v", connectionString, numCPU)
 		dbHandle.SetMaxIdleConns(numCPU)
 		dbHandle.SetMaxOpenConns(numCPU)
 	} else {
-		logger.Warn(logSender, "error creating postgres database handler, connection string: %v, error: %v", connectionString, err)
+		logger.Warn(logSender, "error creating postgres database handler, connection string: \"%v\", error: %v", connectionString, err)
 	}
 	return err
 }

+ 2 - 2
dataprovider/sqlite.go

@@ -35,10 +35,10 @@ func initializeSQLiteProvider(basePath string) error {
 	}
 	dbHandle, err = sql.Open("sqlite3", connectionString)
 	if err == nil {
-		logger.Debug(logSender, "sqlite database handle created, connection string: %v", connectionString)
+		logger.Debug(logSender, "sqlite database handle created, connection string: \"%v\"", connectionString)
 		dbHandle.SetMaxOpenConns(1)
 	} else {
-		logger.Warn(logSender, "error creating sqlite database handler, connection string: %v, error: %v", connectionString, err)
+		logger.Warn(logSender, "error creating sqlite database handler, connection string: \"%v\", error: %v", connectionString, err)
 	}
 	return err
 }

+ 1 - 1
sftpd/server.go

@@ -88,7 +88,7 @@ func (c Configuration) Initialize(configDir string) error {
 
 	listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort))
 	if err != nil {
-		logger.Warn(logSender, "error starting listener on address %v: %v", listener.Addr().String(), err)
+		logger.Warn(logSender, "error starting listener on address %s:%d: %v", c.BindAddress, c.BindPort, err)
 		return err
 	}
 

+ 0 - 23
sftpd/sftpd.go

@@ -247,7 +247,6 @@ func removeTransfer(transfer *Transfer) {
 		}
 	}
 	if indexToRemove >= 0 {
-		//logger.Debug(logSender, "remove index %v from active transfer, size: %v", indexToRemove, len(activeTransfers))
 		activeTransfers[indexToRemove] = activeTransfers[len(activeTransfers)-1]
 		activeTransfers = activeTransfers[:len(activeTransfers)-1]
 	} else {
@@ -259,29 +258,7 @@ func updateConnectionActivity(id string) {
 	mutex.Lock()
 	defer mutex.Unlock()
 	if c, ok := openConnections[id]; ok {
-		//logger.Debug(logSender, "update connection activity, id: %v", id)
 		c.lastActivity = time.Now()
 		openConnections[id] = c
 	}
-	//logger.Debug(logSender, "connection activity updated: %+v", openConnections)
-}
-
-func logConnections() {
-	mutex.RLock()
-	defer mutex.RUnlock()
-	for _, c := range openConnections {
-		logger.Debug(logSender, "active connection %+v", c)
-	}
-}
-
-func logTransfers() {
-	mutex.RLock()
-	defer mutex.RUnlock()
-	if len(activeTransfers) > 0 {
-		for _, v := range activeTransfers {
-			logger.Debug(logSender, "active transfer: %+v", v)
-		}
-	} else {
-		logger.Debug(logSender, "no active transfer")
-	}
 }

+ 72 - 3
sftpd/sftpd_test.go

@@ -11,17 +11,25 @@ import (
 	"path/filepath"
 	"runtime"
 	"testing"
+	"time"
+
+	_ "github.com/go-sql-driver/mysql"
+	_ "github.com/lib/pq"
+	_ "github.com/mattn/go-sqlite3"
 
 	"golang.org/x/crypto/ssh"
 
 	"github.com/drakkan/sftpgo/api"
+	"github.com/drakkan/sftpgo/config"
 	"github.com/drakkan/sftpgo/dataprovider"
+	"github.com/drakkan/sftpgo/logger"
+	"github.com/drakkan/sftpgo/sftpd"
 	"github.com/pkg/sftp"
+	"github.com/rs/zerolog"
 )
 
-// To run test cases you need to manually start sftpgo using port 2022 for sftp and 8080 for http API
-
 const (
+	logSender       = "sftpdTesting"
 	sftpServerAddr  = "127.0.0.1:2022"
 	defaultUsername = "test_user_sftp"
 	defaultPassword = "test_password"
@@ -71,12 +79,73 @@ var (
 	homeBasePath string
 )
 
-func init() {
+func TestMain(m *testing.M) {
 	if runtime.GOOS == "windows" {
 		homeBasePath = "C:\\"
 	} else {
 		homeBasePath = "/tmp"
 	}
+	configDir := ".."
+	logfilePath := filepath.Join(configDir, "sftpgo_sftpd_test.log")
+	confName := "sftpgo.conf"
+	logger.InitLogger(logfilePath, zerolog.DebugLevel)
+	configFilePath := filepath.Join(configDir, confName)
+	config.LoadConfig(configFilePath)
+	providerConf := config.GetProviderConf()
+
+	err := dataprovider.Initialize(providerConf, configDir)
+	if err != nil {
+		logger.Warn(logSender, "error initializing data provider: %v", err)
+		os.Exit(1)
+	}
+	dataProvider := dataprovider.GetProvider()
+	sftpdConf := config.GetSFTPDConfig()
+	httpdConf := config.GetHTTPDConfig()
+	router := api.GetHTTPRouter()
+
+	sftpd.SetDataProvider(dataProvider)
+	api.SetDataProvider(dataProvider)
+
+	go func() {
+		logger.Debug(logSender, "initializing SFTP server with config %+v", sftpdConf)
+		if err := sftpdConf.Initialize(configDir); err != nil {
+			logger.Error(logSender, "could not start SFTP server: %v", err)
+		}
+	}()
+
+	go func() {
+		logger.Debug(logSender, "initializing HTTP server with config %+v", httpdConf)
+		s := &http.Server{
+			Addr:           fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort),
+			Handler:        router,
+			ReadTimeout:    300 * time.Second,
+			WriteTimeout:   300 * time.Second,
+			MaxHeaderBytes: 1 << 20, // 1MB
+		}
+		if err := s.ListenAndServe(); err != nil {
+			logger.Error(logSender, "could not start HTTP server: %v", err)
+		}
+	}()
+
+	waitTCPListening(fmt.Sprintf("%s:%d", sftpdConf.BindAddress, sftpdConf.BindPort))
+	waitTCPListening(fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort))
+
+	exitCode := m.Run()
+	os.Remove(logfilePath)
+	os.Exit(exitCode)
+}
+
+func waitTCPListening(address string) {
+	for {
+		conn, err := net.Dial("tcp", address)
+		if err != nil {
+			fmt.Printf("tcp server %v not listening: %v\n", address, err)
+			continue
+		}
+		fmt.Printf("tcp server %v now listening\n", address)
+		defer conn.Close()
+		break
+	}
 }
 
 func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) {