From c8f7fc9bc9977b2bd79690a348c4f06f3a40554b Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Tue, 11 May 2021 06:54:06 +0200 Subject: [PATCH] httpd/webdav: add a list of hosts allowed to send proxy headers X-Forwarded-For, X-Real-IP and X-Forwarded-Proto headers will be ignored for hosts not included in this list. This is a backward incompatible change, before the proxy headers were always used --- Dockerfile | 3 +- Dockerfile.alpine | 3 +- common/common_test.go | 13 + config/config.go | 14 + config/config_test.go | 8 + docs/full-configuration.md | 4 +- docs/howto/postgresql-s3.md | 8 +- ftpd/ftpd_test.go | 16 +- httpd/auth_utils.go | 14 +- httpd/httpd.go | 30 +- httpd/httpd_test.go | 36 ++- httpd/internal_test.go | 153 +++++++-- httpd/middleware.go | 28 +- httpd/server.go | 462 ++++++++++++++------------- httpd/webclient.go | 24 -- sftpd/sftpd_test.go | 28 +- sftpgo.json | 8 +- templates/webadmin/base.html | 2 + templates/webclient/base.html | 2 + templates/webclient/credentials.html | 2 +- utils/utils.go | 50 +++ webdavd/internal_test.go | 69 ++-- webdavd/server.go | 49 ++- webdavd/webdavd.go | 16 + webdavd/webdavd_test.go | 10 + 25 files changed, 669 insertions(+), 383 deletions(-) diff --git a/Dockerfile b/Dockerfile index 794bdc25..a680433f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -52,8 +52,7 @@ ENV SFTPGO_HTTPD__STATIC_FILES_PATH=/usr/share/sftpgo/static # Modify the default configuration file RUN sed -i "s|\"users_base_dir\": \"\",|\"users_base_dir\": \"/srv/sftpgo/data\",|" /etc/sftpgo/sftpgo.json && \ - sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json && \ - sed -i "s|\"address\": \"127.0.0.1\",|\"address\": \"\",|" /etc/sftpgo/sftpgo.json + sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups diff --git a/Dockerfile.alpine b/Dockerfile.alpine index 05dba742..a5d83b44 100644 --- a/Dockerfile.alpine +++ b/Dockerfile.alpine @@ -57,8 +57,7 @@ ENV SFTPGO_HTTPD__STATIC_FILES_PATH=/usr/share/sftpgo/static # Modify the default configuration file RUN sed -i "s|\"users_base_dir\": \"\",|\"users_base_dir\": \"/srv/sftpgo/data\",|" /etc/sftpgo/sftpgo.json && \ - sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json && \ - sed -i "s|\"address\": \"127.0.0.1\",|\"address\": \"\",|" /etc/sftpgo/sftpgo.json + sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups diff --git a/common/common_test.go b/common/common_test.go index 875f87ba..2cd8259d 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -697,6 +697,19 @@ func TestCachedFs(t *testing.T) { assert.NoError(t, err) } +func TestParseAllowedIPAndRanges(t *testing.T) { + _, err := utils.ParseAllowedIPAndRanges([]string{"1.1.1.1", "not an ip"}) + assert.Error(t, err) + _, err = utils.ParseAllowedIPAndRanges([]string{"1.1.1.5", "192.168.1.0/240"}) + assert.Error(t, err) + allow, err := utils.ParseAllowedIPAndRanges([]string{"192.168.1.2", "172.16.0.0/24"}) + assert.NoError(t, err) + assert.True(t, allow[0](net.ParseIP("192.168.1.2"))) + assert.False(t, allow[0](net.ParseIP("192.168.2.2"))) + assert.True(t, allow[1](net.ParseIP("172.16.0.1"))) + assert.False(t, allow[1](net.ParseIP("172.16.1.1"))) +} + func BenchmarkBcryptHashing(b *testing.B) { bcryptPassword := "bcryptpassword" for i := 0; i < b.N; i++ { diff --git a/config/config.go b/config/config.go index b3aa3a4d..5c4c9c9e 100644 --- a/config/config.go +++ b/config/config.go @@ -60,6 +60,7 @@ var ( ClientAuthType: 0, TLSCipherSuites: nil, Prefix: "", + ProxyAllowed: nil, } defaultHTTPDBinding = httpd.Binding{ Address: "127.0.0.1", @@ -69,6 +70,7 @@ var ( EnableHTTPS: false, ClientAuthType: 0, TLSCipherSuites: nil, + ProxyAllowed: nil, } defaultRateLimiter = common.RateLimiterConfig{ Average: 0, @@ -768,6 +770,12 @@ func getWebDAVDBindingFromEnv(idx int) { isSet = true } + proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PROXY_ALLOWED", idx)) + if ok { + binding.ProxyAllowed = proxyAllowed + isSet = true + } + prefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PREFIX", idx)) if ok { binding.Prefix = prefix @@ -833,6 +841,12 @@ func getHTTPDBindingFromEnv(idx int) { isSet = true } + proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PROXY_ALLOWED", idx)) + if ok { + binding.ProxyAllowed = proxyAllowed + isSet = true + } + if isSet { if len(globalConf.HTTPDConfig.Bindings) > idx { globalConf.HTTPDConfig.Bindings[idx] = binding diff --git a/config/config_test.go b/config/config_test.go index 27776cc3..421ca35f 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -572,6 +572,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) { os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT", "8000") os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS", "0") os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES", "TLS_RSA_WITH_AES_128_CBC_SHA ") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED", "192.168.10.1") os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS", "127.0.1.1") os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT", "9000") os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS", "1") @@ -582,6 +583,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) { os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT") os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS") os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED") os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS") os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT") os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS") @@ -605,6 +607,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) { require.Equal(t, 0, bindings[1].ClientAuthType) require.Len(t, bindings[1].TLSCipherSuites, 1) require.Equal(t, "TLS_RSA_WITH_AES_128_CBC_SHA", bindings[1].TLSCipherSuites[0]) + require.Equal(t, "192.168.10.1", bindings[1].ProxyAllowed[0]) require.Empty(t, bindings[1].Prefix) require.Equal(t, 9000, bindings[2].Port) require.Equal(t, "127.0.1.1", bindings[2].Address) @@ -634,6 +637,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) { os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS", "1") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE", "1") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES", " TLS_AES_256_GCM_SHA384 , TLS_CHACHA20_POLY1305_SHA256") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED", " 192.168.9.1 , 172.16.25.0/24") t.Cleanup(func() { os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__ADDRESS") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__PORT") @@ -650,6 +654,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) { os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED") }) configDir := ".." @@ -680,6 +685,9 @@ func TestHTTPDBindingsFromEnv(t *testing.T) { require.Len(t, bindings[2].TLSCipherSuites, 2) require.Equal(t, "TLS_AES_256_GCM_SHA384", bindings[2].TLSCipherSuites[0]) require.Equal(t, "TLS_CHACHA20_POLY1305_SHA256", bindings[2].TLSCipherSuites[1]) + require.Len(t, bindings[2].ProxyAllowed, 2) + require.Equal(t, "192.168.9.1", bindings[2].ProxyAllowed[0]) + require.Equal(t, "172.16.25.0/24", bindings[2].ProxyAllowed[1]) } func TestHTTPClientCertificatesFromEnv(t *testing.T) { diff --git a/docs/full-configuration.md b/docs/full-configuration.md index d49dd544..d6c87d6e 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -147,6 +147,7 @@ The configuration file contains the following sections: - `client_auth_type`, integer. Set to `1` to require a client certificate and verify it. Set to `2` to request a client certificate during the TLS handshake and verify it if given, in this mode the client is allowed not to send a certificate. At least one certification authority must be defined in order to verify client certificates. If no certification authority is defined, this setting is ignored. Default: 0. - `tls_cipher_suites`, list of strings. List of supported cipher suites for TLS version 1.2. If empty, a default list of secure cipher suites is used, with a preference order based on hardware performance. Note that TLS 1.3 ciphersuites are not configurable. The supported ciphersuites names are defined [here](https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L52). Any invalid name will be silently ignored. The order matters, the ciphers listed first will be the preferred ones. Default: empty. - `prefix`, string. Prefix for WebDAV resources, if empty WebDAV resources will be available at the `/` URI. If defined it must be an absolute URI, for example `/dav`. Default: "". + - `proxy_allowed`, list of IP addresses and IP ranges allowed to set `X-Forwarded-For`, `X-Real-IP` headers. Any of the indicated headers, if set on requests from a connection address not in this list, will be silently ignored. Default: empty. - `bind_port`, integer. Deprecated, please use `bindings`. - `bind_address`, string. Deprecated, please use `bindings`. - `certificate_file`, string. Certificate for WebDAV over HTTPS. This can be an absolute path or a path relative to the config dir. @@ -216,8 +217,9 @@ The configuration file contains the following sections: - `enable_https`, boolean. Set to `true` and provide both a certificate and a key file to enable HTTPS connection for this binding. Default `false`. - `client_auth_type`, integer. Set to `1` to require client certificate authentication in addition to JWT/Web authentication. You need to define at least a certificate authority for this to work. Default: 0. - `tls_cipher_suites`, list of strings. List of supported cipher suites for TLS version 1.2. If empty, a default list of secure cipher suites is used, with a preference order based on hardware performance. Note that TLS 1.3 ciphersuites are not configurable. The supported ciphersuites names are defined [here](https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L52). Any invalid name will be silently ignored. The order matters, the ciphers listed first will be the preferred ones. Default: empty. + - `proxy_allowed`, list of IP addresses and IP ranges allowed to set `X-Forwarded-For`, `X-Real-IP`, `X-Forwarded-Proto` headers. Any of the indicated headers, if set on requests from a connection address not in this list, will be silently ignored. Default: empty. - `bind_port`, integer. Deprecated, please use `bindings`. - - `bind_address`, string. Deprecated, please use `bindings`. Leave blank to listen on all available network interfaces. On \*NIX you can specify an absolute path to listen on a Unix-domain socket. Default: "127.0.0.1" + - `bind_address`, string. Deprecated, please use `bindings`. Leave blank to listen on all available network interfaces. On \*NIX you can specify an absolute path to listen on a Unix-domain socket. Default: "" - `templates_path`, string. Path to the HTML web templates. This can be an absolute path or a path relative to the config dir - `static_files_path`, string. Path to the static files for the web interface. This can be an absolute path or a path relative to the config dir. If both `templates_path` and `static_files_path` are empty the built-in web interface will be disabled - `backups_path`, string. Path to the backup directory. This can be an absolute path or a path relative to the config dir. We don't allow backups in arbitrary paths for security reasons diff --git a/docs/howto/postgresql-s3.md b/docs/howto/postgresql-s3.md index 4e932f53..468ba19d 100644 --- a/docs/howto/postgresql-s3.md +++ b/docs/howto/postgresql-s3.md @@ -192,13 +192,7 @@ systemctl status sftpgo The easiest way to add virtual users is to use the built-in Web interface. -You can expose the Web Admin interface over the network replacing `"bind_address": "127.0.0.1"` in the `httpd` configuration section with `"bind_address": ""` and apply the change restarting the SFTPGo service with the following command. - -```shell -sudo systemctl restart sftpgo -``` - -So now open the Web Admin URL. +So navigate to the Web Admin URL. [http://127.0.0.1:8080/web/admin](http://127.0.0.1:8080/web/admin) diff --git a/ftpd/ftpd_test.go b/ftpd/ftpd_test.go index 3b28d450..5aabdc1c 100644 --- a/ftpd/ftpd_test.go +++ b/ftpd/ftpd_test.go @@ -760,8 +760,10 @@ func TestMaxConnections(t *testing.T) { oldValue := common.Config.MaxTotalConnections common.Config.MaxTotalConnections = 1 - user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + user := getTestUser() + err := dataprovider.AddUser(&user) assert.NoError(t, err) + user.Password = "" client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { err = checkBasicFTP(client) @@ -771,7 +773,7 @@ func TestMaxConnections(t *testing.T) { err = client.Quit() assert.NoError(t, err) } - _, err = httpdtest.RemoveUser(user, http.StatusOK) + err = dataprovider.DeleteUser(user.Username) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) @@ -783,8 +785,10 @@ func TestMaxPerHostConnections(t *testing.T) { oldValue := common.Config.MaxPerHostConnections common.Config.MaxPerHostConnections = 1 - user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + user := getTestUser() + err := dataprovider.AddUser(&user) assert.NoError(t, err) + user.Password = "" client, err := getFTPClient(user, true, nil) if assert.NoError(t, err) { err = checkBasicFTP(client) @@ -794,7 +798,7 @@ func TestMaxPerHostConnections(t *testing.T) { err = client.Quit() assert.NoError(t, err) } - _, err = httpdtest.RemoveUser(user, http.StatusOK) + err = dataprovider.DeleteUser(user.Username) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) @@ -851,7 +855,7 @@ func TestRateLimiter(t *testing.T) { assert.Contains(t, err.Error(), "banned client IP") } - _, err = httpdtest.RemoveUser(user, http.StatusOK) + err = dataprovider.DeleteUser(user.Username) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) @@ -893,7 +897,7 @@ func TestDefender(t *testing.T) { assert.Contains(t, err.Error(), "banned client IP") } - _, err = httpdtest.RemoveUser(user, http.StatusOK) + err = dataprovider.DeleteUser(user.Username) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) diff --git a/httpd/auth_utils.go b/httpd/auth_utils.go index 29b4a447..987329a0 100644 --- a/httpd/auth_utils.go +++ b/httpd/auth_utils.go @@ -137,7 +137,7 @@ func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, r *http.Reque Path: basePath, Expires: time.Now().Add(tokenDuration), HttpOnly: true, - Secure: r.TLS != nil, + Secure: isTLS(r), }) return nil @@ -150,11 +150,21 @@ func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request) { Path: webBasePath, MaxAge: -1, HttpOnly: true, - Secure: r.TLS != nil, + Secure: isTLS(r), }) invalidateToken(r) } +func isTLS(r *http.Request) bool { + if r.TLS != nil { + return true + } + if proto, ok := r.Context().Value(forwardedProtoKey).(string); ok { + return proto == "https" + } + return false +} + func isTokenInvalidated(r *http.Request) bool { isTokenFound := false token := jwtauth.TokenFromHeader(r) diff --git a/httpd/httpd.go b/httpd/httpd.go index c9e1edbc..71f72086 100644 --- a/httpd/httpd.go +++ b/httpd/httpd.go @@ -6,6 +6,7 @@ package httpd import ( "fmt" + "net" "net/http" "net/url" "path" @@ -156,6 +157,19 @@ type Binding struct { // any invalid name will be silently ignored. // The order matters, the ciphers listed first will be the preferred ones. TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"` + // List of IP addresses and IP ranges allowed to set X-Forwarded-For, X-Real-IP, + // X-Forwarded-Proto headers. + ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"` + allowHeadersFrom []func(net.IP) bool +} + +func (b *Binding) parseAllowedProxy() error { + allowedFuncs, err := utils.ParseAllowedIPAndRanges(b.ProxyAllowed) + if err != nil { + return err + } + b.allowHeadersFrom = allowedFuncs + return nil } // GetAddress returns the binding address @@ -252,6 +266,14 @@ func (c *Conf) isWebClientEnabled() bool { return false } +func (c *Conf) checkRequiredDirs(staticFilesPath, templatesPath string) error { + if (c.isWebAdminEnabled() || c.isWebClientEnabled()) && (staticFilesPath == "" || templatesPath == "") { + return fmt.Errorf("required directory is invalid, static file path: %#v template path: %#v", + staticFilesPath, templatesPath) + } + return nil +} + // Initialize configures and starts the HTTP server func (c *Conf) Initialize(configDir string) error { logger.Debug(logSender, "", "initializing HTTP server with config %+v", c) @@ -261,9 +283,8 @@ func (c *Conf) Initialize(configDir string) error { if backupsPath == "" { return fmt.Errorf("required directory is invalid, backup path %#v", backupsPath) } - if (c.isWebAdminEnabled() || c.isWebClientEnabled()) && (staticFilesPath == "" || templatesPath == "") { - return fmt.Errorf("required directory is invalid, static file path: %#v template path: %#v", - staticFilesPath, templatesPath) + if err := c.checkRequiredDirs(staticFilesPath, templatesPath); err != nil { + return err } certificateFile := getConfigPath(c.CertificateFile, configDir) certificateKeyFile := getConfigPath(c.CertificateKeyFile, configDir) @@ -303,6 +324,9 @@ func (c *Conf) Initialize(configDir string) error { if !binding.IsValid() { continue } + if err := binding.parseAllowedProxy(); err != nil { + return err + } go func(b Binding) { server := newHttpdServer(b, staticFilesPath) diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index b56d196b..84ba0309 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -326,6 +326,12 @@ func TestInitialization(t *testing.T) { err = httpdConf.Initialize(configDir) assert.Error(t, err) httpdConf.CARevocationLists = nil + httpdConf.Bindings[0].ProxyAllowed = []string{"invalid ip/network"} + err = httpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is not a valid IP range") + } + httpdConf.Bindings[0].ProxyAllowed = nil httpdConf.Bindings[0].EnableWebAdmin = false httpdConf.Bindings[0].EnableWebClient = false httpdConf.Bindings[0].Port = 8081 @@ -3288,6 +3294,22 @@ func TestRateLimiter(t *testing.T) { err = resp.Body.Close() assert.NoError(t, err) + resp, err = client.Get(httpBaseURL + webLoginPath) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + assert.Equal(t, "1", resp.Header.Get("Retry-After")) + assert.NotEmpty(t, resp.Header.Get("X-Retry-In")) + err = resp.Body.Close() + assert.NoError(t, err) + + resp, err = client.Get(httpBaseURL + webClientLoginPath) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + assert.Equal(t, "1", resp.Header.Get("Retry-After")) + assert.NotEmpty(t, resp.Header.Get("X-Retry-In")) + err = resp.Body.Close() + assert.NoError(t, err) + err = common.Initialize(oldConfig) assert.NoError(t, err) } @@ -4625,10 +4647,13 @@ func TestDefender(t *testing.T) { remoteAddr := "172.16.5.6:9876" + webAdminToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) webToken, err := getJWTWebClientTokenFromTestServerWithAddr(defaultUsername, defaultPassword, remoteAddr) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil) req.RemoteAddr = remoteAddr + req.RequestURI = webClientFilesPath setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) @@ -4642,11 +4667,20 @@ func TestDefender(t *testing.T) { assert.Error(t, err) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) req.RemoteAddr = remoteAddr + req.RequestURI = webClientFilesPath setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "your IP address is banned") + req, _ = http.NewRequest(http.MethodGet, webUsersPath, nil) + req.RemoteAddr = remoteAddr + req.RequestURI = webUsersPath + setJWTCookieForReq(req, webAdminToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "your IP address is banned") + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) req.RemoteAddr = remoteAddr req.Header.Set("X-Real-IP", "127.0.0.1:2345") @@ -5192,7 +5226,7 @@ func TestWebAdminLoginMock(t *testing.T) { req.Header.Set("X-Forwarded-For", "10.9.9.9") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - assert.Contains(t, rr.Body.String(), "Login from IP 127.0.1.1:4567 is not allowed") + assert.Contains(t, rr.Body.String(), "login from IP 127.0.1.1 not allowed") // invalid csrf token form = getLoginForm(altAdminUsername, altAdminPassword, "invalid csrf") diff --git a/httpd/internal_test.go b/httpd/internal_test.go index 2f91bf94..e056be03 100644 --- a/httpd/internal_test.go +++ b/httpd/internal_test.go @@ -424,7 +424,7 @@ func TestCreateTokenError(t *testing.T) { } req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) - server.checkAddrAndSendToken(rr, req, admin) + server.generateAndSendToken(rr, req, admin) assert.Equal(t, http.StatusInternalServerError, rr.Code) rr = httptest.NewRecorder() @@ -565,22 +565,6 @@ func TestJWTTokenValidation(t *testing.T) { assert.Equal(t, http.StatusBadRequest, rr.Code) } -func TestAdminAllowListConnAddr(t *testing.T) { - server := httpdServer{} - admin := dataprovider.Admin{ - Filters: dataprovider.AdminFilters{ - AllowList: []string{"192.168.1.0/24"}, - }, - } - rr := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) - ctx := context.WithValue(req.Context(), connAddrKey, "127.0.0.1:4567") - req.RemoteAddr = "192.168.1.16:1234" - server.checkAddrAndSendToken(rr, req.WithContext(ctx), admin) - assert.Equal(t, http.StatusForbidden, rr.Code, rr.Body.String()) - assert.Equal(t, "context value connection address", connAddrKey.String()) -} - func TestUpdateContextFromCookie(t *testing.T) { server := httpdServer{ tokenAuth: jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil), @@ -672,14 +656,6 @@ func TestCookieExpiration(t *testing.T) { cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) - req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) - req.RemoteAddr = "172.16.1.2:1234" - ctx = jwtauth.NewContext(req.Context(), token, nil) - ctx = context.WithValue(ctx, connAddrKey, "10.9.9.9") - server.checkCookieExpiration(rr, req.WithContext(ctx)) - cookie = rr.Header().Get("Set-Cookie") - assert.Empty(t, cookie) - req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) req.RemoteAddr = "172.16.1.12:4567" ctx = jwtauth.NewContext(req.Context(), token, nil) @@ -749,17 +725,10 @@ func TestCookieExpiration(t *testing.T) { cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) - req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) - req.RemoteAddr = "172.16.4.12:4567" - ctx = jwtauth.NewContext(req.Context(), token, nil) - server.checkCookieExpiration(rr, req.WithContext(context.WithValue(ctx, connAddrKey, "172.16.0.1:4567"))) - cookie = rr.Header().Get("Set-Cookie") - assert.Empty(t, cookie) - req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) req.RemoteAddr = "172.16.4.16:4567" ctx = jwtauth.NewContext(req.Context(), token, nil) - server.checkCookieExpiration(rr, req.WithContext(context.WithValue(ctx, connAddrKey, "172.16.4.18:4567"))) + server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.NotEmpty(t, cookie) @@ -1014,6 +983,111 @@ func TestJWTTokenCleanup(t *testing.T) { stopJWTTokensCleanupTicker() } +func TestProxyHeaders(t *testing.T) { + username := "adminTest" + password := "testPwd" + admin := dataprovider.Admin{ + Username: username, + Password: password, + Permissions: []string{dataprovider.PermAdminAny}, + Status: 1, + Filters: dataprovider.AdminFilters{ + AllowList: []string{"172.19.2.0/24"}, + }, + } + + err := dataprovider.AddAdmin(&admin) + assert.NoError(t, err) + + testIP := "10.29.1.9" + validForwardedFor := "172.19.2.6" + b := Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: false, + ProxyAllowed: []string{testIP, "10.8.0.0/30"}, + } + err = b.parseAllowedProxy() + assert.NoError(t, err) + server := newHttpdServer(b, "") + server.initializeRouter() + testServer := httptest.NewServer(server.router) + defer testServer.Close() + + req, err := http.NewRequest(http.MethodGet, tokenPath, nil) + assert.NoError(t, err) + req.Header.Set("X-Forwarded-For", validForwardedFor) + req.Header.Set(xForwardedProto, "https") + req.RemoteAddr = "127.0.0.1:123" + req.SetBasicAuth(username, password) + rr := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.Contains(t, rr.Body.String(), "login from IP 127.0.0.1 not allowed") + + req.RemoteAddr = testIP + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + + req.RemoteAddr = "10.8.0.2" + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + + form := make(url.Values) + form.Set("username", username) + form.Set("password", password) + form.Set(csrfFormToken, createCSRFToken()) + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = testIP + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), "login from IP 10.29.1.9 not allowed") + + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = testIP + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-Forwarded-For", validForwardedFor) + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) + cookie := rr.Header().Get("Set-Cookie") + assert.NotContains(t, cookie, "Secure") + + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = testIP + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-Forwarded-For", validForwardedFor) + req.Header.Set(xForwardedProto, "https") + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) + cookie = rr.Header().Get("Set-Cookie") + assert.Contains(t, cookie, "Secure") + + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = testIP + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-Forwarded-For", validForwardedFor) + req.Header.Set(xForwardedProto, "http") + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) + cookie = rr.Header().Get("Set-Cookie") + assert.NotContains(t, cookie, "Secure") + + err = dataprovider.DeleteAdmin(username) + assert.NoError(t, err) +} + func TestWebAdminRedirect(t *testing.T) { b := Binding{ Address: "", @@ -1306,3 +1380,16 @@ func TestManageKeysInvalidClaims(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token claims") } + +func TestTLSReq(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, webClientLoginPath, nil) + assert.NoError(t, err) + req.TLS = &tls.ConnectionState{} + assert.True(t, isTLS(req)) + req.TLS = nil + ctx := context.WithValue(req.Context(), forwardedProtoKey, "https") + assert.True(t, isTLS(req.WithContext(ctx))) + ctx = context.WithValue(req.Context(), forwardedProtoKey, "http") + assert.False(t, isTLS(req.WithContext(ctx))) + assert.Equal(t, "context value forwarded proto", forwardedProtoKey.String()) +} diff --git a/httpd/middleware.go b/httpd/middleware.go index 77caee9f..3c1ebbe6 100644 --- a/httpd/middleware.go +++ b/httpd/middleware.go @@ -1,23 +1,19 @@ package httpd import ( - "context" "errors" - "fmt" "net/http" - "time" "github.com/go-chi/jwtauth/v5" "github.com/lestrrat-go/jwx/jwt" - "github.com/drakkan/sftpgo/common" "github.com/drakkan/sftpgo/logger" "github.com/drakkan/sftpgo/utils" ) var ( - connAddrKey = &contextKey{"connection address"} - errInvalidToken = errors.New("invalid JWT token") + forwardedProtoKey = &contextKey{"forwarded proto"} + errInvalidToken = errors.New("invalid JWT token") ) type contextKey struct { @@ -28,13 +24,6 @@ func (k *contextKey) String() string { return "context value " + k.name } -func saveConnectionAddress(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), connAddrKey, r.RemoteAddr) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error { token, _, err := jwtauth.FromContext(r.Context()) @@ -188,16 +177,3 @@ func verifyCSRFHeader(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } - -func rateLimiter(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if delay, err := common.LimitRate(common.ProtocolHTTP, utils.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { - delay += 499999999 * time.Nanosecond - w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds())) - w.Header().Set("X-Retry-In", delay.String()) - sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) - return - } - next.ServeHTTP(w, r) - }) -} diff --git a/httpd/server.go b/httpd/server.go index 5895f335..7579c64a 100644 --- a/httpd/server.go +++ b/httpd/server.go @@ -1,11 +1,13 @@ package httpd import ( + "context" "crypto/tls" "crypto/x509" "errors" "fmt" "log" + "net" "net/http" "time" @@ -23,7 +25,10 @@ import ( "github.com/drakkan/sftpgo/version" ) -var compressor = middleware.NewCompressor(5) +var ( + compressor = middleware.NewCompressor(5) + xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") +) type httpdServer struct { binding Binding @@ -111,10 +116,6 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler { func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize) - ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) - common.Connections.AddClientConnection(ipAddr) - defer common.Connections.RemoveClientConnection(ipAddr) - if err := r.ParseForm(); err != nil { renderClientLoginPage(w, err.Error()) return @@ -130,16 +131,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re return } - if !common.Connections.IsNewConnectionAllowed(ipAddr) { - logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached") - renderClientLoginPage(w, "configured connections limit reached") - return - } - if common.IsBanned(ipAddr) { - renderClientLoginPage(w, "your IP address is banned") - return - } - + ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolHTTP); err != nil { renderClientLoginPage(w, fmt.Sprintf("access denied by post connect hook: %v", err)) return @@ -204,14 +196,6 @@ func (s *httpdServer) handleWebAdminLoginPost(w http.ResponseWriter, r *http.Req renderLoginPage(w, err.Error()) return } - if connAddr, ok := r.Context().Value(connAddrKey).(string); ok { - if connAddr != r.RemoteAddr { - if !admin.CanLoginFromIP(utils.GetIPFromRemoteAddress(connAddr)) { - renderLoginPage(w, fmt.Sprintf("Login from IP %v is not allowed", connAddr)) - return - } - } - } c := jwtTokenClaims{ Username: admin.Username, Permissions: admin.Permissions, @@ -246,19 +230,10 @@ func (s *httpdServer) getToken(w http.ResponseWriter, r *http.Request) { return } - s.checkAddrAndSendToken(w, r, admin) + s.generateAndSendToken(w, r, admin) } -func (s *httpdServer) checkAddrAndSendToken(w http.ResponseWriter, r *http.Request, admin dataprovider.Admin) { - if connAddr, ok := r.Context().Value(connAddrKey).(string); ok { - if connAddr != r.RemoteAddr { - if !admin.CanLoginFromIP(utils.GetIPFromRemoteAddress(connAddr)) { - sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden) - return - } - } - } - +func (s *httpdServer) generateAndSendToken(w http.ResponseWriter, r *http.Request, admin dataprovider.Admin) { c := jwtTokenClaims{ Username: admin.Username, Permissions: admin.Permissions, @@ -330,15 +305,6 @@ func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, logger.Debug(logSender, "", "admin %#v cannot login from %v, unable to refresh cookie", admin.Username, r.RemoteAddr) return } - if connAddr, ok := r.Context().Value(connAddrKey).(string); ok { - if connAddr != r.RemoteAddr { - if !admin.CanLoginFromIP(utils.GetIPFromRemoteAddress(connAddr)) { - logger.Debug(logSender, "", "admin %#v cannot login from %v, unable to refresh cookie", - admin.Username, connAddr) - return - } - } - } logger.Debug(logSender, "", "cookie refreshed for admin %#v", admin.Username) tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebAdmin) //nolint:errcheck } @@ -357,200 +323,266 @@ func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request { return r } +func (s *httpdServer) checkConnection(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) + ip := net.ParseIP(ipAddr) + if ip != nil { + for _, allow := range s.binding.allowHeadersFrom { + if allow(ip) { + parsedIP := utils.GetRealIP(r) + if parsedIP != "" { + ipAddr = parsedIP + r.RemoteAddr = ipAddr + } + if forwardedProto := r.Header.Get(xForwardedProto); forwardedProto != "" { + ctx := context.WithValue(r.Context(), forwardedProtoKey, forwardedProto) + r = r.WithContext(ctx) + } + break + } + } + } + + common.Connections.AddClientConnection(ipAddr) + defer common.Connections.RemoveClientConnection(ipAddr) + + if !common.Connections.IsNewConnectionAllowed(ipAddr) { + logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached") + s.sendForbiddenResponse(w, r, "configured connections limit reached") + return + } + if common.IsBanned(ipAddr) { + s.sendForbiddenResponse(w, r, "your IP address is banned") + return + } + if delay, err := common.LimitRate(common.ProtocolHTTP, ipAddr); err != nil { + delay += 499999999 * time.Nanosecond + w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds())) + w.Header().Set("X-Retry-In", delay.String()) + s.sendTooManyRequestResponse(w, r, err) + return + } + + next.ServeHTTP(w, r) + }) +} + +func (s *httpdServer) sendTooManyRequestResponse(w http.ResponseWriter, r *http.Request, err error) { + if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) { + r = s.updateContextFromCookie(r) + if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) { + renderClientMessagePage(w, r, http.StatusText(http.StatusTooManyRequests), "Rate limit exceeded", + http.StatusTooManyRequests, err, "") + return + } + renderMessagePage(w, r, http.StatusText(http.StatusTooManyRequests), "Rate limit exceeded", http.StatusTooManyRequests, + err, "") + return + } + sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) +} + +func (s *httpdServer) sendForbiddenResponse(w http.ResponseWriter, r *http.Request, message string) { + if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) { + r = s.updateContextFromCookie(r) + if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) { + renderClientForbiddenPage(w, r, message) + return + } + renderForbiddenPage(w, r, message) + return + } + sendAPIResponse(w, r, errors.New(message), message, http.StatusForbidden) +} + func (s *httpdServer) initializeRouter() { s.tokenAuth = jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil) s.router = chi.NewRouter() - s.router.Use(saveConnectionAddress) + s.router.Use(middleware.RequestID) + s.router.Use(logger.NewStructuredLogger(logger.GetLogger())) + s.router.Use(middleware.Recoverer) + s.router.Use(s.checkConnection) s.router.Use(middleware.GetHead) s.router.Use(middleware.StripSlashes) - s.router.Use(middleware.RealIP) - s.router.Use(rateLimiter) - s.router.Group(func(r chi.Router) { - r.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) { - render.PlainText(w, r, "ok") - }) - }) - - s.router.Group(func(router chi.Router) { - router.Use(middleware.RequestID) - router.Use(logger.NewStructuredLogger(logger.GetLogger())) - router.Use(middleware.Recoverer) - - router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) { - r = s.updateContextFromCookie(r) - if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) { - renderClientNotFoundPage(w, r, nil) - return - } - renderNotFoundPage(w, r, nil) + s.router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) { + r = s.updateContextFromCookie(r) + if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) { + renderClientNotFoundPage(w, r, nil) return } - sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) - })) + renderNotFoundPage(w, r, nil) + return + } + sendAPIResponse(w, r, nil, http.StatusText(http.StatusNotFound), http.StatusNotFound) + })) - router.Get(tokenPath, s.getToken) + s.router.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) { + render.PlainText(w, r, "ok") + }) - router.Group(func(router chi.Router) { - router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader)) - router.Use(jwtAuthenticatorAPI) + s.router.Get(tokenPath, s.getToken) - router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) { - render.JSON(w, r, version.Get()) - }) + s.router.Group(func(router chi.Router) { + router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader)) + router.Use(jwtAuthenticatorAPI) - router.Get(logoutPath, s.logout) - router.Put(adminPwdPath, changeAdminPassword) - - router.With(checkPerm(dataprovider.PermAdminViewServerStatus)). - Get(serverStatusPath, func(w http.ResponseWriter, r *http.Request) { - render.JSON(w, r, getServicesStatus()) - }) - - router.With(checkPerm(dataprovider.PermAdminViewConnections)). - Get(activeConnectionsPath, func(w http.ResponseWriter, r *http.Request) { - render.JSON(w, r, common.Connections.GetStats()) - }) - - router.With(checkPerm(dataprovider.PermAdminCloseConnections)). - Delete(activeConnectionsPath+"/{connectionID}", handleCloseConnection) - router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanPath, getQuotaScans) - router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanPath, startQuotaScan) - router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanVFolderPath, getVFolderQuotaScans) - router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanVFolderPath, startVFolderQuotaScan) - router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath, getUsers) - router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(userPath, addUser) - router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath+"/{username}", getUserByUsername) - router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(userPath+"/{username}", updateUser) - router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(userPath+"/{username}", deleteUser) - router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath, getFolders) - router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath+"/{name}", getFolderByName) - router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(folderPath, addFolder) - router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(folderPath+"/{name}", updateFolder) - router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(folderPath+"/{name}", deleteFolder) - router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(dumpDataPath, dumpData) - router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(loadDataPath, loadData) - router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(loadDataPath, loadDataFromRequest) - router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateUsedQuotaPath, updateUserQuotaUsage) - router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateFolderUsedQuotaPath, updateVFolderQuotaUsage) - router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderBanTime, getBanTime) - router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderScore, getScore) - router.With(checkPerm(dataprovider.PermAdminManageDefender)).Post(defenderUnban, unban) - router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath, getAdmins) - router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(adminPath, addAdmin) - router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath+"/{username}", getAdminByUsername) - router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Put(adminPath+"/{username}", updateAdmin) - router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Delete(adminPath+"/{username}", deleteAdmin) + router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) { + render.JSON(w, r, version.Get()) }) - if s.enableWebAdmin || s.enableWebClient { - router.Group(func(router chi.Router) { - router.Use(compressor.Handler) - fileServer(router, webStaticFilesPath, http.Dir(s.staticFilesPath)) - }) - if s.enableWebClient { - router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently) - }) - router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently) - }) - } else { - router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently) - }) - router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently) - }) - } - } + router.Get(logoutPath, s.logout) + router.Put(adminPwdPath, changeAdminPassword) + router.With(checkPerm(dataprovider.PermAdminViewServerStatus)). + Get(serverStatusPath, func(w http.ResponseWriter, r *http.Request) { + render.JSON(w, r, getServicesStatus()) + }) + + router.With(checkPerm(dataprovider.PermAdminViewConnections)). + Get(activeConnectionsPath, func(w http.ResponseWriter, r *http.Request) { + render.JSON(w, r, common.Connections.GetStats()) + }) + + router.With(checkPerm(dataprovider.PermAdminCloseConnections)). + Delete(activeConnectionsPath+"/{connectionID}", handleCloseConnection) + router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanPath, getQuotaScans) + router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanPath, startQuotaScan) + router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanVFolderPath, getVFolderQuotaScans) + router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanVFolderPath, startVFolderQuotaScan) + router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath, getUsers) + router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(userPath, addUser) + router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath+"/{username}", getUserByUsername) + router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(userPath+"/{username}", updateUser) + router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(userPath+"/{username}", deleteUser) + router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath, getFolders) + router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath+"/{name}", getFolderByName) + router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(folderPath, addFolder) + router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(folderPath+"/{name}", updateFolder) + router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(folderPath+"/{name}", deleteFolder) + router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(dumpDataPath, dumpData) + router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(loadDataPath, loadData) + router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(loadDataPath, loadDataFromRequest) + router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateUsedQuotaPath, updateUserQuotaUsage) + router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateFolderUsedQuotaPath, updateVFolderQuotaUsage) + router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderBanTime, getBanTime) + router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderScore, getScore) + router.With(checkPerm(dataprovider.PermAdminManageDefender)).Post(defenderUnban, unban) + router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath, getAdmins) + router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(adminPath, addAdmin) + router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath+"/{username}", getAdminByUsername) + router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Put(adminPath+"/{username}", updateAdmin) + router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Delete(adminPath+"/{username}", deleteAdmin) + }) + + if s.enableWebAdmin || s.enableWebClient { + s.router.Group(func(router chi.Router) { + router.Use(compressor.Handler) + fileServer(router, webStaticFilesPath, http.Dir(s.staticFilesPath)) + }) if s.enableWebClient { - router.Get(webBaseClientPath, func(w http.ResponseWriter, r *http.Request) { + s.router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently) }) - router.Get(webClientLoginPath, handleClientWebLogin) - router.Post(webClientLoginPath, s.handleWebClientLoginPost) - - router.Group(func(router chi.Router) { - router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie)) - router.Use(jwtAuthenticatorWebClient) - - router.Get(webClientLogoutPath, handleWebClientLogout) - router.With(s.refreshCookie).Get(webClientFilesPath, handleClientGetFiles) - router.With(s.refreshCookie).Get(webClientCredentialsPath, handleClientGetCredentials) - router.Post(webChangeClientPwdPath, handleWebClientChangePwdPost) - router.With(checkClientPerm(dataprovider.WebClientPubKeyChangeDisabled)). - Post(webChangeClientKeysPath, handleWebClientManageKeysPost) + s.router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently) }) - } - - if s.enableWebAdmin { - router.Get(webBaseAdminPath, func(w http.ResponseWriter, r *http.Request) { + } else { + s.router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently) }) - router.Get(webLoginPath, handleWebLogin) - router.Post(webLoginPath, s.handleWebAdminLoginPost) - - router.Group(func(router chi.Router) { - router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie)) - router.Use(jwtAuthenticatorWebAdmin) - - router.Get(webLogoutPath, handleWebLogout) - router.With(s.refreshCookie).Get(webChangeAdminPwdPath, handleWebAdminChangePwd) - router.Post(webChangeAdminPwdPath, handleWebAdminChangePwdPost) - router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie). - Get(webUsersPath, handleGetWebUsers) - router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie). - Get(webUserPath, handleWebAddUserGet) - router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie). - Get(webUserPath+"/{username}", handleWebUpdateUserGet) - router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webUserPath, handleWebAddUserPost) - router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webUserPath+"/{username}", handleWebUpdateUserPost) - router.With(checkPerm(dataprovider.PermAdminViewConnections), s.refreshCookie). - Get(webConnectionsPath, handleWebGetConnections) - router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie). - Get(webFoldersPath, handleWebGetFolders) - router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie). - Get(webFolderPath, handleWebAddFolderGet) - router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webFolderPath, handleWebAddFolderPost) - router.With(checkPerm(dataprovider.PermAdminViewServerStatus), s.refreshCookie). - Get(webStatusPath, handleWebGetStatus) - router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie). - Get(webAdminsPath, handleGetWebAdmins) - router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie). - Get(webAdminPath, handleWebAddAdminGet) - router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie). - Get(webAdminPath+"/{username}", handleWebUpdateAdminGet) - router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath, handleWebAddAdminPost) - router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath+"/{username}", handleWebUpdateAdminPost) - router.With(checkPerm(dataprovider.PermAdminManageAdmins), verifyCSRFHeader). - Delete(webAdminPath+"/{username}", deleteAdmin) - router.With(checkPerm(dataprovider.PermAdminCloseConnections), verifyCSRFHeader). - Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection) - router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie). - Get(webFolderPath+"/{name}", handleWebUpdateFolderGet) - router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webFolderPath+"/{name}", handleWebUpdateFolderPost) - router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader). - Delete(webFolderPath+"/{name}", deleteFolder) - router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader). - Post(webScanVFolderPath, startVFolderQuotaScan) - router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader). - Delete(webUserPath+"/{username}", deleteUser) - router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader). - Post(webQuotaScanPath, startQuotaScan) - router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webMaintenancePath, handleWebMaintenance) - router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webBackupPath, dumpData) - router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webRestorePath, handleWebRestore) - router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie). - Get(webTemplateUser, handleWebTemplateUserGet) - router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateUser, handleWebTemplateUserPost) - router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie). - Get(webTemplateFolder, handleWebTemplateFolderGet) - router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateFolder, handleWebTemplateFolderPost) + s.router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently) }) } - }) + } + + if s.enableWebClient { + s.router.Get(webBaseClientPath, func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently) + }) + s.router.Get(webClientLoginPath, handleClientWebLogin) + s.router.Post(webClientLoginPath, s.handleWebClientLoginPost) + + s.router.Group(func(router chi.Router) { + router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie)) + router.Use(jwtAuthenticatorWebClient) + + router.Get(webClientLogoutPath, handleWebClientLogout) + router.With(s.refreshCookie).Get(webClientFilesPath, handleClientGetFiles) + router.With(s.refreshCookie).Get(webClientCredentialsPath, handleClientGetCredentials) + router.Post(webChangeClientPwdPath, handleWebClientChangePwdPost) + router.With(checkClientPerm(dataprovider.WebClientPubKeyChangeDisabled)). + Post(webChangeClientKeysPath, handleWebClientManageKeysPost) + }) + } + + if s.enableWebAdmin { + s.router.Get(webBaseAdminPath, func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently) + }) + s.router.Get(webLoginPath, handleWebLogin) + s.router.Post(webLoginPath, s.handleWebAdminLoginPost) + + s.router.Group(func(router chi.Router) { + router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie)) + router.Use(jwtAuthenticatorWebAdmin) + + router.Get(webLogoutPath, handleWebLogout) + router.With(s.refreshCookie).Get(webChangeAdminPwdPath, handleWebAdminChangePwd) + router.Post(webChangeAdminPwdPath, handleWebAdminChangePwdPost) + router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie). + Get(webUsersPath, handleGetWebUsers) + router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie). + Get(webUserPath, handleWebAddUserGet) + router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie). + Get(webUserPath+"/{username}", handleWebUpdateUserGet) + router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webUserPath, handleWebAddUserPost) + router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webUserPath+"/{username}", handleWebUpdateUserPost) + router.With(checkPerm(dataprovider.PermAdminViewConnections), s.refreshCookie). + Get(webConnectionsPath, handleWebGetConnections) + router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie). + Get(webFoldersPath, handleWebGetFolders) + router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie). + Get(webFolderPath, handleWebAddFolderGet) + router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webFolderPath, handleWebAddFolderPost) + router.With(checkPerm(dataprovider.PermAdminViewServerStatus), s.refreshCookie). + Get(webStatusPath, handleWebGetStatus) + router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie). + Get(webAdminsPath, handleGetWebAdmins) + router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie). + Get(webAdminPath, handleWebAddAdminGet) + router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie). + Get(webAdminPath+"/{username}", handleWebUpdateAdminGet) + router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath, handleWebAddAdminPost) + router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath+"/{username}", handleWebUpdateAdminPost) + router.With(checkPerm(dataprovider.PermAdminManageAdmins), verifyCSRFHeader). + Delete(webAdminPath+"/{username}", deleteAdmin) + router.With(checkPerm(dataprovider.PermAdminCloseConnections), verifyCSRFHeader). + Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection) + router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie). + Get(webFolderPath+"/{name}", handleWebUpdateFolderGet) + router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webFolderPath+"/{name}", handleWebUpdateFolderPost) + router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader). + Delete(webFolderPath+"/{name}", deleteFolder) + router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader). + Post(webScanVFolderPath, startVFolderQuotaScan) + router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader). + Delete(webUserPath+"/{username}", deleteUser) + router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader). + Post(webQuotaScanPath, startQuotaScan) + router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webMaintenancePath, handleWebMaintenance) + router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webBackupPath, dumpData) + router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webRestorePath, handleWebRestore) + router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie). + Get(webTemplateUser, handleWebTemplateUserGet) + router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateUser, handleWebTemplateUserPost) + router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie). + Get(webTemplateFolder, handleWebTemplateFolderGet) + router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateFolder, handleWebTemplateFolderPost) + }) + } } diff --git a/httpd/webclient.go b/httpd/webclient.go index 280fcee7..6f82a754 100644 --- a/httpd/webclient.go +++ b/httpd/webclient.go @@ -279,24 +279,11 @@ func handleWebClientLogout(w http.ResponseWriter, r *http.Request) { } func handleClientGetFiles(w http.ResponseWriter, r *http.Request) { - ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) - common.Connections.AddClientConnection(ipAddr) - defer common.Connections.RemoveClientConnection(ipAddr) - claims, err := getTokenClaims(r) if err != nil || claims.Username == "" { renderClientForbiddenPage(w, r, "Invalid token claims") return } - if !common.Connections.IsNewConnectionAllowed(ipAddr) { - logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached") - renderClientForbiddenPage(w, r, "configured connections limit reached") - return - } - if common.IsBanned(ipAddr) { - renderClientForbiddenPage(w, r, "your IP address is banned") - return - } user, err := dataprovider.UserExists(claims.Username) if err != nil { @@ -635,16 +622,5 @@ func checkWebClientUser(user *dataprovider.User, r *http.Request, connectionID s logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, r.RemoteAddr) return fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr) } - if connAddr, ok := r.Context().Value(connAddrKey).(string); ok { - if connAddr != r.RemoteAddr { - connIPAddr := utils.GetIPFromRemoteAddress(connAddr) - if common.IsBanned(connIPAddr) { - return errors.New("your IP address is banned") - } - if !user.IsLoginFromAddrAllowed(connIPAddr) { - return fmt.Errorf("login from IP %v is not allowed", connIPAddr) - } - } - } return nil } diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index df1b0246..f6252138 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -553,7 +553,7 @@ func TestDefender(t *testing.T) { _, _, err = getSftpClient(user, usePubKey) assert.Error(t, err) - _, err = httpdtest.RemoveUser(user, http.StatusOK) + err = dataprovider.DeleteUser(user.Username) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) @@ -2907,21 +2907,24 @@ func TestMaxConnections(t *testing.T) { common.Config.MaxTotalConnections = 1 usePubKey := true - u := getTestUser(usePubKey) - user, _, err := httpdtest.AddUser(u, http.StatusCreated) + user := getTestUser(usePubKey) + err := dataprovider.AddUser(&user) assert.NoError(t, err) + user.Password = "" conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { - defer conn.Close() - defer client.Close() assert.NoError(t, checkBasicSFTP(client)) s, c, err := getSftpClient(user, usePubKey) if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") { c.Close() s.Close() } + err = client.Close() + assert.NoError(t, err) + err = conn.Close() + assert.NoError(t, err) } - _, err = httpdtest.RemoveUser(user, http.StatusOK) + err = dataprovider.DeleteUser(user.Username) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) @@ -2935,21 +2938,24 @@ func TestMaxPerHostConnections(t *testing.T) { common.Config.MaxPerHostConnections = 1 usePubKey := true - u := getTestUser(usePubKey) - user, _, err := httpdtest.AddUser(u, http.StatusCreated) + user := getTestUser(usePubKey) + err := dataprovider.AddUser(&user) assert.NoError(t, err) + user.Password = "" conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { - defer conn.Close() - defer client.Close() assert.NoError(t, checkBasicSFTP(client)) s, c, err := getSftpClient(user, usePubKey) if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") { c.Close() s.Close() } + err = client.Close() + assert.NoError(t, err) + err = conn.Close() + assert.NoError(t, err) } - _, err = httpdtest.RemoveUser(user, http.StatusOK) + err = dataprovider.DeleteUser(user.Username) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) diff --git a/sftpgo.json b/sftpgo.json index ddaafec3..ba3e501f 100644 --- a/sftpgo.json +++ b/sftpgo.json @@ -107,7 +107,8 @@ "enable_https": false, "client_auth_type": 0, "tls_cipher_suites": [], - "prefix": "" + "prefix": "", + "proxy_allowed": [] } ], "certificate_file": "", @@ -180,12 +181,13 @@ "bindings": [ { "port": 8080, - "address": "127.0.0.1", + "address": "", "enable_web_admin": true, "enable_web_client": true, "enable_https": false, "client_auth_type": 0, - "tls_cipher_suites": [] + "tls_cipher_suites": [], + "proxy_allowed": [] } ], "templates_path": "templates", diff --git a/templates/webadmin/base.html b/templates/webadmin/base.html index b8e1820a..730857d7 100644 --- a/templates/webadmin/base.html +++ b/templates/webadmin/base.html @@ -209,6 +209,7 @@ + {{if .LoggedAdmin.Username}} + {{end}} {{block "dialog" .}}{{end}} diff --git a/templates/webclient/base.html b/templates/webclient/base.html index adbdd4a9..9f7e509e 100644 --- a/templates/webclient/base.html +++ b/templates/webclient/base.html @@ -168,6 +168,7 @@ + {{if .LoggedUser.Username}} + {{end}} {{block "dialog" .}}{{end}} diff --git a/templates/webclient/credentials.html b/templates/webclient/credentials.html index db039211..29a5c725 100644 --- a/templates/webclient/credentials.html +++ b/templates/webclient/credentials.html @@ -57,7 +57,7 @@
- One public key per line diff --git a/utils/utils.go b/utils/utils.go index 06bd3e1c..13c39d7c 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -38,6 +38,11 @@ const ( osWindows = "windows" ) +var ( + xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") + xRealIP = http.CanonicalHeaderKey("X-Real-IP") +) + // IsStringInSlice searches a string in a slice and returns true if the string is found func IsStringInSlice(obj string, list []string) bool { for i := 0; i < len(list); i++ { @@ -516,3 +521,48 @@ func GetSSHPublicKeyAsString(pubKey []byte) (string, error) { } return string(ssh.MarshalAuthorizedKey(k)), nil } + +// GetRealIP returns the ip address as result of parsing either the +// X-Real-IP header or the X-Forwarded-For header +func GetRealIP(r *http.Request) string { + var ip string + + if xrip := r.Header.Get(xRealIP); xrip != "" { + ip = xrip + } else if xff := r.Header.Get(xForwardedFor); xff != "" { + i := strings.Index(xff, ", ") + if i == -1 { + i = len(xff) + } + ip = strings.TrimSpace(xff[:i]) + } + if net.ParseIP(ip) == nil { + return "" + } + + return ip +} + +// ParseAllowedIPAndRanges returns a list of functions that allow to find if a +func ParseAllowedIPAndRanges(allowed []string) ([]func(net.IP) bool, error) { + res := make([]func(net.IP) bool, len(allowed)) + for i, allowFrom := range allowed { + if strings.LastIndex(allowFrom, "/") > 0 { + _, ipRange, err := net.ParseCIDR(allowFrom) + if err != nil { + return nil, fmt.Errorf("given string %q is not a valid IP range: %v", allowFrom, err) + } + + res[i] = ipRange.Contains + } else { + allowed := net.ParseIP(allowFrom) + if allowed == nil { + return nil, fmt.Errorf("given string %q is not a valid IP address", allowFrom) + } + + res[i] = allowed.Equal + } + } + + return res, nil +} diff --git a/webdavd/internal_test.go b/webdavd/internal_test.go index eedd2a0d..1b1cedad 100644 --- a/webdavd/internal_test.go +++ b/webdavd/internal_test.go @@ -23,6 +23,7 @@ import ( "github.com/drakkan/sftpgo/common" "github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/kms" + "github.com/drakkan/sftpgo/utils" "github.com/drakkan/sftpgo/vfs" ) @@ -409,36 +410,68 @@ func TestUserInvalidParams(t *testing.T) { } func TestRemoteAddress(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, "/username", nil) - assert.NoError(t, err) - assert.Empty(t, req.RemoteAddr) - remoteAddr1 := "100.100.100.100" remoteAddr2 := "172.172.172.172" + c := &Configuration{ + Bindings: []Binding{ + { + Port: 9000, + ProxyAllowed: []string{remoteAddr2, "10.8.0.0/30"}, + }, + }, + } + + server := webDavServer{ + config: c, + binding: c.Bindings[0], + } + err := server.binding.parseAllowedProxy() + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, "/", nil) + assert.NoError(t, err) + assert.Empty(t, req.RemoteAddr) + req.Header.Set("X-Forwarded-For", remoteAddr1) - checkRemoteAddress(req) - assert.Equal(t, remoteAddr1, req.RemoteAddr) + ip := utils.GetRealIP(req) + assert.Equal(t, remoteAddr1, ip) + // this will be ignore, remoteAddr1 is not allowed to se this header + req.Header.Set("X-Forwarded-For", remoteAddr2) + req.RemoteAddr = remoteAddr1 + ip = server.checkRemoteAddress(req) + assert.Equal(t, remoteAddr1, ip) req.RemoteAddr = "" + ip = server.checkRemoteAddress(req) + assert.Empty(t, ip) req.Header.Set("X-Forwarded-For", fmt.Sprintf("%v, %v", remoteAddr2, remoteAddr1)) - checkRemoteAddress(req) - assert.Equal(t, remoteAddr2, req.RemoteAddr) + ip = utils.GetRealIP(req) + assert.Equal(t, remoteAddr2, ip) + + req.RemoteAddr = remoteAddr2 + req.Header.Set("X-Forwarded-For", fmt.Sprintf("%v, %v", "12.34.56.78", "172.16.2.4")) + ip = server.checkRemoteAddress(req) + assert.Equal(t, "12.34.56.78", ip) + assert.Equal(t, ip, req.RemoteAddr) + + req.RemoteAddr = "10.8.0.2" + req.Header.Set("X-Forwarded-For", remoteAddr1) + ip = server.checkRemoteAddress(req) + assert.Equal(t, remoteAddr1, ip) + assert.Equal(t, ip, req.RemoteAddr) + + req.RemoteAddr = "10.8.0.3" + req.Header.Set("X-Forwarded-For", "not an ip") + ip = server.checkRemoteAddress(req) + assert.Equal(t, "10.8.0.3", ip) + assert.Equal(t, ip, req.RemoteAddr) req.Header.Del("X-Forwarded-For") req.RemoteAddr = "" req.Header.Set("X-Real-IP", remoteAddr1) - checkRemoteAddress(req) - assert.Equal(t, remoteAddr1, req.RemoteAddr) + ip = utils.GetRealIP(req) + assert.Equal(t, remoteAddr1, ip) req.RemoteAddr = "" - - oldValue := common.Config.ProxyProtocol - common.Config.ProxyProtocol = 1 - - checkRemoteAddress(req) - assert.Empty(t, req.RemoteAddr) - - common.Config.ProxyProtocol = oldValue } func TestConnWithNilRequest(t *testing.T) { diff --git a/webdavd/server.go b/webdavd/server.go index f0aa3d30..e3f9716c 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -7,11 +7,11 @@ import ( "errors" "fmt" "log" + "net" "net/http" "path" "path/filepath" "runtime/debug" - "strings" "time" "github.com/go-chi/chi/v5/middleware" @@ -27,9 +27,7 @@ import ( ) var ( - err401 = errors.New("Unauthorized") - xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") - xRealIP = http.CanonicalHeaderKey("X-Real-IP") + err401 = errors.New("Unauthorized") ) type webDavServer struct { @@ -145,8 +143,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } }() - checkRemoteAddress(r) - ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) + ipAddr := s.checkRemoteAddress(r) common.Connections.AddClientConnection(ipAddr) defer common.Connections.RemoveClientConnection(ipAddr) @@ -327,6 +324,24 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo return connID, nil } +func (s *webDavServer) checkRemoteAddress(r *http.Request) string { + ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) + ip := net.ParseIP(ipAddr) + if ip != nil { + for _, allow := range s.binding.allowHeadersFrom { + if allow(ip) { + parsedIP := utils.GetRealIP(r) + if parsedIP != "" { + ipAddr = parsedIP + r.RemoteAddr = ipAddr + } + break + } + } + } + return ipAddr +} + func writeLog(r *http.Request, err error) { scheme := "http" if r.TLS != nil { @@ -352,28 +367,6 @@ func writeLog(r *http.Request, err error) { Send() } -func checkRemoteAddress(r *http.Request) { - if common.Config.ProxyProtocol != 0 { - return - } - - var ip string - - if xrip := r.Header.Get(xRealIP); xrip != "" { - ip = xrip - } else if xff := r.Header.Get(xForwardedFor); xff != "" { - i := strings.Index(xff, ", ") - if i == -1 { - i = len(xff) - } - ip = strings.TrimSpace(xff[:i]) - } - - if len(ip) > 0 { - r.RemoteAddr = ip - } -} - func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error) { metrics.AddLoginAttempt(loginMethod) if err != nil { diff --git a/webdavd/webdavd.go b/webdavd/webdavd.go index 53858429..5398977d 100644 --- a/webdavd/webdavd.go +++ b/webdavd/webdavd.go @@ -3,6 +3,7 @@ package webdavd import ( "fmt" + "net" "path/filepath" "github.com/go-chi/chi/v5/middleware" @@ -90,6 +91,18 @@ type Binding struct { // Prefix for WebDAV resources, if empty WebDAV resources will be available at the // root ("/") URI. If defined it must be an absolute URI. Prefix string `json:"prefix" mapstructure:"prefix"` + // List of IP addresses and IP ranges allowed to set X-Forwarded-For/X-Real-IP headers. + ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"` + allowHeadersFrom []func(net.IP) bool +} + +func (b *Binding) parseAllowedProxy() error { + allowedFuncs, err := utils.ParseAllowedIPAndRanges(b.ProxyAllowed) + if err != nil { + return err + } + b.allowHeadersFrom = allowedFuncs + return nil } func (b *Binding) isMutualTLSEnabled() bool { @@ -191,6 +204,9 @@ func (c *Configuration) Initialize(configDir string) error { if !binding.IsValid() { continue } + if err := binding.parseAllowedProxy(); err != nil { + return err + } go func(binding Binding) { server := webDavServer{ diff --git a/webdavd/webdavd_test.go b/webdavd/webdavd_test.go index c89aae81..41ff56f6 100644 --- a/webdavd/webdavd_test.go +++ b/webdavd/webdavd_test.go @@ -469,6 +469,12 @@ func TestInitialization(t *testing.T) { cfg.CertificateKeyFile = keyPath cfg.CACertificates = []string{caCrtPath} cfg.CARevocationLists = []string{caCRLPath} + cfg.Bindings[0].ProxyAllowed = []string{"not valid"} + err = cfg.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is not a valid IP address") + } + cfg.Bindings[0].ProxyAllowed = nil err = cfg.Initialize(configDir) assert.Error(t, err) } @@ -1180,6 +1186,10 @@ func TestQuotaLimits(t *testing.T) { // test quota files err = uploadFile(testFilePath, testFileName+".quota", testFileSize, client) if !assert.NoError(t, err, "username: %v", user.Username) { + info, err := os.Stat(testFilePath) + if assert.NoError(t, err) { + fmt.Printf("local file size %v", info.Size()) + } printLatestLogs(20) } err = uploadFile(testFilePath, testFileName+".quota1", testFileSize, client)