httpd/webdav: add a list of hosts allowed to send proxy headers

X-Forwarded-For, X-Real-IP and X-Forwarded-Proto headers will be ignored
for hosts not included in this list.

This is a backward incompatible change, before the proxy headers were
always used
This commit is contained in:
Nicola Murino 2021-05-11 06:54:06 +02:00
parent f1b998ce16
commit c8f7fc9bc9
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
25 changed files with 669 additions and 383 deletions

View file

@ -52,8 +52,7 @@ ENV SFTPGO_HTTPD__STATIC_FILES_PATH=/usr/share/sftpgo/static
# Modify the default configuration file
RUN sed -i "s|\"users_base_dir\": \"\",|\"users_base_dir\": \"/srv/sftpgo/data\",|" /etc/sftpgo/sftpgo.json && \
sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json && \
sed -i "s|\"address\": \"127.0.0.1\",|\"address\": \"\",|" /etc/sftpgo/sftpgo.json
sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json
RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups

View file

@ -57,8 +57,7 @@ ENV SFTPGO_HTTPD__STATIC_FILES_PATH=/usr/share/sftpgo/static
# Modify the default configuration file
RUN sed -i "s|\"users_base_dir\": \"\",|\"users_base_dir\": \"/srv/sftpgo/data\",|" /etc/sftpgo/sftpgo.json && \
sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json && \
sed -i "s|\"address\": \"127.0.0.1\",|\"address\": \"\",|" /etc/sftpgo/sftpgo.json
sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json
RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups

View file

@ -697,6 +697,19 @@ func TestCachedFs(t *testing.T) {
assert.NoError(t, err)
}
func TestParseAllowedIPAndRanges(t *testing.T) {
_, err := utils.ParseAllowedIPAndRanges([]string{"1.1.1.1", "not an ip"})
assert.Error(t, err)
_, err = utils.ParseAllowedIPAndRanges([]string{"1.1.1.5", "192.168.1.0/240"})
assert.Error(t, err)
allow, err := utils.ParseAllowedIPAndRanges([]string{"192.168.1.2", "172.16.0.0/24"})
assert.NoError(t, err)
assert.True(t, allow[0](net.ParseIP("192.168.1.2")))
assert.False(t, allow[0](net.ParseIP("192.168.2.2")))
assert.True(t, allow[1](net.ParseIP("172.16.0.1")))
assert.False(t, allow[1](net.ParseIP("172.16.1.1")))
}
func BenchmarkBcryptHashing(b *testing.B) {
bcryptPassword := "bcryptpassword"
for i := 0; i < b.N; i++ {

View file

@ -60,6 +60,7 @@ var (
ClientAuthType: 0,
TLSCipherSuites: nil,
Prefix: "",
ProxyAllowed: nil,
}
defaultHTTPDBinding = httpd.Binding{
Address: "127.0.0.1",
@ -69,6 +70,7 @@ var (
EnableHTTPS: false,
ClientAuthType: 0,
TLSCipherSuites: nil,
ProxyAllowed: nil,
}
defaultRateLimiter = common.RateLimiterConfig{
Average: 0,
@ -768,6 +770,12 @@ func getWebDAVDBindingFromEnv(idx int) {
isSet = true
}
proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PROXY_ALLOWED", idx))
if ok {
binding.ProxyAllowed = proxyAllowed
isSet = true
}
prefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PREFIX", idx))
if ok {
binding.Prefix = prefix
@ -833,6 +841,12 @@ func getHTTPDBindingFromEnv(idx int) {
isSet = true
}
proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PROXY_ALLOWED", idx))
if ok {
binding.ProxyAllowed = proxyAllowed
isSet = true
}
if isSet {
if len(globalConf.HTTPDConfig.Bindings) > idx {
globalConf.HTTPDConfig.Bindings[idx] = binding

View file

@ -572,6 +572,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT", "8000")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS", "0")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES", "TLS_RSA_WITH_AES_128_CBC_SHA ")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED", "192.168.10.1")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS", "127.0.1.1")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT", "9000")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS", "1")
@ -582,6 +583,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS")
@ -605,6 +607,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
require.Equal(t, 0, bindings[1].ClientAuthType)
require.Len(t, bindings[1].TLSCipherSuites, 1)
require.Equal(t, "TLS_RSA_WITH_AES_128_CBC_SHA", bindings[1].TLSCipherSuites[0])
require.Equal(t, "192.168.10.1", bindings[1].ProxyAllowed[0])
require.Empty(t, bindings[1].Prefix)
require.Equal(t, 9000, bindings[2].Port)
require.Equal(t, "127.0.1.1", bindings[2].Address)
@ -634,6 +637,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS", "1")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE", "1")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES", " TLS_AES_256_GCM_SHA384 , TLS_CHACHA20_POLY1305_SHA256")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED", " 192.168.9.1 , 172.16.25.0/24")
t.Cleanup(func() {
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__ADDRESS")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__PORT")
@ -650,6 +654,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED")
})
configDir := ".."
@ -680,6 +685,9 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.Len(t, bindings[2].TLSCipherSuites, 2)
require.Equal(t, "TLS_AES_256_GCM_SHA384", bindings[2].TLSCipherSuites[0])
require.Equal(t, "TLS_CHACHA20_POLY1305_SHA256", bindings[2].TLSCipherSuites[1])
require.Len(t, bindings[2].ProxyAllowed, 2)
require.Equal(t, "192.168.9.1", bindings[2].ProxyAllowed[0])
require.Equal(t, "172.16.25.0/24", bindings[2].ProxyAllowed[1])
}
func TestHTTPClientCertificatesFromEnv(t *testing.T) {

View file

@ -147,6 +147,7 @@ The configuration file contains the following sections:
- `client_auth_type`, integer. Set to `1` to require a client certificate and verify it. Set to `2` to request a client certificate during the TLS handshake and verify it if given, in this mode the client is allowed not to send a certificate. At least one certification authority must be defined in order to verify client certificates. If no certification authority is defined, this setting is ignored. Default: 0.
- `tls_cipher_suites`, list of strings. List of supported cipher suites for TLS version 1.2. If empty, a default list of secure cipher suites is used, with a preference order based on hardware performance. Note that TLS 1.3 ciphersuites are not configurable. The supported ciphersuites names are defined [here](https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L52). Any invalid name will be silently ignored. The order matters, the ciphers listed first will be the preferred ones. Default: empty.
- `prefix`, string. Prefix for WebDAV resources, if empty WebDAV resources will be available at the `/` URI. If defined it must be an absolute URI, for example `/dav`. Default: "".
- `proxy_allowed`, list of IP addresses and IP ranges allowed to set `X-Forwarded-For`, `X-Real-IP` headers. Any of the indicated headers, if set on requests from a connection address not in this list, will be silently ignored. Default: empty.
- `bind_port`, integer. Deprecated, please use `bindings`.
- `bind_address`, string. Deprecated, please use `bindings`.
- `certificate_file`, string. Certificate for WebDAV over HTTPS. This can be an absolute path or a path relative to the config dir.
@ -216,8 +217,9 @@ The configuration file contains the following sections:
- `enable_https`, boolean. Set to `true` and provide both a certificate and a key file to enable HTTPS connection for this binding. Default `false`.
- `client_auth_type`, integer. Set to `1` to require client certificate authentication in addition to JWT/Web authentication. You need to define at least a certificate authority for this to work. Default: 0.
- `tls_cipher_suites`, list of strings. List of supported cipher suites for TLS version 1.2. If empty, a default list of secure cipher suites is used, with a preference order based on hardware performance. Note that TLS 1.3 ciphersuites are not configurable. The supported ciphersuites names are defined [here](https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L52). Any invalid name will be silently ignored. The order matters, the ciphers listed first will be the preferred ones. Default: empty.
- `proxy_allowed`, list of IP addresses and IP ranges allowed to set `X-Forwarded-For`, `X-Real-IP`, `X-Forwarded-Proto` headers. Any of the indicated headers, if set on requests from a connection address not in this list, will be silently ignored. Default: empty.
- `bind_port`, integer. Deprecated, please use `bindings`.
- `bind_address`, string. Deprecated, please use `bindings`. Leave blank to listen on all available network interfaces. On \*NIX you can specify an absolute path to listen on a Unix-domain socket. Default: "127.0.0.1"
- `bind_address`, string. Deprecated, please use `bindings`. Leave blank to listen on all available network interfaces. On \*NIX you can specify an absolute path to listen on a Unix-domain socket. Default: ""
- `templates_path`, string. Path to the HTML web templates. This can be an absolute path or a path relative to the config dir
- `static_files_path`, string. Path to the static files for the web interface. This can be an absolute path or a path relative to the config dir. If both `templates_path` and `static_files_path` are empty the built-in web interface will be disabled
- `backups_path`, string. Path to the backup directory. This can be an absolute path or a path relative to the config dir. We don't allow backups in arbitrary paths for security reasons

View file

@ -192,13 +192,7 @@ systemctl status sftpgo
The easiest way to add virtual users is to use the built-in Web interface.
You can expose the Web Admin interface over the network replacing `"bind_address": "127.0.0.1"` in the `httpd` configuration section with `"bind_address": ""` and apply the change restarting the SFTPGo service with the following command.
```shell
sudo systemctl restart sftpgo
```
So now open the Web Admin URL.
So navigate to the Web Admin URL.
[http://127.0.0.1:8080/web/admin](http://127.0.0.1:8080/web/admin)

View file

@ -760,8 +760,10 @@ func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
user := getTestUser()
err := dataprovider.AddUser(&user)
assert.NoError(t, err)
user.Password = ""
client, err := getFTPClient(user, true, nil)
if assert.NoError(t, err) {
err = checkBasicFTP(client)
@ -771,7 +773,7 @@ func TestMaxConnections(t *testing.T) {
err = client.Quit()
assert.NoError(t, err)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
err = dataprovider.DeleteUser(user.Username)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
@ -783,8 +785,10 @@ func TestMaxPerHostConnections(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 1
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
user := getTestUser()
err := dataprovider.AddUser(&user)
assert.NoError(t, err)
user.Password = ""
client, err := getFTPClient(user, true, nil)
if assert.NoError(t, err) {
err = checkBasicFTP(client)
@ -794,7 +798,7 @@ func TestMaxPerHostConnections(t *testing.T) {
err = client.Quit()
assert.NoError(t, err)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
err = dataprovider.DeleteUser(user.Username)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
@ -851,7 +855,7 @@ func TestRateLimiter(t *testing.T) {
assert.Contains(t, err.Error(), "banned client IP")
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
err = dataprovider.DeleteUser(user.Username)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
@ -893,7 +897,7 @@ func TestDefender(t *testing.T) {
assert.Contains(t, err.Error(), "banned client IP")
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
err = dataprovider.DeleteUser(user.Username)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)

View file

@ -137,7 +137,7 @@ func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, r *http.Reque
Path: basePath,
Expires: time.Now().Add(tokenDuration),
HttpOnly: true,
Secure: r.TLS != nil,
Secure: isTLS(r),
})
return nil
@ -150,11 +150,21 @@ func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request) {
Path: webBasePath,
MaxAge: -1,
HttpOnly: true,
Secure: r.TLS != nil,
Secure: isTLS(r),
})
invalidateToken(r)
}
func isTLS(r *http.Request) bool {
if r.TLS != nil {
return true
}
if proto, ok := r.Context().Value(forwardedProtoKey).(string); ok {
return proto == "https"
}
return false
}
func isTokenInvalidated(r *http.Request) bool {
isTokenFound := false
token := jwtauth.TokenFromHeader(r)

View file

@ -6,6 +6,7 @@ package httpd
import (
"fmt"
"net"
"net/http"
"net/url"
"path"
@ -156,6 +157,19 @@ type Binding struct {
// any invalid name will be silently ignored.
// The order matters, the ciphers listed first will be the preferred ones.
TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"`
// List of IP addresses and IP ranges allowed to set X-Forwarded-For, X-Real-IP,
// X-Forwarded-Proto headers.
ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"`
allowHeadersFrom []func(net.IP) bool
}
func (b *Binding) parseAllowedProxy() error {
allowedFuncs, err := utils.ParseAllowedIPAndRanges(b.ProxyAllowed)
if err != nil {
return err
}
b.allowHeadersFrom = allowedFuncs
return nil
}
// GetAddress returns the binding address
@ -252,6 +266,14 @@ func (c *Conf) isWebClientEnabled() bool {
return false
}
func (c *Conf) checkRequiredDirs(staticFilesPath, templatesPath string) error {
if (c.isWebAdminEnabled() || c.isWebClientEnabled()) && (staticFilesPath == "" || templatesPath == "") {
return fmt.Errorf("required directory is invalid, static file path: %#v template path: %#v",
staticFilesPath, templatesPath)
}
return nil
}
// Initialize configures and starts the HTTP server
func (c *Conf) Initialize(configDir string) error {
logger.Debug(logSender, "", "initializing HTTP server with config %+v", c)
@ -261,9 +283,8 @@ func (c *Conf) Initialize(configDir string) error {
if backupsPath == "" {
return fmt.Errorf("required directory is invalid, backup path %#v", backupsPath)
}
if (c.isWebAdminEnabled() || c.isWebClientEnabled()) && (staticFilesPath == "" || templatesPath == "") {
return fmt.Errorf("required directory is invalid, static file path: %#v template path: %#v",
staticFilesPath, templatesPath)
if err := c.checkRequiredDirs(staticFilesPath, templatesPath); err != nil {
return err
}
certificateFile := getConfigPath(c.CertificateFile, configDir)
certificateKeyFile := getConfigPath(c.CertificateKeyFile, configDir)
@ -303,6 +324,9 @@ func (c *Conf) Initialize(configDir string) error {
if !binding.IsValid() {
continue
}
if err := binding.parseAllowedProxy(); err != nil {
return err
}
go func(b Binding) {
server := newHttpdServer(b, staticFilesPath)

View file

@ -326,6 +326,12 @@ func TestInitialization(t *testing.T) {
err = httpdConf.Initialize(configDir)
assert.Error(t, err)
httpdConf.CARevocationLists = nil
httpdConf.Bindings[0].ProxyAllowed = []string{"invalid ip/network"}
err = httpdConf.Initialize(configDir)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "is not a valid IP range")
}
httpdConf.Bindings[0].ProxyAllowed = nil
httpdConf.Bindings[0].EnableWebAdmin = false
httpdConf.Bindings[0].EnableWebClient = false
httpdConf.Bindings[0].Port = 8081
@ -3288,6 +3294,22 @@ func TestRateLimiter(t *testing.T) {
err = resp.Body.Close()
assert.NoError(t, err)
resp, err = client.Get(httpBaseURL + webLoginPath)
assert.NoError(t, err)
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
assert.Equal(t, "1", resp.Header.Get("Retry-After"))
assert.NotEmpty(t, resp.Header.Get("X-Retry-In"))
err = resp.Body.Close()
assert.NoError(t, err)
resp, err = client.Get(httpBaseURL + webClientLoginPath)
assert.NoError(t, err)
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
assert.Equal(t, "1", resp.Header.Get("Retry-After"))
assert.NotEmpty(t, resp.Header.Get("X-Retry-In"))
err = resp.Body.Close()
assert.NoError(t, err)
err = common.Initialize(oldConfig)
assert.NoError(t, err)
}
@ -4625,10 +4647,13 @@ func TestDefender(t *testing.T) {
remoteAddr := "172.16.5.6:9876"
webAdminToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
assert.NoError(t, err)
webToken, err := getJWTWebClientTokenFromTestServerWithAddr(defaultUsername, defaultPassword, remoteAddr)
assert.NoError(t, err)
req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
req.RemoteAddr = remoteAddr
req.RequestURI = webClientFilesPath
setJWTCookieForReq(req, webToken)
rr := executeRequest(req)
checkResponseCode(t, http.StatusOK, rr)
@ -4642,11 +4667,20 @@ func TestDefender(t *testing.T) {
assert.Error(t, err)
req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
req.RemoteAddr = remoteAddr
req.RequestURI = webClientFilesPath
setJWTCookieForReq(req, webToken)
rr = executeRequest(req)
checkResponseCode(t, http.StatusForbidden, rr)
assert.Contains(t, rr.Body.String(), "your IP address is banned")
req, _ = http.NewRequest(http.MethodGet, webUsersPath, nil)
req.RemoteAddr = remoteAddr
req.RequestURI = webUsersPath
setJWTCookieForReq(req, webAdminToken)
rr = executeRequest(req)
checkResponseCode(t, http.StatusForbidden, rr)
assert.Contains(t, rr.Body.String(), "your IP address is banned")
req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
req.RemoteAddr = remoteAddr
req.Header.Set("X-Real-IP", "127.0.0.1:2345")
@ -5192,7 +5226,7 @@ func TestWebAdminLoginMock(t *testing.T) {
req.Header.Set("X-Forwarded-For", "10.9.9.9")
rr = executeRequest(req)
checkResponseCode(t, http.StatusOK, rr)
assert.Contains(t, rr.Body.String(), "Login from IP 127.0.1.1:4567 is not allowed")
assert.Contains(t, rr.Body.String(), "login from IP 127.0.1.1 not allowed")
// invalid csrf token
form = getLoginForm(altAdminUsername, altAdminPassword, "invalid csrf")

View file

@ -424,7 +424,7 @@ func TestCreateTokenError(t *testing.T) {
}
req, _ := http.NewRequest(http.MethodGet, tokenPath, nil)
server.checkAddrAndSendToken(rr, req, admin)
server.generateAndSendToken(rr, req, admin)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
rr = httptest.NewRecorder()
@ -565,22 +565,6 @@ func TestJWTTokenValidation(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, rr.Code)
}
func TestAdminAllowListConnAddr(t *testing.T) {
server := httpdServer{}
admin := dataprovider.Admin{
Filters: dataprovider.AdminFilters{
AllowList: []string{"192.168.1.0/24"},
},
}
rr := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, tokenPath, nil)
ctx := context.WithValue(req.Context(), connAddrKey, "127.0.0.1:4567")
req.RemoteAddr = "192.168.1.16:1234"
server.checkAddrAndSendToken(rr, req.WithContext(ctx), admin)
assert.Equal(t, http.StatusForbidden, rr.Code, rr.Body.String())
assert.Equal(t, "context value connection address", connAddrKey.String())
}
func TestUpdateContextFromCookie(t *testing.T) {
server := httpdServer{
tokenAuth: jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil),
@ -672,14 +656,6 @@ func TestCookieExpiration(t *testing.T) {
cookie = rr.Header().Get("Set-Cookie")
assert.Empty(t, cookie)
req, _ = http.NewRequest(http.MethodGet, tokenPath, nil)
req.RemoteAddr = "172.16.1.2:1234"
ctx = jwtauth.NewContext(req.Context(), token, nil)
ctx = context.WithValue(ctx, connAddrKey, "10.9.9.9")
server.checkCookieExpiration(rr, req.WithContext(ctx))
cookie = rr.Header().Get("Set-Cookie")
assert.Empty(t, cookie)
req, _ = http.NewRequest(http.MethodGet, tokenPath, nil)
req.RemoteAddr = "172.16.1.12:4567"
ctx = jwtauth.NewContext(req.Context(), token, nil)
@ -749,17 +725,10 @@ func TestCookieExpiration(t *testing.T) {
cookie = rr.Header().Get("Set-Cookie")
assert.Empty(t, cookie)
req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
req.RemoteAddr = "172.16.4.12:4567"
ctx = jwtauth.NewContext(req.Context(), token, nil)
server.checkCookieExpiration(rr, req.WithContext(context.WithValue(ctx, connAddrKey, "172.16.0.1:4567")))
cookie = rr.Header().Get("Set-Cookie")
assert.Empty(t, cookie)
req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
req.RemoteAddr = "172.16.4.16:4567"
ctx = jwtauth.NewContext(req.Context(), token, nil)
server.checkCookieExpiration(rr, req.WithContext(context.WithValue(ctx, connAddrKey, "172.16.4.18:4567")))
server.checkCookieExpiration(rr, req.WithContext(ctx))
cookie = rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, cookie)
@ -1014,6 +983,111 @@ func TestJWTTokenCleanup(t *testing.T) {
stopJWTTokensCleanupTicker()
}
func TestProxyHeaders(t *testing.T) {
username := "adminTest"
password := "testPwd"
admin := dataprovider.Admin{
Username: username,
Password: password,
Permissions: []string{dataprovider.PermAdminAny},
Status: 1,
Filters: dataprovider.AdminFilters{
AllowList: []string{"172.19.2.0/24"},
},
}
err := dataprovider.AddAdmin(&admin)
assert.NoError(t, err)
testIP := "10.29.1.9"
validForwardedFor := "172.19.2.6"
b := Binding{
Address: "",
Port: 8080,
EnableWebAdmin: true,
EnableWebClient: false,
ProxyAllowed: []string{testIP, "10.8.0.0/30"},
}
err = b.parseAllowedProxy()
assert.NoError(t, err)
server := newHttpdServer(b, "")
server.initializeRouter()
testServer := httptest.NewServer(server.router)
defer testServer.Close()
req, err := http.NewRequest(http.MethodGet, tokenPath, nil)
assert.NoError(t, err)
req.Header.Set("X-Forwarded-For", validForwardedFor)
req.Header.Set(xForwardedProto, "https")
req.RemoteAddr = "127.0.0.1:123"
req.SetBasicAuth(username, password)
rr := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusUnauthorized, rr.Code)
assert.Contains(t, rr.Body.String(), "login from IP 127.0.0.1 not allowed")
req.RemoteAddr = testIP
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
req.RemoteAddr = "10.8.0.2"
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
form := make(url.Values)
form.Set("username", username)
form.Set("password", password)
form.Set(csrfFormToken, createCSRFToken())
req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = testIP
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
assert.Contains(t, rr.Body.String(), "login from IP 10.29.1.9 not allowed")
req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = testIP
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor)
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
cookie := rr.Header().Get("Set-Cookie")
assert.NotContains(t, cookie, "Secure")
req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = testIP
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor)
req.Header.Set(xForwardedProto, "https")
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
cookie = rr.Header().Get("Set-Cookie")
assert.Contains(t, cookie, "Secure")
req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = testIP
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor)
req.Header.Set(xForwardedProto, "http")
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
cookie = rr.Header().Get("Set-Cookie")
assert.NotContains(t, cookie, "Secure")
err = dataprovider.DeleteAdmin(username)
assert.NoError(t, err)
}
func TestWebAdminRedirect(t *testing.T) {
b := Binding{
Address: "",
@ -1306,3 +1380,16 @@ func TestManageKeysInvalidClaims(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rr.Code)
assert.Contains(t, rr.Body.String(), "Invalid token claims")
}
func TestTLSReq(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, webClientLoginPath, nil)
assert.NoError(t, err)
req.TLS = &tls.ConnectionState{}
assert.True(t, isTLS(req))
req.TLS = nil
ctx := context.WithValue(req.Context(), forwardedProtoKey, "https")
assert.True(t, isTLS(req.WithContext(ctx)))
ctx = context.WithValue(req.Context(), forwardedProtoKey, "http")
assert.False(t, isTLS(req.WithContext(ctx)))
assert.Equal(t, "context value forwarded proto", forwardedProtoKey.String())
}

View file

@ -1,23 +1,19 @@
package httpd
import (
"context"
"errors"
"fmt"
"net/http"
"time"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/jwt"
"github.com/drakkan/sftpgo/common"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/utils"
)
var (
connAddrKey = &contextKey{"connection address"}
errInvalidToken = errors.New("invalid JWT token")
forwardedProtoKey = &contextKey{"forwarded proto"}
errInvalidToken = errors.New("invalid JWT token")
)
type contextKey struct {
@ -28,13 +24,6 @@ func (k *contextKey) String() string {
return "context value " + k.name
}
func saveConnectionAddress(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), connAddrKey, r.RemoteAddr)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error {
token, _, err := jwtauth.FromContext(r.Context())
@ -188,16 +177,3 @@ func verifyCSRFHeader(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}
func rateLimiter(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if delay, err := common.LimitRate(common.ProtocolHTTP, utils.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
delay += 499999999 * time.Nanosecond
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
w.Header().Set("X-Retry-In", delay.String())
sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}

View file

@ -1,11 +1,13 @@
package httpd
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"log"
"net"
"net/http"
"time"
@ -23,7 +25,10 @@ import (
"github.com/drakkan/sftpgo/version"
)
var compressor = middleware.NewCompressor(5)
var (
compressor = middleware.NewCompressor(5)
xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto")
)
type httpdServer struct {
binding Binding
@ -111,10 +116,6 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize)
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
common.Connections.AddClientConnection(ipAddr)
defer common.Connections.RemoveClientConnection(ipAddr)
if err := r.ParseForm(); err != nil {
renderClientLoginPage(w, err.Error())
return
@ -130,16 +131,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
return
}
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
renderClientLoginPage(w, "configured connections limit reached")
return
}
if common.IsBanned(ipAddr) {
renderClientLoginPage(w, "your IP address is banned")
return
}
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolHTTP); err != nil {
renderClientLoginPage(w, fmt.Sprintf("access denied by post connect hook: %v", err))
return
@ -204,14 +196,6 @@ func (s *httpdServer) handleWebAdminLoginPost(w http.ResponseWriter, r *http.Req
renderLoginPage(w, err.Error())
return
}
if connAddr, ok := r.Context().Value(connAddrKey).(string); ok {
if connAddr != r.RemoteAddr {
if !admin.CanLoginFromIP(utils.GetIPFromRemoteAddress(connAddr)) {
renderLoginPage(w, fmt.Sprintf("Login from IP %v is not allowed", connAddr))
return
}
}
}
c := jwtTokenClaims{
Username: admin.Username,
Permissions: admin.Permissions,
@ -246,19 +230,10 @@ func (s *httpdServer) getToken(w http.ResponseWriter, r *http.Request) {
return
}
s.checkAddrAndSendToken(w, r, admin)
s.generateAndSendToken(w, r, admin)
}
func (s *httpdServer) checkAddrAndSendToken(w http.ResponseWriter, r *http.Request, admin dataprovider.Admin) {
if connAddr, ok := r.Context().Value(connAddrKey).(string); ok {
if connAddr != r.RemoteAddr {
if !admin.CanLoginFromIP(utils.GetIPFromRemoteAddress(connAddr)) {
sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
}
}
func (s *httpdServer) generateAndSendToken(w http.ResponseWriter, r *http.Request, admin dataprovider.Admin) {
c := jwtTokenClaims{
Username: admin.Username,
Permissions: admin.Permissions,
@ -330,15 +305,6 @@ func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request,
logger.Debug(logSender, "", "admin %#v cannot login from %v, unable to refresh cookie", admin.Username, r.RemoteAddr)
return
}
if connAddr, ok := r.Context().Value(connAddrKey).(string); ok {
if connAddr != r.RemoteAddr {
if !admin.CanLoginFromIP(utils.GetIPFromRemoteAddress(connAddr)) {
logger.Debug(logSender, "", "admin %#v cannot login from %v, unable to refresh cookie",
admin.Username, connAddr)
return
}
}
}
logger.Debug(logSender, "", "cookie refreshed for admin %#v", admin.Username)
tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebAdmin) //nolint:errcheck
}
@ -357,200 +323,266 @@ func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request {
return r
}
func (s *httpdServer) checkConnection(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
ip := net.ParseIP(ipAddr)
if ip != nil {
for _, allow := range s.binding.allowHeadersFrom {
if allow(ip) {
parsedIP := utils.GetRealIP(r)
if parsedIP != "" {
ipAddr = parsedIP
r.RemoteAddr = ipAddr
}
if forwardedProto := r.Header.Get(xForwardedProto); forwardedProto != "" {
ctx := context.WithValue(r.Context(), forwardedProtoKey, forwardedProto)
r = r.WithContext(ctx)
}
break
}
}
}
common.Connections.AddClientConnection(ipAddr)
defer common.Connections.RemoveClientConnection(ipAddr)
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
s.sendForbiddenResponse(w, r, "configured connections limit reached")
return
}
if common.IsBanned(ipAddr) {
s.sendForbiddenResponse(w, r, "your IP address is banned")
return
}
if delay, err := common.LimitRate(common.ProtocolHTTP, ipAddr); err != nil {
delay += 499999999 * time.Nanosecond
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
w.Header().Set("X-Retry-In", delay.String())
s.sendTooManyRequestResponse(w, r, err)
return
}
next.ServeHTTP(w, r)
})
}
func (s *httpdServer) sendTooManyRequestResponse(w http.ResponseWriter, r *http.Request, err error) {
if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
r = s.updateContextFromCookie(r)
if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
renderClientMessagePage(w, r, http.StatusText(http.StatusTooManyRequests), "Rate limit exceeded",
http.StatusTooManyRequests, err, "")
return
}
renderMessagePage(w, r, http.StatusText(http.StatusTooManyRequests), "Rate limit exceeded", http.StatusTooManyRequests,
err, "")
return
}
sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}
func (s *httpdServer) sendForbiddenResponse(w http.ResponseWriter, r *http.Request, message string) {
if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
r = s.updateContextFromCookie(r)
if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
renderClientForbiddenPage(w, r, message)
return
}
renderForbiddenPage(w, r, message)
return
}
sendAPIResponse(w, r, errors.New(message), message, http.StatusForbidden)
}
func (s *httpdServer) initializeRouter() {
s.tokenAuth = jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil)
s.router = chi.NewRouter()
s.router.Use(saveConnectionAddress)
s.router.Use(middleware.RequestID)
s.router.Use(logger.NewStructuredLogger(logger.GetLogger()))
s.router.Use(middleware.Recoverer)
s.router.Use(s.checkConnection)
s.router.Use(middleware.GetHead)
s.router.Use(middleware.StripSlashes)
s.router.Use(middleware.RealIP)
s.router.Use(rateLimiter)
s.router.Group(func(r chi.Router) {
r.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) {
render.PlainText(w, r, "ok")
})
})
s.router.Group(func(router chi.Router) {
router.Use(middleware.RequestID)
router.Use(logger.NewStructuredLogger(logger.GetLogger()))
router.Use(middleware.Recoverer)
router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
r = s.updateContextFromCookie(r)
if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
renderClientNotFoundPage(w, r, nil)
return
}
renderNotFoundPage(w, r, nil)
s.router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
r = s.updateContextFromCookie(r)
if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
renderClientNotFoundPage(w, r, nil)
return
}
sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
}))
renderNotFoundPage(w, r, nil)
return
}
sendAPIResponse(w, r, nil, http.StatusText(http.StatusNotFound), http.StatusNotFound)
}))
router.Get(tokenPath, s.getToken)
s.router.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) {
render.PlainText(w, r, "ok")
})
router.Group(func(router chi.Router) {
router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader))
router.Use(jwtAuthenticatorAPI)
s.router.Get(tokenPath, s.getToken)
router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, version.Get())
})
s.router.Group(func(router chi.Router) {
router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader))
router.Use(jwtAuthenticatorAPI)
router.Get(logoutPath, s.logout)
router.Put(adminPwdPath, changeAdminPassword)
router.With(checkPerm(dataprovider.PermAdminViewServerStatus)).
Get(serverStatusPath, func(w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, getServicesStatus())
})
router.With(checkPerm(dataprovider.PermAdminViewConnections)).
Get(activeConnectionsPath, func(w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, common.Connections.GetStats())
})
router.With(checkPerm(dataprovider.PermAdminCloseConnections)).
Delete(activeConnectionsPath+"/{connectionID}", handleCloseConnection)
router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanPath, getQuotaScans)
router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanPath, startQuotaScan)
router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanVFolderPath, getVFolderQuotaScans)
router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanVFolderPath, startVFolderQuotaScan)
router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath, getUsers)
router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(userPath, addUser)
router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath+"/{username}", getUserByUsername)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(userPath+"/{username}", updateUser)
router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(userPath+"/{username}", deleteUser)
router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath, getFolders)
router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath+"/{name}", getFolderByName)
router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(folderPath, addFolder)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(folderPath+"/{name}", updateFolder)
router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(folderPath+"/{name}", deleteFolder)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(dumpDataPath, dumpData)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(loadDataPath, loadData)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(loadDataPath, loadDataFromRequest)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateUsedQuotaPath, updateUserQuotaUsage)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateFolderUsedQuotaPath, updateVFolderQuotaUsage)
router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderBanTime, getBanTime)
router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderScore, getScore)
router.With(checkPerm(dataprovider.PermAdminManageDefender)).Post(defenderUnban, unban)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath, getAdmins)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(adminPath, addAdmin)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath+"/{username}", getAdminByUsername)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Put(adminPath+"/{username}", updateAdmin)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Delete(adminPath+"/{username}", deleteAdmin)
router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, version.Get())
})
if s.enableWebAdmin || s.enableWebClient {
router.Group(func(router chi.Router) {
router.Use(compressor.Handler)
fileServer(router, webStaticFilesPath, http.Dir(s.staticFilesPath))
})
if s.enableWebClient {
router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
})
router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
})
} else {
router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
})
router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
})
}
}
router.Get(logoutPath, s.logout)
router.Put(adminPwdPath, changeAdminPassword)
router.With(checkPerm(dataprovider.PermAdminViewServerStatus)).
Get(serverStatusPath, func(w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, getServicesStatus())
})
router.With(checkPerm(dataprovider.PermAdminViewConnections)).
Get(activeConnectionsPath, func(w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, common.Connections.GetStats())
})
router.With(checkPerm(dataprovider.PermAdminCloseConnections)).
Delete(activeConnectionsPath+"/{connectionID}", handleCloseConnection)
router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanPath, getQuotaScans)
router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanPath, startQuotaScan)
router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanVFolderPath, getVFolderQuotaScans)
router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanVFolderPath, startVFolderQuotaScan)
router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath, getUsers)
router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(userPath, addUser)
router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath+"/{username}", getUserByUsername)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(userPath+"/{username}", updateUser)
router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(userPath+"/{username}", deleteUser)
router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath, getFolders)
router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath+"/{name}", getFolderByName)
router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(folderPath, addFolder)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(folderPath+"/{name}", updateFolder)
router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(folderPath+"/{name}", deleteFolder)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(dumpDataPath, dumpData)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(loadDataPath, loadData)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(loadDataPath, loadDataFromRequest)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateUsedQuotaPath, updateUserQuotaUsage)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateFolderUsedQuotaPath, updateVFolderQuotaUsage)
router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderBanTime, getBanTime)
router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderScore, getScore)
router.With(checkPerm(dataprovider.PermAdminManageDefender)).Post(defenderUnban, unban)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath, getAdmins)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(adminPath, addAdmin)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath+"/{username}", getAdminByUsername)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Put(adminPath+"/{username}", updateAdmin)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Delete(adminPath+"/{username}", deleteAdmin)
})
if s.enableWebAdmin || s.enableWebClient {
s.router.Group(func(router chi.Router) {
router.Use(compressor.Handler)
fileServer(router, webStaticFilesPath, http.Dir(s.staticFilesPath))
})
if s.enableWebClient {
router.Get(webBaseClientPath, func(w http.ResponseWriter, r *http.Request) {
s.router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
})
router.Get(webClientLoginPath, handleClientWebLogin)
router.Post(webClientLoginPath, s.handleWebClientLoginPost)
router.Group(func(router chi.Router) {
router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie))
router.Use(jwtAuthenticatorWebClient)
router.Get(webClientLogoutPath, handleWebClientLogout)
router.With(s.refreshCookie).Get(webClientFilesPath, handleClientGetFiles)
router.With(s.refreshCookie).Get(webClientCredentialsPath, handleClientGetCredentials)
router.Post(webChangeClientPwdPath, handleWebClientChangePwdPost)
router.With(checkClientPerm(dataprovider.WebClientPubKeyChangeDisabled)).
Post(webChangeClientKeysPath, handleWebClientManageKeysPost)
s.router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
})
}
if s.enableWebAdmin {
router.Get(webBaseAdminPath, func(w http.ResponseWriter, r *http.Request) {
} else {
s.router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
})
router.Get(webLoginPath, handleWebLogin)
router.Post(webLoginPath, s.handleWebAdminLoginPost)
router.Group(func(router chi.Router) {
router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie))
router.Use(jwtAuthenticatorWebAdmin)
router.Get(webLogoutPath, handleWebLogout)
router.With(s.refreshCookie).Get(webChangeAdminPwdPath, handleWebAdminChangePwd)
router.Post(webChangeAdminPwdPath, handleWebAdminChangePwdPost)
router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie).
Get(webUsersPath, handleGetWebUsers)
router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie).
Get(webUserPath, handleWebAddUserGet)
router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie).
Get(webUserPath+"/{username}", handleWebUpdateUserGet)
router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webUserPath, handleWebAddUserPost)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webUserPath+"/{username}", handleWebUpdateUserPost)
router.With(checkPerm(dataprovider.PermAdminViewConnections), s.refreshCookie).
Get(webConnectionsPath, handleWebGetConnections)
router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie).
Get(webFoldersPath, handleWebGetFolders)
router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie).
Get(webFolderPath, handleWebAddFolderGet)
router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webFolderPath, handleWebAddFolderPost)
router.With(checkPerm(dataprovider.PermAdminViewServerStatus), s.refreshCookie).
Get(webStatusPath, handleWebGetStatus)
router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
Get(webAdminsPath, handleGetWebAdmins)
router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
Get(webAdminPath, handleWebAddAdminGet)
router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
Get(webAdminPath+"/{username}", handleWebUpdateAdminGet)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath, handleWebAddAdminPost)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath+"/{username}", handleWebUpdateAdminPost)
router.With(checkPerm(dataprovider.PermAdminManageAdmins), verifyCSRFHeader).
Delete(webAdminPath+"/{username}", deleteAdmin)
router.With(checkPerm(dataprovider.PermAdminCloseConnections), verifyCSRFHeader).
Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection)
router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie).
Get(webFolderPath+"/{name}", handleWebUpdateFolderGet)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webFolderPath+"/{name}", handleWebUpdateFolderPost)
router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader).
Delete(webFolderPath+"/{name}", deleteFolder)
router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader).
Post(webScanVFolderPath, startVFolderQuotaScan)
router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader).
Delete(webUserPath+"/{username}", deleteUser)
router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader).
Post(webQuotaScanPath, startQuotaScan)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webMaintenancePath, handleWebMaintenance)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webBackupPath, dumpData)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webRestorePath, handleWebRestore)
router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).
Get(webTemplateUser, handleWebTemplateUserGet)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateUser, handleWebTemplateUserPost)
router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).
Get(webTemplateFolder, handleWebTemplateFolderGet)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateFolder, handleWebTemplateFolderPost)
s.router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
})
}
})
}
if s.enableWebClient {
s.router.Get(webBaseClientPath, func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
})
s.router.Get(webClientLoginPath, handleClientWebLogin)
s.router.Post(webClientLoginPath, s.handleWebClientLoginPost)
s.router.Group(func(router chi.Router) {
router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie))
router.Use(jwtAuthenticatorWebClient)
router.Get(webClientLogoutPath, handleWebClientLogout)
router.With(s.refreshCookie).Get(webClientFilesPath, handleClientGetFiles)
router.With(s.refreshCookie).Get(webClientCredentialsPath, handleClientGetCredentials)
router.Post(webChangeClientPwdPath, handleWebClientChangePwdPost)
router.With(checkClientPerm(dataprovider.WebClientPubKeyChangeDisabled)).
Post(webChangeClientKeysPath, handleWebClientManageKeysPost)
})
}
if s.enableWebAdmin {
s.router.Get(webBaseAdminPath, func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
})
s.router.Get(webLoginPath, handleWebLogin)
s.router.Post(webLoginPath, s.handleWebAdminLoginPost)
s.router.Group(func(router chi.Router) {
router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie))
router.Use(jwtAuthenticatorWebAdmin)
router.Get(webLogoutPath, handleWebLogout)
router.With(s.refreshCookie).Get(webChangeAdminPwdPath, handleWebAdminChangePwd)
router.Post(webChangeAdminPwdPath, handleWebAdminChangePwdPost)
router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie).
Get(webUsersPath, handleGetWebUsers)
router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie).
Get(webUserPath, handleWebAddUserGet)
router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie).
Get(webUserPath+"/{username}", handleWebUpdateUserGet)
router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webUserPath, handleWebAddUserPost)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webUserPath+"/{username}", handleWebUpdateUserPost)
router.With(checkPerm(dataprovider.PermAdminViewConnections), s.refreshCookie).
Get(webConnectionsPath, handleWebGetConnections)
router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie).
Get(webFoldersPath, handleWebGetFolders)
router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie).
Get(webFolderPath, handleWebAddFolderGet)
router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webFolderPath, handleWebAddFolderPost)
router.With(checkPerm(dataprovider.PermAdminViewServerStatus), s.refreshCookie).
Get(webStatusPath, handleWebGetStatus)
router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
Get(webAdminsPath, handleGetWebAdmins)
router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
Get(webAdminPath, handleWebAddAdminGet)
router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
Get(webAdminPath+"/{username}", handleWebUpdateAdminGet)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath, handleWebAddAdminPost)
router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath+"/{username}", handleWebUpdateAdminPost)
router.With(checkPerm(dataprovider.PermAdminManageAdmins), verifyCSRFHeader).
Delete(webAdminPath+"/{username}", deleteAdmin)
router.With(checkPerm(dataprovider.PermAdminCloseConnections), verifyCSRFHeader).
Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection)
router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie).
Get(webFolderPath+"/{name}", handleWebUpdateFolderGet)
router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webFolderPath+"/{name}", handleWebUpdateFolderPost)
router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader).
Delete(webFolderPath+"/{name}", deleteFolder)
router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader).
Post(webScanVFolderPath, startVFolderQuotaScan)
router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader).
Delete(webUserPath+"/{username}", deleteUser)
router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader).
Post(webQuotaScanPath, startQuotaScan)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webMaintenancePath, handleWebMaintenance)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webBackupPath, dumpData)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webRestorePath, handleWebRestore)
router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).
Get(webTemplateUser, handleWebTemplateUserGet)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateUser, handleWebTemplateUserPost)
router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).
Get(webTemplateFolder, handleWebTemplateFolderGet)
router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateFolder, handleWebTemplateFolderPost)
})
}
}

View file

@ -279,24 +279,11 @@ func handleWebClientLogout(w http.ResponseWriter, r *http.Request) {
}
func handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
common.Connections.AddClientConnection(ipAddr)
defer common.Connections.RemoveClientConnection(ipAddr)
claims, err := getTokenClaims(r)
if err != nil || claims.Username == "" {
renderClientForbiddenPage(w, r, "Invalid token claims")
return
}
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
renderClientForbiddenPage(w, r, "configured connections limit reached")
return
}
if common.IsBanned(ipAddr) {
renderClientForbiddenPage(w, r, "your IP address is banned")
return
}
user, err := dataprovider.UserExists(claims.Username)
if err != nil {
@ -635,16 +622,5 @@ func checkWebClientUser(user *dataprovider.User, r *http.Request, connectionID s
logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, r.RemoteAddr)
return fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr)
}
if connAddr, ok := r.Context().Value(connAddrKey).(string); ok {
if connAddr != r.RemoteAddr {
connIPAddr := utils.GetIPFromRemoteAddress(connAddr)
if common.IsBanned(connIPAddr) {
return errors.New("your IP address is banned")
}
if !user.IsLoginFromAddrAllowed(connIPAddr) {
return fmt.Errorf("login from IP %v is not allowed", connIPAddr)
}
}
}
return nil
}

View file

@ -553,7 +553,7 @@ func TestDefender(t *testing.T) {
_, _, err = getSftpClient(user, usePubKey)
assert.Error(t, err)
_, err = httpdtest.RemoveUser(user, http.StatusOK)
err = dataprovider.DeleteUser(user.Username)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
@ -2907,21 +2907,24 @@ func TestMaxConnections(t *testing.T) {
common.Config.MaxTotalConnections = 1
usePubKey := true
u := getTestUser(usePubKey)
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
user := getTestUser(usePubKey)
err := dataprovider.AddUser(&user)
assert.NoError(t, err)
user.Password = ""
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
s, c, err := getSftpClient(user, usePubKey)
if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") {
c.Close()
s.Close()
}
err = client.Close()
assert.NoError(t, err)
err = conn.Close()
assert.NoError(t, err)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
err = dataprovider.DeleteUser(user.Username)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
@ -2935,21 +2938,24 @@ func TestMaxPerHostConnections(t *testing.T) {
common.Config.MaxPerHostConnections = 1
usePubKey := true
u := getTestUser(usePubKey)
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
user := getTestUser(usePubKey)
err := dataprovider.AddUser(&user)
assert.NoError(t, err)
user.Password = ""
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
s, c, err := getSftpClient(user, usePubKey)
if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") {
c.Close()
s.Close()
}
err = client.Close()
assert.NoError(t, err)
err = conn.Close()
assert.NoError(t, err)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
err = dataprovider.DeleteUser(user.Username)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)

View file

@ -107,7 +107,8 @@
"enable_https": false,
"client_auth_type": 0,
"tls_cipher_suites": [],
"prefix": ""
"prefix": "",
"proxy_allowed": []
}
],
"certificate_file": "",
@ -180,12 +181,13 @@
"bindings": [
{
"port": 8080,
"address": "127.0.0.1",
"address": "",
"enable_web_admin": true,
"enable_web_client": true,
"enable_https": false,
"client_auth_type": 0,
"tls_cipher_suites": []
"tls_cipher_suites": [],
"proxy_allowed": []
}
],
"templates_path": "templates",

View file

@ -209,6 +209,7 @@
<i class="fas fa-angle-up"></i>
</a>
{{if .LoggedAdmin.Username}}
<!-- Logout Modal-->
<div class="modal fade" id="logoutModal" tabindex="-1" role="dialog" aria-labelledby="modalLabel"
aria-hidden="true">
@ -228,6 +229,7 @@
</div>
</div>
</div>
{{end}}
{{block "dialog" .}}{{end}}

View file

@ -168,6 +168,7 @@
<i class="fas fa-angle-up"></i>
</a>
{{if .LoggedUser.Username}}
<!-- Logout Modal-->
<div class="modal fade" id="logoutModal" tabindex="-1" role="dialog" aria-labelledby="modalLabel"
aria-hidden="true">
@ -187,6 +188,7 @@
</div>
</div>
</div>
{{end}}
{{block "dialog" .}}{{end}}

View file

@ -57,7 +57,7 @@
<div class="form-group row">
<label for="idPublicKeys" class="col-sm-2 col-form-label">Keys</label>
<div class="col-sm-10">
<textarea class="form-control" id="idPublicKeys" name="public_keys" rows="3"
<textarea class="form-control" id="idPublicKeys" name="public_keys" rows="5"
aria-describedby="pkHelpBlock">{{range .PublicKeys}}{{.}}&#10;{{end}}</textarea>
<small id="pkHelpBlock" class="form-text text-muted">
One public key per line

View file

@ -38,6 +38,11 @@ const (
osWindows = "windows"
)
var (
xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
xRealIP = http.CanonicalHeaderKey("X-Real-IP")
)
// IsStringInSlice searches a string in a slice and returns true if the string is found
func IsStringInSlice(obj string, list []string) bool {
for i := 0; i < len(list); i++ {
@ -516,3 +521,48 @@ func GetSSHPublicKeyAsString(pubKey []byte) (string, error) {
}
return string(ssh.MarshalAuthorizedKey(k)), nil
}
// GetRealIP returns the ip address as result of parsing either the
// X-Real-IP header or the X-Forwarded-For header
func GetRealIP(r *http.Request) string {
var ip string
if xrip := r.Header.Get(xRealIP); xrip != "" {
ip = xrip
} else if xff := r.Header.Get(xForwardedFor); xff != "" {
i := strings.Index(xff, ", ")
if i == -1 {
i = len(xff)
}
ip = strings.TrimSpace(xff[:i])
}
if net.ParseIP(ip) == nil {
return ""
}
return ip
}
// ParseAllowedIPAndRanges returns a list of functions that allow to find if a
func ParseAllowedIPAndRanges(allowed []string) ([]func(net.IP) bool, error) {
res := make([]func(net.IP) bool, len(allowed))
for i, allowFrom := range allowed {
if strings.LastIndex(allowFrom, "/") > 0 {
_, ipRange, err := net.ParseCIDR(allowFrom)
if err != nil {
return nil, fmt.Errorf("given string %q is not a valid IP range: %v", allowFrom, err)
}
res[i] = ipRange.Contains
} else {
allowed := net.ParseIP(allowFrom)
if allowed == nil {
return nil, fmt.Errorf("given string %q is not a valid IP address", allowFrom)
}
res[i] = allowed.Equal
}
}
return res, nil
}

View file

@ -23,6 +23,7 @@ import (
"github.com/drakkan/sftpgo/common"
"github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/kms"
"github.com/drakkan/sftpgo/utils"
"github.com/drakkan/sftpgo/vfs"
)
@ -409,36 +410,68 @@ func TestUserInvalidParams(t *testing.T) {
}
func TestRemoteAddress(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/username", nil)
assert.NoError(t, err)
assert.Empty(t, req.RemoteAddr)
remoteAddr1 := "100.100.100.100"
remoteAddr2 := "172.172.172.172"
c := &Configuration{
Bindings: []Binding{
{
Port: 9000,
ProxyAllowed: []string{remoteAddr2, "10.8.0.0/30"},
},
},
}
server := webDavServer{
config: c,
binding: c.Bindings[0],
}
err := server.binding.parseAllowedProxy()
assert.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, "/", nil)
assert.NoError(t, err)
assert.Empty(t, req.RemoteAddr)
req.Header.Set("X-Forwarded-For", remoteAddr1)
checkRemoteAddress(req)
assert.Equal(t, remoteAddr1, req.RemoteAddr)
ip := utils.GetRealIP(req)
assert.Equal(t, remoteAddr1, ip)
// this will be ignore, remoteAddr1 is not allowed to se this header
req.Header.Set("X-Forwarded-For", remoteAddr2)
req.RemoteAddr = remoteAddr1
ip = server.checkRemoteAddress(req)
assert.Equal(t, remoteAddr1, ip)
req.RemoteAddr = ""
ip = server.checkRemoteAddress(req)
assert.Empty(t, ip)
req.Header.Set("X-Forwarded-For", fmt.Sprintf("%v, %v", remoteAddr2, remoteAddr1))
checkRemoteAddress(req)
assert.Equal(t, remoteAddr2, req.RemoteAddr)
ip = utils.GetRealIP(req)
assert.Equal(t, remoteAddr2, ip)
req.RemoteAddr = remoteAddr2
req.Header.Set("X-Forwarded-For", fmt.Sprintf("%v, %v", "12.34.56.78", "172.16.2.4"))
ip = server.checkRemoteAddress(req)
assert.Equal(t, "12.34.56.78", ip)
assert.Equal(t, ip, req.RemoteAddr)
req.RemoteAddr = "10.8.0.2"
req.Header.Set("X-Forwarded-For", remoteAddr1)
ip = server.checkRemoteAddress(req)
assert.Equal(t, remoteAddr1, ip)
assert.Equal(t, ip, req.RemoteAddr)
req.RemoteAddr = "10.8.0.3"
req.Header.Set("X-Forwarded-For", "not an ip")
ip = server.checkRemoteAddress(req)
assert.Equal(t, "10.8.0.3", ip)
assert.Equal(t, ip, req.RemoteAddr)
req.Header.Del("X-Forwarded-For")
req.RemoteAddr = ""
req.Header.Set("X-Real-IP", remoteAddr1)
checkRemoteAddress(req)
assert.Equal(t, remoteAddr1, req.RemoteAddr)
ip = utils.GetRealIP(req)
assert.Equal(t, remoteAddr1, ip)
req.RemoteAddr = ""
oldValue := common.Config.ProxyProtocol
common.Config.ProxyProtocol = 1
checkRemoteAddress(req)
assert.Empty(t, req.RemoteAddr)
common.Config.ProxyProtocol = oldValue
}
func TestConnWithNilRequest(t *testing.T) {

View file

@ -7,11 +7,11 @@ import (
"errors"
"fmt"
"log"
"net"
"net/http"
"path"
"path/filepath"
"runtime/debug"
"strings"
"time"
"github.com/go-chi/chi/v5/middleware"
@ -27,9 +27,7 @@ import (
)
var (
err401 = errors.New("Unauthorized")
xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
xRealIP = http.CanonicalHeaderKey("X-Real-IP")
err401 = errors.New("Unauthorized")
)
type webDavServer struct {
@ -145,8 +143,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}()
checkRemoteAddress(r)
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
ipAddr := s.checkRemoteAddress(r)
common.Connections.AddClientConnection(ipAddr)
defer common.Connections.RemoveClientConnection(ipAddr)
@ -327,6 +324,24 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo
return connID, nil
}
func (s *webDavServer) checkRemoteAddress(r *http.Request) string {
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
ip := net.ParseIP(ipAddr)
if ip != nil {
for _, allow := range s.binding.allowHeadersFrom {
if allow(ip) {
parsedIP := utils.GetRealIP(r)
if parsedIP != "" {
ipAddr = parsedIP
r.RemoteAddr = ipAddr
}
break
}
}
}
return ipAddr
}
func writeLog(r *http.Request, err error) {
scheme := "http"
if r.TLS != nil {
@ -352,28 +367,6 @@ func writeLog(r *http.Request, err error) {
Send()
}
func checkRemoteAddress(r *http.Request) {
if common.Config.ProxyProtocol != 0 {
return
}
var ip string
if xrip := r.Header.Get(xRealIP); xrip != "" {
ip = xrip
} else if xff := r.Header.Get(xForwardedFor); xff != "" {
i := strings.Index(xff, ", ")
if i == -1 {
i = len(xff)
}
ip = strings.TrimSpace(xff[:i])
}
if len(ip) > 0 {
r.RemoteAddr = ip
}
}
func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error) {
metrics.AddLoginAttempt(loginMethod)
if err != nil {

View file

@ -3,6 +3,7 @@ package webdavd
import (
"fmt"
"net"
"path/filepath"
"github.com/go-chi/chi/v5/middleware"
@ -90,6 +91,18 @@ type Binding struct {
// Prefix for WebDAV resources, if empty WebDAV resources will be available at the
// root ("/") URI. If defined it must be an absolute URI.
Prefix string `json:"prefix" mapstructure:"prefix"`
// List of IP addresses and IP ranges allowed to set X-Forwarded-For/X-Real-IP headers.
ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"`
allowHeadersFrom []func(net.IP) bool
}
func (b *Binding) parseAllowedProxy() error {
allowedFuncs, err := utils.ParseAllowedIPAndRanges(b.ProxyAllowed)
if err != nil {
return err
}
b.allowHeadersFrom = allowedFuncs
return nil
}
func (b *Binding) isMutualTLSEnabled() bool {
@ -191,6 +204,9 @@ func (c *Configuration) Initialize(configDir string) error {
if !binding.IsValid() {
continue
}
if err := binding.parseAllowedProxy(); err != nil {
return err
}
go func(binding Binding) {
server := webDavServer{

View file

@ -469,6 +469,12 @@ func TestInitialization(t *testing.T) {
cfg.CertificateKeyFile = keyPath
cfg.CACertificates = []string{caCrtPath}
cfg.CARevocationLists = []string{caCRLPath}
cfg.Bindings[0].ProxyAllowed = []string{"not valid"}
err = cfg.Initialize(configDir)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "is not a valid IP address")
}
cfg.Bindings[0].ProxyAllowed = nil
err = cfg.Initialize(configDir)
assert.Error(t, err)
}
@ -1180,6 +1186,10 @@ func TestQuotaLimits(t *testing.T) {
// test quota files
err = uploadFile(testFilePath, testFileName+".quota", testFileSize, client)
if !assert.NoError(t, err, "username: %v", user.Username) {
info, err := os.Stat(testFilePath)
if assert.NoError(t, err) {
fmt.Printf("local file size %v", info.Size())
}
printLatestLogs(20)
}
err = uploadFile(testFilePath, testFileName+".quota1", testFileSize, client)