diff --git a/README.md b/README.md new file mode 100644 index 00000000..2f2a025b --- /dev/null +++ b/README.md @@ -0,0 +1,200 @@ +# SFTPGo + +Full featured and highly configurable SFTP server software + +## Features + +- Each account is chrooted to his Home Dir +- SFTP accounts are virtual accounts stored in a "data provider" +- SQLite, MySQL and PostgreSQL data providers are supported. The `Provider` interface could be extented to support non SQL backends too +- Public key and password authentication +- Quota support: accounts can have individual quota expressed as max number of files and max total size +- Bandwidth throttling is supported, with distinct settings for upload and download +- Per user maximum concurrent sessions +- Per user permissions: list directories content, upload, download, delete, rename, create directories, create symlinks can be enabled or disabled +- Per user files/folders ownership: you can map all the users to the system account that runs SFTPGo (all platforms are supported) or you can run SFTPGo as root user and map each user or group of users to a different system account (*NIX only) +- REST API for users and quota management and real time reports for the active connections with possibility of forcibly closing a connection +- Log files are accurate and they are saved in the easily parsable JSON format +- Automatically terminating idle connections + +## Platforms + +SFTPGo is developed and tested on Linux, regularly the test cases are executed and pass on macOS and Windows. +Other UNIX variants such as *BSD should work too. + +## Requirements + +- Go 1.12 or higher +- A suitable SQL server to use as data provider: MySQL (4.1+) or SQLite 3.x or PostreSQL (10+) + +## Installation + +Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell: + +```bash +$ go get -u github.com/drakkan/sftpgo +``` + +Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. + +A systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree. + +Alternately you can use distro packages: + +- Arch Linux PKGBUILD is available on [AUR](https://aur.archlinux.org/sftpgo-git "SFTPGo") + +## Configuration + +The `sftpgo` executable supports the following command line flags: + +- `-config-dir` string. Location of the config dir. This directory should contain the `sftpgo.conf` configuration file, the private key for the SFTP server (`id_rsa` file) and the SQLite database if you use SQLite as data provider. The server private key will be autogenerated if the user that executes SFTPGo has write access to the config-dir. The default value is "." +- `-log-file-path` string. Location for the log file, default "sftpgo.log" + +Before starting `sftpgo` a dataprovider must be configured. + +Sample SQL scripts to create the required database structure can be found insite the source tree [sql](https://github.com/drakkan/sftpgo/tree/master/sql "sql") directory. + +The `sftpgo.conf` configuration file contains the following sections: + +- **"sftpd"**, the configuration for the SFTP server + - `bind_port`, integer the port used for serving SFTP requests. Default: 2022 + - `bind_address`, string. Leave blank to listen on all available network interfaces. Default: "" + - `idle_timeout`, integer. Time in minutes after which an idle client will be disconnected. Default: 15 + - `umask`, string. Umask for the new files and directories. This setting has no effect on Windows. Default: "0022" +- **"data_provider"**, the configuration for the data provider + - `driver`, string. Supported drivers are `sqlite`, `mysql`, `postgresql` + - `name`, string. Database name + - `host`, string. Database host. Leave empty for driver `sqlite` + - `port`, integer. Database port. Leave empty for driver `sqlite` + - `username`, string. Database user. Leave empty for driver `sqlite` + - `password`, string. Database password. Leave empty for driver `sqlite` + - `sslmode`, integer. Used for drivers `mysql` and `postgresql`. 0 disable SSL/TLS connections, 1 require ssl, 2 set ssl mode to `verify-ca` for driver `postgresql` and `skip-verify` for driver `mysql`, 3 set ssl mode to `verify-full` for driver `postgresql` and `preferred` for driver `mysql` + - `connectionstring`, string. Provide a custom database connection string. If not empty this connection string will be used instead of build one using the previous parameters + - `users_table`, string. Database table for SFTP users + - `manage_users`, integer. Set to 0 to disable users management, 1 to enable + - `track_quota`, integer. Set to 0 to disable quota tracking, 1 to update the used quota each time a user upload or delete a file +- **"httpd"**, the configuration for the HTTP server used to serve REST API + - `bind_port`, integer the port used for serving HTTP requests. Default: 8080 + - `bind_address`, string. Leave blank to listen on all available network interfaces. Default: "127.0.0.1" + +Here is a full example showing the default config: + +```{ + "sftpd":{ + "bind_port":2022, + "bind_address": "", + "idle_timeout": 15, + "umask": "0022" + }, + "data_provider": { + "driver": "sqlite", + "name": "sftpgo.db", + "host": "", + "port": 5432, + "username": "", + "password": "", + "sslmode": 0, + "connection_string": "", + "users_table": "users", + "manage_users": 1, + "track_quota": 1 + }, + "httpd":{ + "bind_port":8080, + "bind_address": "127.0.0.1" + } +} +``` + +## Account's configuration properties + +For each account the following properties can be configured: + +- `username` +- `password` used for password authentication. The password with be stored using argon2id hashing algo +- `public_key` used for public key authentication. At least one between password and public key is mandatory +- `home_dir` The user cannot upload or download files outside this directory. Must be an absolute path +- `uid`, `gid`. If sftpgo runs as root then the created files and directories will be assigned to this system uid/gid. Ignored on windows and if sftpgo runs as non root user: in this case files and directories for all SFTP users will be owned by the system user that runs sftpgo +- `max_sessions` maximum concurrent sessions. 0 means unlimited +- `quota_size` maximum size allowed. 0 means unlimited +- `quota_files` maximum number of files allowed. 0 means unlimited +- `permissions` the following permissions are supported: + - `*` all permission are granted + - `list` list items is allowed + - `download` download files is allowed + - `upload` upload files is allowed + - `delete` delete files or directories is allowed + - `rename` rename files or directories is allowed + - `create_dirs` create directories is allowed + - `create_symlinks` create links is allowed +- `upload_bandwidth` maximum upload bandwidth as KB/s, 0 means unlimited +- `download_bandwidth` maximum download bandwidth as KB/s, 0 means unlimited + +These properties are stored inside the data provider. If you want to use your existing accounts you can create a database view. Since a view is read only you have to disable user management and quota tracking so sftpgo will never try to write to the view. + +## REST API + +SFTPGo exposes REST API to manage users and quota and to get real time reports for the active connections with possibility of forcibly closing a connection. + +If quota tracking is enabled in `sftpgo.conf` configuration file then the used size and number of files are updated each time a file is added/removed. If files are added/removed not using SFTP you can rescan the user home dir and update the used quota using the REST API. + +REST API is designed to run on localhost or on a trusted network, if you need https or authentication you can setup a reverse proxy using an HTTP Server such as Apache or NGNIX. + +The OpenAPI 3 schema for the exposed API can be found inside the source tree: [openapi.yaml](https://github.com/drakkan/sftpgo/tree/master/api/schema/openapi.yaml "OpenAPI 3 specs"). + +## Logs + +Inside the log file each line is a JSON struct, each struct has a `sender` fields that identify the log type. + +The logs can be divided into the following categories: + +- **"app logs"**, internal logs used to debug `sftpgo`: + - `sender` string. This is generally the package name that emits the log + - `time` string. Date/time with millisecond precision + - `level` string + - `message` string +- **"transfer logs"**, SFTP transfer logs: + - `sender` string. `SFTPUpload` or `SFTPDownload` + - `time` string. Date/time with millisecond precision + - `level` string + - `elapsed_ms`, int64. Elapsed time, as milliseconds, for the upload/download + - `size_bytes`, int64. Size, as bytes, of the download/upload + - `username`, string + - `file_path` string + - `connection_id` string. Unique SFTP connection identifier +- **"command logs"**, SFTP command logs: + - `sender` string. `SFTPRename`, `SFTPRmdir`, `SFTPMkdir`, `SFTPSymlink`, `SFTPRemove` + - `level` string + - `username`, string + - `file_path` string + - `target_path` string + - `connection_id` string. Unique SFTP connection identifier +- **"http logs"**, REST API logs: + - `sender` string. `httpd` + - `level` string + - `remote_addr` string. IP and port of the remote client + - `proto` string, for example `HTTP/1.1` + - `method` string. HTTP method (`GET`, `POST`, `PUT`, `DELETE` etc.) + - `user_agent` string + - `uri` string. Full uri + - `resp_status` integer. HTTP response status code + - `resp_size` integer. Size in bytes of the HTTP response + - `elapsed_ms` int64. Elapsed time, as milliseconds, to complete the request + - `request_id` string. Unique request identifier + +## Acknowledgements + +- [pkg/sftp](https://github.com/pkg/sftp) +- [go-chi](https://github.com/go-chi/chi) +- [zerolog](https://github.com/rs/zerolog) +- [lumberjack](https://gopkg.in/natefinch/lumberjack.v2) +- [argon2id](https://github.com/alexedwards/argon2id) +- [go-sqlite3](https://github.com/mattn/go-sqlite3) +- [go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) +- [lib/pq](https://github.com/lib/pq) + +Some code was initially taken from [Pterodactyl sftp server](https://github.com/pterodactyl/sftp-server) + +## License + +GNU GPLv3 \ No newline at end of file diff --git a/api/api.go b/api/api.go new file mode 100644 index 00000000..36fb5b49 --- /dev/null +++ b/api/api.go @@ -0,0 +1,69 @@ +package api + +import ( + "net/http" + + "github.com/drakkan/sftpgo/dataprovider" + "github.com/go-chi/chi" + "github.com/go-chi/render" +) + +const ( + logSender = "api" + activeConnectionsPath = "/api/v1/sftp_connection" + quotaScanPath = "/api/v1/quota_scan" + userPath = "/api/v1/user" +) + +var ( + router *chi.Mux + dataProvider dataprovider.Provider +) + +// HTTPDConf httpd daemon configuration +type HTTPDConf struct { + BindPort int `json:"bind_port"` + BindAddress string `json:"bind_address"` +} + +type apiResponse struct { + Error string `json:"error"` + Message string `json:"message"` + HTTPStatus int `json:"status"` +} + +func init() { + initializeRouter() +} + +// SetDataProvider sets the data provider +func SetDataProvider(provider dataprovider.Provider) { + dataProvider = provider +} + +func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) { + var errorString string + if err != nil { + errorString = err.Error() + } + resp := apiResponse{ + Error: errorString, + Message: message, + HTTPStatus: code, + } + if code != http.StatusOK { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(code) + } + render.JSON(w, r, resp) +} + +func getRespStatus(err error) int { + if _, ok := err.(*dataprovider.ValidationError); ok { + return http.StatusBadRequest + } + if _, ok := err.(*dataprovider.MethodDisabledError); ok { + return http.StatusForbidden + } + return http.StatusInternalServerError +} diff --git a/api/api_test.go b/api/api_test.go new file mode 100644 index 00000000..5688c6d0 --- /dev/null +++ b/api/api_test.go @@ -0,0 +1,292 @@ +package api_test + +import ( + "net/http" + "path/filepath" + "runtime" + "testing" + + "github.com/drakkan/sftpgo/api" + "github.com/drakkan/sftpgo/dataprovider" +) + +// To run test cases you need to manually start sftpgo using port 2022 for sftp and 8080 for http API + +const ( + defaultUsername = "test_user" + defaultPassword = "test_password" + testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" +) + +var ( + defaultPerms = []string{dataprovider.PermAny} + homeBasePath string +) + +func init() { + if runtime.GOOS == "windows" { + homeBasePath = "C:\\" + } else { + homeBasePath = "/tmp" + } +} + +func getTestUser() dataprovider.User { + return dataprovider.User{ + Username: defaultUsername, + Password: defaultPassword, + HomeDir: filepath.Join(homeBasePath, defaultUsername), + Permissions: defaultPerms, + } +} + +func TestBasicUserHandling(t *testing.T) { + user, err := api.AddUser(getTestUser(), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + user.MaxSessions = 10 + user.QuotaSize = 4096 + user.QuotaFiles = 2 + user.UploadBandwidth = 128 + user.DownloadBandwidth = 64 + user, err = api.UpdateUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to update user: %v", err) + } + users, err := api.GetUsers(0, 0, defaultUsername, http.StatusOK) + if err != nil { + t.Errorf("unable to get users: %v", err) + } + if len(users) != 1 { + t.Errorf("number of users mismatch, expected: 1, actual: %v", len(users)) + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove: %v", err) + } +} + +func TestAddUserNoCredentials(t *testing.T) { + u := getTestUser() + u.Password = "" + u.PublicKey = "" + _, err := api.AddUser(u, http.StatusBadRequest) + if err != nil { + t.Errorf("unexpected error adding user with no credentials: %v", err) + } +} + +func TestAddUserNoUsername(t *testing.T) { + u := getTestUser() + u.Username = "" + _, err := api.AddUser(u, http.StatusBadRequest) + if err != nil { + t.Errorf("unexpected error adding user with no home dir: %v", err) + } +} + +func TestAddUserNoHomeDir(t *testing.T) { + u := getTestUser() + u.HomeDir = "" + _, err := api.AddUser(u, http.StatusBadRequest) + if err != nil { + t.Errorf("unexpected error adding user with no home dir: %v", err) + } +} + +func TestAddUserInvalidHomeDir(t *testing.T) { + u := getTestUser() + u.HomeDir = "relative_path" + _, err := api.AddUser(u, http.StatusBadRequest) + if err != nil { + t.Errorf("unexpected error adding user with invalid home dir: %v", err) + } +} + +func TestAddUserNoPerms(t *testing.T) { + u := getTestUser() + u.Permissions = []string{} + _, err := api.AddUser(u, http.StatusBadRequest) + if err != nil { + t.Errorf("unexpected error adding user with no perms: %v", err) + } +} + +func TestAddUserInvalidPerms(t *testing.T) { + u := getTestUser() + u.Permissions = []string{"invalidPerm"} + _, err := api.AddUser(u, http.StatusBadRequest) + if err != nil { + t.Errorf("unexpected error adding user with no perms: %v", err) + } +} + +func TestUpdateUser(t *testing.T) { + user, err := api.AddUser(getTestUser(), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + user.HomeDir = filepath.Join(homeBasePath, "testmod") + user.UID = 33 + user.GID = 101 + user.MaxSessions = 10 + user.QuotaSize = 4096 + user.QuotaFiles = 2 + user.Permissions = []string{dataprovider.PermCreateDirs, dataprovider.PermDelete, dataprovider.PermDownload} + user.UploadBandwidth = 1024 + user.DownloadBandwidth = 512 + user, err = api.UpdateUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to update user: %v", err) + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove: %v", err) + } +} + +func TestUpdateUserNoCredentials(t *testing.T) { + user, err := api.AddUser(getTestUser(), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + user.Password = "" + user.PublicKey = "" + // password and public key will be ommitted from json serialization if empty and so they will remain unchanged + // and no validation error will be raised + _, err = api.UpdateUser(user, http.StatusOK) + if err != nil { + t.Errorf("unexpected error updating user with no credentials: %v", err) + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove: %v", err) + } +} + +func TestUpdateUserEmptyHomeDir(t *testing.T) { + user, err := api.AddUser(getTestUser(), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + user.HomeDir = "" + _, err = api.UpdateUser(user, http.StatusBadRequest) + if err != nil { + t.Errorf("unexpected error updating user with empty home dir: %v", err) + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove: %v", err) + } +} + +func TestUpdateUserInvalidHomeDir(t *testing.T) { + user, err := api.AddUser(getTestUser(), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + user.HomeDir = "relative_path" + _, err = api.UpdateUser(user, http.StatusBadRequest) + if err != nil { + t.Errorf("unexpected error updating user with empty home dir: %v", err) + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove: %v", err) + } +} + +func TestUpdateNonExistentUser(t *testing.T) { + _, err := api.UpdateUser(getTestUser(), http.StatusNotFound) + if err != nil { + t.Errorf("unable to update user: %v", err) + } +} + +func TestGetNonExistentUser(t *testing.T) { + _, err := api.GetUserByID(0, http.StatusNotFound) + if err != nil { + t.Errorf("unable to get user: %v", err) + } +} + +func TestDeleteNonExistentUser(t *testing.T) { + err := api.RemoveUser(getTestUser(), http.StatusNotFound) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestAddDuplicateUser(t *testing.T) { + user, err := api.AddUser(getTestUser(), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + _, err = api.AddUser(getTestUser(), http.StatusInternalServerError) + if err != nil { + t.Errorf("unable to add second user: %v", err) + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestGetUsers(t *testing.T) { + user1, err := api.AddUser(getTestUser(), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + u := getTestUser() + u.Username = defaultUsername + "1" + user2, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add second user: %v", err) + } + users, err := api.GetUsers(0, 0, "", http.StatusOK) + if err != nil { + t.Errorf("unable to get users: %v", err) + } + if len(users) < 2 { + t.Errorf("at least 2 users are expected") + } + users, err = api.GetUsers(1, 0, "", http.StatusOK) + if len(users) != 1 { + t.Errorf("1 user are expected") + } + users, err = api.GetUsers(1, 1, "", http.StatusOK) + if len(users) != 1 { + t.Errorf("1 user are expected") + } + err = api.RemoveUser(user1, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } + err = api.RemoveUser(user2, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestGetQuotaScans(t *testing.T) { + _, err := api.GetQuotaScans(http.StatusOK) + if err != nil { + t.Errorf("unable to get quota scans: %v", err) + } +} + +func TestStartQuotaScan(t *testing.T) { + user, err := api.AddUser(getTestUser(), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + err = api.StartQuotaScan(user, http.StatusCreated) + if err != nil { + t.Errorf("unable to start quota scan: %v", err) + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} diff --git a/api/api_utils.go b/api/api_utils.go new file mode 100644 index 00000000..2d734ea6 --- /dev/null +++ b/api/api_utils.go @@ -0,0 +1,236 @@ +package api + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "strconv" + + "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/sftpd" + "github.com/drakkan/sftpgo/utils" + "github.com/go-chi/render" +) + +const ( + httpBaseURL = "http://127.0.0.1:8080" +) + +var ( + defaultPerms = []string{dataprovider.PermAny} +) + +// AddUser add a new user, useful for tests +func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) { + var newUser dataprovider.User + userAsJSON, err := json.Marshal(user) + if err != nil { + return newUser, err + } + resp, err := http.Post(httpBaseURL+userPath, "application/json", bytes.NewBuffer(userAsJSON)) + if err != nil { + return newUser, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode, resp) + if expectedStatusCode != http.StatusOK { + return newUser, err + } + if err == nil { + err = render.DecodeJSON(resp.Body, &newUser) + } + if err == nil { + err = checkUser(user, newUser) + } + return newUser, err +} + +// UpdateUser update an user, useful for tests +func UpdateUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, error) { + var newUser dataprovider.User + userAsJSON, err := json.Marshal(user) + if err != nil { + return user, err + } + req, err := http.NewRequest(http.MethodPut, httpBaseURL+userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(userAsJSON)) + if err != nil { + return user, err + } + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return user, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode, resp) + if expectedStatusCode != http.StatusOK { + return newUser, err + } + if err == nil { + newUser, err = GetUserByID(user.ID, expectedStatusCode) + } + if err == nil { + err = checkUser(user, newUser) + } + return newUser, err +} + +// RemoveUser remove user, useful for tests +func RemoveUser(user dataprovider.User, expectedStatusCode int) error { + req, err := http.NewRequest(http.MethodDelete, httpBaseURL+userPath+"/"+strconv.FormatInt(user.ID, 10), nil) + if err != nil { + return err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + return checkResponse(resp.StatusCode, expectedStatusCode, resp) +} + +// GetUserByID get user by id, useful for tests +func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, error) { + var user dataprovider.User + resp, err := http.Get(httpBaseURL + userPath + "/" + strconv.FormatInt(userID, 10)) + if err != nil { + return user, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode, resp) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &user) + } + return user, err +} + +// GetUsers useful for tests +func GetUsers(limit int64, offset int64, username string, expectedStatusCode int) ([]dataprovider.User, error) { + var users []dataprovider.User + req, err := http.NewRequest(http.MethodGet, httpBaseURL+userPath, nil) + if err != nil { + return users, err + } + q := req.URL.Query() + if limit > 0 { + q.Add("limit", strconv.FormatInt(limit, 10)) + } + if offset > 0 { + q.Add("offset", strconv.FormatInt(offset, 10)) + } + if len(username) > 0 { + q.Add("username", username) + } + req.URL.RawQuery = q.Encode() + resp, err := http.DefaultClient.Do(req) + if err != nil { + return users, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode, resp) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &users) + } + return users, err +} + +// GetQuotaScans get active quota scans, useful for tests +func GetQuotaScans(expectedStatusCode int) ([]sftpd.ActiveQuotaScan, error) { + var quotaScans []sftpd.ActiveQuotaScan + resp, err := http.Get(httpBaseURL + quotaScanPath) + if err != nil { + return quotaScans, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode, resp) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, "aScans) + } + return quotaScans, err +} + +// StartQuotaScan start a new quota scan +func StartQuotaScan(user dataprovider.User, expectedStatusCode int) error { + userAsJSON, err := json.Marshal(user) + if err != nil { + return err + } + resp, err := http.Post(httpBaseURL+quotaScanPath, "application/json", bytes.NewBuffer(userAsJSON)) + if err != nil { + return err + } + defer resp.Body.Close() + return checkResponse(resp.StatusCode, expectedStatusCode, resp) +} + +func checkResponse(actual int, expected int, resp *http.Response) error { + if expected != actual { + return fmt.Errorf("wrong status code: got %v want %v", actual, expected) + } + if expected != http.StatusOK { + b, err := ioutil.ReadAll(resp.Body) + if err == nil { + fmt.Printf("request: %v, response body: %v", resp.Request.URL, string(b)) + } + } + return nil +} + +func checkUser(expected dataprovider.User, actual dataprovider.User) error { + if len(actual.Password) > 0 { + return errors.New("User password must not be visible") + } + if len(actual.PublicKey) > 0 { + return errors.New("User public key must not be visible") + } + if expected.ID <= 0 { + if actual.ID <= 0 { + return errors.New("actual user ID must be > 0") + } else if actual.ID <= 0 { + return errors.New("user ID must be >=0") + } + } else { + if actual.ID != expected.ID { + return errors.New("user ID mismatch") + } + } + if expected.Username != actual.Username { + return errors.New("Username mismatch") + } + if expected.HomeDir != actual.HomeDir { + return errors.New("HomeDir mismatch") + } + if expected.UID != actual.UID { + return errors.New("UID mismatch") + } + if expected.GID != actual.GID { + return errors.New("GID mismatch") + } + if expected.MaxSessions != actual.MaxSessions { + return errors.New("MaxSessions mismatch") + } + if expected.QuotaSize != actual.QuotaSize { + return errors.New("QuotaSize mismatch") + } + if expected.QuotaFiles != actual.QuotaFiles { + return errors.New("QuotaFiles mismatch") + } + if len(expected.Permissions) != len(actual.Permissions) { + return errors.New("Permissions mismatch") + } + for _, v := range expected.Permissions { + if !utils.IsStringInSlice(v, actual.Permissions) { + return errors.New("Permissions contents mismatch") + } + } + if expected.UploadBandwidth != actual.UploadBandwidth { + return errors.New("UploadBandwidth mismatch") + } + if expected.DownloadBandwidth != actual.DownloadBandwidth { + return errors.New("DownloadBandwidth mismatch") + } + return nil +} diff --git a/api/quota.go b/api/quota.go new file mode 100644 index 00000000..449b0098 --- /dev/null +++ b/api/quota.go @@ -0,0 +1,47 @@ +package api + +import ( + "net/http" + + "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/sftpd" + "github.com/drakkan/sftpgo/utils" + "github.com/go-chi/render" +) + +func getQuotaScans(w http.ResponseWriter, r *http.Request) { + render.JSON(w, r, sftpd.GetQuotaScans()) +} + +func startQuotaScan(w http.ResponseWriter, r *http.Request) { + var u dataprovider.User + err := render.DecodeJSON(r.Body, &u) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + user, err := dataprovider.UserExists(dataProvider, u.Username) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusNotFound) + return + } + if sftpd.AddQuotaScan(user.Username) { + sendAPIResponse(w, r, err, "Scan started", http.StatusCreated) + go func() { + numFiles, size, err := utils.ScanDirContents(user.HomeDir) + if err != nil { + logger.Warn(logSender, "error scanning user home dir %v: %v", user.HomeDir, err) + } else { + err := dataprovider.UpdateUserQuota(dataProvider, user.Username, numFiles, size, true) + if err != nil { + logger.Debug(logSender, "error updating user quota for %v: %v", user.Username, err) + } + logger.Debug(logSender, "user dir scanned, user: %v, dir: %v", user.Username, user.HomeDir) + } + sftpd.RemoveQuotaScan(user.Username) + }() + } else { + sendAPIResponse(w, r, err, "Another scan is already in progress", http.StatusConflict) + } +} diff --git a/api/router.go b/api/router.go new file mode 100644 index 00000000..22bebde1 --- /dev/null +++ b/api/router.go @@ -0,0 +1,77 @@ +package api + +import ( + "net/http" + + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/sftpd" + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" + "github.com/go-chi/render" +) + +// GetHTTPRouter returns the configured HTTP router +func GetHTTPRouter() http.Handler { + return router +} + +func initializeRouter() { + router = chi.NewRouter() + router.Use(middleware.RequestID) + router.Use(middleware.RealIP) + router.Use(logger.NewStructuredLogger(logger.GetLogger())) + router.Use(middleware.Recoverer) + + router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) + })) + + router.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sendAPIResponse(w, r, nil, "Method not allowed", http.StatusMethodNotAllowed) + })) + + router.Get(activeConnectionsPath, func(w http.ResponseWriter, r *http.Request) { + render.JSON(w, r, sftpd.GetConnectionsStats()) + }) + + router.Delete(activeConnectionsPath+"/{connectionID}", func(w http.ResponseWriter, r *http.Request) { + connectionID := chi.URLParam(r, "connectionID") + if connectionID == "" { + sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest) + return + } + if sftpd.CloseActiveConnection(connectionID) { + sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK) + } else { + sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) + } + }) + + router.Get(quotaScanPath, func(w http.ResponseWriter, r *http.Request) { + getQuotaScans(w, r) + }) + + router.Post(quotaScanPath, func(w http.ResponseWriter, r *http.Request) { + startQuotaScan(w, r) + }) + + router.Get(userPath, func(w http.ResponseWriter, r *http.Request) { + getUsers(w, r) + }) + + router.Post(userPath, func(w http.ResponseWriter, r *http.Request) { + addUser(w, r) + }) + + router.Get(userPath+"/{userID}", func(w http.ResponseWriter, r *http.Request) { + getUserByID(w, r) + }) + + router.Put(userPath+"/{userID}", func(w http.ResponseWriter, r *http.Request) { + updateUser(w, r) + }) + + router.Delete(userPath+"/{userID}", func(w http.ResponseWriter, r *http.Request) { + deleteUser(w, r) + }) +} diff --git a/api/schema/openapi.yaml b/api/schema/openapi.yaml new file mode 100644 index 00000000..a2c72021 --- /dev/null +++ b/api/schema/openapi.yaml @@ -0,0 +1,654 @@ +openapi: 3.0.1 +info: + title: SFTPGo + description: 'SFTPGo REST API' + version: 1.0.0 + +servers: +- url: /api/v1 +paths: + /sftp_connection: + get: + tags: + - connections + summary: Get the active sftp users and info about their uploads/downloads + operationId: get_connections + responses: + 200: + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref : '#/components/schemas/ConnectionStatus' + /sftp_connection/{connectionID}: + delete: + tags: + - connections + summary: Terminate an active SFTP connection + operationId: close_connection + parameters: + - name: connectionID + in: path + description: ID of the connection to close + required: true + schema: + type: string + responses: + 200: + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 200 + message: "Connection closed" + error: "" + 400: + description: Bad request + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 400 + message: "" + error: "Error description if any" + 404: + description: Not Found + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 404 + message: "" + error: "Error description if any" + 500: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 500 + message: "" + error: "Error description if any" + /quota_scan: + get: + tags: + - quota + summary: Get the active quota scans + operationId: get_quota_scan + responses: + 200: + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref : '#/components/schemas/QuotaScan' + post: + tags: + - quota + summary: start a new quota scan + description: A quota scan update the number of files and their total size for the given user + operationId: start_quota_scan + requestBody: + required: true + content: + application/json: + schema: + $ref : '#/components/schemas/User' + responses: + 201: + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 201 + message: "Scan started" + error: "" + 400: + description: Bad request + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 400 + message: "" + error: "Error description if any" + 403: + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 403 + message: "" + error: "Error description if any" + 404: + description: Not Found + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 404 + message: "" + error: "Error description if any" + 409: + description: Another scan is already in progress for this user + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 409 + message: "Another scan is already in progress" + error: "Error description if any" + 500: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 500 + message: "" + error: "Error description if any" + /user: + get: + tags: + - users + summary: Returns an array with one or more SFTP users + description: For security reasons password and public key are empty in the response + operationId: get_users + parameters: + - in: query + name: offset + schema: + type: integer + minimum: 0 + default: 0 + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: The maximum number of items to return. Max value is 500, default is 100 + - in: query + name: order + required: false + description: Ordering users by username + schema: + type: string + enum: + - ASC + - DESC + example: ASC + - in: query + name: username + required: false + description: Filter by username, extact match case sensitive + schema: + type: string + responses: + 200: + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref : '#/components/schemas/User' + 400: + description: Bad request + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 400 + message: "" + error: "Error description if any" + 403: + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 403 + message: "" + error: "Error description if any" + 500: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 500 + message: "" + error: "Error description if any" + post: + tags: + - users + summary: Adds a new SFTP user + operationId: add_user + requestBody: + required: true + content: + application/json: + schema: + $ref : '#/components/schemas/User' + responses: + 200: + description: successful operation + content: + application/json: + schema: + $ref : '#/components/schemas/User' + 400: + description: Bad request + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 400 + message: "" + error: "Error description if any" + 403: + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 403 + message: "" + error: "Error description if any" + 500: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 500 + message: "" + error: "Error description if any" + /user/{userID}: + get: + tags: + - users + summary: Find user by ID + description: For security reasons password and public key are empty in the response + operationId: getUserByID + parameters: + - name: userID + in: path + description: ID of the user to retrieve + required: true + schema: + type: integer + format: int32 + responses: + 200: + description: successful operation + content: + application/json: + schema: + $ref : '#/components/schemas/User' + 400: + description: Bad request + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 400 + message: "" + error: "Error description if any" + 403: + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 403 + message: "" + error: "Error description if any" + 404: + description: Not Found + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 404 + message: "" + error: "Error description if any" + 500: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 500 + message: "" + error: "Error description if any" + put: + tags: + - users + summary: Update an user + operationId: updateUser + parameters: + - name: userID + in: path + description: ID of the user to update + required: true + schema: + type: integer + format: int32 + requestBody: + required: true + content: + application/json: + schema: + $ref : '#/components/schemas/User' + responses: + 200: + description: successful operation + content: + application/json: + schema: + $ref : '#/components/schemas/ApiResponse' + example: + status: 200 + message: "User updated" + error: "" + 400: + description: Bad request + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 400 + message: "" + error: "Error description if any" + 403: + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 403 + message: "" + error: "Error description if any" + 404: + description: Not Found + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 404 + message: "" + error: "Error description if any" + 500: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 500 + message: "" + error: "Error description if any" + delete: + tags: + - users + summary: Delete an user + operationId: deleteUser + parameters: + - name: userID + in: path + description: ID of the user to delete + required: true + schema: + type: integer + format: int32 + responses: + 200: + description: successful operation + content: + application/json: + schema: + $ref : '#/components/schemas/ApiResponse' + example: + status: 200 + message: "User deleted" + error: "" + 400: + description: Bad request + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 400 + message: "" + error: "Error description if any" + 403: + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 403 + message: "" + error: "Error description if any" + 404: + description: Not Found + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 404 + message: "" + error: "Error description if any" + 500: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + status: 500 + message: "" + error: "Error description if any" +components: + schemas: + Permission: + type: string + enum: + - '*' + - list + - download + - upload + - delete + - rename + - create_dirs + - create_symlinks + description: > + Permissions: + * `*` - all permission are granted + * `list` - list items is allowed + * `download` - download files is allowed + * `upload` - upload files is allowed + * `delete` - delete files or directories is allowed + * `rename` - rename files or directories is allowed + * `create_dirs` - create directories is allowed + * `create_symlinks` - create links is allowed + User: + type: object + properties: + id: + type: integer + format: int32 + minimum: 1 + username: + type: string + password: + type: string + nullable: true + description: password or public key are mandatory. For security reasons this field is omitted when you search/get users + public_key: + type: string + nullable: true + description: password or public key are mandatory. For security reasons this field is omitted when you search/get users + home_dir: + type: string + description: path to the user home directory. The user cannot upload or download files outside this directory. SFTPGo tries to automatically create this folder if missing. Must be an absolute path + uid: + type: integer + format: int32 + minimum: 0 + maximum: 65535 + description: if you run sftpgo as root user the created files and directories will be assigned to this uid. 0 means no change, the owner will be the user that runs sftpgo. Ignored on windows + gid: + type: integer + format: int32 + minimum: 0 + maximum: 65535 + description: if you run sftpgo as root user the created files and directories will be assigned to this gid. 0 means no change, the group will be the one of the user that runs sftpgo. Ignored on windows + max_sessions: + type: integer + format: int32 + description: limit the sessions that an sftp user can open. 0 means unlimited + quota_size: + type: integer + format: int64 + description: quota as size. 0 menas unlimited. Please note that quota is updated if files are added/removed via sftp otherwise a quota scan is needed + quota_files: + type: integer + format: int32 + description: quota as number of files. 0 menas unlimited. Please note that quota is updated if files are added/removed via sftp otherwise a quota scan is needed + permissions: + type: array + items: + $ref: '#/components/schemas/Permission' + minItems: 1 + used_quota_size: + type: integer + format: int64 + used_quota_file: + type: integer + format: int32 + last_quota_scan: + type: integer + format: int64 + description: last quota scan as unix timestamp + upload_bandwidth: + type: integer + format: int32 + description: Maximum upload bandwidth as KB/s, 0 means unlimited + download_bandwidth: + type: integer + format: int32 + description: Maximum download bandwidth as KB/s, 0 means unlimited + SFTPTransfer: + type: object + properties: + operation_type: + type: string + enum: + - upload + - download + start_time: + type: integer + format: int64 + description: start time as unix timestamp + size: + type: integer + format: int64 + description: bytes transferred + last_activity: + type: integer + format: int64 + description: last transfer activity as unix timestamp + ConnectionStatus: + type: object + properties: + username: + type: string + description: connected username + connection_id: + type: string + description: unique sftp connection identifier + client_version: + type: string + description: SFTP client version + remote_address: + type: string + description: Remote address for the connected SFTP client + connection_time: + type: integer + format: int64 + description: connection time as unix timestamp + last_activity: + type: integer + format: int64 + description: last client activity as unix timestamp + active_transfers: + type: array + items: + $ref : '#/components/schemas/SFTPTransfer' + QuotaScan: + type: object + properties: + username: + type: string + description: username with an active scan + start_time: + type: integer + format: int64 + description: scan start time as unix timestamp + ApiResponse: + type: object + properties: + status: + type: integer + format: int32 + minimum: 200 + maximum: 500 + example: 200 + description: HTTP Status code, for example 200 OK, 400 Bad request and so on + message: + type: string + nullable: true + description: additional message if any + error: + type: string + nullable: true + description: error description if any diff --git a/api/user.go b/api/user.go new file mode 100644 index 00000000..5129bb4a --- /dev/null +++ b/api/user.go @@ -0,0 +1,152 @@ +package api + +import ( + "database/sql" + "errors" + "net/http" + "strconv" + + "github.com/drakkan/sftpgo/dataprovider" + "github.com/go-chi/chi" + "github.com/go-chi/render" +) + +func getUsers(w http.ResponseWriter, r *http.Request) { + limit := 100 + offset := 0 + order := "ASC" + username := "" + var err error + if _, ok := r.URL.Query()["limit"]; ok { + limit, err = strconv.Atoi(r.URL.Query().Get("limit")) + if err != nil { + err = errors.New("Invalid limit") + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + if limit > 500 { + limit = 500 + } + } + if _, ok := r.URL.Query()["offset"]; ok { + offset, err = strconv.Atoi(r.URL.Query().Get("offset")) + if err != nil { + err = errors.New("Invalid offset") + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + } + if _, ok := r.URL.Query()["order"]; ok { + order = r.URL.Query().Get("order") + if order != "ASC" && order != "DESC" { + err = errors.New("Invalid order") + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + } + if _, ok := r.URL.Query()["username"]; ok { + username = r.URL.Query().Get("username") + } + users, err := dataprovider.GetUsers(dataProvider, limit, offset, order, username) + if err == nil { + render.JSON(w, r, users) + } else { + sendAPIResponse(w, r, err, "", http.StatusInternalServerError) + } +} + +func getUserByID(w http.ResponseWriter, r *http.Request) { + userID, err := strconv.ParseInt(chi.URLParam(r, "userID"), 10, 64) + if err != nil { + err = errors.New("Invalid userID") + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + user, err := dataprovider.GetUserByID(dataProvider, userID) + if err == nil { + user.Password = "" + user.PublicKey = "" + render.JSON(w, r, user) + } else if err == sql.ErrNoRows { + sendAPIResponse(w, r, err, "", http.StatusNotFound) + } else { + sendAPIResponse(w, r, err, "", http.StatusInternalServerError) + } +} + +func addUser(w http.ResponseWriter, r *http.Request) { + var user dataprovider.User + err := render.DecodeJSON(r.Body, &user) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + err = dataprovider.AddUser(dataProvider, user) + if err == nil { + user, err = dataprovider.UserExists(dataProvider, user.Username) + if err == nil { + user.Password = "" + user.PublicKey = "" + render.JSON(w, r, user) + } else { + sendAPIResponse(w, r, err, "", http.StatusInternalServerError) + } + } else { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + } +} + +func updateUser(w http.ResponseWriter, r *http.Request) { + userID, err := strconv.ParseInt(chi.URLParam(r, "userID"), 10, 64) + if err != nil { + err = errors.New("Invalid userID") + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + user, err := dataprovider.GetUserByID(dataProvider, userID) + if err == sql.ErrNoRows { + sendAPIResponse(w, r, err, "", http.StatusNotFound) + return + } else if err != nil { + sendAPIResponse(w, r, err, "", http.StatusInternalServerError) + return + } + err = render.DecodeJSON(r.Body, &user) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + if user.ID != userID { + sendAPIResponse(w, r, err, "user ID in request body does not match user ID in path parameter", http.StatusBadRequest) + return + } + err = dataprovider.UpdateUser(dataProvider, user) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + } else { + sendAPIResponse(w, r, err, "User updated", http.StatusOK) + } +} + +func deleteUser(w http.ResponseWriter, r *http.Request) { + userID, err := strconv.ParseInt(chi.URLParam(r, "userID"), 10, 64) + if err != nil { + err = errors.New("Invalid userID") + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + user, err := dataprovider.GetUserByID(dataProvider, userID) + if err == sql.ErrNoRows { + sendAPIResponse(w, r, err, "", http.StatusNotFound) + return + } else if err != nil { + sendAPIResponse(w, r, err, "", http.StatusInternalServerError) + return + } + err = dataprovider.DeleteUser(dataProvider, user) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusInternalServerError) + } else { + sendAPIResponse(w, r, err, "User deleted", http.StatusOK) + } +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 00000000..a0e2f1f4 --- /dev/null +++ b/config/config.go @@ -0,0 +1,88 @@ +package config + +import ( + "encoding/json" + "os" + + "github.com/drakkan/sftpgo/api" + "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/sftpd" +) + +const ( + logSender = "config" +) + +var ( + globalConf globalConfig +) + +type globalConfig struct { + SFTPD sftpd.Configuration `json:"sftpd"` + ProviderConf dataprovider.Config `json:"data_provider"` + HTTPDConfig api.HTTPDConf `json:"httpd"` +} + +func init() { + // create a default configuration to use if no config file is provided + globalConf = globalConfig{ + SFTPD: sftpd.Configuration{ + BindPort: 2022, + BindAddress: "", + IdleTimeout: 15, + Umask: "0022", + }, + ProviderConf: dataprovider.Config{ + Driver: "sqlite", + Name: "sftpgo.db", + Host: "", + Port: 5432, + Username: "", + Password: "", + ConnectionString: "", + UsersTable: "users", + ManageUsers: 1, + SSLMode: 0, + TrackQuota: 1, + }, + HTTPDConfig: api.HTTPDConf{ + BindPort: 8080, + BindAddress: "127.0.0.1", + }, + } +} + +// GetSFTPDConfig returns sftpd configuration +func GetSFTPDConfig() sftpd.Configuration { + return globalConf.SFTPD +} + +// GetHTTPDConfig returns httpd configuration +func GetHTTPDConfig() api.HTTPDConf { + return globalConf.HTTPDConfig +} + +//GetProviderConf returns data provider configuration +func GetProviderConf() dataprovider.Config { + return globalConf.ProviderConf +} + +// LoadConfig loads configuration from sftpgo.conf +func LoadConfig(configPath string) error { + logger.Debug(logSender, "load config from path: %v", configPath) + //globalConf.basePath = basePath + file, err := os.Open(configPath) + if err != nil { + logger.Warn(logSender, "error loading configuration file: %v. Default configuration will be used", err) + return err + } + defer file.Close() + err = json.NewDecoder(file).Decode(&globalConf) + if err != nil { + logger.Warn(logSender, "error parsing config file: %v", err) + return err + } + logger.Debug(logSender, "config loaded: %+v", globalConf) + return err +} diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go new file mode 100644 index 00000000..6add9807 --- /dev/null +++ b/dataprovider/dataprovider.go @@ -0,0 +1,231 @@ +package dataprovider + +import ( + "database/sql" + "fmt" + "path/filepath" + "strings" + + "github.com/alexedwards/argon2id" + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/utils" +) + +const ( + // SQLiteDataProviderName name for sqlite db provider + SQLiteDataProviderName = "sqlite" + // PGSSQLDataProviderName name for postgresql db provider + PGSSQLDataProviderName = "postgresql" + // MySQLDataProviderName name for mysql db provider + MySQLDataProviderName = "mysql" + + logSender = "dataProvider" + argonPwdPrefix = "$argon2id$" + manageUsersDisabledError = "please set manage_users to 1 in sftpgo.conf to enable this method" + trackQuotaDisabledError = "please set track_quota to 1 in sftpgo.conf to enable this method" +) + +var ( + // SupportedProviders data provider in config file must be one of these strings + SupportedProviders = []string{SQLiteDataProviderName, PGSSQLDataProviderName, MySQLDataProviderName} + dbHandle *sql.DB + config Config + provider Provider + sqlPlaceholders []string + validPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermDelete, PermRename, + PermCreateDirs, PermCreateSymlinks} +) + +// Config provider configuration +type Config struct { + Driver string `json:"driver"` + Name string `json:"name"` + Host string `json:"host"` + Port int `json:"port"` + Username string `json:"username"` + Password string `json:"password"` + ConnectionString string `json:"connection_string"` + UsersTable string `json:"users_table"` + ManageUsers int `json:"manage_users"` + SSLMode int `json:"sslmode"` + TrackQuota int `json:"track_quota"` +} + +// ValidationError raised if input data is not valid +type ValidationError struct { + err string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("Validation error: %s", e.err) +} + +// MethodDisabledError raised if a method is disable in config file +type MethodDisabledError struct { + err string +} + +func (e *MethodDisabledError) Error() string { + return fmt.Sprintf("Method disabled error: %s", e.err) +} + +// GetProvider get the configured provider +func GetProvider() Provider { + return provider +} + +// Provider interface for data providers +type Provider interface { + validateUserAndPass(username string, password string) (User, error) + validateUserAndPubKey(username string, pubKey string) (User, error) + updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error + getUsedQuota(username string) (int, int64, error) + userExists(username string) (User, error) + addUser(user User) error + updateUser(user User) error + deleteUser(user User) error + getUsers(limit int, offset int, order string, username string) ([]User, error) + getUserByID(ID int64) (User, error) +} + +// Initialize auth provider +func Initialize(cnf Config, basePath string) error { + config = cnf + sqlPlaceholders = getSQLPlaceholders() + if config.Driver == SQLiteDataProviderName { + provider = SQLiteProvider{} + return initializeSQLiteProvider(basePath) + } else if config.Driver == PGSSQLDataProviderName { + provider = PGSQLProvider{} + return initializePGSQLProvider() + } else if config.Driver == MySQLDataProviderName { + provider = SQLiteProvider{} + return initializeMySQLProvider() + } + return fmt.Errorf("Unsupported data provider: %v", config.Driver) +} + +// CheckUserAndPass returns the user with the given username and password if exists +func CheckUserAndPass(p Provider, username string, password string) (User, error) { + return p.validateUserAndPass(username, password) +} + +// CheckUserAndPubKey returns the user with the given username and public key if exists +func CheckUserAndPubKey(p Provider, username string, pubKey string) (User, error) { + return p.validateUserAndPubKey(username, pubKey) +} + +// UpdateUserQuota update the quota for the given user +func UpdateUserQuota(p Provider, username string, filesAdd int, sizeAdd int64, reset bool) error { + if config.TrackQuota == 0 { + return &MethodDisabledError{err: trackQuotaDisabledError} + } + return p.updateQuota(username, filesAdd, sizeAdd, reset) +} + +// GetUsedQuota returns the used quota for the given user +func GetUsedQuota(p Provider, username string) (int, int64, error) { + if config.TrackQuota == 0 { + return 0, 0, &MethodDisabledError{err: trackQuotaDisabledError} + } + return p.getUsedQuota(username) +} + +// UserExists checks if the given username exists +func UserExists(p Provider, username string) (User, error) { + return p.userExists(username) +} + +// AddUser adds a new user, ManageUsers configuration must be set to 1 to enable this method +func AddUser(p Provider, user User) error { + if config.ManageUsers == 0 { + return &MethodDisabledError{err: manageUsersDisabledError} + } + return p.addUser(user) +} + +// UpdateUser updates an existing user, ManageUsers configuration must be set to 1 to enable this method +func UpdateUser(p Provider, user User) error { + if config.ManageUsers == 0 { + return &MethodDisabledError{err: manageUsersDisabledError} + } + return p.updateUser(user) +} + +// DeleteUser deletes an existing user, ManageUsers configuration must be set to 1 to enable this method +func DeleteUser(p Provider, user User) error { + if config.ManageUsers == 0 { + return &MethodDisabledError{err: manageUsersDisabledError} + } + return p.deleteUser(user) +} + +// GetUsers returns an array of users respecting limit and offset +func GetUsers(p Provider, limit int, offset int, order string, username string) ([]User, error) { + return p.getUsers(limit, offset, order, username) +} + +// GetUserByID returns the user with the given ID +func GetUserByID(p Provider, ID int64) (User, error) { + return p.getUserByID(ID) +} + +func validateUser(user *User) error { + if len(user.Username) == 0 || len(user.HomeDir) == 0 { + return &ValidationError{err: "Mandatory parameters missing"} + } + if len(user.Password) == 0 && len(user.PublicKey) == 0 { + return &ValidationError{err: "Please set password or public_key"} + } + if len(user.Permissions) == 0 { + return &ValidationError{err: "Please grant some permissions to this user"} + } + if !filepath.IsAbs(user.HomeDir) { + return &ValidationError{err: fmt.Sprintf("home_dir must be an absolute path, actual value: %v", user.HomeDir)} + } + for _, p := range user.Permissions { + if !utils.IsStringInSlice(p, validPerms) { + return &ValidationError{err: fmt.Sprintf("Invalid permission: %v", p)} + } + } + if !strings.HasPrefix(user.Password, argonPwdPrefix) { + pwd, err := argon2id.CreateHash(user.Password, argon2id.DefaultParams) + if err != nil { + return err + } + user.Password = pwd + } + if len(user.PublicKey) > 0 { + _, _, _, _, err := ssh.ParseAuthorizedKey([]byte(user.PublicKey)) + if err != nil { + return err + } + } + return nil +} + +func getSSLMode() string { + if config.Driver == PGSSQLDataProviderName { + if config.SSLMode == 0 { + return "disable" + } else if config.SSLMode == 1 { + return "require" + } else if config.SSLMode == 2 { + return "verify-ca" + } else if config.SSLMode == 3 { + return "verify-full" + } + } else if config.Driver == MySQLDataProviderName { + if config.SSLMode == 0 { + return "false" + } else if config.SSLMode == 1 { + return "true" + } else if config.SSLMode == 2 { + return "skip-verify" + } else if config.SSLMode == 3 { + return "preferred" + } + } + return "" +} diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go new file mode 100644 index 00000000..218cb090 --- /dev/null +++ b/dataprovider/mysql.go @@ -0,0 +1,90 @@ +package dataprovider + +import ( + "database/sql" + "fmt" + "runtime" + "time" + + "github.com/drakkan/sftpgo/logger" +) + +// MySQLProvider auth provider for sqlite database +type MySQLProvider struct { +} + +func initializeMySQLProvider() error { + var err error + var connectionString string + if len(config.ConnectionString) == 0 { + connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8&interpolateParams=true&timeout=10s&tls=%v", + config.Username, config.Password, config.Host, config.Port, config.Name, getSSLMode()) + } else { + connectionString = config.ConnectionString + } + dbHandle, err = sql.Open("mysql", connectionString) + if err == nil { + numCPU := runtime.NumCPU() + logger.Debug(logSender, "mysql database handle created, connection string: %v, connections: %v", connectionString, numCPU) + dbHandle.SetMaxIdleConns(numCPU) + dbHandle.SetMaxOpenConns(numCPU) + dbHandle.SetConnMaxLifetime(1800 * time.Second) + } else { + logger.Warn(logSender, "error creating mysql database handler, connection string: %v, error: %v", connectionString, err) + } + return err +} + +func (p MySQLProvider) validateUserAndPass(username string, password string) (User, error) { + return sqlCommonValidateUserAndPass(username, password) +} + +func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { + return sqlCommonValidateUserAndPubKey(username, publicKey) +} + +func (p MySQLProvider) getUserByID(ID int64) (User, error) { + return sqlCommonGetUserByID(ID) +} + +func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { + tx, err := dbHandle.Begin() + if err != nil { + logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err) + return err + } + err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p) + if err == nil { + err = tx.Commit() + } else { + err = tx.Rollback() + } + if err != nil { + logger.Warn(logSender, "error closing transaction to update quota for user %v: %v", username, err) + } + return err +} + +func (p MySQLProvider) getUsedQuota(username string) (int, int64, error) { + return sqlCommonGetUsedQuota(username) +} + +func (p MySQLProvider) userExists(username string) (User, error) { + return sqlCommonCheckUserExists(username) +} + +func (p MySQLProvider) addUser(user User) error { + return sqlCommonAddUser(user) +} + +func (p MySQLProvider) updateUser(user User) error { + return sqlCommonUpdateUser(user) +} + +func (p MySQLProvider) deleteUser(user User) error { + return sqlCommonDeleteUser(user) +} + +func (p MySQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { + return sqlCommonGetUsers(limit, offset, order, username) +} diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go new file mode 100644 index 00000000..bbc8d7cc --- /dev/null +++ b/dataprovider/pgsql.go @@ -0,0 +1,88 @@ +package dataprovider + +import ( + "database/sql" + "fmt" + "runtime" + + "github.com/drakkan/sftpgo/logger" +) + +// PGSQLProvider auth provider for PostgreSQL database +type PGSQLProvider struct { +} + +func initializePGSQLProvider() error { + var err error + var connectionString string + if len(config.ConnectionString) == 0 { + connectionString = fmt.Sprintf("host='%v' port=%v dbname='%v' user='%v' password='%v' sslmode=%v connect_timeout=10", + config.Host, config.Port, config.Name, config.Username, config.Password, getSSLMode()) + } else { + connectionString = config.ConnectionString + } + dbHandle, err = sql.Open("postgres", connectionString) + if err == nil { + numCPU := runtime.NumCPU() + logger.Debug(logSender, "postgres database handle created, connection string: %v, connections: %v", connectionString, numCPU) + dbHandle.SetMaxIdleConns(numCPU) + dbHandle.SetMaxOpenConns(numCPU) + } else { + logger.Warn(logSender, "error creating postgres database handler, connection string: %v, error: %v", connectionString, err) + } + return err +} + +func (p PGSQLProvider) validateUserAndPass(username string, password string) (User, error) { + return sqlCommonValidateUserAndPass(username, password) +} + +func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { + return sqlCommonValidateUserAndPubKey(username, publicKey) +} + +func (p PGSQLProvider) getUserByID(ID int64) (User, error) { + return sqlCommonGetUserByID(ID) +} + +func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { + tx, err := dbHandle.Begin() + if err != nil { + logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err) + return err + } + err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p) + if err == nil { + err = tx.Commit() + } else { + err = tx.Rollback() + } + if err != nil { + logger.Warn(logSender, "error closing transaction to update quota for user %v: %v", username, err) + } + return err +} + +func (p PGSQLProvider) getUsedQuota(username string) (int, int64, error) { + return sqlCommonGetUsedQuota(username) +} + +func (p PGSQLProvider) userExists(username string) (User, error) { + return sqlCommonCheckUserExists(username) +} + +func (p PGSQLProvider) addUser(user User) error { + return sqlCommonAddUser(user) +} + +func (p PGSQLProvider) updateUser(user User) error { + return sqlCommonUpdateUser(user) +} + +func (p PGSQLProvider) deleteUser(user User) error { + return sqlCommonDeleteUser(user) +} + +func (p PGSQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { + return sqlCommonGetUsers(limit, offset, order, username) +} diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go new file mode 100644 index 00000000..93f271e2 --- /dev/null +++ b/dataprovider/sqlcommon.go @@ -0,0 +1,289 @@ +package dataprovider + +import ( + "database/sql" + "encoding/json" + "errors" + "strings" + "time" + + "golang.org/x/crypto/ssh" + + "github.com/alexedwards/argon2id" + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/utils" +) + +func getUserByUsername(username string) (User, error) { + var user User + q := getUserByUsernameQuery() + stmt, err := dbHandle.Prepare(q) + if err != nil { + logger.Debug(logSender, "error preparing database query %v: %v", q, err) + return user, err + } + defer stmt.Close() + + row := stmt.QueryRow(username) + return getUserFromDbRow(row, nil) +} + +func sqlCommonValidateUserAndPass(username string, password string) (User, error) { + var user User + if len(password) == 0 { + return user, errors.New("Credentials cannot be null or empty") + } + user, err := getUserByUsername(username) + if err != nil { + logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err) + } else { + match := false + if strings.HasPrefix(user.Password, argonPwdPrefix) { + match, err = argon2id.ComparePasswordAndHash(password, user.Password) + if err != nil { + logger.Warn(logSender, "error comparing password with argon hash: %v", err) + return user, err + } + } else { + // clear text password match + match = (user.Password == password) + } + if !match { + err = errors.New("Invalid credentials") + } + } + return user, err +} + +func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error) { + var user User + if len(pubKey) == 0 { + return user, errors.New("Credentials cannot be null or empty") + } + user, err := getUserByUsername(username) + if err != nil { + logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err) + } else { + if len(user.PublicKey) > 0 { + storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(user.PublicKey)) + if err != nil { + logger.Warn(logSender, "error parsing stored public key for user %v: %v", username, err) + return user, err + } + if string(storedPubKey.Marshal()) != pubKey { + err = errors.New("Invalid credentials") + } + } else { + err = errors.New("Invalid credentials") + } + } + return user, err +} + +func sqlCommonGetUserByID(ID int64) (User, error) { + var user User + q := getUserByIDQuery() + stmt, err := dbHandle.Prepare(q) + if err != nil { + logger.Debug(logSender, "error preparing database query %v: %v", q, err) + return user, err + } + defer stmt.Close() + + row := stmt.QueryRow(ID) + return getUserFromDbRow(row, nil) +} + +func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, p Provider) error { + var usedFiles int + var usedSize int64 + var err error + if reset { + usedFiles = 0 + usedSize = 0 + } else { + usedFiles, usedSize, err = p.getUsedQuota(username) + if err != nil { + return err + } + } + usedFiles += filesAdd + usedSize += sizeAdd + if usedFiles < 0 { + logger.Warn(logSender, "used files is negative, probably some files were added and not tracked, please rescan quota!") + usedFiles = 0 + } + if usedSize < 0 { + logger.Warn(logSender, "used files is negative, probably some files were added and not tracked, please rescan quota!") + usedSize = 0 + } + + q := getUpdateQuotaQuery() + stmt, err := dbHandle.Prepare(q) + if err != nil { + logger.Debug(logSender, "error preparing database query %v: %v", q, err) + return err + } + defer stmt.Close() + _, err = stmt.Exec(usedSize, usedFiles, utils.GetTimeAsMsSinceEpoch(time.Now()), username) + if err == nil { + logger.Debug(logSender, "quota updated for user %v, new files: %v new size: %v", username, usedFiles, usedSize) + } else { + logger.Warn(logSender, "error updating quota for username %v: %v", username, err) + } + return err +} + +func sqlCommonGetUsedQuota(username string) (int, int64, error) { + q := getQuotaQuery() + stmt, err := dbHandle.Prepare(q) + if err != nil { + logger.Warn(logSender, "error preparing database query %v: %v", q, err) + return 0, 0, err + } + defer stmt.Close() + + var usedFiles int + var usedSize int64 + err = stmt.QueryRow(username).Scan(&usedSize, &usedFiles) + if err != nil { + logger.Warn(logSender, "error getting user quota: %v, error: %v", username, err) + return 0, 0, err + } + return usedFiles, usedSize, err +} + +func sqlCommonCheckUserExists(username string) (User, error) { + var user User + q := getUserByUsernameQuery() + stmt, err := dbHandle.Prepare(q) + if err != nil { + logger.Warn(logSender, "error preparing database query %v: %v", q, err) + return user, err + } + defer stmt.Close() + row := stmt.QueryRow(username) + return getUserFromDbRow(row, nil) +} + +func sqlCommonAddUser(user User) error { + err := validateUser(&user) + if err != nil { + return err + } + q := getAddUserQuery() + stmt, err := dbHandle.Prepare(q) + if err != nil { + logger.Warn(logSender, "error preparing database query %v: %v", q, err) + return err + } + defer stmt.Close() + permissions, err := user.GetPermissionsAsJSON() + if err != nil { + return err + } + _, err = stmt.Exec(user.Username, user.Password, user.PublicKey, user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, + user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth) + return err +} + +func sqlCommonUpdateUser(user User) error { + err := validateUser(&user) + if err != nil { + return err + } + q := getUpdateUserQuery() + stmt, err := dbHandle.Prepare(q) + if err != nil { + logger.Warn(logSender, "error preparing database query %v: %v", q, err) + return err + } + defer stmt.Close() + permissions, err := user.GetPermissionsAsJSON() + if err != nil { + return err + } + _, err = stmt.Exec(user.Password, user.PublicKey, user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, + user.QuotaFiles, permissions, user.UploadBandwidth, user.DownloadBandwidth, user.ID) + return err +} + +func sqlCommonDeleteUser(user User) error { + q := getDeleteUserQuery() + stmt, err := dbHandle.Prepare(q) + if err != nil { + logger.Warn(logSender, "error preparing database query %v: %v", q, err) + return err + } + defer stmt.Close() + _, err = stmt.Exec(user.ID) + return err +} + +func sqlCommonGetUsers(limit int, offset int, order string, username string) ([]User, error) { + users := []User{} + q := getUsersQuery(order, username) + stmt, err := dbHandle.Prepare(q) + if err != nil { + logger.Warn(logSender, "error preparing database query %v: %v", q, err) + return nil, err + } + defer stmt.Close() + var rows *sql.Rows + if len(username) > 0 { + rows, err = stmt.Query(username, limit, offset) + } else { + rows, err = stmt.Query(limit, offset) + } + if err == nil { + defer rows.Close() + for rows.Next() { + u, err := getUserFromDbRow(nil, rows) + // hide password and public key + u.Password = "" + u.PublicKey = "" + if err == nil { + users = append(users, u) + } else { + break + } + } + } + + return users, err +} + +func getUserFromDbRow(row *sql.Row, rows *sql.Rows) (User, error) { + var user User + var permissions sql.NullString + var password sql.NullString + var publicKey sql.NullString + var err error + if row != nil { + err = row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions, + &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaScan, + &user.UploadBandwidth, &user.DownloadBandwidth) + + } else { + err = rows.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions, + &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaScan, + &user.UploadBandwidth, &user.DownloadBandwidth) + } + if err != nil { + return user, err + } + if password.Valid { + user.Password = password.String + } + if publicKey.Valid { + user.PublicKey = publicKey.String + } + if permissions.Valid { + var list []string + err = json.Unmarshal([]byte(permissions.String), &list) + if err == nil { + user.Permissions = list + } + } + return user, err +} diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go new file mode 100644 index 00000000..1179244b --- /dev/null +++ b/dataprovider/sqlite.go @@ -0,0 +1,86 @@ +package dataprovider + +import ( + "database/sql" + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/drakkan/sftpgo/logger" +) + +// SQLiteProvider auth provider for sqlite database +type SQLiteProvider struct { +} + +func initializeSQLiteProvider(basePath string) error { + var err error + var connectionString string + if len(config.ConnectionString) == 0 { + dbPath := filepath.Join(basePath, config.Name) + fi, err := os.Stat(dbPath) + if err != nil { + logger.Warn(logSender, "sqlite database file does not exists, please be sure to create and inizialiaze"+ + " a database before starting sftpgo") + return err + } + if fi.Size() == 0 { + return errors.New("sqlite database file is invalid, please be sure to create and inizialiaze" + + " a database before starting sftpgo") + } + connectionString = fmt.Sprintf("file:%v?cache=shared", dbPath) + } else { + connectionString = config.ConnectionString + } + dbHandle, err = sql.Open("sqlite3", connectionString) + if err == nil { + logger.Debug(logSender, "sqlite database handle created, connection string: %v", connectionString) + dbHandle.SetMaxOpenConns(1) + } else { + logger.Warn(logSender, "error creating sqlite database handler, connection string: %v, error: %v", connectionString, err) + } + return err +} + +func (p SQLiteProvider) validateUserAndPass(username string, password string) (User, error) { + return sqlCommonValidateUserAndPass(username, password) +} + +func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { + return sqlCommonValidateUserAndPubKey(username, publicKey) +} + +func (p SQLiteProvider) getUserByID(ID int64) (User, error) { + return sqlCommonGetUserByID(ID) +} + +func (p SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { + // we keep only 1 open connection (SetMaxOpenConns(1)) so a transaction is not needed and it could block + // the database access since it will try to open a new connection + return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p) +} + +func (p SQLiteProvider) getUsedQuota(username string) (int, int64, error) { + return sqlCommonGetUsedQuota(username) +} + +func (p SQLiteProvider) userExists(username string) (User, error) { + return sqlCommonCheckUserExists(username) +} + +func (p SQLiteProvider) addUser(user User) error { + return sqlCommonAddUser(user) +} + +func (p SQLiteProvider) updateUser(user User) error { + return sqlCommonUpdateUser(user) +} + +func (p SQLiteProvider) deleteUser(user User) error { + return sqlCommonDeleteUser(user) +} + +func (p SQLiteProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { + return sqlCommonGetUsers(limit, offset, order, username) +} diff --git a/dataprovider/sqlqueries.go b/dataprovider/sqlqueries.go new file mode 100644 index 00000000..07d89cdb --- /dev/null +++ b/dataprovider/sqlqueries.go @@ -0,0 +1,66 @@ +package dataprovider + +import "fmt" + +const ( + selectUserFields = "id,username,password,public_key,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions," + + "used_quota_size,used_quota_files,last_quota_scan,upload_bandwidth,download_bandwidth" +) + +func getSQLPlaceholders() []string { + var placeholders []string + for i := 1; i <= 20; i++ { + if config.Driver == PGSSQLDataProviderName { + placeholders = append(placeholders, fmt.Sprintf("$%v", i)) + } else { + placeholders = append(placeholders, "?") + } + } + return placeholders +} + +func getUserByUsernameQuery() string { + return fmt.Sprintf(`SELECT %v FROM %v WHERE username = %v`, selectUserFields, config.UsersTable, sqlPlaceholders[0]) +} + +func getUserByIDQuery() string { + return fmt.Sprintf(`SELECT %v FROM %v WHERE id = %v`, selectUserFields, config.UsersTable, sqlPlaceholders[0]) +} + +func getUsersQuery(order string, username string) string { + if len(username) > 0 { + return fmt.Sprintf(`SELECT %v FROM %v WHERE username = %v ORDER BY username %v LIMIT %v OFFSET %v`, + selectUserFields, config.UsersTable, sqlPlaceholders[0], order, sqlPlaceholders[1], sqlPlaceholders[2]) + } + return fmt.Sprintf(`SELECT %v FROM %v ORDER BY username %v LIMIT %v OFFSET %v`, selectUserFields, config.UsersTable, + order, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getUpdateQuotaQuery() string { + return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_scan = %v + WHERE username = %v`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getQuotaQuery() string { + return fmt.Sprintf(`SELECT used_quota_size,used_quota_files FROM %v WHERE username = %v`, config.UsersTable, + sqlPlaceholders[0]) +} + +func getAddUserQuery() string { + return fmt.Sprintf(`INSERT INTO %v (username,password,public_key,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions, + used_quota_size,used_quota_files,last_quota_scan,upload_bandwidth,download_bandwidth) + VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0,0,%v,%v)`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1], + sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], + sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11]) +} + +func getUpdateUserQuery() string { + return fmt.Sprintf(`UPDATE %v SET password=%v,public_key=%v,home_dir=%v,uid=%v,gid=%v,max_sessions=%v,quota_size=%v, + quota_files=%v,permissions=%v,upload_bandwidth=%v,download_bandwidth=%v WHERE id = %v`, config.UsersTable, + sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], + sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11]) +} + +func getDeleteUserQuery() string { + return fmt.Sprintf(`DELETE FROM %v WHERE id = %v`, config.UsersTable, sqlPlaceholders[0]) +} diff --git a/dataprovider/user.go b/dataprovider/user.go new file mode 100644 index 00000000..25c4c4db --- /dev/null +++ b/dataprovider/user.go @@ -0,0 +1,84 @@ +package dataprovider + +import ( + "encoding/json" + "path/filepath" + + "github.com/drakkan/sftpgo/utils" +) + +// Permissions +const ( + PermAny = "*" + PermListItems = "list" + PermDownload = "download" + PermUpload = "upload" + PermDelete = "delete" + PermRename = "rename" + PermCreateDirs = "create_dirs" + PermCreateSymlinks = "create_symlinks" +) + +// User defines an SFTP user +type User struct { + ID int64 `json:"id"` + Username string `json:"username"` + Password string `json:"password,omitempty"` + PublicKey string `json:"public_key,omitempty"` + HomeDir string `json:"home_dir"` + UID int `json:"uid"` + GID int `json:"gid"` + MaxSessions int `json:"max_sessions"` + QuotaSize int64 `json:"quota_size"` + QuotaFiles int `json:"quota_files"` + Permissions []string `json:"permissions"` + UsedQuotaSize int64 `json:"used_quota_size"` + UsedQuotaFiles int `json:"used_quota_files"` + LastQuotaScan int64 `json:"last_quota_scan"` + UploadBandwidth int64 `json:"upload_bandwidth"` + DownloadBandwidth int64 `json:"download_bandwidth"` +} + +// HasPerm returns true if the user has the given permission or any permission +func (u *User) HasPerm(permission string) bool { + if utils.IsStringInSlice(PermAny, u.Permissions) { + return true + } + return utils.IsStringInSlice(permission, u.Permissions) +} + +// HasOption returns true if the user has the give option +/*func (u *User) HasOption(option string) bool { + return utils.IsStringInSlice(option, u.Options) +}*/ + +// GetPermissionsAsJSON returns the permission as json byte array +func (u *User) GetPermissionsAsJSON() ([]byte, error) { + return json.Marshal(u.Permissions) +} + +// GetOptionsAsJSON returns the permission as json byte array +/*func (u *User) GetOptionsAsJSON() ([]byte, error) { + return json.Marshal(u.Options) +}*/ + +// GetUID returns a validate uid +func (u *User) GetUID() int { + if u.UID <= 0 || u.UID > 65535 { + return -1 + } + return u.UID +} + +// GetGID returns a validate gid +func (u *User) GetGID() int { + if u.GID <= 0 || u.GID > 65535 { + return -1 + } + return u.GID +} + +// GetHomeDir returns user home dir cleaned +func (u *User) GetHomeDir() string { + return filepath.Clean(u.HomeDir) +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..69ec4dab --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module github.com/drakkan/sftpgo + +go 1.12 + +require ( + github.com/alexedwards/argon2id v0.0.0-20190612080829-01a59b2b8802 + github.com/go-chi/chi v4.0.2+incompatible + github.com/go-chi/render v1.0.1 + github.com/go-sql-driver/mysql v1.4.1 + github.com/kr/fs v0.1.0 // indirect + github.com/lib/pq v1.1.1 + github.com/mattn/go-sqlite3 v1.10.0 + github.com/pkg/sftp v1.10.0 + github.com/rs/zerolog v1.14.3 + golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 + gopkg.in/natefinch/lumberjack.v2 v2.0.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..271e7eef --- /dev/null +++ b/go.sum @@ -0,0 +1,37 @@ +github.com/alexedwards/argon2id v0.0.0-20190612080829-01a59b2b8802 h1:RwMM1q/QSKYIGbHfOkf843hE8sSUJtf1dMwFPtEDmm0= +github.com/alexedwards/argon2id v0.0.0-20190612080829-01a59b2b8802/go.mod h1:4dsm7ufQm1Gwl8S2ss57u+2J7KlxIL2QUmFGlGtWogY= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/go-chi/chi v4.0.2+incompatible h1:maB6vn6FqCxrpz4FqWdh4+lwpyZIQS7YEAUcHlgXVRs= +github.com/go-chi/chi v4.0.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= +github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= +github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns= +github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= +github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= +github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= +github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= +github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/sftp v1.10.0 h1:DGA1KlA9esU6WcicH+P8PxFZOl15O6GYtab1cIJdOlE= +github.com/pkg/sftp v1.10.0/go.mod h1:NxmoDg/QLVWluQDUYG7XBZTLUpKeFa8e3aMf1BfjyHk= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.14.3 h1:4EGfSkR2hJDB0s3oFfrlPqjU1e4WLncergLil3nEKW0= +github.com/rs/zerolog v1.14.3/go.mod h1:3WXPzbXEEliJ+a6UFE4vhIxV8qR1EML6ngzP9ug4eYg= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 h1:HuIa8hRrWRSrqYzx1qI49NNxhdi2PrY7gxVSq1JjLDc= +golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= +gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= diff --git a/init/sftpgo.service b/init/sftpgo.service new file mode 100644 index 00000000..20388e45 --- /dev/null +++ b/init/sftpgo.service @@ -0,0 +1,16 @@ +[Unit] +Description=SFTPGo sftp server +After=network.target + +[Service] +User=root +Group=root +Type=simple +WorkingDirectory=/etc/sftpgo +ExecStart=/usr/bin/sftpgo -config-dir /etc/sftpgo -log-file-path /var/log/sftpgo.log +KillMode=mixed +Restart=always +RestartSec=10s + +[Install] +WantedBy=multi-user.target diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 00000000..62f6cd6b --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,80 @@ +package logger + +import ( + "fmt" + + "github.com/rs/zerolog" + lumberjack "gopkg.in/natefinch/lumberjack.v2" +) + +const ( + dateFormat = "2006-01-02T15:04.05.000" // YYYY-MM-DDTHH:MM.SS.ZZZ +) + +var ( + logger zerolog.Logger +) + +// GetLogger get the logger instance +func GetLogger() *zerolog.Logger { + return &logger +} + +// InitLogger initialize loggers +func InitLogger(logFilePath string, level zerolog.Level) { + logMaxSize := 10 // MB + logMaxBackups := 5 + logMaxAge := 28 // days + + zerolog.TimeFieldFormat = dateFormat + logger = zerolog.New(&lumberjack.Logger{ + Filename: logFilePath, + MaxSize: logMaxSize, + MaxBackups: logMaxBackups, + MaxAge: logMaxAge, + Compress: false, + }).With().Timestamp().Logger().Level(level) +} + +// Debug log at debug level for sender +func Debug(sender string, format string, v ...interface{}) { + logger.Debug().Str("sender", sender).Msg(fmt.Sprintf(format, v...)) +} + +// Info log at info level for sender +func Info(sender string, format string, v ...interface{}) { + logger.Info().Str("sender", sender).Msg(fmt.Sprintf(format, v...)) +} + +// Warn log at warn level for sender +func Warn(sender string, format string, v ...interface{}) { + logger.Warn().Str("sender", sender).Msg(fmt.Sprintf(format, v...)) +} + +// Error log at error level for sender +func Error(sender string, format string, v ...interface{}) { + logger.Error().Str("sender", sender).Msg(fmt.Sprintf(format, v...)) +} + +// TransferLog logs an SFTP upload or download +func TransferLog(operation string, path string, elapsed int64, size int64, user string, connectionID string) { + logger.Info(). + Str("sender", operation). + Int64("elapsed_ms", elapsed). + Int64("size_bytes", size). + Str("username", user). + Str("file_path", path). + Str("connection_id", connectionID). + Msg("") +} + +// CommandLog log an SFTP command +func CommandLog(command string, path string, target string, user string, connectionID string) { + logger.Info(). + Str("sender", command). + Str("username", user). + Str("file_path", path). + Str("target_path", target). + Str("connection_id", connectionID). + Msg("") +} diff --git a/logger/request_logger.go b/logger/request_logger.go new file mode 100644 index 00000000..9fd4575f --- /dev/null +++ b/logger/request_logger.go @@ -0,0 +1,65 @@ +package logger + +import ( + "fmt" + "net/http" + "time" + + "github.com/go-chi/chi/middleware" + "github.com/rs/zerolog" +) + +// StructuredLogger that uses zerolog +type StructuredLogger struct { + Logger *zerolog.Logger +} + +// StructuredLoggerEntry using zerolog logger +type StructuredLoggerEntry struct { + Logger *zerolog.Logger + fields map[string]interface{} +} + +// NewStructuredLogger returns RequestLogger +func NewStructuredLogger(logger *zerolog.Logger) func(next http.Handler) http.Handler { + return middleware.RequestLogger(&StructuredLogger{logger}) +} + +// NewLogEntry creates a new log entry +func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + fields := map[string]interface{}{ + "remote_addr": r.RemoteAddr, + "proto": r.Proto, + "method": r.Method, + "user_agent": r.UserAgent(), + "uri": fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)} + + reqID := middleware.GetReqID(r.Context()) + if reqID != "" { + fields["request_id"] = reqID + } + + return &StructuredLoggerEntry{Logger: l.Logger, fields: fields} +} + +// Write a new entry +func (l *StructuredLoggerEntry) Write(status, bytes int, elapsed time.Duration) { + l.Logger.Info().Fields(l.fields).Int( + "resp_status", status).Int( + "resp_size", bytes).Int64( + "elapsed_ms", elapsed.Nanoseconds()/1000000).Str( + "sender", "httpd").Msg( + "") +} + +// Panic logs panics +func (l *StructuredLoggerEntry) Panic(v interface{}, stack []byte) { + l.Logger.Error().Fields(l.fields).Str( + "stack", string(stack)).Str( + "panic", fmt.Sprintf("%+v", v)).Msg("") +} diff --git a/main.go b/main.go new file mode 100644 index 00000000..8e1a502a --- /dev/null +++ b/main.go @@ -0,0 +1,83 @@ +package main // import "github.com/drakkan/sftpgo" + +import ( + "flag" + "fmt" + "net/http" + "os" + "path/filepath" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + + "github.com/drakkan/sftpgo/api" + "github.com/drakkan/sftpgo/config" + "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/sftpd" + + "github.com/rs/zerolog" +) + +func main() { + confName := "sftpgo.conf" + logSender := "main" + var ( + configDir string + logFilePath string + ) + flag.StringVar(&configDir, "config-dir", ".", "Location for SFTPGo config dir. It must contain sftpgo.conf, "+ + "the private key for the SFTP server (id_rsa file) and the SQLite database if you use SQLite as data provider. "+ + "The server private key will be autogenerated if the user that executes SFTPGo has write access to the config-dir") + flag.StringVar(&logFilePath, "log-file-path", "sftpgo.log", "Location for the log file") + flag.Parse() + + configFilePath := filepath.Join(configDir, confName) + logger.InitLogger(logFilePath, zerolog.DebugLevel) + logger.Info(logSender, "starting SFTPGo, config dir: %v", configDir) + config.LoadConfig(configFilePath) + providerConf := config.GetProviderConf() + + err := dataprovider.Initialize(providerConf, configDir) + if err != nil { + logger.Warn(logSender, "error initializing data provider: %v", err) + os.Exit(1) + } + + dataProvider := dataprovider.GetProvider() + sftpdConf := config.GetSFTPDConfig() + httpdConf := config.GetHTTPDConfig() + router := api.GetHTTPRouter() + + sftpd.SetDataProvider(dataProvider) + api.SetDataProvider(dataProvider) + + shutdown := make(chan bool) + + go func() { + logger.Debug(logSender, "initializing SFTP server with config %+v", sftpdConf) + if err := sftpdConf.Initalize(configDir); err != nil { + logger.Error(logSender, "could not start SFTP server: %v", err) + } + shutdown <- true + }() + + go func() { + logger.Debug(logSender, "initializing HTTP server with config %+v", httpdConf) + s := &http.Server{ + Addr: fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort), + Handler: router, + ReadTimeout: 300 * time.Second, + WriteTimeout: 300 * time.Second, + MaxHeaderBytes: 1 << 20, // 1MB + } + if err := s.ListenAndServe(); err != nil { + logger.Error(logSender, "could not start HTTP server: %v", err) + } + shutdown <- true + }() + + <-shutdown +} diff --git a/sftpd/handler.go b/sftpd/handler.go new file mode 100644 index 00000000..5e3736ce --- /dev/null +++ b/sftpd/handler.go @@ -0,0 +1,507 @@ +package sftpd + +import ( + "fmt" + "io" + "io/ioutil" + "net" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/drakkan/sftpgo/utils" + + "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/logger" + "golang.org/x/crypto/ssh" + + "github.com/pkg/sftp" +) + +// Connection details for an authenticated user +type Connection struct { + ID string + User dataprovider.User + ClientVersion string + RemoteAddr net.Addr + StartTime time.Time + lastActivity time.Time + lock *sync.Mutex + sshConn *ssh.ServerConn + dataProvider dataprovider.Provider +} + +// Fileread creates a reader for a file on the system and returns the reader back. +func (c Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { + updateConnectionActivity(c.ID) + + if !c.User.HasPerm(dataprovider.PermDownload) { + return nil, sftp.ErrSshFxPermissionDenied + } + + p, err := c.buildPath(request.Filepath) + if err != nil { + return nil, sftp.ErrSshFxNoSuchFile + } + + c.lock.Lock() + defer c.lock.Unlock() + + if _, err := os.Stat(p); os.IsNotExist(err) { + return nil, sftp.ErrSshFxNoSuchFile + } + + file, err := os.Open(p) + if err != nil { + logger.Error(logSender, "could not open file \"%v\" for reading: %v", p, err) + return nil, sftp.ErrSshFxFailure + } + + logger.Debug(logSender, "fileread requested for path: \"%v\", user: %v", p, c.User.Username) + + transfer := Transfer{ + file: file, + path: p, + start: time.Now(), + bytesSent: 0, + bytesReceived: 0, + user: c.User, + connectionID: c.ID, + transferType: transferDownload, + isNewFile: false, + } + addTransfer(&transfer) + return &transfer, nil +} + +// Filewrite handles the write actions for a file on the system. +func (c Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) { + updateConnectionActivity(c.ID) + if !c.User.HasPerm(dataprovider.PermUpload) { + return nil, sftp.ErrSshFxPermissionDenied + } + + p, err := c.buildPath(request.Filepath) + if err != nil { + return nil, sftp.ErrSshFxNoSuchFile + } + + c.lock.Lock() + defer c.lock.Unlock() + + stat, statErr := os.Stat(p) + // If the file doesn't exist we need to create it, as well as the directory pathway + // leading up to where that file will be created. + if os.IsNotExist(statErr) { + if !c.hasSpace(true) { + logger.Info(logSender, "denying file write due to space limit") + return nil, sftp.ErrSshFxFailure + } + + if _, err := os.Stat(filepath.Dir(p)); os.IsNotExist(err) { + if !c.User.HasPerm(dataprovider.PermCreateDirs) { + return nil, sftp.ErrSshFxPermissionDenied + } + } + + dirsToCreate, err := c.findNonexistentDirs(p) + if err != nil { + return nil, sftp.ErrSshFxFailure + } + + last := len(dirsToCreate) - 1 + for i := range dirsToCreate { + d := dirsToCreate[last-i] + if err := os.Mkdir(d, 0777); err != nil { + logger.Error(logSender, "error making path for file, dir: %v, path: %v", d, p) + return nil, sftp.ErrSshFxFailure + } + utils.SetPathPermissions(d, c.User.GetUID(), c.User.GetGID()) + } + + file, err := os.Create(p) + if err != nil { + logger.Error(logSender, "error creating file %v: %v", p, err) + return nil, sftp.ErrSshFxFailure + } + + utils.SetPathPermissions(p, c.User.GetUID(), c.User.GetGID()) + + logger.Debug(logSender, "file upload/replace started for path \"%v\" user: %v", p, c.User.Username) + + transfer := Transfer{ + file: file, + path: p, + start: time.Now(), + bytesSent: 0, + bytesReceived: 0, + user: c.User, + connectionID: c.ID, + transferType: transferUpload, + isNewFile: true, + } + addTransfer(&transfer) + return &transfer, nil + } + + if statErr != nil { + logger.Error("error performing file stat %v: %v", p, statErr) + return nil, sftp.ErrSshFxFailure + } + + if !c.hasSpace(false) { + logger.Info(logSender, "denying file write due to space limit") + return nil, sftp.ErrSshFxFailure + } + + // Not sure this would ever happen, but lets not find out. + if stat.IsDir() { + logger.Warn("attempted to open a directory for writing to: %v", p) + return nil, sftp.ErrSshFxOpUnsupported + } + + var osFlags int + trunc := false + sftpFileOpenFlags := request.Pflags() + if sftpFileOpenFlags.Read && sftpFileOpenFlags.Write { + osFlags |= os.O_RDWR + } else if sftpFileOpenFlags.Write { + osFlags |= os.O_WRONLY + } + if sftpFileOpenFlags.Append { + osFlags |= os.O_APPEND + } + if sftpFileOpenFlags.Creat { + osFlags |= os.O_CREATE + } + if sftpFileOpenFlags.Trunc { + osFlags |= os.O_TRUNC + trunc = true + } + + if !trunc { + // see https://github.com/pkg/sftp/issues/295 + logger.Info(logSender, "upload resume is not supported, returning error") + return nil, sftp.ErrSshFxOpUnsupported + } + + // we use 0666 so the umask is applied + file, err := os.OpenFile(p, osFlags, 0666) + if err != nil { + logger.Error(logSender, "error opening existing file, flags: %v, source: %v, err: %v", request.Flags, p, err) + return nil, sftp.ErrSshFxFailure + } + + if trunc { + logger.Debug(logSender, "file truncation requested update quota for user %v", c.User.Username) + dataprovider.UpdateUserQuota(dataProvider, c.User.Username, -1, -stat.Size(), false) + } + + utils.SetPathPermissions(p, c.User.GetUID(), c.User.GetGID()) + + logger.Debug(logSender, "file upload started for path \"%v\" user: %v", p, c.User.Username) + + transfer := Transfer{ + file: file, + path: p, + start: time.Now(), + bytesSent: 0, + bytesReceived: 0, + user: c.User, + connectionID: c.ID, + transferType: transferUpload, + isNewFile: trunc, + } + addTransfer(&transfer) + return &transfer, nil +} + +// Filecmd hander for basic SFTP system calls related to files, but not anything to do with reading +// or writing to those files. +func (c Connection) Filecmd(request *sftp.Request) error { + updateConnectionActivity(c.ID) + + p, err := c.buildPath(request.Filepath) + if err != nil { + return sftp.ErrSshFxNoSuchFile + } + + var target string + // If a target is provided in this request validate that it is going to the correct + // location for the server. If it is not, return an operation unsupported error. This + // is maybe not the best error response, but its not wrong either. + if request.Target != "" { + target, err = c.buildPath(request.Target) + if err != nil { + return sftp.ErrSshFxOpUnsupported + } + } + + logger.Debug(logSender, "new cmd, method: %v user: %v", request.Method, c.User.Username) + + switch request.Method { + case "Setstat": + return nil + case "Rename": + if !c.User.HasPerm(dataprovider.PermRename) { + return sftp.ErrSshFxPermissionDenied + } + + logger.CommandLog(sftpdRenameLogSender, p, target, c.User.Username, c.ID) + if err := os.Rename(p, target); err != nil { + logger.Error("failed to rename file, source: %v target: %v: %v", p, target, err) + return sftp.ErrSshFxFailure + } + + break + case "Rmdir": + if !c.User.HasPerm(dataprovider.PermDelete) { + return sftp.ErrSshFxPermissionDenied + } + + logger.CommandLog(sftpdRmdirLogSender, p, target, c.User.Username, c.ID) + numFiles, size, err := utils.ScanDirContents(p) + if err != nil { + logger.Error("failed to remove directory %v, scanning error: %v", p, err) + return sftp.ErrSshFxFailure + } + if err := os.RemoveAll(p); err != nil { + logger.Error("failed to remove directory %v: %v", p, err) + return sftp.ErrSshFxFailure + } + + dataprovider.UpdateUserQuota(dataProvider, c.User.Username, -numFiles, -size, false) + + return sftp.ErrSshFxOk + case "Mkdir": + if !c.User.HasPerm(dataprovider.PermCreateDirs) { + return sftp.ErrSshFxPermissionDenied + } + + logger.CommandLog(sftpdMkdirLogSender, p, target, c.User.Username, c.ID) + dirsToCreate, err := c.findNonexistentDirs(filepath.Join(p, "testfile")) + if err != nil { + return sftp.ErrSshFxFailure + } + + last := len(dirsToCreate) - 1 + for i := range dirsToCreate { + d := dirsToCreate[last-i] + if err := os.Mkdir(d, 0777); err != nil { + logger.Error(logSender, "error making path dir: %v, full path: %v", d, p) + return sftp.ErrSshFxFailure + } + utils.SetPathPermissions(d, c.User.GetUID(), c.User.GetGID()) + } + break + case "Symlink": + if !c.User.HasPerm(dataprovider.PermCreateSymlinks) { + return sftp.ErrSshFxPermissionDenied + } + + logger.CommandLog(sftpdSymlinkLogSender, p, target, c.User.Username, c.ID) + if err := os.Symlink(p, target); err != nil { + logger.Warn("failed to create symlink %v->%v: %v", p, target, err) + return sftp.ErrSshFxFailure + } + + break + case "Remove": + if !c.User.HasPerm(dataprovider.PermDelete) { + return sftp.ErrSshFxPermissionDenied + } + + logger.CommandLog(sftpdRemoveLogSender, p, target, c.User.Username, c.ID) + var size int64 + var fi os.FileInfo + if fi, err = os.Lstat(p); err != nil { + logger.Error(logSender, "failed to remove a file %v: stat error: %v", p, err) + return sftp.ErrSshFxFailure + } + size = fi.Size() + if err := os.Remove(p); err != nil { + logger.Error(logSender, "failed to remove a file %v: %v", p, err) + return sftp.ErrSshFxFailure + } + + if fi.Mode()&os.ModeSymlink != os.ModeSymlink { + dataprovider.UpdateUserQuota(dataProvider, c.User.Username, -1, -size, false) + } + + return sftp.ErrSshFxOk + default: + return sftp.ErrSshFxOpUnsupported + } + + var fileLocation = p + if target != "" { + fileLocation = target + } + + utils.SetPathPermissions(fileLocation, c.User.GetUID(), c.User.GetGID()) + + return sftp.ErrSshFxOk +} + +// Filelist is the handler for SFTP filesystem list calls. This will handle calls to list the contents of +// a directory as well as perform file/folder stat calls. +func (c Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) { + updateConnectionActivity(c.ID) + p, err := c.buildPath(request.Filepath) + if err != nil { + return nil, sftp.ErrSshFxNoSuchFile + } + + switch request.Method { + case "List": + if !c.User.HasPerm(dataprovider.PermListItems) { + return nil, sftp.ErrSshFxPermissionDenied + } + + logger.Debug(logSender, "requested list file for dir: %v user: %v", p, c.User.Username) + + files, err := ioutil.ReadDir(p) + if err != nil { + logger.Error(logSender, "error listing directory: %v", err) + return nil, sftp.ErrSshFxFailure + } + + return ListerAt(files), nil + case "Stat": + if !c.User.HasPerm(dataprovider.PermListItems) { + return nil, sftp.ErrSshFxPermissionDenied + } + + logger.Debug(logSender, "requested stat for file: %v user: %v", p, c.User.Username) + s, err := os.Stat(p) + if os.IsNotExist(err) { + return nil, sftp.ErrSshFxNoSuchFile + } else if err != nil { + logger.Error(logSender, "error running STAT on file: %v", err) + return nil, sftp.ErrSshFxFailure + } + + return ListerAt([]os.FileInfo{s}), nil + default: + return nil, sftp.ErrSshFxOpUnsupported + } +} + +func (c Connection) hasSpace(checkFiles bool) bool { + if (checkFiles && c.User.QuotaFiles > 0) || c.User.QuotaSize > 0 { + numFile, size, err := dataprovider.GetUsedQuota(c.dataProvider, c.User.Username) + if err != nil { + if _, ok := err.(*dataprovider.MethodDisabledError); ok { + logger.Warn(logSender, "quota enforcement not possibile for user %v: %v", c.User.Username, err) + return true + } + logger.Warn(logSender, "error getting used quota for %v: %v", c.User.Username, err) + return false + } + if (checkFiles && numFile >= c.User.QuotaFiles) || size >= c.User.QuotaSize { + logger.Debug(logSender, "quota exceed for user %v, num files: %v/%v, size: %v/%v check files: %v", + c.User.Username, numFile, c.User.QuotaFiles, size, c.User.QuotaSize, checkFiles) + return false + } + } + return true +} + +// Normalizes a directory we get from the SFTP request to ensure the user is not able to escape +// from their data directory. After normalization if the directory is still within their home +// path it is returned. If they managed to "escape" an error will be returned. +func (c Connection) buildPath(rawPath string) (string, error) { + r := filepath.Clean(filepath.Join(c.User.HomeDir, rawPath)) + p, err := filepath.EvalSymlinks(r) + if err != nil && !os.IsNotExist(err) { + return "", err + } else if os.IsNotExist(err) { + // The requested directory doesn't exist, so at this point we need to iterate up the + // path chain until we hit a directory that _does_ exist and can be validated. + _, err = c.findFirstExistingDir(r) + if err != nil { + logger.Warn(logSender, "error resolving not existent path: %v", err) + } + return r, err + } + + err = c.isSubDir(p) + if err != nil { + logger.Warn(logSender, "Invalid path resolution, dir: %v outside user home: %v err: %v", p, c.User.HomeDir, err) + } + return r, err +} + +// iterate up the path chain until we hit a directory that does exist and can be validated. +// all nonexistent directories will be returned +func (c Connection) findNonexistentDirs(path string) ([]string, error) { + results := []string{} + cleanPath := filepath.Clean(path) + parent := filepath.Dir(cleanPath) + _, err := os.Stat(parent) + + for os.IsNotExist(err) { + results = append(results, parent) + parent = filepath.Dir(parent) + _, err = os.Stat(parent) + } + if err != nil { + return results, err + } + p, err := filepath.EvalSymlinks(parent) + if err != nil { + return results, err + } + err = c.isSubDir(p) + if err != nil { + logger.Warn(logSender, "Error finding non existing dir: %v", err) + } + return results, err +} + +// iterate up the path chain until we hit a directory that does exist and can be validated. +func (c Connection) findFirstExistingDir(path string) (string, error) { + results, err := c.findNonexistentDirs(path) + if err != nil { + logger.Warn(logSender, "unable to find non existent dirs: %v", err) + return "", err + } + var parent string + if len(results) > 0 { + lastMissingDir := results[len(results)-1] + parent = filepath.Dir(lastMissingDir) + } else { + parent = c.User.GetHomeDir() + } + p, err := filepath.EvalSymlinks(parent) + if err != nil { + return "", err + } + fileInfo, err := os.Stat(p) + if err != nil { + return "", err + } + if !fileInfo.IsDir() { + return "", fmt.Errorf("resolved path is not a dir: %v", p) + } + err = c.isSubDir(p) + return p, err +} + +// checks if sub is a subpath of the user home dir. +// EvalSymlink must be used on sub before calling this method +func (c Connection) isSubDir(sub string) error { + // home dir must exist and it is already a validated absolute path + parent, err := filepath.EvalSymlinks(c.User.HomeDir) + if err != nil { + logger.Warn(logSender, "invalid home dir %v: %v", c.User.HomeDir, err) + return err + } + if !strings.HasPrefix(sub, parent) { + logger.Warn(logSender, "dir %v is not inside: %v ", sub, parent) + return fmt.Errorf("dir %v is not inside: %v", sub, parent) + } + return nil +} diff --git a/sftpd/lister.go b/sftpd/lister.go new file mode 100644 index 00000000..8429f94f --- /dev/null +++ b/sftpd/lister.go @@ -0,0 +1,23 @@ +package sftpd + +import ( + "io" + "os" +) + +// ListerAt .. +type ListerAt []os.FileInfo + +// ListAt returns the number of entries copied and an io.EOF error if we made it to the end of the file list. +// Take a look at the pkg/sftp godoc for more information about how this function should work. +func (l ListerAt) ListAt(f []os.FileInfo, offset int64) (int, error) { + if offset >= int64(len(l)) { + return 0, io.EOF + } + + n := copy(f, l[offset:]) + if n < len(f) { + return n, io.EOF + } + return n, nil +} diff --git a/sftpd/server.go b/sftpd/server.go new file mode 100644 index 00000000..a361d76f --- /dev/null +++ b/sftpd/server.go @@ -0,0 +1,283 @@ +package sftpd + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "os" + "path/filepath" + "strconv" + "sync" + "time" + + "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/utils" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +// Configuration server configuration +type Configuration struct { + BindPort int `json:"bind_port"` + BindAddress string `json:"bind_address"` + IdleTimeout int `json:"idle_timeout"` + Umask string `json:"umask"` +} + +// Initalize the SFTP server and add a persistent listener to handle inbound SFTP connections. +func (c Configuration) Initalize(configDir string) error { + umask, err := strconv.ParseUint(c.Umask, 8, 8) + if err == nil { + utils.SetUmask(int(umask), c.Umask) + } else { + logger.Warn(logSender, "error reading umask, please fix your config file: %v", err) + } + serverConfig := &ssh.ServerConfig{ + NoClientAuth: false, + MaxAuthTries: 10, + PasswordCallback: func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + sp, err := c.validatePasswordCredentials(conn, pass) + if err != nil { + return nil, errors.New("could not validate credentials") + } + + return sp, nil + }, + PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + sp, err := c.validatePublicKeyCredentials(conn, string(pubKey.Marshal())) + if err != nil { + return nil, errors.New("could not validate credentials") + } + + return sp, nil + }, + ServerVersion: "SSH-2.0-SFTPServer", + } + + if _, err := os.Stat(filepath.Join(configDir, "id_rsa")); os.IsNotExist(err) { + logger.Info(logSender, "creating new private key for server") + if err := c.generatePrivateKey(configDir); err != nil { + return err + } + } else if err != nil { + return err + } + + privateBytes, err := ioutil.ReadFile(filepath.Join(configDir, "id_rsa")) + if err != nil { + return err + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + return err + } + + // Add our private key to the server configuration. + serverConfig.AddHostKey(private) + + listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort)) + if err != nil { + logger.Warn(logSender, "error starting listener on address %v: %v", listener.Addr().String(), err) + return err + } + + logger.Info(logSender, "server listener registered address: %v", listener.Addr().String()) + if c.IdleTimeout > 0 { + startIdleTimer(time.Duration(c.IdleTimeout) * time.Minute) + } + + for { + conn, _ := listener.Accept() + if conn != nil { + go c.AcceptInboundConnection(conn, serverConfig) + } + } +} + +// AcceptInboundConnection handles an inbound connection to the instance and determines if we should serve the request or not. +func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) { + defer conn.Close() + + // Before beginning a handshake must be performed on the incoming net.Conn + sconn, chans, reqs, err := ssh.NewServerConn(conn, config) + if err != nil { + logger.Warn(logSender, "failed to accept an incoming connection: %v", err) + return + } + defer sconn.Close() + + logger.Debug(logSender, "accepted inbound connection, ip: %v", conn.RemoteAddr().String()) + + go ssh.DiscardRequests(reqs) + + for newChannel := range chans { + // If its not a session channel we just move on because its not something we + // know how to handle at this point. + if newChannel.ChannelType() != "session" { + logger.Debug(logSender, "received an unknown channel type: %v", newChannel.ChannelType()) + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + + channel, requests, err := newChannel.Accept() + if err != nil { + logger.Warn(logSender, "could not accept a channel: %v", err) + continue + } + + // Channels have a type that is dependent on the protocol. For SFTP this is "subsystem" + // with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc) + go func(in <-chan *ssh.Request) { + for req := range in { + ok := false + + switch req.Type { + case "subsystem": + if string(req.Payload[4:]) == "sftp" { + ok = true + } + } + + req.Reply(ok, nil) + } + }(requests) + + var user dataprovider.User + + err = json.Unmarshal([]byte(sconn.Permissions.Extensions["user"]), &user) + + if err != nil { + logger.Warn(logSender, "Unable to deserialize user info, cannot serve connection: %v", err) + return + } + + connectionID := hex.EncodeToString(sconn.SessionID()) + + // Create a new handler for the currently logged in user's server. + handler := c.createHandler(sconn, user, connectionID) + + // Create the server instance for the channel using the handler we created above. + server := sftp.NewRequestServer(channel, handler) + + if err := server.Serve(); err == io.EOF { + logger.Debug(logSender, "connection closed, id: %v", connectionID) + server.Close() + } else if err != nil { + logger.Error(logSender, "sftp connection closed with error id %v: %v", connectionID, err) + } + + removeConnection(connectionID) + } +} + +func (c Configuration) createHandler(conn *ssh.ServerConn, user dataprovider.User, connectionID string) sftp.Handlers { + + connection := Connection{ + ID: connectionID, + User: user, + ClientVersion: string(conn.ClientVersion()), + RemoteAddr: conn.RemoteAddr(), + StartTime: time.Now(), + lastActivity: time.Now(), + lock: new(sync.Mutex), + sshConn: conn, + } + + addConnection(connectionID, connection) + + return sftp.Handlers{ + FileGet: connection, + FilePut: connection, + FileCmd: connection, + FileList: connection, + } +} + +func loginUser(user dataprovider.User) (*ssh.Permissions, error) { + if !filepath.IsAbs(user.HomeDir) { + logger.Warn(logSender, "user %v has invalid home dir: %v. Home dir must be an absolute path, login not allowed", + user.Username, user.HomeDir) + return nil, fmt.Errorf("Cannot login user with invalid home dir: %v", user.HomeDir) + } + if _, err := os.Stat(user.HomeDir); os.IsNotExist(err) { + logger.Debug(logSender, "home directory \"%v\" for user %v does not exist, try to create", user.HomeDir, user.Username) + err := os.MkdirAll(user.HomeDir, 0777) + if err == nil { + utils.SetPathPermissions(user.HomeDir, user.GetUID(), user.GetGID()) + } + } + + if user.MaxSessions > 0 { + activeSessions := getActiveSessions(user.Username) + if activeSessions >= user.MaxSessions { + logger.Debug(logSender, "authentication refused for user: %v, too many open sessions: %v/%v", user.Username, + activeSessions, user.MaxSessions) + return nil, fmt.Errorf("Too many open sessions: %v", activeSessions) + } + } + + json, err := json.Marshal(user) + if err != nil { + logger.Warn(logSender, "error serializing user info: %v, authentication rejected", err) + return nil, err + } + p := &ssh.Permissions{} + p.Extensions = make(map[string]string) + p.Extensions["user"] = string(json) + return p, nil +} + +func (c Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey string) (*ssh.Permissions, error) { + var err error + var user dataprovider.User + + if user, err = dataprovider.CheckUserAndPubKey(dataProvider, conn.User(), pubKey); err == nil { + return loginUser(user) + } + return nil, err +} + +func (c Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + var err error + var user dataprovider.User + + if user, err = dataprovider.CheckUserAndPass(dataProvider, conn.User(), string(pass)); err == nil { + return loginUser(user) + } + return nil, err +} + +// Generates a private key that will be used by the SFTP server. +func (c Configuration) generatePrivateKey(configDir string) error { + key, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return err + } + + o, err := os.OpenFile(filepath.Join(configDir, "id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer o.Close() + + pkey := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + } + + if err := pem.Encode(o, pkey); err != nil { + return err + } + + return nil +} diff --git a/sftpd/sftpd.go b/sftpd/sftpd.go new file mode 100644 index 00000000..f670e6ca --- /dev/null +++ b/sftpd/sftpd.go @@ -0,0 +1,287 @@ +package sftpd + +import ( + "sync" + "time" + + "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/utils" +) + +const ( + logSender = "sftpd" + sftpUploadLogSender = "SFTPUpload" + sftpdDownloadLogSender = "SFTPDownload" + sftpdRenameLogSender = "SFTPRename" + sftpdRmdirLogSender = "SFTPRmdir" + sftpdMkdirLogSender = "SFTPMkdir" + sftpdSymlinkLogSender = "SFTPSymlink" + sftpdRemoveLogSender = "SFTPRemove" + operationDownload = "download" + operationUpload = "upload" +) + +var ( + mutex sync.RWMutex + openConnections map[string]Connection + activeTransfers []*Transfer + idleConnectionTicker *time.Ticker + idleTimeout time.Duration + activeQuotaScans []ActiveQuotaScan + dataProvider dataprovider.Provider +) + +type connectionTransfer struct { + OperationType string `json:"operation_type"` + StartTime int64 `json:"start_time"` + Size int64 `json:"size"` + LastActivity int64 `json:"last_activity"` +} + +// ActiveQuotaScan username and start data for a quota scan +type ActiveQuotaScan struct { + Username string `json:"usernane"` + StartTime int64 `json:"start_time"` +} + +// ConnectionStatus status for an active connection +type ConnectionStatus struct { + Username string `json:"username"` + ConnectionID string `json:"connection_id"` + ClientVersion string `json:"client_version"` + RemoteAddress string `json:"remote_address"` + ConnectionTime int64 `json:"connection_time"` + LastActivity int64 `json:"last_activity"` + Transfers []connectionTransfer `json:"active_transfers"` +} + +func init() { + openConnections = make(map[string]Connection) +} + +// SetDataProvider sets the data provider +func SetDataProvider(provider dataprovider.Provider) { + dataProvider = provider +} + +func getActiveSessions(username string) int { + mutex.RLock() + defer mutex.RUnlock() + numSessions := 0 + for _, c := range openConnections { + if c.User.Username == username { + numSessions++ + } + } + return numSessions +} + +// GetQuotaScans returns the active quota scans +func GetQuotaScans() []ActiveQuotaScan { + mutex.RLock() + defer mutex.RUnlock() + scans := make([]ActiveQuotaScan, len(activeQuotaScans)) + copy(scans, activeQuotaScans) + return scans +} + +// AddQuotaScan add an user to the ones with active quota scans. +// Returns false if the user has a quota scan already running +func AddQuotaScan(username string) bool { + mutex.Lock() + defer mutex.Unlock() + for _, s := range activeQuotaScans { + if s.Username == username { + return false + } + } + activeQuotaScans = append(activeQuotaScans, ActiveQuotaScan{ + Username: username, + StartTime: utils.GetTimeAsMsSinceEpoch(time.Now()), + }) + return true +} + +// RemoveQuotaScan remove and user from the ones with active quota scans +func RemoveQuotaScan(username string) { + mutex.Lock() + defer mutex.Unlock() + indexToRemove := -1 + for i, s := range activeQuotaScans { + if s.Username == username { + indexToRemove = i + break + } + } + if indexToRemove >= 0 { + activeQuotaScans[indexToRemove] = activeQuotaScans[len(activeQuotaScans)-1] + activeQuotaScans = activeQuotaScans[:len(activeQuotaScans)-1] + } +} + +// CloseActiveConnection close an active SFTP connection, returns true on success +func CloseActiveConnection(connectionID string) bool { + result := false + mutex.RLock() + defer mutex.RUnlock() + for _, c := range openConnections { + if c.ID == connectionID { + logger.Debug(logSender, "closing connection with id: %v", connectionID) + c.sshConn.Close() + result = true + break + } + } + return result +} + +// GetConnectionsStats returns stats for active connections +func GetConnectionsStats() []ConnectionStatus { + mutex.RLock() + defer mutex.RUnlock() + stats := []ConnectionStatus{} + for _, c := range openConnections { + conn := ConnectionStatus{ + Username: c.User.Username, + ConnectionID: c.ID, + ClientVersion: c.ClientVersion, + RemoteAddress: c.RemoteAddr.String(), + ConnectionTime: utils.GetTimeAsMsSinceEpoch(c.StartTime), + LastActivity: utils.GetTimeAsMsSinceEpoch(c.lastActivity), + Transfers: []connectionTransfer{}, + } + for _, t := range activeTransfers { + if t.connectionID == c.ID { + if utils.GetTimeAsMsSinceEpoch(t.lastActivity) > conn.LastActivity { + conn.LastActivity = utils.GetTimeAsMsSinceEpoch(t.lastActivity) + } + var operationType string + var size int64 + if t.transferType == transferUpload { + operationType = operationUpload + size = t.bytesReceived + } else { + operationType = operationDownload + size = t.bytesSent + } + connTransfer := connectionTransfer{ + OperationType: operationType, + StartTime: utils.GetTimeAsMsSinceEpoch(t.start), + Size: size, + LastActivity: utils.GetTimeAsMsSinceEpoch(t.lastActivity), + } + conn.Transfers = append(conn.Transfers, connTransfer) + } + } + stats = append(stats, conn) + } + return stats +} + +func startIdleTimer(maxIdleTime time.Duration) { + idleConnectionTicker = time.NewTicker(5 * time.Minute) + idleTimeout = maxIdleTime + go func() { + for t := range idleConnectionTicker.C { + logger.Debug(logSender, "idle connections check ticker %v", t) + checkIdleConnections() + } + }() +} + +func checkIdleConnections() { + mutex.RLock() + defer mutex.RUnlock() + for _, c := range openConnections { + idleTime := time.Since(c.lastActivity) + for _, t := range activeTransfers { + if t.connectionID == c.ID { + transferIdleTime := time.Since(t.lastActivity) + if transferIdleTime < idleTime { + logger.Debug(logSender, "idle time: %v setted to transfer idle time: %v connection id: %v", + idleTime, transferIdleTime, c.ID) + idleTime = transferIdleTime + } + } + } + if idleTime > idleTimeout { + logger.Debug(logSender, "close idle connection id: %v idle time: %v", c.ID, idleTime) + err := c.sshConn.Close() + if err != nil { + logger.Warn(logSender, "error closing idle connection: %v", err) + } + } + } + logger.Debug(logSender, "check idle connections ended") +} + +func addConnection(id string, conn Connection) { + mutex.Lock() + defer mutex.Unlock() + openConnections[id] = conn + logger.Debug(logSender, "connection added, num open connections: %v", len(openConnections)) +} + +func removeConnection(id string) { + mutex.Lock() + defer mutex.Unlock() + delete(openConnections, id) + logger.Debug(logSender, "connection removed, num open connections: %v", len(openConnections)) +} + +func addTransfer(transfer *Transfer) { + mutex.Lock() + defer mutex.Unlock() + activeTransfers = append(activeTransfers, transfer) +} + +func removeTransfer(transfer *Transfer) { + mutex.Lock() + defer mutex.Unlock() + indexToRemove := -1 + for i, v := range activeTransfers { + if v == transfer { + indexToRemove = i + break + } + } + if indexToRemove >= 0 { + //logger.Debug(logSender, "remove index %v from active transfer, size: %v", indexToRemove, len(activeTransfers)) + activeTransfers[indexToRemove] = activeTransfers[len(activeTransfers)-1] + activeTransfers = activeTransfers[:len(activeTransfers)-1] + } else { + logger.Warn(logSender, "transfer to remove not found!") + } +} + +func updateConnectionActivity(id string) { + mutex.Lock() + defer mutex.Unlock() + if c, ok := openConnections[id]; ok { + //logger.Debug(logSender, "update connection activity, id: %v", id) + c.lastActivity = time.Now() + openConnections[id] = c + } + //logger.Debug(logSender, "connection activity updated: %+v", openConnections) +} + +func logConnections() { + mutex.RLock() + defer mutex.RUnlock() + for _, c := range openConnections { + logger.Debug(logSender, "active connection %+v", c) + } +} + +func logTransfers() { + mutex.RLock() + defer mutex.RUnlock() + if len(activeTransfers) > 0 { + for _, v := range activeTransfers { + logger.Debug(logSender, "active transfer: %+v", v) + } + } else { + logger.Debug(logSender, "no active transfer") + } +} diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go new file mode 100644 index 00000000..222eac62 --- /dev/null +++ b/sftpd/sftpd_test.go @@ -0,0 +1,816 @@ +package sftpd_test + +import ( + "crypto/rand" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "os" + "path/filepath" + "runtime" + "testing" + + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/api" + "github.com/drakkan/sftpgo/dataprovider" + "github.com/pkg/sftp" +) + +// To run test cases you need to manually start sftpgo using port 2022 for sftp and 8080 for http API + +const ( + sftpServerAddr = "127.0.0.1:2022" + defaultUsername = "test_user" + defaultPassword = "test_password" + testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" + testPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn +NhAAAAAwEAAQAAAYEAtN449A/nY5O6cSH/9Doa8a3ISU0WZJaHydTaCLuO+dkqtNpnV5mq +zFbKidXAI1eSwVctw9ReVOl1uK6aZF3lbXdOD8W9PXobR9KUUT2qBx5QC4ibfAqDKWymDA +PG9ylzz64hsYBqJr7VNk9kTFEUsDmWzLabLoH42Elnp8mF/lTkWIcpVp0ly/etS08gttXo +XenekJ1vRuxOYWDCEzGPU7kGc920TmM14k7IDdPoOh5+3sRUKedKeOUrVDH1f0n7QjHQsZ +cbshp8tgqzf734zu8cTqNrr+6taptdEOOij1iUL/qYGfzny/hA48tO5+UFUih5W8ftp0+E +NBIDkkGgk2MJ92I7QAXyMVsIABXco+mJT7pQi9tqlODGIQ3AOj0gcA3X/Ib8QX77Ih3TPi +XEh77/P1XiYZOgpp2cRmNH8QbqaL9u898hDvJwIPJPuj2lIltTElH7hjBf5LQfCzrLV7BD +10rM7sl4jr+A2q8jl1Ikp+25kainBBZSbrDummT9AAAFgDU/VLk1P1S5AAAAB3NzaC1yc2 +EAAAGBALTeOPQP52OTunEh//Q6GvGtyElNFmSWh8nU2gi7jvnZKrTaZ1eZqsxWyonVwCNX +ksFXLcPUXlTpdbiummRd5W13Tg/FvT16G0fSlFE9qgceUAuIm3wKgylspgwDxvcpc8+uIb +GAaia+1TZPZExRFLA5lsy2my6B+NhJZ6fJhf5U5FiHKVadJcv3rUtPILbV6F3p3pCdb0bs +TmFgwhMxj1O5BnPdtE5jNeJOyA3T6Doeft7EVCnnSnjlK1Qx9X9J+0Ix0LGXG7IafLYKs3 ++9+M7vHE6ja6/urWqbXRDjoo9YlC/6mBn858v4QOPLTuflBVIoeVvH7adPhDQSA5JBoJNj +CfdiO0AF8jFbCAAV3KPpiU+6UIvbapTgxiENwDo9IHAN1/yG/EF++yId0z4lxIe+/z9V4m +GToKadnEZjR/EG6mi/bvPfIQ7ycCDyT7o9pSJbUxJR+4YwX+S0Hws6y1ewQ9dKzO7JeI6/ +gNqvI5dSJKftuZGopwQWUm6w7ppk/QAAAAMBAAEAAAGAHKnC+Nq0XtGAkIFE4N18e6SAwy +0WSWaZqmCzFQM0S2AhJnweOIG/0ZZHjsRzKKauOTmppQk40dgVsejpytIek9R+aH172gxJ +2n4Cx0UwduRU5x8FFQlNc/kl722B0JWfJuB/snOZXv6LJ4o5aObIkozt2w9tVFeAqjYn2S +1UsNOfRHBXGsTYwpRDwFWP56nKo2d2wBBTHDhCy6fb2dLW1fvSi/YspueOGIlHpvlYKi2/ +CWqvs9xVrwcScMtiDoQYq0khhO0efLCxvg/o+W9CLMVM2ms4G1zoSUQKN0oYWWQJyW4+VI +YneWO8UpN0J3ElXKi7bhgAat7dBaM1g9IrAzk153DiEFZNsPxGOgL/+YdQN7zUBx/z7EkI +jyv80RV7fpUXvcq2p+qNl6UVig3VSzRrnsaJkUWu/A0u59ha7ocv6NxDIXjxpIDJme16GF +quiGVBQNnYJymS/vFEbGf6bgf7iRmMCRUMG4nqLA6fPYP9uAtch+CmDfVLZC/fIdC5AAAA +wQCDissV4zH6bfqgxJSuYNk8Vbb+19cF3b7gH1rVlB3zxpCAgcRgMHC+dP1z2NRx7UW9MR +nye6kjpkzZZ0OigLqo7TtEq8uTglD9o6W7mRXqhy5A/ySOmqPL3ernHHQhGuoNODYAHkOU +u2Rh8HXi+VLwKZcLInPOYJvcuLG4DxN8WfeVvlMHwhAOaTNNOtL4XZDHQeIPc4qHmJymmv +sV7GuyQ6yW5C10uoGdxRPd90Bh4z4h2bKfZFjvEBbSBVkqrlAAAADBAN/zNtNayd/dX7Cr +Nb4sZuzCh+CW4BH8GOePZWNCATwBbNXBVb5cR+dmuTqYm+Ekz0VxVQRA1TvKncluJOQpoa +Xj8r0xdIgqkehnfDPMKtYVor06B9Fl1jrXtXU0Vrr6QcBWruSVyK1ZxqcmcNK/+KolVepe +A6vcl/iKaG4U7su166nxLST06M2EgcSVsFJHpKn5+WAXC+X0Gx8kNjWIIb3GpiChdc0xZD +mq02xZthVJrTCVw/e7gfDoB2QRsNV8HwAAAMEAzsCghZVp+0YsYg9oOrw4tEqcbEXEMhwY +0jW8JNL8Spr1Ibp5Dw6bRSk5azARjmJtnMJhJ3oeHfF0eoISqcNuQXGndGQbVM9YzzAzc1 +NbbCNsVroqKlChT5wyPNGS+phi2bPARBno7WSDvshTZ7dAVEP2c9MJW0XwoSevwKlhgSdt +RLFFQ/5nclJSdzPBOmQouC0OBcMFSrYtMeknJ4VvueVvve5HcHFaEsaMc7ABAGaLYaBQOm +iixITGvaNZh/tjAAAACW5pY29sYUBwMQE= +-----END OPENSSH PRIVATE KEY-----` +) + +var ( + allPerms = []string{dataprovider.PermAny} + homeBasePath string +) + +func init() { + if runtime.GOOS == "windows" { + homeBasePath = "C:\\" + } else { + homeBasePath = "/tmp" + } +} + +func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: defaultUsername, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + if usePubKey { + key, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) + if err != nil { + return nil, err + } + config.Auth = []ssh.AuthMethod{ssh.PublicKeys(key)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + return sftpClient, err +} + +func createTestFile(path string, size int64) error { + content := make([]byte, size) + _, err := rand.Read(content) + if err != nil { + return err + } + return ioutil.WriteFile(path, content, 0666) +} + +func getTestUser(usePubKey bool) dataprovider.User { + user := dataprovider.User{ + Username: defaultUsername, + Password: defaultPassword, + HomeDir: filepath.Join(homeBasePath, defaultUsername), + Permissions: allPerms, + } + if usePubKey { + user.PublicKey = testPubKey + user.Password = "" + } + return user +} + +func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error { + srcFile, err := os.Open(localSourcePath) + if err != nil { + return err + } + defer srcFile.Close() + destFile, err := client.Create(remoteDestPath) + if err != nil { + return err + } + defer destFile.Close() + _, err = io.Copy(destFile, srcFile) + if expectedSize > 0 { + fi, err := client.Lstat(remoteDestPath) + if err != nil { + return err + } + if fi.Size() != expectedSize { + return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize) + } + } + return err +} + +func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error { + downloadDest, err := os.Create(localDestPath) + if err != nil { + return err + } + defer downloadDest.Close() + sftpSrcFile, err := client.Open(remoteSourcePath) + if err != nil { + return err + } + defer sftpSrcFile.Close() + _, err = io.Copy(downloadDest, sftpSrcFile) + if err != nil { + return err + } + err = downloadDest.Sync() + if err != nil { + return err + } + if expectedSize > 0 { + fi, err := downloadDest.Stat() + if err != nil { + return err + } + if fi.Size() != expectedSize { + return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize) + } + } + return err +} + +func TestBasicSFTPHandling(t *testing.T) { + usePubKey := false + user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + _, err := client.Getwd() + if err != nil { + t.Errorf("unable to get working dir: %v", err) + } + _, err = client.ReadDir(".") + if err != nil { + t.Errorf("unable to read remote dir: %v", err) + } + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + expectedQuotaSize := user.UsedQuotaSize + testFileSize + expectedQuotaFiles := user.UsedQuotaFiles + 1 + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + err = client.Symlink(testFileName, testFileName+".link") + if err != nil { + t.Errorf("error creating symlink: %v", err) + } + err = client.Remove(testFileName + ".link") + if err != nil { + t.Errorf("error removing symlink: %v", err) + } + localDownloadPath := filepath.Join(homeBasePath, "test_download.dat") + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + if err != nil { + t.Errorf("file download error: %v", err) + } + user, err = api.GetUserByID(user.ID, http.StatusOK) + if err != nil { + t.Errorf("error getting user: %v", err) + } + if expectedQuotaFiles != user.UsedQuotaFiles { + t.Errorf("quota files does not match, expected: %v, actual: %v", expectedQuotaFiles, user.UsedQuotaFiles) + } + if expectedQuotaSize != user.UsedQuotaSize { + t.Errorf("quota size does not match, expected: %v, actual: %v", expectedQuotaSize, user.UsedQuotaSize) + } + err = client.Remove(testFileName) + if err != nil { + t.Errorf("error removing uploaded file: %v", err) + } + _, err = client.Lstat(testFileName) + if err == nil { + t.Errorf("stat for deleted file must not succeed") + } + user, err = api.GetUserByID(user.ID, http.StatusOK) + if err != nil { + t.Errorf("error getting user: %v", err) + } + if (expectedQuotaFiles - 1) != user.UsedQuotaFiles { + t.Errorf("quota files does not match after delete, expected: %v, actual: %v", expectedQuotaFiles-1, user.UsedQuotaFiles) + } + if (expectedQuotaSize - testFileSize) != user.UsedQuotaSize { + t.Errorf("quota size does not match, expected: %v, actual: %v", expectedQuotaSize-testFileSize, user.UsedQuotaSize) + } + err = client.Mkdir("test") + if err != nil { + t.Errorf("error mkdir: %v", err) + } + err = client.Rename("test", "test1") + if err != nil { + t.Errorf("error rename: %v", err) + } + err = client.Remove("test1") + if err != nil { + t.Errorf("error rmdir: %v", err) + } + err = client.MkdirAll("/test/test") + if err != nil { + t.Errorf("error mkdir all: %v", err) + } + err = client.Remove("/test") + if err != nil { + t.Errorf("error rmdir all: %v", err) + } + _, err = client.Lstat("/test") + if err == nil { + t.Errorf("stat for deleted dir must not succeed") + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +// basic tests to verify virtual chroot, should be improved to cover more cases ... +func TestEscapeHomeDir(t *testing.T) { + usePubKey := true + user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + _, err := client.Getwd() + if err != nil { + t.Errorf("unable to get working dir: %v", err) + } + testDir := "testDir" + linkPath := filepath.Join(homeBasePath, defaultUsername, testDir) + err = os.Symlink(homeBasePath, linkPath) + if err != nil { + t.Errorf("error making local symlink: %v", err) + } + _, err = client.ReadDir(testDir) + if err == nil { + t.Errorf("reading a symbolic link outside home dir should not suceeded") + } + os.Remove(linkPath) + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + remoteDestPath := filepath.Join("..", "..", testFileName) + err = sftpUploadFile(testFilePath, remoteDestPath, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + _, err = client.Lstat(testFileName) + if err != nil { + t.Errorf("file stat error: %v the file was created outside the user dir!", err) + } + err = client.Remove(testFileName) + if err != nil { + t.Errorf("error removing uploaded file: %v", err) + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestHomeSpecialChars(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.HomeDir = filepath.Join(homeBasePath, "abc açà#&%lk") + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + _, err := client.Getwd() + if err != nil { + t.Errorf("unable to get working dir: %v", err) + } + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + files, err := client.ReadDir(".") + if err != nil { + t.Errorf("unable to read remote dir: %v", err) + } + if len(files) < 1 { + t.Errorf("expected at least 1 file in this dir") + } + err = client.Remove(testFileName) + if err != nil { + t.Errorf("error removing uploaded file: %v", err) + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestLoginPubKey(t *testing.T) { + usePubKey := true + user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + _, err := client.Getwd() + if err != nil { + t.Errorf("unable to get working dir: %v", err) + } + _, err = client.ReadDir(".") + if err != nil { + t.Errorf("unable to read remote dir: %v", err) + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestLoginAfterUserUpdateEmptyPwd(t *testing.T) { + usePubKey := false + user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + user.Password = "" + user.PublicKey = "" + // password and public key should remain unchanged + _, err = api.UpdateUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to update user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + _, err := client.Getwd() + if err != nil { + t.Errorf("unable to get working dir: %v", err) + } + _, err = client.ReadDir(".") + if err != nil { + t.Errorf("unable to read remote dir: %v", err) + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestLoginAfterUserUpdateEmptyPubKey(t *testing.T) { + usePubKey := true + user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + user.Password = "" + user.PublicKey = "" + // password and public key should remain unchanged + _, err = api.UpdateUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to update user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + _, err := client.Getwd() + if err != nil { + t.Errorf("unable to get working dir: %v", err) + } + _, err = client.ReadDir(".") + if err != nil { + t.Errorf("unable to read remote dir: %v", err) + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestMaxSessions(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.MaxSessions = 1 + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + _, err := client.Getwd() + if err != nil { + t.Errorf("unable to get working dir: %v", err) + } + _, err = client.ReadDir(".") + if err != nil { + t.Errorf("unable to read remote dir: %v", err) + } + _, err = getSftpClient(user, usePubKey) + if err == nil { + t.Errorf("max sessions exceeded, new login should not succeed") + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestQuotaScan(t *testing.T) { + usePubKey := false + user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + testFileSize := int64(65535) + expectedQuotaSize := user.UsedQuotaSize + testFileSize + expectedQuotaFiles := user.UsedQuotaFiles + 1 + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } + // create user with the same home dir, so there is at least an untracked file + user, err = api.AddUser(getTestUser(usePubKey), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + err = api.StartQuotaScan(user, http.StatusCreated) + if err != nil { + t.Errorf("error starting quota scan: %v", err) + } + scans, err := api.GetQuotaScans(http.StatusOK) + if err != nil { + t.Errorf("error getting active quota scans: %v", err) + } + for len(scans) > 0 { + scans, err = api.GetQuotaScans(http.StatusOK) + if err != nil { + t.Errorf("error getting active quota scans: %v", err) + break + } + } + user, err = api.GetUserByID(user.ID, http.StatusOK) + if err != nil { + t.Errorf("error getting user: %v", err) + } + if expectedQuotaFiles != user.UsedQuotaFiles { + t.Errorf("quota files does not match after scan, expected: %v, actual: %v", expectedQuotaFiles, user.UsedQuotaFiles) + } + if expectedQuotaSize != user.UsedQuotaSize { + t.Errorf("quota size does not match after scan, expected: %v, actual: %v", expectedQuotaSize, user.UsedQuotaSize) + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestPermList(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks} + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + _, err = client.ReadDir(".") + if err == nil { + t.Errorf("read remote dir without permission should not succeed") + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestPermDownload(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks} + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + localDownloadPath := filepath.Join(homeBasePath, "test_download.dat") + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + if err == nil { + t.Errorf("file download without permission should not succeed") + } + err = client.Remove(testFileName) + if err != nil { + t.Errorf("error removing uploaded file: %v", err) + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestPermUpload(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermDelete, dataprovider.PermRename, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks} + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err == nil { + t.Errorf("file upload without permission should not succeed") + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestPermDelete(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermRename, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks} + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + err = client.Remove(testFileName) + if err == nil { + t.Errorf("delete without permission should not succeed") + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestPermRename(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks} + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + err = client.Rename(testFileName, testFileName+".rename") + if err == nil { + t.Errorf("rename without permission should not succeed") + } + err = client.Remove(testFileName) + if err != nil { + t.Errorf("error removing uploaded file: %v", err) + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestPermCreateDirs(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermRename, dataprovider.PermCreateSymlinks} + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + err = client.Mkdir("testdir") + if err == nil { + t.Errorf("mkdir without permission should not succeed") + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestPermSymlink(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermRename, dataprovider.PermCreateDirs} + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + err = client.Symlink(testFilePath, testFilePath+".symlink") + if err == nil { + t.Errorf("symlink without permission should not succeed") + } + err = client.Remove(testFileName) + if err != nil { + t.Errorf("error removing uploaded file: %v", err) + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} diff --git a/sftpd/transfer.go b/sftpd/transfer.go new file mode 100644 index 00000000..22e9bf4e --- /dev/null +++ b/sftpd/transfer.go @@ -0,0 +1,86 @@ +package sftpd + +import ( + "os" + "time" + + "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/logger" +) + +const ( + transferUpload = iota + transferDownload +) + +// Transfer struct, it contains the transfer details for an upload or a download +type Transfer struct { + file *os.File + path string + start time.Time + bytesSent int64 + bytesReceived int64 + user dataprovider.User + connectionID string + transferType int + lastActivity time.Time + isNewFile bool +} + +// ReadAt update sent bytes +func (t *Transfer) ReadAt(p []byte, off int64) (n int, err error) { + readed, e := t.file.ReadAt(p, off) + t.bytesSent += int64(readed) + t.lastActivity = time.Now() + t.handleThrottle() + return readed, e +} + +// WriteAt update received bytes +func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) { + written, e := t.file.WriteAt(p, off) + t.bytesReceived += int64(written) + t.lastActivity = time.Now() + t.handleThrottle() + return written, e +} + +// Close method called when the transfer is completed, we log the transfer info +func (t *Transfer) Close() error { + elapsed := time.Since(t.start).Nanoseconds() / 1000000 + if t.transferType == transferDownload { + logger.TransferLog(sftpdDownloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID) + } else { + logger.TransferLog(sftpUploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID) + } + removeTransfer(t) + if t.transferType == transferUpload && t.bytesReceived > 0 { + numFiles := 0 + if t.isNewFile { + numFiles++ + } + dataprovider.UpdateUserQuota(dataProvider, t.user.Username, numFiles, t.bytesReceived, false) + } + return t.file.Close() +} + +func (t *Transfer) handleThrottle() { + var wantedBandwidth int64 + var trasferredBytes int64 + if t.transferType == transferDownload { + wantedBandwidth = t.user.DownloadBandwidth + trasferredBytes = t.bytesSent + } else { + wantedBandwidth = t.user.UploadBandwidth + trasferredBytes = t.bytesReceived + } + if wantedBandwidth > 0 { + // real and wanted elapsed as milliseconds, bytes as kilobytes + realElapsed := time.Since(t.start).Nanoseconds() / 1000000 + wantedElapsed := 1000 * (trasferredBytes / 1000) / wantedBandwidth + if wantedElapsed > realElapsed { + toSleep := time.Duration(wantedElapsed - realElapsed) + time.Sleep(toSleep * time.Millisecond) + } + } +} diff --git a/sftpgo.conf b/sftpgo.conf new file mode 100644 index 00000000..972d5229 --- /dev/null +++ b/sftpgo.conf @@ -0,0 +1,25 @@ +{ + "sftpd":{ + "bind_port":2022, + "bind_address": "", + "idle_timeout": 15, + "umask": "0022" + }, + "data_provider": { + "driver": "sqlite", + "name": "sftpgo.db", + "host": "", + "port": 5432, + "username": "", + "password": "", + "sslmode": 0, + "connection_string": "", + "users_table": "users", + "manage_users": 1, + "track_quota": 1 + }, + "httpd":{ + "bind_port":8080, + "bind_address": "127.0.0.1" + } +} diff --git a/sql/mysql/20190706.sql b/sql/mysql/20190706.sql new file mode 100644 index 00000000..2e040cc5 --- /dev/null +++ b/sql/mysql/20190706.sql @@ -0,0 +1,6 @@ +BEGIN; +-- +-- Create model User +-- +CREATE TABLE `users` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, `password` varchar(255) NULL, `public_key` longtext NULL, `home_dir` varchar(255) NOT NULL, `uid` integer NOT NULL, `gid` integer NOT NULL, `max_sessions` integer NOT NULL, `quota_size` bigint NOT NULL, `quota_files` integer NOT NULL, `permissions` longtext NOT NULL, `used_quota_size` bigint NOT NULL, `used_quota_files` integer NOT NULL, `last_quota_scan` bigint NOT NULL, `upload_bandwidth` integer NOT NULL, `download_bandwidth` integer NOT NULL); +COMMIT; diff --git a/sql/pgsql/20190706.sql b/sql/pgsql/20190706.sql new file mode 100644 index 00000000..9fbf1847 --- /dev/null +++ b/sql/pgsql/20190706.sql @@ -0,0 +1,7 @@ +BEGIN; +-- +-- Create model User +-- +CREATE TABLE "users" ("id" serial NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE, "password" varchar(255) NULL, "public_key" text NULL, "home_dir" varchar(255) NOT NULL, "uid" integer NOT NULL, "gid" integer NOT NULL, "max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_scan" bigint NOT NULL, "upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL); +COMMIT; + diff --git a/sql/sqlite/20190706.sql b/sql/sqlite/20190706.sql new file mode 100644 index 00000000..4395f0cb --- /dev/null +++ b/sql/sqlite/20190706.sql @@ -0,0 +1,6 @@ +BEGIN; +-- +-- Create model User +-- +CREATE TABLE "users" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "username" varchar(255) NOT NULL UNIQUE, "password" varchar(255) NULL, "public_key" text NULL, "home_dir" varchar(255) NOT NULL, "uid" integer NOT NULL, "gid" integer NOT NULL, "max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_scan" bigint NOT NULL, "upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL); +COMMIT; diff --git a/utils/umask_unix.go b/utils/umask_unix.go new file mode 100644 index 00000000..a08f85ef --- /dev/null +++ b/utils/umask_unix.go @@ -0,0 +1,15 @@ +// +build !windows + +package utils + +import ( + "syscall" + + "github.com/drakkan/sftpgo/logger" +) + +// SetUmask set umask on unix systems +func SetUmask(umask int, configValue string) { + logger.Debug(logSender, "set umask to %v (%v)", configValue, umask) + syscall.Umask(umask) +} diff --git a/utils/umask_windows.go b/utils/umask_windows.go new file mode 100644 index 00000000..5ac774fc --- /dev/null +++ b/utils/umask_windows.go @@ -0,0 +1,8 @@ +package utils + +import "github.com/drakkan/sftpgo/logger" + +// SetUmask does nothing on windows +func SetUmask(umask int, configValue string) { + logger.Debug(logSender, "umask not available on windows, configured value %v (%v)", configValue, umask) +} diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 00000000..56b79fa1 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,69 @@ +package utils + +import ( + "os" + "path/filepath" + "runtime" + "time" + + "github.com/drakkan/sftpgo/logger" +) + +const logSender = "utils" + +// IsStringInSlice search a string in a slice +func IsStringInSlice(obj string, list []string) bool { + for _, v := range list { + if v == obj { + return true + } + } + return false +} + +// GetTimeAsMsSinceEpoch returns unix timestamp as milliseconds from a time struct +func GetTimeAsMsSinceEpoch(t time.Time) int64 { + return t.UnixNano() / 1000000 +} + +// ScanDirContents returns the number of files contained in a directory and their size +func ScanDirContents(path string) (int, int64, error) { + var numFiles int + var size int64 + var err error + numFiles = 0 + size = 0 + isDir, err := isDirectory(path) + if err == nil && isDir { + err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info != nil && info.Mode().IsRegular() { + size += info.Size() + numFiles++ + } + + return err + }) + } + + return numFiles, size, err +} + +func isDirectory(path string) (bool, error) { + fileInfo, err := os.Stat(path) + if err != nil { + return false, err + } + return fileInfo.IsDir(), err +} + +// SetPathPermissions call os.Chown on unix does nothing on windows +func SetPathPermissions(path string, uid int, gid int) { + if runtime.GOOS != "windows" { + if err := os.Chown(path, uid, gid); err != nil { + logger.Warn(logSender, "error chowning path %v: %v", path, err) + } + } +}