Forráskód Böngészése

httpd/webdav: add a list of hosts allowed to send proxy headers

X-Forwarded-For, X-Real-IP and X-Forwarded-Proto headers will be ignored
for hosts not included in this list.

This is a backward incompatible change, before the proxy headers were
always used
Nicola Murino 4 éve
szülő
commit
c8f7fc9bc9

+ 1 - 2
Dockerfile

@@ -52,8 +52,7 @@ ENV SFTPGO_HTTPD__STATIC_FILES_PATH=/usr/share/sftpgo/static
 
 
 # Modify the default configuration file
 # Modify the default configuration file
 RUN sed -i "s|\"users_base_dir\": \"\",|\"users_base_dir\": \"/srv/sftpgo/data\",|" /etc/sftpgo/sftpgo.json && \
 RUN sed -i "s|\"users_base_dir\": \"\",|\"users_base_dir\": \"/srv/sftpgo/data\",|" /etc/sftpgo/sftpgo.json && \
-    sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json && \
-    sed -i "s|\"address\": \"127.0.0.1\",|\"address\": \"\",|" /etc/sftpgo/sftpgo.json
+    sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json
 
 
 RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups
 RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups
 
 

+ 1 - 2
Dockerfile.alpine

@@ -57,8 +57,7 @@ ENV SFTPGO_HTTPD__STATIC_FILES_PATH=/usr/share/sftpgo/static
 
 
 # Modify the default configuration file
 # Modify the default configuration file
 RUN sed -i "s|\"users_base_dir\": \"\",|\"users_base_dir\": \"/srv/sftpgo/data\",|" /etc/sftpgo/sftpgo.json && \
 RUN sed -i "s|\"users_base_dir\": \"\",|\"users_base_dir\": \"/srv/sftpgo/data\",|" /etc/sftpgo/sftpgo.json && \
-    sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json && \
-    sed -i "s|\"address\": \"127.0.0.1\",|\"address\": \"\",|" /etc/sftpgo/sftpgo.json
+    sed -i "s|\"backups\"|\"/srv/sftpgo/backups\"|" /etc/sftpgo/sftpgo.json
 
 
 RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups
 RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups
 
 

+ 13 - 0
common/common_test.go

@@ -697,6 +697,19 @@ func TestCachedFs(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 }
 }
 
 
+func TestParseAllowedIPAndRanges(t *testing.T) {
+	_, err := utils.ParseAllowedIPAndRanges([]string{"1.1.1.1", "not an ip"})
+	assert.Error(t, err)
+	_, err = utils.ParseAllowedIPAndRanges([]string{"1.1.1.5", "192.168.1.0/240"})
+	assert.Error(t, err)
+	allow, err := utils.ParseAllowedIPAndRanges([]string{"192.168.1.2", "172.16.0.0/24"})
+	assert.NoError(t, err)
+	assert.True(t, allow[0](net.ParseIP("192.168.1.2")))
+	assert.False(t, allow[0](net.ParseIP("192.168.2.2")))
+	assert.True(t, allow[1](net.ParseIP("172.16.0.1")))
+	assert.False(t, allow[1](net.ParseIP("172.16.1.1")))
+}
+
 func BenchmarkBcryptHashing(b *testing.B) {
 func BenchmarkBcryptHashing(b *testing.B) {
 	bcryptPassword := "bcryptpassword"
 	bcryptPassword := "bcryptpassword"
 	for i := 0; i < b.N; i++ {
 	for i := 0; i < b.N; i++ {

+ 14 - 0
config/config.go

@@ -60,6 +60,7 @@ var (
 		ClientAuthType:  0,
 		ClientAuthType:  0,
 		TLSCipherSuites: nil,
 		TLSCipherSuites: nil,
 		Prefix:          "",
 		Prefix:          "",
+		ProxyAllowed:    nil,
 	}
 	}
 	defaultHTTPDBinding = httpd.Binding{
 	defaultHTTPDBinding = httpd.Binding{
 		Address:         "127.0.0.1",
 		Address:         "127.0.0.1",
@@ -69,6 +70,7 @@ var (
 		EnableHTTPS:     false,
 		EnableHTTPS:     false,
 		ClientAuthType:  0,
 		ClientAuthType:  0,
 		TLSCipherSuites: nil,
 		TLSCipherSuites: nil,
+		ProxyAllowed:    nil,
 	}
 	}
 	defaultRateLimiter = common.RateLimiterConfig{
 	defaultRateLimiter = common.RateLimiterConfig{
 		Average:                0,
 		Average:                0,
@@ -768,6 +770,12 @@ func getWebDAVDBindingFromEnv(idx int) {
 		isSet = true
 		isSet = true
 	}
 	}
 
 
+	proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PROXY_ALLOWED", idx))
+	if ok {
+		binding.ProxyAllowed = proxyAllowed
+		isSet = true
+	}
+
 	prefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PREFIX", idx))
 	prefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PREFIX", idx))
 	if ok {
 	if ok {
 		binding.Prefix = prefix
 		binding.Prefix = prefix
@@ -833,6 +841,12 @@ func getHTTPDBindingFromEnv(idx int) {
 		isSet = true
 		isSet = true
 	}
 	}
 
 
+	proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PROXY_ALLOWED", idx))
+	if ok {
+		binding.ProxyAllowed = proxyAllowed
+		isSet = true
+	}
+
 	if isSet {
 	if isSet {
 		if len(globalConf.HTTPDConfig.Bindings) > idx {
 		if len(globalConf.HTTPDConfig.Bindings) > idx {
 			globalConf.HTTPDConfig.Bindings[idx] = binding
 			globalConf.HTTPDConfig.Bindings[idx] = binding

+ 8 - 0
config/config_test.go

@@ -572,6 +572,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT", "8000")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT", "8000")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS", "0")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS", "0")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES", "TLS_RSA_WITH_AES_128_CBC_SHA ")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES", "TLS_RSA_WITH_AES_128_CBC_SHA ")
+	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED", "192.168.10.1")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS", "127.0.1.1")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS", "127.0.1.1")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT", "9000")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT", "9000")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS", "1")
 	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS", "1")
@@ -582,6 +583,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES")
+		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS")
 		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS")
@@ -605,6 +607,7 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
 	require.Equal(t, 0, bindings[1].ClientAuthType)
 	require.Equal(t, 0, bindings[1].ClientAuthType)
 	require.Len(t, bindings[1].TLSCipherSuites, 1)
 	require.Len(t, bindings[1].TLSCipherSuites, 1)
 	require.Equal(t, "TLS_RSA_WITH_AES_128_CBC_SHA", bindings[1].TLSCipherSuites[0])
 	require.Equal(t, "TLS_RSA_WITH_AES_128_CBC_SHA", bindings[1].TLSCipherSuites[0])
+	require.Equal(t, "192.168.10.1", bindings[1].ProxyAllowed[0])
 	require.Empty(t, bindings[1].Prefix)
 	require.Empty(t, bindings[1].Prefix)
 	require.Equal(t, 9000, bindings[2].Port)
 	require.Equal(t, 9000, bindings[2].Port)
 	require.Equal(t, "127.0.1.1", bindings[2].Address)
 	require.Equal(t, "127.0.1.1", bindings[2].Address)
@@ -634,6 +637,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
 	os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS", "1")
 	os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS", "1")
 	os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE", "1")
 	os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE", "1")
 	os.Setenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES", " TLS_AES_256_GCM_SHA384 , TLS_CHACHA20_POLY1305_SHA256")
 	os.Setenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES", " TLS_AES_256_GCM_SHA384 , TLS_CHACHA20_POLY1305_SHA256")
+	os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED", " 192.168.9.1 , 172.16.25.0/24")
 	t.Cleanup(func() {
 	t.Cleanup(func() {
 		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__ADDRESS")
 		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__ADDRESS")
 		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__PORT")
 		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__PORT")
@@ -650,6 +654,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
 		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT")
 		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT")
 		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE")
 		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE")
 		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES")
 		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES")
+		os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED")
 	})
 	})
 
 
 	configDir := ".."
 	configDir := ".."
@@ -680,6 +685,9 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
 	require.Len(t, bindings[2].TLSCipherSuites, 2)
 	require.Len(t, bindings[2].TLSCipherSuites, 2)
 	require.Equal(t, "TLS_AES_256_GCM_SHA384", bindings[2].TLSCipherSuites[0])
 	require.Equal(t, "TLS_AES_256_GCM_SHA384", bindings[2].TLSCipherSuites[0])
 	require.Equal(t, "TLS_CHACHA20_POLY1305_SHA256", bindings[2].TLSCipherSuites[1])
 	require.Equal(t, "TLS_CHACHA20_POLY1305_SHA256", bindings[2].TLSCipherSuites[1])
+	require.Len(t, bindings[2].ProxyAllowed, 2)
+	require.Equal(t, "192.168.9.1", bindings[2].ProxyAllowed[0])
+	require.Equal(t, "172.16.25.0/24", bindings[2].ProxyAllowed[1])
 }
 }
 
 
 func TestHTTPClientCertificatesFromEnv(t *testing.T) {
 func TestHTTPClientCertificatesFromEnv(t *testing.T) {

+ 3 - 1
docs/full-configuration.md

@@ -147,6 +147,7 @@ The configuration file contains the following sections:
     - `client_auth_type`, integer. Set to `1` to require a client certificate and verify it. Set to `2` to request a client certificate during the TLS handshake and verify it if given, in this mode the client is allowed not to send a certificate. At least one certification authority must be defined in order to verify client certificates. If no certification authority is defined, this setting is ignored. Default: 0.
     - `client_auth_type`, integer. Set to `1` to require a client certificate and verify it. Set to `2` to request a client certificate during the TLS handshake and verify it if given, in this mode the client is allowed not to send a certificate. At least one certification authority must be defined in order to verify client certificates. If no certification authority is defined, this setting is ignored. Default: 0.
     - `tls_cipher_suites`, list of strings. List of supported cipher suites for TLS version 1.2. If empty, a default list of secure cipher suites is used, with a preference order based on hardware performance. Note that TLS 1.3 ciphersuites are not configurable. The supported ciphersuites names are defined [here](https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L52). Any invalid name will be silently ignored. The order matters, the ciphers listed first will be the preferred ones. Default: empty.
     - `tls_cipher_suites`, list of strings. List of supported cipher suites for TLS version 1.2. If empty, a default list of secure cipher suites is used, with a preference order based on hardware performance. Note that TLS 1.3 ciphersuites are not configurable. The supported ciphersuites names are defined [here](https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L52). Any invalid name will be silently ignored. The order matters, the ciphers listed first will be the preferred ones. Default: empty.
     - `prefix`, string. Prefix for WebDAV resources, if empty WebDAV resources will be available at the `/` URI. If defined it must be an absolute URI, for example `/dav`. Default: "".
     - `prefix`, string. Prefix for WebDAV resources, if empty WebDAV resources will be available at the `/` URI. If defined it must be an absolute URI, for example `/dav`. Default: "".
+    - `proxy_allowed`, list of IP addresses and IP ranges allowed to set `X-Forwarded-For`, `X-Real-IP` headers. Any of the indicated headers, if set on requests from a connection address not in this list, will be silently ignored. Default: empty.
   - `bind_port`, integer. Deprecated, please use `bindings`.
   - `bind_port`, integer. Deprecated, please use `bindings`.
   - `bind_address`, string. Deprecated, please use `bindings`.
   - `bind_address`, string. Deprecated, please use `bindings`.
   - `certificate_file`, string. Certificate for WebDAV over HTTPS. This can be an absolute path or a path relative to the config dir.
   - `certificate_file`, string. Certificate for WebDAV over HTTPS. This can be an absolute path or a path relative to the config dir.
@@ -216,8 +217,9 @@ The configuration file contains the following sections:
     - `enable_https`, boolean. Set to `true` and provide both a certificate and a key file to enable HTTPS connection for this binding. Default `false`.
     - `enable_https`, boolean. Set to `true` and provide both a certificate and a key file to enable HTTPS connection for this binding. Default `false`.
     - `client_auth_type`, integer. Set to `1` to require client certificate authentication in addition to JWT/Web authentication. You need to define at least a certificate authority for this to work. Default: 0.
     - `client_auth_type`, integer. Set to `1` to require client certificate authentication in addition to JWT/Web authentication. You need to define at least a certificate authority for this to work. Default: 0.
     - `tls_cipher_suites`, list of strings. List of supported cipher suites for TLS version 1.2. If empty, a default list of secure cipher suites is used, with a preference order based on hardware performance. Note that TLS 1.3 ciphersuites are not configurable. The supported ciphersuites names are defined [here](https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L52). Any invalid name will be silently ignored. The order matters, the ciphers listed first will be the preferred ones. Default: empty.
     - `tls_cipher_suites`, list of strings. List of supported cipher suites for TLS version 1.2. If empty, a default list of secure cipher suites is used, with a preference order based on hardware performance. Note that TLS 1.3 ciphersuites are not configurable. The supported ciphersuites names are defined [here](https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L52). Any invalid name will be silently ignored. The order matters, the ciphers listed first will be the preferred ones. Default: empty.
+    - `proxy_allowed`, list of IP addresses and IP ranges allowed to set `X-Forwarded-For`, `X-Real-IP`, `X-Forwarded-Proto` headers. Any of the indicated headers, if set on requests from a connection address not in this list, will be silently ignored. Default: empty.
   - `bind_port`, integer. Deprecated, please use `bindings`.
   - `bind_port`, integer. Deprecated, please use `bindings`.
-  - `bind_address`, string. Deprecated, please use `bindings`. Leave blank to listen on all available network interfaces. On \*NIX you can specify an absolute path to listen on a Unix-domain socket. Default: "127.0.0.1"
+  - `bind_address`, string. Deprecated, please use `bindings`. Leave blank to listen on all available network interfaces. On \*NIX you can specify an absolute path to listen on a Unix-domain socket. Default: ""
   - `templates_path`, string. Path to the HTML web templates. This can be an absolute path or a path relative to the config dir
   - `templates_path`, string. Path to the HTML web templates. This can be an absolute path or a path relative to the config dir
   - `static_files_path`, string. Path to the static files for the web interface. This can be an absolute path or a path relative to the config dir. If both `templates_path` and `static_files_path` are empty the built-in web interface will be disabled
   - `static_files_path`, string. Path to the static files for the web interface. This can be an absolute path or a path relative to the config dir. If both `templates_path` and `static_files_path` are empty the built-in web interface will be disabled
   - `backups_path`, string. Path to the backup directory. This can be an absolute path or a path relative to the config dir. We don't allow backups in arbitrary paths for security reasons
   - `backups_path`, string. Path to the backup directory. This can be an absolute path or a path relative to the config dir. We don't allow backups in arbitrary paths for security reasons

+ 1 - 7
docs/howto/postgresql-s3.md

@@ -192,13 +192,7 @@ systemctl status sftpgo
 
 
 The easiest way to add virtual users is to use the built-in Web interface.
 The easiest way to add virtual users is to use the built-in Web interface.
 
 
-You can expose the Web Admin interface over the network replacing `"bind_address": "127.0.0.1"` in the `httpd` configuration section with `"bind_address": ""` and apply the change restarting the SFTPGo service with the following command.
-
-```shell
-sudo systemctl restart sftpgo
-```
-
-So now open the Web Admin URL.
+So navigate to the Web Admin URL.
 
 
 [http://127.0.0.1:8080/web/admin](http://127.0.0.1:8080/web/admin)
 [http://127.0.0.1:8080/web/admin](http://127.0.0.1:8080/web/admin)
 
 

+ 10 - 6
ftpd/ftpd_test.go

@@ -760,8 +760,10 @@ func TestMaxConnections(t *testing.T) {
 	oldValue := common.Config.MaxTotalConnections
 	oldValue := common.Config.MaxTotalConnections
 	common.Config.MaxTotalConnections = 1
 	common.Config.MaxTotalConnections = 1
 
 
-	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
+	user := getTestUser()
+	err := dataprovider.AddUser(&user)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
+	user.Password = ""
 	client, err := getFTPClient(user, true, nil)
 	client, err := getFTPClient(user, true, nil)
 	if assert.NoError(t, err) {
 	if assert.NoError(t, err) {
 		err = checkBasicFTP(client)
 		err = checkBasicFTP(client)
@@ -771,7 +773,7 @@ func TestMaxConnections(t *testing.T) {
 		err = client.Quit()
 		err = client.Quit()
 		assert.NoError(t, err)
 		assert.NoError(t, err)
 	}
 	}
-	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	err = dataprovider.DeleteUser(user.Username)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	err = os.RemoveAll(user.GetHomeDir())
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
 	assert.NoError(t, err)
@@ -783,8 +785,10 @@ func TestMaxPerHostConnections(t *testing.T) {
 	oldValue := common.Config.MaxPerHostConnections
 	oldValue := common.Config.MaxPerHostConnections
 	common.Config.MaxPerHostConnections = 1
 	common.Config.MaxPerHostConnections = 1
 
 
-	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
+	user := getTestUser()
+	err := dataprovider.AddUser(&user)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
+	user.Password = ""
 	client, err := getFTPClient(user, true, nil)
 	client, err := getFTPClient(user, true, nil)
 	if assert.NoError(t, err) {
 	if assert.NoError(t, err) {
 		err = checkBasicFTP(client)
 		err = checkBasicFTP(client)
@@ -794,7 +798,7 @@ func TestMaxPerHostConnections(t *testing.T) {
 		err = client.Quit()
 		err = client.Quit()
 		assert.NoError(t, err)
 		assert.NoError(t, err)
 	}
 	}
-	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	err = dataprovider.DeleteUser(user.Username)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	err = os.RemoveAll(user.GetHomeDir())
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
 	assert.NoError(t, err)
@@ -851,7 +855,7 @@ func TestRateLimiter(t *testing.T) {
 		assert.Contains(t, err.Error(), "banned client IP")
 		assert.Contains(t, err.Error(), "banned client IP")
 	}
 	}
 
 
-	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	err = dataprovider.DeleteUser(user.Username)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	err = os.RemoveAll(user.GetHomeDir())
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
 	assert.NoError(t, err)
@@ -893,7 +897,7 @@ func TestDefender(t *testing.T) {
 		assert.Contains(t, err.Error(), "banned client IP")
 		assert.Contains(t, err.Error(), "banned client IP")
 	}
 	}
 
 
-	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	err = dataprovider.DeleteUser(user.Username)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	err = os.RemoveAll(user.GetHomeDir())
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
 	assert.NoError(t, err)

+ 12 - 2
httpd/auth_utils.go

@@ -137,7 +137,7 @@ func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, r *http.Reque
 		Path:     basePath,
 		Path:     basePath,
 		Expires:  time.Now().Add(tokenDuration),
 		Expires:  time.Now().Add(tokenDuration),
 		HttpOnly: true,
 		HttpOnly: true,
-		Secure:   r.TLS != nil,
+		Secure:   isTLS(r),
 	})
 	})
 
 
 	return nil
 	return nil
@@ -150,11 +150,21 @@ func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request) {
 		Path:     webBasePath,
 		Path:     webBasePath,
 		MaxAge:   -1,
 		MaxAge:   -1,
 		HttpOnly: true,
 		HttpOnly: true,
-		Secure:   r.TLS != nil,
+		Secure:   isTLS(r),
 	})
 	})
 	invalidateToken(r)
 	invalidateToken(r)
 }
 }
 
 
+func isTLS(r *http.Request) bool {
+	if r.TLS != nil {
+		return true
+	}
+	if proto, ok := r.Context().Value(forwardedProtoKey).(string); ok {
+		return proto == "https"
+	}
+	return false
+}
+
 func isTokenInvalidated(r *http.Request) bool {
 func isTokenInvalidated(r *http.Request) bool {
 	isTokenFound := false
 	isTokenFound := false
 	token := jwtauth.TokenFromHeader(r)
 	token := jwtauth.TokenFromHeader(r)

+ 27 - 3
httpd/httpd.go

@@ -6,6 +6,7 @@ package httpd
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"net"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 	"path"
 	"path"
@@ -156,6 +157,19 @@ type Binding struct {
 	// any invalid name will be silently ignored.
 	// any invalid name will be silently ignored.
 	// The order matters, the ciphers listed first will be the preferred ones.
 	// The order matters, the ciphers listed first will be the preferred ones.
 	TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"`
 	TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"`
+	// List of IP addresses and IP ranges allowed to set X-Forwarded-For, X-Real-IP,
+	// X-Forwarded-Proto headers.
+	ProxyAllowed     []string `json:"proxy_allowed" mapstructure:"proxy_allowed"`
+	allowHeadersFrom []func(net.IP) bool
+}
+
+func (b *Binding) parseAllowedProxy() error {
+	allowedFuncs, err := utils.ParseAllowedIPAndRanges(b.ProxyAllowed)
+	if err != nil {
+		return err
+	}
+	b.allowHeadersFrom = allowedFuncs
+	return nil
 }
 }
 
 
 // GetAddress returns the binding address
 // GetAddress returns the binding address
@@ -252,6 +266,14 @@ func (c *Conf) isWebClientEnabled() bool {
 	return false
 	return false
 }
 }
 
 
+func (c *Conf) checkRequiredDirs(staticFilesPath, templatesPath string) error {
+	if (c.isWebAdminEnabled() || c.isWebClientEnabled()) && (staticFilesPath == "" || templatesPath == "") {
+		return fmt.Errorf("required directory is invalid, static file path: %#v template path: %#v",
+			staticFilesPath, templatesPath)
+	}
+	return nil
+}
+
 // Initialize configures and starts the HTTP server
 // Initialize configures and starts the HTTP server
 func (c *Conf) Initialize(configDir string) error {
 func (c *Conf) Initialize(configDir string) error {
 	logger.Debug(logSender, "", "initializing HTTP server with config %+v", c)
 	logger.Debug(logSender, "", "initializing HTTP server with config %+v", c)
@@ -261,9 +283,8 @@ func (c *Conf) Initialize(configDir string) error {
 	if backupsPath == "" {
 	if backupsPath == "" {
 		return fmt.Errorf("required directory is invalid, backup path %#v", backupsPath)
 		return fmt.Errorf("required directory is invalid, backup path %#v", backupsPath)
 	}
 	}
-	if (c.isWebAdminEnabled() || c.isWebClientEnabled()) && (staticFilesPath == "" || templatesPath == "") {
-		return fmt.Errorf("required directory is invalid, static file path: %#v template path: %#v",
-			staticFilesPath, templatesPath)
+	if err := c.checkRequiredDirs(staticFilesPath, templatesPath); err != nil {
+		return err
 	}
 	}
 	certificateFile := getConfigPath(c.CertificateFile, configDir)
 	certificateFile := getConfigPath(c.CertificateFile, configDir)
 	certificateKeyFile := getConfigPath(c.CertificateKeyFile, configDir)
 	certificateKeyFile := getConfigPath(c.CertificateKeyFile, configDir)
@@ -303,6 +324,9 @@ func (c *Conf) Initialize(configDir string) error {
 		if !binding.IsValid() {
 		if !binding.IsValid() {
 			continue
 			continue
 		}
 		}
+		if err := binding.parseAllowedProxy(); err != nil {
+			return err
+		}
 
 
 		go func(b Binding) {
 		go func(b Binding) {
 			server := newHttpdServer(b, staticFilesPath)
 			server := newHttpdServer(b, staticFilesPath)

+ 35 - 1
httpd/httpd_test.go

@@ -326,6 +326,12 @@ func TestInitialization(t *testing.T) {
 	err = httpdConf.Initialize(configDir)
 	err = httpdConf.Initialize(configDir)
 	assert.Error(t, err)
 	assert.Error(t, err)
 	httpdConf.CARevocationLists = nil
 	httpdConf.CARevocationLists = nil
+	httpdConf.Bindings[0].ProxyAllowed = []string{"invalid ip/network"}
+	err = httpdConf.Initialize(configDir)
+	if assert.Error(t, err) {
+		assert.Contains(t, err.Error(), "is not a valid IP range")
+	}
+	httpdConf.Bindings[0].ProxyAllowed = nil
 	httpdConf.Bindings[0].EnableWebAdmin = false
 	httpdConf.Bindings[0].EnableWebAdmin = false
 	httpdConf.Bindings[0].EnableWebClient = false
 	httpdConf.Bindings[0].EnableWebClient = false
 	httpdConf.Bindings[0].Port = 8081
 	httpdConf.Bindings[0].Port = 8081
@@ -3288,6 +3294,22 @@ func TestRateLimiter(t *testing.T) {
 	err = resp.Body.Close()
 	err = resp.Body.Close()
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
+	resp, err = client.Get(httpBaseURL + webLoginPath)
+	assert.NoError(t, err)
+	assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+	assert.Equal(t, "1", resp.Header.Get("Retry-After"))
+	assert.NotEmpty(t, resp.Header.Get("X-Retry-In"))
+	err = resp.Body.Close()
+	assert.NoError(t, err)
+
+	resp, err = client.Get(httpBaseURL + webClientLoginPath)
+	assert.NoError(t, err)
+	assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+	assert.Equal(t, "1", resp.Header.Get("Retry-After"))
+	assert.NotEmpty(t, resp.Header.Get("X-Retry-In"))
+	err = resp.Body.Close()
+	assert.NoError(t, err)
+
 	err = common.Initialize(oldConfig)
 	err = common.Initialize(oldConfig)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 }
 }
@@ -4625,10 +4647,13 @@ func TestDefender(t *testing.T) {
 
 
 	remoteAddr := "172.16.5.6:9876"
 	remoteAddr := "172.16.5.6:9876"
 
 
+	webAdminToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
+	assert.NoError(t, err)
 	webToken, err := getJWTWebClientTokenFromTestServerWithAddr(defaultUsername, defaultPassword, remoteAddr)
 	webToken, err := getJWTWebClientTokenFromTestServerWithAddr(defaultUsername, defaultPassword, remoteAddr)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
 	req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
 	req.RemoteAddr = remoteAddr
 	req.RemoteAddr = remoteAddr
+	req.RequestURI = webClientFilesPath
 	setJWTCookieForReq(req, webToken)
 	setJWTCookieForReq(req, webToken)
 	rr := executeRequest(req)
 	rr := executeRequest(req)
 	checkResponseCode(t, http.StatusOK, rr)
 	checkResponseCode(t, http.StatusOK, rr)
@@ -4642,11 +4667,20 @@ func TestDefender(t *testing.T) {
 	assert.Error(t, err)
 	assert.Error(t, err)
 	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
 	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
 	req.RemoteAddr = remoteAddr
 	req.RemoteAddr = remoteAddr
+	req.RequestURI = webClientFilesPath
 	setJWTCookieForReq(req, webToken)
 	setJWTCookieForReq(req, webToken)
 	rr = executeRequest(req)
 	rr = executeRequest(req)
 	checkResponseCode(t, http.StatusForbidden, rr)
 	checkResponseCode(t, http.StatusForbidden, rr)
 	assert.Contains(t, rr.Body.String(), "your IP address is banned")
 	assert.Contains(t, rr.Body.String(), "your IP address is banned")
 
 
+	req, _ = http.NewRequest(http.MethodGet, webUsersPath, nil)
+	req.RemoteAddr = remoteAddr
+	req.RequestURI = webUsersPath
+	setJWTCookieForReq(req, webAdminToken)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusForbidden, rr)
+	assert.Contains(t, rr.Body.String(), "your IP address is banned")
+
 	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
 	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
 	req.RemoteAddr = remoteAddr
 	req.RemoteAddr = remoteAddr
 	req.Header.Set("X-Real-IP", "127.0.0.1:2345")
 	req.Header.Set("X-Real-IP", "127.0.0.1:2345")
@@ -5192,7 +5226,7 @@ func TestWebAdminLoginMock(t *testing.T) {
 	req.Header.Set("X-Forwarded-For", "10.9.9.9")
 	req.Header.Set("X-Forwarded-For", "10.9.9.9")
 	rr = executeRequest(req)
 	rr = executeRequest(req)
 	checkResponseCode(t, http.StatusOK, rr)
 	checkResponseCode(t, http.StatusOK, rr)
-	assert.Contains(t, rr.Body.String(), "Login from IP 127.0.1.1:4567 is not allowed")
+	assert.Contains(t, rr.Body.String(), "login from IP 127.0.1.1 not allowed")
 
 
 	// invalid csrf token
 	// invalid csrf token
 	form = getLoginForm(altAdminUsername, altAdminPassword, "invalid csrf")
 	form = getLoginForm(altAdminUsername, altAdminPassword, "invalid csrf")

+ 120 - 33
httpd/internal_test.go

@@ -424,7 +424,7 @@ func TestCreateTokenError(t *testing.T) {
 	}
 	}
 	req, _ := http.NewRequest(http.MethodGet, tokenPath, nil)
 	req, _ := http.NewRequest(http.MethodGet, tokenPath, nil)
 
 
-	server.checkAddrAndSendToken(rr, req, admin)
+	server.generateAndSendToken(rr, req, admin)
 	assert.Equal(t, http.StatusInternalServerError, rr.Code)
 	assert.Equal(t, http.StatusInternalServerError, rr.Code)
 
 
 	rr = httptest.NewRecorder()
 	rr = httptest.NewRecorder()
@@ -565,22 +565,6 @@ func TestJWTTokenValidation(t *testing.T) {
 	assert.Equal(t, http.StatusBadRequest, rr.Code)
 	assert.Equal(t, http.StatusBadRequest, rr.Code)
 }
 }
 
 
-func TestAdminAllowListConnAddr(t *testing.T) {
-	server := httpdServer{}
-	admin := dataprovider.Admin{
-		Filters: dataprovider.AdminFilters{
-			AllowList: []string{"192.168.1.0/24"},
-		},
-	}
-	rr := httptest.NewRecorder()
-	req, _ := http.NewRequest(http.MethodGet, tokenPath, nil)
-	ctx := context.WithValue(req.Context(), connAddrKey, "127.0.0.1:4567")
-	req.RemoteAddr = "192.168.1.16:1234"
-	server.checkAddrAndSendToken(rr, req.WithContext(ctx), admin)
-	assert.Equal(t, http.StatusForbidden, rr.Code, rr.Body.String())
-	assert.Equal(t, "context value connection address", connAddrKey.String())
-}
-
 func TestUpdateContextFromCookie(t *testing.T) {
 func TestUpdateContextFromCookie(t *testing.T) {
 	server := httpdServer{
 	server := httpdServer{
 		tokenAuth: jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil),
 		tokenAuth: jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil),
@@ -672,14 +656,6 @@ func TestCookieExpiration(t *testing.T) {
 	cookie = rr.Header().Get("Set-Cookie")
 	cookie = rr.Header().Get("Set-Cookie")
 	assert.Empty(t, cookie)
 	assert.Empty(t, cookie)
 
 
-	req, _ = http.NewRequest(http.MethodGet, tokenPath, nil)
-	req.RemoteAddr = "172.16.1.2:1234"
-	ctx = jwtauth.NewContext(req.Context(), token, nil)
-	ctx = context.WithValue(ctx, connAddrKey, "10.9.9.9")
-	server.checkCookieExpiration(rr, req.WithContext(ctx))
-	cookie = rr.Header().Get("Set-Cookie")
-	assert.Empty(t, cookie)
-
 	req, _ = http.NewRequest(http.MethodGet, tokenPath, nil)
 	req, _ = http.NewRequest(http.MethodGet, tokenPath, nil)
 	req.RemoteAddr = "172.16.1.12:4567"
 	req.RemoteAddr = "172.16.1.12:4567"
 	ctx = jwtauth.NewContext(req.Context(), token, nil)
 	ctx = jwtauth.NewContext(req.Context(), token, nil)
@@ -749,17 +725,10 @@ func TestCookieExpiration(t *testing.T) {
 	cookie = rr.Header().Get("Set-Cookie")
 	cookie = rr.Header().Get("Set-Cookie")
 	assert.Empty(t, cookie)
 	assert.Empty(t, cookie)
 
 
-	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
-	req.RemoteAddr = "172.16.4.12:4567"
-	ctx = jwtauth.NewContext(req.Context(), token, nil)
-	server.checkCookieExpiration(rr, req.WithContext(context.WithValue(ctx, connAddrKey, "172.16.0.1:4567")))
-	cookie = rr.Header().Get("Set-Cookie")
-	assert.Empty(t, cookie)
-
 	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
 	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
 	req.RemoteAddr = "172.16.4.16:4567"
 	req.RemoteAddr = "172.16.4.16:4567"
 	ctx = jwtauth.NewContext(req.Context(), token, nil)
 	ctx = jwtauth.NewContext(req.Context(), token, nil)
-	server.checkCookieExpiration(rr, req.WithContext(context.WithValue(ctx, connAddrKey, "172.16.4.18:4567")))
+	server.checkCookieExpiration(rr, req.WithContext(ctx))
 	cookie = rr.Header().Get("Set-Cookie")
 	cookie = rr.Header().Get("Set-Cookie")
 	assert.NotEmpty(t, cookie)
 	assert.NotEmpty(t, cookie)
 
 
@@ -1014,6 +983,111 @@ func TestJWTTokenCleanup(t *testing.T) {
 	stopJWTTokensCleanupTicker()
 	stopJWTTokensCleanupTicker()
 }
 }
 
 
+func TestProxyHeaders(t *testing.T) {
+	username := "adminTest"
+	password := "testPwd"
+	admin := dataprovider.Admin{
+		Username:    username,
+		Password:    password,
+		Permissions: []string{dataprovider.PermAdminAny},
+		Status:      1,
+		Filters: dataprovider.AdminFilters{
+			AllowList: []string{"172.19.2.0/24"},
+		},
+	}
+
+	err := dataprovider.AddAdmin(&admin)
+	assert.NoError(t, err)
+
+	testIP := "10.29.1.9"
+	validForwardedFor := "172.19.2.6"
+	b := Binding{
+		Address:         "",
+		Port:            8080,
+		EnableWebAdmin:  true,
+		EnableWebClient: false,
+		ProxyAllowed:    []string{testIP, "10.8.0.0/30"},
+	}
+	err = b.parseAllowedProxy()
+	assert.NoError(t, err)
+	server := newHttpdServer(b, "")
+	server.initializeRouter()
+	testServer := httptest.NewServer(server.router)
+	defer testServer.Close()
+
+	req, err := http.NewRequest(http.MethodGet, tokenPath, nil)
+	assert.NoError(t, err)
+	req.Header.Set("X-Forwarded-For", validForwardedFor)
+	req.Header.Set(xForwardedProto, "https")
+	req.RemoteAddr = "127.0.0.1:123"
+	req.SetBasicAuth(username, password)
+	rr := httptest.NewRecorder()
+	testServer.Config.Handler.ServeHTTP(rr, req)
+	assert.Equal(t, http.StatusUnauthorized, rr.Code)
+	assert.Contains(t, rr.Body.String(), "login from IP 127.0.0.1 not allowed")
+
+	req.RemoteAddr = testIP
+	rr = httptest.NewRecorder()
+	testServer.Config.Handler.ServeHTTP(rr, req)
+	assert.Equal(t, http.StatusOK, rr.Code)
+
+	req.RemoteAddr = "10.8.0.2"
+	rr = httptest.NewRecorder()
+	testServer.Config.Handler.ServeHTTP(rr, req)
+	assert.Equal(t, http.StatusOK, rr.Code)
+
+	form := make(url.Values)
+	form.Set("username", username)
+	form.Set("password", password)
+	form.Set(csrfFormToken, createCSRFToken())
+	req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = testIP
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	rr = httptest.NewRecorder()
+	testServer.Config.Handler.ServeHTTP(rr, req)
+	assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
+	assert.Contains(t, rr.Body.String(), "login from IP 10.29.1.9 not allowed")
+
+	req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = testIP
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	req.Header.Set("X-Forwarded-For", validForwardedFor)
+	rr = httptest.NewRecorder()
+	testServer.Config.Handler.ServeHTTP(rr, req)
+	assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
+	cookie := rr.Header().Get("Set-Cookie")
+	assert.NotContains(t, cookie, "Secure")
+
+	req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = testIP
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	req.Header.Set("X-Forwarded-For", validForwardedFor)
+	req.Header.Set(xForwardedProto, "https")
+	rr = httptest.NewRecorder()
+	testServer.Config.Handler.ServeHTTP(rr, req)
+	assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
+	cookie = rr.Header().Get("Set-Cookie")
+	assert.Contains(t, cookie, "Secure")
+
+	req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode())))
+	assert.NoError(t, err)
+	req.RemoteAddr = testIP
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	req.Header.Set("X-Forwarded-For", validForwardedFor)
+	req.Header.Set(xForwardedProto, "http")
+	rr = httptest.NewRecorder()
+	testServer.Config.Handler.ServeHTTP(rr, req)
+	assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
+	cookie = rr.Header().Get("Set-Cookie")
+	assert.NotContains(t, cookie, "Secure")
+
+	err = dataprovider.DeleteAdmin(username)
+	assert.NoError(t, err)
+}
+
 func TestWebAdminRedirect(t *testing.T) {
 func TestWebAdminRedirect(t *testing.T) {
 	b := Binding{
 	b := Binding{
 		Address:         "",
 		Address:         "",
@@ -1306,3 +1380,16 @@ func TestManageKeysInvalidClaims(t *testing.T) {
 	assert.Equal(t, http.StatusInternalServerError, rr.Code)
 	assert.Equal(t, http.StatusInternalServerError, rr.Code)
 	assert.Contains(t, rr.Body.String(), "Invalid token claims")
 	assert.Contains(t, rr.Body.String(), "Invalid token claims")
 }
 }
+
+func TestTLSReq(t *testing.T) {
+	req, err := http.NewRequest(http.MethodGet, webClientLoginPath, nil)
+	assert.NoError(t, err)
+	req.TLS = &tls.ConnectionState{}
+	assert.True(t, isTLS(req))
+	req.TLS = nil
+	ctx := context.WithValue(req.Context(), forwardedProtoKey, "https")
+	assert.True(t, isTLS(req.WithContext(ctx)))
+	ctx = context.WithValue(req.Context(), forwardedProtoKey, "http")
+	assert.False(t, isTLS(req.WithContext(ctx)))
+	assert.Equal(t, "context value forwarded proto", forwardedProtoKey.String())
+}

+ 2 - 26
httpd/middleware.go

@@ -1,23 +1,19 @@
 package httpd
 package httpd
 
 
 import (
 import (
-	"context"
 	"errors"
 	"errors"
-	"fmt"
 	"net/http"
 	"net/http"
-	"time"
 
 
 	"github.com/go-chi/jwtauth/v5"
 	"github.com/go-chi/jwtauth/v5"
 	"github.com/lestrrat-go/jwx/jwt"
 	"github.com/lestrrat-go/jwx/jwt"
 
 
-	"github.com/drakkan/sftpgo/common"
 	"github.com/drakkan/sftpgo/logger"
 	"github.com/drakkan/sftpgo/logger"
 	"github.com/drakkan/sftpgo/utils"
 	"github.com/drakkan/sftpgo/utils"
 )
 )
 
 
 var (
 var (
-	connAddrKey     = &contextKey{"connection address"}
-	errInvalidToken = errors.New("invalid JWT token")
+	forwardedProtoKey = &contextKey{"forwarded proto"}
+	errInvalidToken   = errors.New("invalid JWT token")
 )
 )
 
 
 type contextKey struct {
 type contextKey struct {
@@ -28,13 +24,6 @@ func (k *contextKey) String() string {
 	return "context value " + k.name
 	return "context value " + k.name
 }
 }
 
 
-func saveConnectionAddress(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		ctx := context.WithValue(r.Context(), connAddrKey, r.RemoteAddr)
-		next.ServeHTTP(w, r.WithContext(ctx))
-	})
-}
-
 func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error {
 func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error {
 	token, _, err := jwtauth.FromContext(r.Context())
 	token, _, err := jwtauth.FromContext(r.Context())
 
 
@@ -188,16 +177,3 @@ func verifyCSRFHeader(next http.Handler) http.Handler {
 		next.ServeHTTP(w, r)
 		next.ServeHTTP(w, r)
 	})
 	})
 }
 }
-
-func rateLimiter(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		if delay, err := common.LimitRate(common.ProtocolHTTP, utils.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
-			delay += 499999999 * time.Nanosecond
-			w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
-			w.Header().Set("X-Retry-In", delay.String())
-			sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
-			return
-		}
-		next.ServeHTTP(w, r)
-	})
-}

+ 246 - 214
httpd/server.go

@@ -1,11 +1,13 @@
 package httpd
 package httpd
 
 
 import (
 import (
+	"context"
 	"crypto/tls"
 	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"log"
 	"log"
+	"net"
 	"net/http"
 	"net/http"
 	"time"
 	"time"
 
 
@@ -23,7 +25,10 @@ import (
 	"github.com/drakkan/sftpgo/version"
 	"github.com/drakkan/sftpgo/version"
 )
 )
 
 
-var compressor = middleware.NewCompressor(5)
+var (
+	compressor      = middleware.NewCompressor(5)
+	xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto")
+)
 
 
 type httpdServer struct {
 type httpdServer struct {
 	binding         Binding
 	binding         Binding
@@ -111,10 +116,6 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
 func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
 func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
 	r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize)
 	r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize)
 
 
-	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
-	common.Connections.AddClientConnection(ipAddr)
-	defer common.Connections.RemoveClientConnection(ipAddr)
-
 	if err := r.ParseForm(); err != nil {
 	if err := r.ParseForm(); err != nil {
 		renderClientLoginPage(w, err.Error())
 		renderClientLoginPage(w, err.Error())
 		return
 		return
@@ -130,16 +131,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 		return
 		return
 	}
 	}
 
 
-	if !common.Connections.IsNewConnectionAllowed(ipAddr) {
-		logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
-		renderClientLoginPage(w, "configured connections limit reached")
-		return
-	}
-	if common.IsBanned(ipAddr) {
-		renderClientLoginPage(w, "your IP address is banned")
-		return
-	}
-
+	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
 	if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolHTTP); err != nil {
 	if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolHTTP); err != nil {
 		renderClientLoginPage(w, fmt.Sprintf("access denied by post connect hook: %v", err))
 		renderClientLoginPage(w, fmt.Sprintf("access denied by post connect hook: %v", err))
 		return
 		return
@@ -204,14 +196,6 @@ func (s *httpdServer) handleWebAdminLoginPost(w http.ResponseWriter, r *http.Req
 		renderLoginPage(w, err.Error())
 		renderLoginPage(w, err.Error())
 		return
 		return
 	}
 	}
-	if connAddr, ok := r.Context().Value(connAddrKey).(string); ok {
-		if connAddr != r.RemoteAddr {
-			if !admin.CanLoginFromIP(utils.GetIPFromRemoteAddress(connAddr)) {
-				renderLoginPage(w, fmt.Sprintf("Login from IP %v is not allowed", connAddr))
-				return
-			}
-		}
-	}
 	c := jwtTokenClaims{
 	c := jwtTokenClaims{
 		Username:    admin.Username,
 		Username:    admin.Username,
 		Permissions: admin.Permissions,
 		Permissions: admin.Permissions,
@@ -246,19 +230,10 @@ func (s *httpdServer) getToken(w http.ResponseWriter, r *http.Request) {
 		return
 		return
 	}
 	}
 
 
-	s.checkAddrAndSendToken(w, r, admin)
+	s.generateAndSendToken(w, r, admin)
 }
 }
 
 
-func (s *httpdServer) checkAddrAndSendToken(w http.ResponseWriter, r *http.Request, admin dataprovider.Admin) {
-	if connAddr, ok := r.Context().Value(connAddrKey).(string); ok {
-		if connAddr != r.RemoteAddr {
-			if !admin.CanLoginFromIP(utils.GetIPFromRemoteAddress(connAddr)) {
-				sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden)
-				return
-			}
-		}
-	}
-
+func (s *httpdServer) generateAndSendToken(w http.ResponseWriter, r *http.Request, admin dataprovider.Admin) {
 	c := jwtTokenClaims{
 	c := jwtTokenClaims{
 		Username:    admin.Username,
 		Username:    admin.Username,
 		Permissions: admin.Permissions,
 		Permissions: admin.Permissions,
@@ -330,15 +305,6 @@ func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request,
 		logger.Debug(logSender, "", "admin %#v cannot login from %v, unable to refresh cookie", admin.Username, r.RemoteAddr)
 		logger.Debug(logSender, "", "admin %#v cannot login from %v, unable to refresh cookie", admin.Username, r.RemoteAddr)
 		return
 		return
 	}
 	}
-	if connAddr, ok := r.Context().Value(connAddrKey).(string); ok {
-		if connAddr != r.RemoteAddr {
-			if !admin.CanLoginFromIP(utils.GetIPFromRemoteAddress(connAddr)) {
-				logger.Debug(logSender, "", "admin %#v cannot login from %v, unable to refresh cookie",
-					admin.Username, connAddr)
-				return
-			}
-		}
-	}
 	logger.Debug(logSender, "", "cookie refreshed for admin %#v", admin.Username)
 	logger.Debug(logSender, "", "cookie refreshed for admin %#v", admin.Username)
 	tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebAdmin) //nolint:errcheck
 	tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebAdmin) //nolint:errcheck
 }
 }
@@ -357,200 +323,266 @@ func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request {
 	return r
 	return r
 }
 }
 
 
+func (s *httpdServer) checkConnection(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
+		ip := net.ParseIP(ipAddr)
+		if ip != nil {
+			for _, allow := range s.binding.allowHeadersFrom {
+				if allow(ip) {
+					parsedIP := utils.GetRealIP(r)
+					if parsedIP != "" {
+						ipAddr = parsedIP
+						r.RemoteAddr = ipAddr
+					}
+					if forwardedProto := r.Header.Get(xForwardedProto); forwardedProto != "" {
+						ctx := context.WithValue(r.Context(), forwardedProtoKey, forwardedProto)
+						r = r.WithContext(ctx)
+					}
+					break
+				}
+			}
+		}
+
+		common.Connections.AddClientConnection(ipAddr)
+		defer common.Connections.RemoveClientConnection(ipAddr)
+
+		if !common.Connections.IsNewConnectionAllowed(ipAddr) {
+			logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
+			s.sendForbiddenResponse(w, r, "configured connections limit reached")
+			return
+		}
+		if common.IsBanned(ipAddr) {
+			s.sendForbiddenResponse(w, r, "your IP address is banned")
+			return
+		}
+		if delay, err := common.LimitRate(common.ProtocolHTTP, ipAddr); err != nil {
+			delay += 499999999 * time.Nanosecond
+			w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
+			w.Header().Set("X-Retry-In", delay.String())
+			s.sendTooManyRequestResponse(w, r, err)
+			return
+		}
+
+		next.ServeHTTP(w, r)
+	})
+}
+
+func (s *httpdServer) sendTooManyRequestResponse(w http.ResponseWriter, r *http.Request, err error) {
+	if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
+		r = s.updateContextFromCookie(r)
+		if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
+			renderClientMessagePage(w, r, http.StatusText(http.StatusTooManyRequests), "Rate limit exceeded",
+				http.StatusTooManyRequests, err, "")
+			return
+		}
+		renderMessagePage(w, r, http.StatusText(http.StatusTooManyRequests), "Rate limit exceeded", http.StatusTooManyRequests,
+			err, "")
+		return
+	}
+	sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
+}
+
+func (s *httpdServer) sendForbiddenResponse(w http.ResponseWriter, r *http.Request, message string) {
+	if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
+		r = s.updateContextFromCookie(r)
+		if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
+			renderClientForbiddenPage(w, r, message)
+			return
+		}
+		renderForbiddenPage(w, r, message)
+		return
+	}
+	sendAPIResponse(w, r, errors.New(message), message, http.StatusForbidden)
+}
+
 func (s *httpdServer) initializeRouter() {
 func (s *httpdServer) initializeRouter() {
 	s.tokenAuth = jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil)
 	s.tokenAuth = jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil)
 	s.router = chi.NewRouter()
 	s.router = chi.NewRouter()
 
 
-	s.router.Use(saveConnectionAddress)
+	s.router.Use(middleware.RequestID)
+	s.router.Use(logger.NewStructuredLogger(logger.GetLogger()))
+	s.router.Use(middleware.Recoverer)
+	s.router.Use(s.checkConnection)
 	s.router.Use(middleware.GetHead)
 	s.router.Use(middleware.GetHead)
 	s.router.Use(middleware.StripSlashes)
 	s.router.Use(middleware.StripSlashes)
-	s.router.Use(middleware.RealIP)
-	s.router.Use(rateLimiter)
-
-	s.router.Group(func(r chi.Router) {
-		r.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) {
-			render.PlainText(w, r, "ok")
-		})
-	})
 
 
-	s.router.Group(func(router chi.Router) {
-		router.Use(middleware.RequestID)
-		router.Use(logger.NewStructuredLogger(logger.GetLogger()))
-		router.Use(middleware.Recoverer)
-
-		router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
-				r = s.updateContextFromCookie(r)
-				if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
-					renderClientNotFoundPage(w, r, nil)
-					return
-				}
-				renderNotFoundPage(w, r, nil)
+	s.router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
+			r = s.updateContextFromCookie(r)
+			if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
+				renderClientNotFoundPage(w, r, nil)
 				return
 				return
 			}
 			}
-			sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
-		}))
+			renderNotFoundPage(w, r, nil)
+			return
+		}
+		sendAPIResponse(w, r, nil, http.StatusText(http.StatusNotFound), http.StatusNotFound)
+	}))
 
 
-		router.Get(tokenPath, s.getToken)
+	s.router.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) {
+		render.PlainText(w, r, "ok")
+	})
 
 
-		router.Group(func(router chi.Router) {
-			router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader))
-			router.Use(jwtAuthenticatorAPI)
+	s.router.Get(tokenPath, s.getToken)
 
 
-			router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) {
-				render.JSON(w, r, version.Get())
-			})
+	s.router.Group(func(router chi.Router) {
+		router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader))
+		router.Use(jwtAuthenticatorAPI)
 
 
-			router.Get(logoutPath, s.logout)
-			router.Put(adminPwdPath, changeAdminPassword)
-
-			router.With(checkPerm(dataprovider.PermAdminViewServerStatus)).
-				Get(serverStatusPath, func(w http.ResponseWriter, r *http.Request) {
-					render.JSON(w, r, getServicesStatus())
-				})
-
-			router.With(checkPerm(dataprovider.PermAdminViewConnections)).
-				Get(activeConnectionsPath, func(w http.ResponseWriter, r *http.Request) {
-					render.JSON(w, r, common.Connections.GetStats())
-				})
-
-			router.With(checkPerm(dataprovider.PermAdminCloseConnections)).
-				Delete(activeConnectionsPath+"/{connectionID}", handleCloseConnection)
-			router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanPath, getQuotaScans)
-			router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanPath, startQuotaScan)
-			router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanVFolderPath, getVFolderQuotaScans)
-			router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanVFolderPath, startVFolderQuotaScan)
-			router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath, getUsers)
-			router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(userPath, addUser)
-			router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath+"/{username}", getUserByUsername)
-			router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(userPath+"/{username}", updateUser)
-			router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(userPath+"/{username}", deleteUser)
-			router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath, getFolders)
-			router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath+"/{name}", getFolderByName)
-			router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(folderPath, addFolder)
-			router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(folderPath+"/{name}", updateFolder)
-			router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(folderPath+"/{name}", deleteFolder)
-			router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(dumpDataPath, dumpData)
-			router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(loadDataPath, loadData)
-			router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(loadDataPath, loadDataFromRequest)
-			router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateUsedQuotaPath, updateUserQuotaUsage)
-			router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateFolderUsedQuotaPath, updateVFolderQuotaUsage)
-			router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderBanTime, getBanTime)
-			router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderScore, getScore)
-			router.With(checkPerm(dataprovider.PermAdminManageDefender)).Post(defenderUnban, unban)
-			router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath, getAdmins)
-			router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(adminPath, addAdmin)
-			router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath+"/{username}", getAdminByUsername)
-			router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Put(adminPath+"/{username}", updateAdmin)
-			router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Delete(adminPath+"/{username}", deleteAdmin)
+		router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) {
+			render.JSON(w, r, version.Get())
 		})
 		})
 
 
-		if s.enableWebAdmin || s.enableWebClient {
-			router.Group(func(router chi.Router) {
-				router.Use(compressor.Handler)
-				fileServer(router, webStaticFilesPath, http.Dir(s.staticFilesPath))
+		router.Get(logoutPath, s.logout)
+		router.Put(adminPwdPath, changeAdminPassword)
+
+		router.With(checkPerm(dataprovider.PermAdminViewServerStatus)).
+			Get(serverStatusPath, func(w http.ResponseWriter, r *http.Request) {
+				render.JSON(w, r, getServicesStatus())
 			})
 			})
-			if s.enableWebClient {
-				router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) {
-					http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
-				})
-				router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) {
-					http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
-				})
-			} else {
-				router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) {
-					http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
-				})
-				router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) {
-					http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
-				})
-			}
-		}
 
 
+		router.With(checkPerm(dataprovider.PermAdminViewConnections)).
+			Get(activeConnectionsPath, func(w http.ResponseWriter, r *http.Request) {
+				render.JSON(w, r, common.Connections.GetStats())
+			})
+
+		router.With(checkPerm(dataprovider.PermAdminCloseConnections)).
+			Delete(activeConnectionsPath+"/{connectionID}", handleCloseConnection)
+		router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanPath, getQuotaScans)
+		router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanPath, startQuotaScan)
+		router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Get(quotaScanVFolderPath, getVFolderQuotaScans)
+		router.With(checkPerm(dataprovider.PermAdminQuotaScans)).Post(quotaScanVFolderPath, startVFolderQuotaScan)
+		router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath, getUsers)
+		router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(userPath, addUser)
+		router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(userPath+"/{username}", getUserByUsername)
+		router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(userPath+"/{username}", updateUser)
+		router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(userPath+"/{username}", deleteUser)
+		router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath, getFolders)
+		router.With(checkPerm(dataprovider.PermAdminViewUsers)).Get(folderPath+"/{name}", getFolderByName)
+		router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(folderPath, addFolder)
+		router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(folderPath+"/{name}", updateFolder)
+		router.With(checkPerm(dataprovider.PermAdminDeleteUsers)).Delete(folderPath+"/{name}", deleteFolder)
+		router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(dumpDataPath, dumpData)
+		router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(loadDataPath, loadData)
+		router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(loadDataPath, loadDataFromRequest)
+		router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateUsedQuotaPath, updateUserQuotaUsage)
+		router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateFolderUsedQuotaPath, updateVFolderQuotaUsage)
+		router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderBanTime, getBanTime)
+		router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderScore, getScore)
+		router.With(checkPerm(dataprovider.PermAdminManageDefender)).Post(defenderUnban, unban)
+		router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath, getAdmins)
+		router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(adminPath, addAdmin)
+		router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Get(adminPath+"/{username}", getAdminByUsername)
+		router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Put(adminPath+"/{username}", updateAdmin)
+		router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Delete(adminPath+"/{username}", deleteAdmin)
+	})
+
+	if s.enableWebAdmin || s.enableWebClient {
+		s.router.Group(func(router chi.Router) {
+			router.Use(compressor.Handler)
+			fileServer(router, webStaticFilesPath, http.Dir(s.staticFilesPath))
+		})
 		if s.enableWebClient {
 		if s.enableWebClient {
-			router.Get(webBaseClientPath, func(w http.ResponseWriter, r *http.Request) {
+			s.router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) {
 				http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
 				http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
 			})
 			})
-			router.Get(webClientLoginPath, handleClientWebLogin)
-			router.Post(webClientLoginPath, s.handleWebClientLoginPost)
-
-			router.Group(func(router chi.Router) {
-				router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie))
-				router.Use(jwtAuthenticatorWebClient)
-
-				router.Get(webClientLogoutPath, handleWebClientLogout)
-				router.With(s.refreshCookie).Get(webClientFilesPath, handleClientGetFiles)
-				router.With(s.refreshCookie).Get(webClientCredentialsPath, handleClientGetCredentials)
-				router.Post(webChangeClientPwdPath, handleWebClientChangePwdPost)
-				router.With(checkClientPerm(dataprovider.WebClientPubKeyChangeDisabled)).
-					Post(webChangeClientKeysPath, handleWebClientManageKeysPost)
+			s.router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) {
+				http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
 			})
 			})
-		}
-
-		if s.enableWebAdmin {
-			router.Get(webBaseAdminPath, func(w http.ResponseWriter, r *http.Request) {
+		} else {
+			s.router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) {
 				http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
 				http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
 			})
 			})
-			router.Get(webLoginPath, handleWebLogin)
-			router.Post(webLoginPath, s.handleWebAdminLoginPost)
-
-			router.Group(func(router chi.Router) {
-				router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie))
-				router.Use(jwtAuthenticatorWebAdmin)
-
-				router.Get(webLogoutPath, handleWebLogout)
-				router.With(s.refreshCookie).Get(webChangeAdminPwdPath, handleWebAdminChangePwd)
-				router.Post(webChangeAdminPwdPath, handleWebAdminChangePwdPost)
-				router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie).
-					Get(webUsersPath, handleGetWebUsers)
-				router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie).
-					Get(webUserPath, handleWebAddUserGet)
-				router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie).
-					Get(webUserPath+"/{username}", handleWebUpdateUserGet)
-				router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webUserPath, handleWebAddUserPost)
-				router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webUserPath+"/{username}", handleWebUpdateUserPost)
-				router.With(checkPerm(dataprovider.PermAdminViewConnections), s.refreshCookie).
-					Get(webConnectionsPath, handleWebGetConnections)
-				router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie).
-					Get(webFoldersPath, handleWebGetFolders)
-				router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie).
-					Get(webFolderPath, handleWebAddFolderGet)
-				router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webFolderPath, handleWebAddFolderPost)
-				router.With(checkPerm(dataprovider.PermAdminViewServerStatus), s.refreshCookie).
-					Get(webStatusPath, handleWebGetStatus)
-				router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
-					Get(webAdminsPath, handleGetWebAdmins)
-				router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
-					Get(webAdminPath, handleWebAddAdminGet)
-				router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
-					Get(webAdminPath+"/{username}", handleWebUpdateAdminGet)
-				router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath, handleWebAddAdminPost)
-				router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath+"/{username}", handleWebUpdateAdminPost)
-				router.With(checkPerm(dataprovider.PermAdminManageAdmins), verifyCSRFHeader).
-					Delete(webAdminPath+"/{username}", deleteAdmin)
-				router.With(checkPerm(dataprovider.PermAdminCloseConnections), verifyCSRFHeader).
-					Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection)
-				router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie).
-					Get(webFolderPath+"/{name}", handleWebUpdateFolderGet)
-				router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webFolderPath+"/{name}", handleWebUpdateFolderPost)
-				router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader).
-					Delete(webFolderPath+"/{name}", deleteFolder)
-				router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader).
-					Post(webScanVFolderPath, startVFolderQuotaScan)
-				router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader).
-					Delete(webUserPath+"/{username}", deleteUser)
-				router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader).
-					Post(webQuotaScanPath, startQuotaScan)
-				router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webMaintenancePath, handleWebMaintenance)
-				router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webBackupPath, dumpData)
-				router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webRestorePath, handleWebRestore)
-				router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).
-					Get(webTemplateUser, handleWebTemplateUserGet)
-				router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateUser, handleWebTemplateUserPost)
-				router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).
-					Get(webTemplateFolder, handleWebTemplateFolderGet)
-				router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateFolder, handleWebTemplateFolderPost)
+			s.router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) {
+				http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
 			})
 			})
 		}
 		}
-	})
+	}
+
+	if s.enableWebClient {
+		s.router.Get(webBaseClientPath, func(w http.ResponseWriter, r *http.Request) {
+			http.Redirect(w, r, webClientLoginPath, http.StatusMovedPermanently)
+		})
+		s.router.Get(webClientLoginPath, handleClientWebLogin)
+		s.router.Post(webClientLoginPath, s.handleWebClientLoginPost)
+
+		s.router.Group(func(router chi.Router) {
+			router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie))
+			router.Use(jwtAuthenticatorWebClient)
+
+			router.Get(webClientLogoutPath, handleWebClientLogout)
+			router.With(s.refreshCookie).Get(webClientFilesPath, handleClientGetFiles)
+			router.With(s.refreshCookie).Get(webClientCredentialsPath, handleClientGetCredentials)
+			router.Post(webChangeClientPwdPath, handleWebClientChangePwdPost)
+			router.With(checkClientPerm(dataprovider.WebClientPubKeyChangeDisabled)).
+				Post(webChangeClientKeysPath, handleWebClientManageKeysPost)
+		})
+	}
+
+	if s.enableWebAdmin {
+		s.router.Get(webBaseAdminPath, func(w http.ResponseWriter, r *http.Request) {
+			http.Redirect(w, r, webLoginPath, http.StatusMovedPermanently)
+		})
+		s.router.Get(webLoginPath, handleWebLogin)
+		s.router.Post(webLoginPath, s.handleWebAdminLoginPost)
+
+		s.router.Group(func(router chi.Router) {
+			router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie))
+			router.Use(jwtAuthenticatorWebAdmin)
+
+			router.Get(webLogoutPath, handleWebLogout)
+			router.With(s.refreshCookie).Get(webChangeAdminPwdPath, handleWebAdminChangePwd)
+			router.Post(webChangeAdminPwdPath, handleWebAdminChangePwdPost)
+			router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie).
+				Get(webUsersPath, handleGetWebUsers)
+			router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie).
+				Get(webUserPath, handleWebAddUserGet)
+			router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie).
+				Get(webUserPath+"/{username}", handleWebUpdateUserGet)
+			router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webUserPath, handleWebAddUserPost)
+			router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webUserPath+"/{username}", handleWebUpdateUserPost)
+			router.With(checkPerm(dataprovider.PermAdminViewConnections), s.refreshCookie).
+				Get(webConnectionsPath, handleWebGetConnections)
+			router.With(checkPerm(dataprovider.PermAdminViewUsers), s.refreshCookie).
+				Get(webFoldersPath, handleWebGetFolders)
+			router.With(checkPerm(dataprovider.PermAdminAddUsers), s.refreshCookie).
+				Get(webFolderPath, handleWebAddFolderGet)
+			router.With(checkPerm(dataprovider.PermAdminAddUsers)).Post(webFolderPath, handleWebAddFolderPost)
+			router.With(checkPerm(dataprovider.PermAdminViewServerStatus), s.refreshCookie).
+				Get(webStatusPath, handleWebGetStatus)
+			router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
+				Get(webAdminsPath, handleGetWebAdmins)
+			router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
+				Get(webAdminPath, handleWebAddAdminGet)
+			router.With(checkPerm(dataprovider.PermAdminManageAdmins), s.refreshCookie).
+				Get(webAdminPath+"/{username}", handleWebUpdateAdminGet)
+			router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath, handleWebAddAdminPost)
+			router.With(checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath+"/{username}", handleWebUpdateAdminPost)
+			router.With(checkPerm(dataprovider.PermAdminManageAdmins), verifyCSRFHeader).
+				Delete(webAdminPath+"/{username}", deleteAdmin)
+			router.With(checkPerm(dataprovider.PermAdminCloseConnections), verifyCSRFHeader).
+				Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection)
+			router.With(checkPerm(dataprovider.PermAdminChangeUsers), s.refreshCookie).
+				Get(webFolderPath+"/{name}", handleWebUpdateFolderGet)
+			router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Post(webFolderPath+"/{name}", handleWebUpdateFolderPost)
+			router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader).
+				Delete(webFolderPath+"/{name}", deleteFolder)
+			router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader).
+				Post(webScanVFolderPath, startVFolderQuotaScan)
+			router.With(checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader).
+				Delete(webUserPath+"/{username}", deleteUser)
+			router.With(checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader).
+				Post(webQuotaScanPath, startQuotaScan)
+			router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webMaintenancePath, handleWebMaintenance)
+			router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(webBackupPath, dumpData)
+			router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webRestorePath, handleWebRestore)
+			router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).
+				Get(webTemplateUser, handleWebTemplateUserGet)
+			router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateUser, handleWebTemplateUserPost)
+			router.With(checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).
+				Get(webTemplateFolder, handleWebTemplateFolderGet)
+			router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(webTemplateFolder, handleWebTemplateFolderPost)
+		})
+	}
 }
 }

+ 0 - 24
httpd/webclient.go

@@ -279,24 +279,11 @@ func handleWebClientLogout(w http.ResponseWriter, r *http.Request) {
 }
 }
 
 
 func handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
 func handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
-	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
-	common.Connections.AddClientConnection(ipAddr)
-	defer common.Connections.RemoveClientConnection(ipAddr)
-
 	claims, err := getTokenClaims(r)
 	claims, err := getTokenClaims(r)
 	if err != nil || claims.Username == "" {
 	if err != nil || claims.Username == "" {
 		renderClientForbiddenPage(w, r, "Invalid token claims")
 		renderClientForbiddenPage(w, r, "Invalid token claims")
 		return
 		return
 	}
 	}
-	if !common.Connections.IsNewConnectionAllowed(ipAddr) {
-		logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
-		renderClientForbiddenPage(w, r, "configured connections limit reached")
-		return
-	}
-	if common.IsBanned(ipAddr) {
-		renderClientForbiddenPage(w, r, "your IP address is banned")
-		return
-	}
 
 
 	user, err := dataprovider.UserExists(claims.Username)
 	user, err := dataprovider.UserExists(claims.Username)
 	if err != nil {
 	if err != nil {
@@ -635,16 +622,5 @@ func checkWebClientUser(user *dataprovider.User, r *http.Request, connectionID s
 		logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, r.RemoteAddr)
 		logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, r.RemoteAddr)
 		return fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr)
 		return fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr)
 	}
 	}
-	if connAddr, ok := r.Context().Value(connAddrKey).(string); ok {
-		if connAddr != r.RemoteAddr {
-			connIPAddr := utils.GetIPFromRemoteAddress(connAddr)
-			if common.IsBanned(connIPAddr) {
-				return errors.New("your IP address is banned")
-			}
-			if !user.IsLoginFromAddrAllowed(connIPAddr) {
-				return fmt.Errorf("login from IP %v is not allowed", connIPAddr)
-			}
-		}
-	}
 	return nil
 	return nil
 }
 }

+ 17 - 11
sftpd/sftpd_test.go

@@ -553,7 +553,7 @@ func TestDefender(t *testing.T) {
 	_, _, err = getSftpClient(user, usePubKey)
 	_, _, err = getSftpClient(user, usePubKey)
 	assert.Error(t, err)
 	assert.Error(t, err)
 
 
-	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	err = dataprovider.DeleteUser(user.Username)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	err = os.RemoveAll(user.GetHomeDir())
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
 	assert.NoError(t, err)
@@ -2907,21 +2907,24 @@ func TestMaxConnections(t *testing.T) {
 	common.Config.MaxTotalConnections = 1
 	common.Config.MaxTotalConnections = 1
 
 
 	usePubKey := true
 	usePubKey := true
-	u := getTestUser(usePubKey)
-	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
+	user := getTestUser(usePubKey)
+	err := dataprovider.AddUser(&user)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
+	user.Password = ""
 	conn, client, err := getSftpClient(user, usePubKey)
 	conn, client, err := getSftpClient(user, usePubKey)
 	if assert.NoError(t, err) {
 	if assert.NoError(t, err) {
-		defer conn.Close()
-		defer client.Close()
 		assert.NoError(t, checkBasicSFTP(client))
 		assert.NoError(t, checkBasicSFTP(client))
 		s, c, err := getSftpClient(user, usePubKey)
 		s, c, err := getSftpClient(user, usePubKey)
 		if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") {
 		if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") {
 			c.Close()
 			c.Close()
 			s.Close()
 			s.Close()
 		}
 		}
+		err = client.Close()
+		assert.NoError(t, err)
+		err = conn.Close()
+		assert.NoError(t, err)
 	}
 	}
-	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	err = dataprovider.DeleteUser(user.Username)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	err = os.RemoveAll(user.GetHomeDir())
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
 	assert.NoError(t, err)
@@ -2935,21 +2938,24 @@ func TestMaxPerHostConnections(t *testing.T) {
 	common.Config.MaxPerHostConnections = 1
 	common.Config.MaxPerHostConnections = 1
 
 
 	usePubKey := true
 	usePubKey := true
-	u := getTestUser(usePubKey)
-	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
+	user := getTestUser(usePubKey)
+	err := dataprovider.AddUser(&user)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
+	user.Password = ""
 	conn, client, err := getSftpClient(user, usePubKey)
 	conn, client, err := getSftpClient(user, usePubKey)
 	if assert.NoError(t, err) {
 	if assert.NoError(t, err) {
-		defer conn.Close()
-		defer client.Close()
 		assert.NoError(t, checkBasicSFTP(client))
 		assert.NoError(t, checkBasicSFTP(client))
 		s, c, err := getSftpClient(user, usePubKey)
 		s, c, err := getSftpClient(user, usePubKey)
 		if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") {
 		if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") {
 			c.Close()
 			c.Close()
 			s.Close()
 			s.Close()
 		}
 		}
+		err = client.Close()
+		assert.NoError(t, err)
+		err = conn.Close()
+		assert.NoError(t, err)
 	}
 	}
-	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	err = dataprovider.DeleteUser(user.Username)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	err = os.RemoveAll(user.GetHomeDir())
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
 	assert.NoError(t, err)

+ 5 - 3
sftpgo.json

@@ -107,7 +107,8 @@
         "enable_https": false,
         "enable_https": false,
         "client_auth_type": 0,
         "client_auth_type": 0,
         "tls_cipher_suites": [],
         "tls_cipher_suites": [],
-        "prefix": ""
+        "prefix": "",
+        "proxy_allowed": []
       }
       }
     ],
     ],
     "certificate_file": "",
     "certificate_file": "",
@@ -180,12 +181,13 @@
     "bindings": [
     "bindings": [
       {
       {
         "port": 8080,
         "port": 8080,
-        "address": "127.0.0.1",
+        "address": "",
         "enable_web_admin": true,
         "enable_web_admin": true,
         "enable_web_client": true,
         "enable_web_client": true,
         "enable_https": false,
         "enable_https": false,
         "client_auth_type": 0,
         "client_auth_type": 0,
-        "tls_cipher_suites": []
+        "tls_cipher_suites": [],
+        "proxy_allowed": []
       }
       }
     ],
     ],
     "templates_path": "templates",
     "templates_path": "templates",

+ 2 - 0
templates/webadmin/base.html

@@ -209,6 +209,7 @@
         <i class="fas fa-angle-up"></i>
         <i class="fas fa-angle-up"></i>
     </a>
     </a>
 
 
+    {{if .LoggedAdmin.Username}}
     <!-- Logout Modal-->
     <!-- Logout Modal-->
     <div class="modal fade" id="logoutModal" tabindex="-1" role="dialog" aria-labelledby="modalLabel"
     <div class="modal fade" id="logoutModal" tabindex="-1" role="dialog" aria-labelledby="modalLabel"
         aria-hidden="true">
         aria-hidden="true">
@@ -228,6 +229,7 @@
             </div>
             </div>
         </div>
         </div>
     </div>
     </div>
+    {{end}}
 
 
     {{block "dialog" .}}{{end}}
     {{block "dialog" .}}{{end}}
 
 

+ 2 - 0
templates/webclient/base.html

@@ -168,6 +168,7 @@
         <i class="fas fa-angle-up"></i>
         <i class="fas fa-angle-up"></i>
     </a>
     </a>
 
 
+    {{if .LoggedUser.Username}}
     <!-- Logout Modal-->
     <!-- Logout Modal-->
     <div class="modal fade" id="logoutModal" tabindex="-1" role="dialog" aria-labelledby="modalLabel"
     <div class="modal fade" id="logoutModal" tabindex="-1" role="dialog" aria-labelledby="modalLabel"
         aria-hidden="true">
         aria-hidden="true">
@@ -187,6 +188,7 @@
             </div>
             </div>
         </div>
         </div>
     </div>
     </div>
+    {{end}}
 
 
     {{block "dialog" .}}{{end}}
     {{block "dialog" .}}{{end}}
 
 

+ 1 - 1
templates/webclient/credentials.html

@@ -57,7 +57,7 @@
             <div class="form-group row">
             <div class="form-group row">
                 <label for="idPublicKeys" class="col-sm-2 col-form-label">Keys</label>
                 <label for="idPublicKeys" class="col-sm-2 col-form-label">Keys</label>
                 <div class="col-sm-10">
                 <div class="col-sm-10">
-                    <textarea class="form-control" id="idPublicKeys" name="public_keys" rows="3"
+                    <textarea class="form-control" id="idPublicKeys" name="public_keys" rows="5"
                         aria-describedby="pkHelpBlock">{{range .PublicKeys}}{{.}}&#10;{{end}}</textarea>
                         aria-describedby="pkHelpBlock">{{range .PublicKeys}}{{.}}&#10;{{end}}</textarea>
                     <small id="pkHelpBlock" class="form-text text-muted">
                     <small id="pkHelpBlock" class="form-text text-muted">
                         One public key per line
                         One public key per line

+ 50 - 0
utils/utils.go

@@ -38,6 +38,11 @@ const (
 	osWindows = "windows"
 	osWindows = "windows"
 )
 )
 
 
+var (
+	xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
+	xRealIP       = http.CanonicalHeaderKey("X-Real-IP")
+)
+
 // IsStringInSlice searches a string in a slice and returns true if the string is found
 // IsStringInSlice searches a string in a slice and returns true if the string is found
 func IsStringInSlice(obj string, list []string) bool {
 func IsStringInSlice(obj string, list []string) bool {
 	for i := 0; i < len(list); i++ {
 	for i := 0; i < len(list); i++ {
@@ -516,3 +521,48 @@ func GetSSHPublicKeyAsString(pubKey []byte) (string, error) {
 	}
 	}
 	return string(ssh.MarshalAuthorizedKey(k)), nil
 	return string(ssh.MarshalAuthorizedKey(k)), nil
 }
 }
+
+// GetRealIP returns the ip address as result of parsing either the
+// X-Real-IP header or the X-Forwarded-For header
+func GetRealIP(r *http.Request) string {
+	var ip string
+
+	if xrip := r.Header.Get(xRealIP); xrip != "" {
+		ip = xrip
+	} else if xff := r.Header.Get(xForwardedFor); xff != "" {
+		i := strings.Index(xff, ", ")
+		if i == -1 {
+			i = len(xff)
+		}
+		ip = strings.TrimSpace(xff[:i])
+	}
+	if net.ParseIP(ip) == nil {
+		return ""
+	}
+
+	return ip
+}
+
+// ParseAllowedIPAndRanges returns a list of functions that allow to find if a
+func ParseAllowedIPAndRanges(allowed []string) ([]func(net.IP) bool, error) {
+	res := make([]func(net.IP) bool, len(allowed))
+	for i, allowFrom := range allowed {
+		if strings.LastIndex(allowFrom, "/") > 0 {
+			_, ipRange, err := net.ParseCIDR(allowFrom)
+			if err != nil {
+				return nil, fmt.Errorf("given string %q is not a valid IP range: %v", allowFrom, err)
+			}
+
+			res[i] = ipRange.Contains
+		} else {
+			allowed := net.ParseIP(allowFrom)
+			if allowed == nil {
+				return nil, fmt.Errorf("given string %q is not a valid IP address", allowFrom)
+			}
+
+			res[i] = allowed.Equal
+		}
+	}
+
+	return res, nil
+}

+ 51 - 18
webdavd/internal_test.go

@@ -23,6 +23,7 @@ import (
 	"github.com/drakkan/sftpgo/common"
 	"github.com/drakkan/sftpgo/common"
 	"github.com/drakkan/sftpgo/dataprovider"
 	"github.com/drakkan/sftpgo/dataprovider"
 	"github.com/drakkan/sftpgo/kms"
 	"github.com/drakkan/sftpgo/kms"
+	"github.com/drakkan/sftpgo/utils"
 	"github.com/drakkan/sftpgo/vfs"
 	"github.com/drakkan/sftpgo/vfs"
 )
 )
 
 
@@ -409,36 +410,68 @@ func TestUserInvalidParams(t *testing.T) {
 }
 }
 
 
 func TestRemoteAddress(t *testing.T) {
 func TestRemoteAddress(t *testing.T) {
-	req, err := http.NewRequest(http.MethodGet, "/username", nil)
-	assert.NoError(t, err)
-	assert.Empty(t, req.RemoteAddr)
-
 	remoteAddr1 := "100.100.100.100"
 	remoteAddr1 := "100.100.100.100"
 	remoteAddr2 := "172.172.172.172"
 	remoteAddr2 := "172.172.172.172"
 
 
+	c := &Configuration{
+		Bindings: []Binding{
+			{
+				Port:         9000,
+				ProxyAllowed: []string{remoteAddr2, "10.8.0.0/30"},
+			},
+		},
+	}
+
+	server := webDavServer{
+		config:  c,
+		binding: c.Bindings[0],
+	}
+	err := server.binding.parseAllowedProxy()
+	assert.NoError(t, err)
+	req, err := http.NewRequest(http.MethodGet, "/", nil)
+	assert.NoError(t, err)
+	assert.Empty(t, req.RemoteAddr)
+
 	req.Header.Set("X-Forwarded-For", remoteAddr1)
 	req.Header.Set("X-Forwarded-For", remoteAddr1)
-	checkRemoteAddress(req)
-	assert.Equal(t, remoteAddr1, req.RemoteAddr)
+	ip := utils.GetRealIP(req)
+	assert.Equal(t, remoteAddr1, ip)
+	// this will be ignore, remoteAddr1 is not allowed to se this header
+	req.Header.Set("X-Forwarded-For", remoteAddr2)
+	req.RemoteAddr = remoteAddr1
+	ip = server.checkRemoteAddress(req)
+	assert.Equal(t, remoteAddr1, ip)
 	req.RemoteAddr = ""
 	req.RemoteAddr = ""
+	ip = server.checkRemoteAddress(req)
+	assert.Empty(t, ip)
 
 
 	req.Header.Set("X-Forwarded-For", fmt.Sprintf("%v, %v", remoteAddr2, remoteAddr1))
 	req.Header.Set("X-Forwarded-For", fmt.Sprintf("%v, %v", remoteAddr2, remoteAddr1))
-	checkRemoteAddress(req)
-	assert.Equal(t, remoteAddr2, req.RemoteAddr)
+	ip = utils.GetRealIP(req)
+	assert.Equal(t, remoteAddr2, ip)
+
+	req.RemoteAddr = remoteAddr2
+	req.Header.Set("X-Forwarded-For", fmt.Sprintf("%v, %v", "12.34.56.78", "172.16.2.4"))
+	ip = server.checkRemoteAddress(req)
+	assert.Equal(t, "12.34.56.78", ip)
+	assert.Equal(t, ip, req.RemoteAddr)
+
+	req.RemoteAddr = "10.8.0.2"
+	req.Header.Set("X-Forwarded-For", remoteAddr1)
+	ip = server.checkRemoteAddress(req)
+	assert.Equal(t, remoteAddr1, ip)
+	assert.Equal(t, ip, req.RemoteAddr)
+
+	req.RemoteAddr = "10.8.0.3"
+	req.Header.Set("X-Forwarded-For", "not an ip")
+	ip = server.checkRemoteAddress(req)
+	assert.Equal(t, "10.8.0.3", ip)
+	assert.Equal(t, ip, req.RemoteAddr)
 
 
 	req.Header.Del("X-Forwarded-For")
 	req.Header.Del("X-Forwarded-For")
 	req.RemoteAddr = ""
 	req.RemoteAddr = ""
 	req.Header.Set("X-Real-IP", remoteAddr1)
 	req.Header.Set("X-Real-IP", remoteAddr1)
-	checkRemoteAddress(req)
-	assert.Equal(t, remoteAddr1, req.RemoteAddr)
+	ip = utils.GetRealIP(req)
+	assert.Equal(t, remoteAddr1, ip)
 	req.RemoteAddr = ""
 	req.RemoteAddr = ""
-
-	oldValue := common.Config.ProxyProtocol
-	common.Config.ProxyProtocol = 1
-
-	checkRemoteAddress(req)
-	assert.Empty(t, req.RemoteAddr)
-
-	common.Config.ProxyProtocol = oldValue
 }
 }
 
 
 func TestConnWithNilRequest(t *testing.T) {
 func TestConnWithNilRequest(t *testing.T) {

+ 21 - 28
webdavd/server.go

@@ -7,11 +7,11 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"log"
 	"log"
+	"net"
 	"net/http"
 	"net/http"
 	"path"
 	"path"
 	"path/filepath"
 	"path/filepath"
 	"runtime/debug"
 	"runtime/debug"
-	"strings"
 	"time"
 	"time"
 
 
 	"github.com/go-chi/chi/v5/middleware"
 	"github.com/go-chi/chi/v5/middleware"
@@ -27,9 +27,7 @@ import (
 )
 )
 
 
 var (
 var (
-	err401        = errors.New("Unauthorized")
-	xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
-	xRealIP       = http.CanonicalHeaderKey("X-Real-IP")
+	err401 = errors.New("Unauthorized")
 )
 )
 
 
 type webDavServer struct {
 type webDavServer struct {
@@ -145,8 +143,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		}
 		}
 	}()
 	}()
 
 
-	checkRemoteAddress(r)
-	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
+	ipAddr := s.checkRemoteAddress(r)
 
 
 	common.Connections.AddClientConnection(ipAddr)
 	common.Connections.AddClientConnection(ipAddr)
 	defer common.Connections.RemoveClientConnection(ipAddr)
 	defer common.Connections.RemoveClientConnection(ipAddr)
@@ -327,6 +324,24 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo
 	return connID, nil
 	return connID, nil
 }
 }
 
 
+func (s *webDavServer) checkRemoteAddress(r *http.Request) string {
+	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
+	ip := net.ParseIP(ipAddr)
+	if ip != nil {
+		for _, allow := range s.binding.allowHeadersFrom {
+			if allow(ip) {
+				parsedIP := utils.GetRealIP(r)
+				if parsedIP != "" {
+					ipAddr = parsedIP
+					r.RemoteAddr = ipAddr
+				}
+				break
+			}
+		}
+	}
+	return ipAddr
+}
+
 func writeLog(r *http.Request, err error) {
 func writeLog(r *http.Request, err error) {
 	scheme := "http"
 	scheme := "http"
 	if r.TLS != nil {
 	if r.TLS != nil {
@@ -352,28 +367,6 @@ func writeLog(r *http.Request, err error) {
 		Send()
 		Send()
 }
 }
 
 
-func checkRemoteAddress(r *http.Request) {
-	if common.Config.ProxyProtocol != 0 {
-		return
-	}
-
-	var ip string
-
-	if xrip := r.Header.Get(xRealIP); xrip != "" {
-		ip = xrip
-	} else if xff := r.Header.Get(xForwardedFor); xff != "" {
-		i := strings.Index(xff, ", ")
-		if i == -1 {
-			i = len(xff)
-		}
-		ip = strings.TrimSpace(xff[:i])
-	}
-
-	if len(ip) > 0 {
-		r.RemoteAddr = ip
-	}
-}
-
 func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error) {
 func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error) {
 	metrics.AddLoginAttempt(loginMethod)
 	metrics.AddLoginAttempt(loginMethod)
 	if err != nil {
 	if err != nil {

+ 16 - 0
webdavd/webdavd.go

@@ -3,6 +3,7 @@ package webdavd
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"net"
 	"path/filepath"
 	"path/filepath"
 
 
 	"github.com/go-chi/chi/v5/middleware"
 	"github.com/go-chi/chi/v5/middleware"
@@ -90,6 +91,18 @@ type Binding struct {
 	// Prefix for WebDAV resources, if empty WebDAV resources will be available at the
 	// Prefix for WebDAV resources, if empty WebDAV resources will be available at the
 	// root ("/") URI. If defined it must be an absolute URI.
 	// root ("/") URI. If defined it must be an absolute URI.
 	Prefix string `json:"prefix" mapstructure:"prefix"`
 	Prefix string `json:"prefix" mapstructure:"prefix"`
+	// List of IP addresses and IP ranges allowed to set X-Forwarded-For/X-Real-IP headers.
+	ProxyAllowed     []string `json:"proxy_allowed" mapstructure:"proxy_allowed"`
+	allowHeadersFrom []func(net.IP) bool
+}
+
+func (b *Binding) parseAllowedProxy() error {
+	allowedFuncs, err := utils.ParseAllowedIPAndRanges(b.ProxyAllowed)
+	if err != nil {
+		return err
+	}
+	b.allowHeadersFrom = allowedFuncs
+	return nil
 }
 }
 
 
 func (b *Binding) isMutualTLSEnabled() bool {
 func (b *Binding) isMutualTLSEnabled() bool {
@@ -191,6 +204,9 @@ func (c *Configuration) Initialize(configDir string) error {
 		if !binding.IsValid() {
 		if !binding.IsValid() {
 			continue
 			continue
 		}
 		}
+		if err := binding.parseAllowedProxy(); err != nil {
+			return err
+		}
 
 
 		go func(binding Binding) {
 		go func(binding Binding) {
 			server := webDavServer{
 			server := webDavServer{

+ 10 - 0
webdavd/webdavd_test.go

@@ -469,6 +469,12 @@ func TestInitialization(t *testing.T) {
 	cfg.CertificateKeyFile = keyPath
 	cfg.CertificateKeyFile = keyPath
 	cfg.CACertificates = []string{caCrtPath}
 	cfg.CACertificates = []string{caCrtPath}
 	cfg.CARevocationLists = []string{caCRLPath}
 	cfg.CARevocationLists = []string{caCRLPath}
+	cfg.Bindings[0].ProxyAllowed = []string{"not valid"}
+	err = cfg.Initialize(configDir)
+	if assert.Error(t, err) {
+		assert.Contains(t, err.Error(), "is not a valid IP address")
+	}
+	cfg.Bindings[0].ProxyAllowed = nil
 	err = cfg.Initialize(configDir)
 	err = cfg.Initialize(configDir)
 	assert.Error(t, err)
 	assert.Error(t, err)
 }
 }
@@ -1180,6 +1186,10 @@ func TestQuotaLimits(t *testing.T) {
 		// test quota files
 		// test quota files
 		err = uploadFile(testFilePath, testFileName+".quota", testFileSize, client)
 		err = uploadFile(testFilePath, testFileName+".quota", testFileSize, client)
 		if !assert.NoError(t, err, "username: %v", user.Username) {
 		if !assert.NoError(t, err, "username: %v", user.Username) {
+			info, err := os.Stat(testFilePath)
+			if assert.NoError(t, err) {
+				fmt.Printf("local file size %v", info.Size())
+			}
 			printLatestLogs(20)
 			printLatestLogs(20)
 		}
 		}
 		err = uploadFile(testFilePath, testFileName+".quota1", testFileSize, client)
 		err = uploadFile(testFilePath, testFileName+".quota1", testFileSize, client)