Explorar o código

httpclient: allow to set custom headers

Nicola Murino %!s(int64=4) %!d(string=hai) anos
pai
achega
600268ebb8

+ 1 - 3
common/actions.go

@@ -155,12 +155,10 @@ func (h *defaultActionHandler) handleHTTP(notification *ActionNotification) erro
 	startTime := time.Now()
 	respCode := 0
 
-	httpClient := httpclient.GetRetraybleHTTPClient()
-
 	var b bytes.Buffer
 	_ = json.NewEncoder(&b).Encode(notification)
 
-	resp, err := httpClient.Post(u.String(), "application/json", &b)
+	resp, err := httpclient.RetryablePost(Config.Actions.Hook, "application/json", &b)
 	if err == nil {
 		respCode = resp.StatusCode
 		resp.Body.Close()

+ 2 - 4
common/common.go

@@ -420,8 +420,7 @@ func (c *Configuration) ExecuteStartupHook() error {
 			return err
 		}
 		startTime := time.Now()
-		httpClient := httpclient.GetRetraybleHTTPClient()
-		resp, err := httpClient.Get(url.String())
+		resp, err := httpclient.RetryableGet(url.String())
 		if err != nil {
 			logger.Warn(logSender, "", "Error executing startup hook: %v", err)
 			return err
@@ -457,13 +456,12 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error {
 				ipAddr, c.PostConnectHook, err)
 			return err
 		}
-		httpClient := httpclient.GetRetraybleHTTPClient()
 		q := url.Query()
 		q.Add("ip", ipAddr)
 		q.Add("protocol", protocol)
 		url.RawQuery = q.Encode()
 
-		resp, err := httpClient.Get(url.String())
+		resp, err := httpclient.RetryableGet(url.String())
 		if err != nil {
 			logger.Warn(protocol, "", "Login from ip %#v denied, error executing post connect hook: %v", ipAddr, err)
 			return err

+ 1 - 2
common/protocol_test.go

@@ -2631,8 +2631,7 @@ func TestNonLocalCrossRenameNonLocalBaseUser(t *testing.T) {
 }
 
 func TestProxyProtocol(t *testing.T) {
-	httpClient := httpclient.GetHTTPClient()
-	resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
+	resp, err := httpclient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
 	if assert.NoError(t, err) {
 		defer resp.Body.Close()
 		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)

+ 29 - 0
config/config.go

@@ -256,6 +256,7 @@ func Init() {
 			CACertificates: nil,
 			Certificates:   nil,
 			SkipTLSVerify:  false,
+			Headers:        nil,
 		},
 		KMSConfig: kms.Configuration{
 			Secrets: kms.Secrets{
@@ -577,6 +578,7 @@ func loadBindingsFromEnv() {
 		getWebDAVDBindingFromEnv(idx)
 		getHTTPDBindingFromEnv(idx)
 		getHTTPClientCertificatesFromEnv(idx)
+		getHTTPClientHeadersFromEnv(idx)
 	}
 }
 
@@ -889,6 +891,33 @@ func getHTTPClientCertificatesFromEnv(idx int) {
 	}
 }
 
+func getHTTPClientHeadersFromEnv(idx int) {
+	header := httpclient.Header{}
+
+	key, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__KEY", idx))
+	if ok {
+		header.Key = key
+	}
+
+	value, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__VALUE", idx))
+	if ok {
+		header.Value = value
+	}
+
+	url, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__URL", idx))
+	if ok {
+		header.URL = url
+	}
+
+	if header.Key != "" && header.Value != "" {
+		if len(globalConf.HTTPConfig.Headers) > idx {
+			globalConf.HTTPConfig.Headers[idx] = header
+		} else {
+			globalConf.HTTPConfig.Headers = append(globalConf.HTTPConfig.Headers, header)
+		}
+	}
+}
+
 func setViperDefaults() {
 	viper.SetDefault("common.idle_timeout", globalConf.Common.IdleTimeout)
 	viper.SetDefault("common.upload_mode", globalConf.Common.UploadMode)

+ 71 - 0
config/config_test.go

@@ -751,6 +751,77 @@ func TestHTTPClientCertificatesFromEnv(t *testing.T) {
 	require.Equal(t, "key9", config.GetHTTPConfig().Certificates[1].Key)
 }
 
+func TestHTTPClientHeadersFromEnv(t *testing.T) {
+	reset()
+
+	configDir := ".."
+	confName := tempConfigName + ".json"
+	configFilePath := filepath.Join(configDir, confName)
+	err := config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+	httpConf := config.GetHTTPConfig()
+	httpConf.Headers = append(httpConf.Headers, httpclient.Header{
+		Key:   "key",
+		Value: "value",
+		URL:   "url",
+	})
+	c := make(map[string]httpclient.Config)
+	c["http"] = httpConf
+	jsonConf, err := json.Marshal(c)
+	require.NoError(t, err)
+	err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
+	require.NoError(t, err)
+	err = config.LoadConfig(configDir, confName)
+	require.NoError(t, err)
+	require.Len(t, config.GetHTTPConfig().Headers, 1)
+	require.Equal(t, "key", config.GetHTTPConfig().Headers[0].Key)
+	require.Equal(t, "value", config.GetHTTPConfig().Headers[0].Value)
+	require.Equal(t, "url", config.GetHTTPConfig().Headers[0].URL)
+
+	os.Setenv("SFTPGO_HTTP__HEADERS__0__KEY", "key0")
+	os.Setenv("SFTPGO_HTTP__HEADERS__0__VALUE", "value0")
+	os.Setenv("SFTPGO_HTTP__HEADERS__0__URL", "url0")
+	os.Setenv("SFTPGO_HTTP__HEADERS__8__KEY", "key8")
+	os.Setenv("SFTPGO_HTTP__HEADERS__9__KEY", "key9")
+	os.Setenv("SFTPGO_HTTP__HEADERS__9__VALUE", "value9")
+	os.Setenv("SFTPGO_HTTP__HEADERS__9__URL", "url9")
+
+	t.Cleanup(func() {
+		os.Unsetenv("SFTPGO_HTTP__HEADERS__0__KEY")
+		os.Unsetenv("SFTPGO_HTTP__HEADERS__0__VALUE")
+		os.Unsetenv("SFTPGO_HTTP__HEADERS__0__URL")
+		os.Unsetenv("SFTPGO_HTTP__HEADERS__8__KEY")
+		os.Unsetenv("SFTPGO_HTTP__HEADERS__9__KEY")
+		os.Unsetenv("SFTPGO_HTTP__HEADERS__9__VALUE")
+		os.Unsetenv("SFTPGO_HTTP__HEADERS__9__URL")
+	})
+
+	err = config.LoadConfig(configDir, confName)
+	require.NoError(t, err)
+	require.Len(t, config.GetHTTPConfig().Headers, 2)
+	require.Equal(t, "key0", config.GetHTTPConfig().Headers[0].Key)
+	require.Equal(t, "value0", config.GetHTTPConfig().Headers[0].Value)
+	require.Equal(t, "url0", config.GetHTTPConfig().Headers[0].URL)
+	require.Equal(t, "key9", config.GetHTTPConfig().Headers[1].Key)
+	require.Equal(t, "value9", config.GetHTTPConfig().Headers[1].Value)
+	require.Equal(t, "url9", config.GetHTTPConfig().Headers[1].URL)
+
+	err = os.Remove(configFilePath)
+	assert.NoError(t, err)
+
+	config.Init()
+
+	err = config.LoadConfig(configDir, "")
+	require.NoError(t, err)
+	require.Len(t, config.GetHTTPConfig().Headers, 2)
+	require.Equal(t, "key0", config.GetHTTPConfig().Headers[0].Key)
+	require.Equal(t, "value0", config.GetHTTPConfig().Headers[0].Value)
+	require.Equal(t, "url0", config.GetHTTPConfig().Headers[0].URL)
+	require.Equal(t, "key9", config.GetHTTPConfig().Headers[1].Key)
+	require.Equal(t, "value9", config.GetHTTPConfig().Headers[1].Value)
+	require.Equal(t, "url9", config.GetHTTPConfig().Headers[1].URL)
+}
+
 func TestConfigFromEnv(t *testing.T) {
 	reset()
 

+ 10 - 32
dataprovider/dataprovider.go

@@ -1912,15 +1912,14 @@ func validateKeyboardAuthResponse(response keyboardAuthHookResponse) error {
 	return nil
 }
 
-func sendKeyboardAuthHTTPReq(url *url.URL, request keyboardAuthHookRequest) (keyboardAuthHookResponse, error) {
+func sendKeyboardAuthHTTPReq(url string, request keyboardAuthHookRequest) (keyboardAuthHookResponse, error) {
 	var response keyboardAuthHookResponse
-	httpClient := httpclient.GetHTTPClient()
 	reqAsJSON, err := json.Marshal(request)
 	if err != nil {
 		providerLog(logger.LevelWarn, "error serializing keyboard interactive auth request: %v", err)
 		return response, err
 	}
-	resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(reqAsJSON))
+	resp, err := httpclient.Post(url, "application/json", bytes.NewBuffer(reqAsJSON))
 	if err != nil {
 		providerLog(logger.LevelWarn, "error getting keyboard interactive auth hook HTTP response: %v", err)
 		return response, err
@@ -1935,12 +1934,6 @@ func sendKeyboardAuthHTTPReq(url *url.URL, request keyboardAuthHookRequest) (key
 
 func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
 	authResult := 0
-	var url *url.URL
-	url, err := url.Parse(authHook)
-	if err != nil {
-		providerLog(logger.LevelWarn, "invalid url for keyboard interactive hook %#v, error: %v", authHook, err)
-		return authResult, err
-	}
 	requestID := xid.New().String()
 	req := keyboardAuthHookRequest{
 		Username:  user.Username,
@@ -1949,8 +1942,9 @@ func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh.
 		RequestID: requestID,
 	}
 	var response keyboardAuthHookResponse
+	var err error
 	for {
-		response, err = sendKeyboardAuthHTTPReq(url, req)
+		response, err = sendKeyboardAuthHTTPReq(authHook, req)
 		if err != nil {
 			return authResult, err
 		}
@@ -2120,12 +2114,6 @@ func isCheckPasswordHookDefined(protocol string) bool {
 func getPasswordHookResponse(username, password, ip, protocol string) ([]byte, error) {
 	if strings.HasPrefix(config.CheckPasswordHook, "http") {
 		var result []byte
-		var url *url.URL
-		url, err := url.Parse(config.CheckPasswordHook)
-		if err != nil {
-			providerLog(logger.LevelWarn, "invalid url for check password hook %#v, error: %v", config.CheckPasswordHook, err)
-			return result, err
-		}
 		req := checkPasswordRequest{
 			Username: username,
 			Password: password,
@@ -2136,8 +2124,7 @@ func getPasswordHookResponse(username, password, ip, protocol string) ([]byte, e
 		if err != nil {
 			return result, err
 		}
-		httpClient := httpclient.GetHTTPClient()
-		resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(reqAsJSON))
+		resp, err := httpclient.Post(config.CheckPasswordHook, "application/json", bytes.NewBuffer(reqAsJSON))
 		if err != nil {
 			providerLog(logger.LevelWarn, "error getting check password hook response: %v", err)
 			return result, err
@@ -2192,8 +2179,8 @@ func getPreLoginHookResponse(loginMethod, ip, protocol string, userAsJSON []byte
 		q.Add("ip", ip)
 		q.Add("protocol", protocol)
 		url.RawQuery = q.Encode()
-		httpClient := httpclient.GetHTTPClient()
-		resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
+
+		resp, err := httpclient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
 		if err != nil {
 			providerLog(logger.LevelWarn, "error getting pre-login hook response: %v", err)
 			return result, err
@@ -2318,8 +2305,7 @@ func ExecutePostLoginHook(user *User, loginMethod, ip, protocol string, err erro
 
 			startTime := time.Now()
 			respCode := 0
-			httpClient := httpclient.GetRetraybleHTTPClient()
-			resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
+			resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
 			if err == nil {
 				respCode = resp.StatusCode
 				resp.Body.Close()
@@ -2353,14 +2339,7 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip,
 		}
 	}
 	if strings.HasPrefix(config.ExternalAuthHook, "http") {
-		var url *url.URL
 		var result []byte
-		url, err := url.Parse(config.ExternalAuthHook)
-		if err != nil {
-			providerLog(logger.LevelWarn, "invalid url for external auth hook %#v, error: %v", config.ExternalAuthHook, err)
-			return result, err
-		}
-		httpClient := httpclient.GetHTTPClient()
 		authRequest := make(map[string]string)
 		authRequest["username"] = username
 		authRequest["ip"] = ip
@@ -2377,7 +2356,7 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip,
 			providerLog(logger.LevelWarn, "error serializing external auth request: %v", err)
 			return result, err
 		}
-		resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(authRequestAsJSON))
+		resp, err := httpclient.Post(config.ExternalAuthHook, "application/json", bytes.NewBuffer(authRequestAsJSON))
 		if err != nil {
 			providerLog(logger.LevelWarn, "error getting external auth hook HTTP response: %v", err)
 			return result, err
@@ -2561,8 +2540,7 @@ func executeAction(operation string, user *User) {
 			q.Add("action", operation)
 			url.RawQuery = q.Encode()
 			startTime := time.Now()
-			httpClient := httpclient.GetRetraybleHTTPClient()
-			resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
+			resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
 			respCode := 0
 			if err == nil {
 				respCode = resp.StatusCode

+ 4 - 0
docs/full-configuration.md

@@ -248,6 +248,10 @@ The configuration file contains the following sections:
     - `cert`, string. Path to the certificate file. The path can be absolute or relative to the config dir.
     - `key`, string. Path to the key file. The path can be absolute or relative to the config dir.
   - `skip_tls_verify`, boolean. if enabled the HTTP client accepts any TLS certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
+  - `headers`, list of structs. You can define a list of http headers to add to each hook. Each struct has the following fields:
+    - `key`, string
+    - `value`, string. The header is silently ignored if `key` or `value` are empty
+    - `url`, string, optional. If not empty, the header will be added only if the request URL starts with the one specified here
 - **kms**, configuration for the Key Management Service, more details can be found [here](./kms.md)
   - `secrets`
     - `url`

+ 81 - 1
httpclient/httpclient.go

@@ -4,9 +4,11 @@ import (
 	"crypto/tls"
 	"crypto/x509"
 	"fmt"
+	"io"
 	"net/http"
 	"os"
 	"path/filepath"
+	"strings"
 	"time"
 
 	"github.com/hashicorp/go-retryablehttp"
@@ -21,6 +23,15 @@ type TLSKeyPair struct {
 	Key  string `json:"key" mapstructure:"key"`
 }
 
+// Header defines an HTTP header.
+// If the URL is not empty, the header is added only if the
+// requested URL starts with the one specified
+type Header struct {
+	Key   string `json:"key" mapstructure:"key"`
+	Value string `json:"value" mapstructure:"value"`
+	URL   string `json:"url" mapstructure:"url"`
+}
+
 // Config defines the configuration for HTTP clients.
 // HTTP clients are used for executing hooks such as the ones used for
 // custom actions, external authentication and pre-login user modifications
@@ -44,7 +55,9 @@ type Config struct {
 	// the server and any host name in that certificate.
 	// In this mode, TLS is susceptible to man-in-the-middle attacks.
 	// This should be used only for testing.
-	SkipTLSVerify   bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"`
+	SkipTLSVerify bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"`
+	// Headers defines a list of http headers to add to each request
+	Headers         []Header `json:"headers" mapstructure:"headers"`
 	customTransport *http.Transport
 	tlsConfig       *tls.Config
 }
@@ -76,6 +89,13 @@ func (c *Config) Initialize(configDir string) error {
 	if err != nil {
 		return err
 	}
+	var headers []Header
+	for _, h := range c.Headers {
+		if h.Key != "" && h.Value != "" {
+			headers = append(headers, h)
+		}
+	}
+	c.Headers = headers
 	httpConfig = *c
 	return nil
 }
@@ -162,3 +182,63 @@ func GetRetraybleHTTPClient() *retryablehttp.Client {
 
 	return client
 }
+
+// Get issues a GET to the specified URL
+func Get(url string) (*http.Response, error) {
+	req, err := http.NewRequest(http.MethodGet, url, nil)
+	if err != nil {
+		return nil, err
+	}
+	addHeaders(req, url)
+	return GetHTTPClient().Do(req)
+}
+
+// Post issues a POST to the specified URL
+func Post(url string, contentType string, body io.Reader) (*http.Response, error) {
+	req, err := http.NewRequest(http.MethodPost, url, body)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", contentType)
+	addHeaders(req, url)
+	return GetHTTPClient().Do(req)
+}
+
+// RetryableGet issues a GET to the specified URL using the retryable client
+func RetryableGet(url string) (*http.Response, error) {
+	req, err := retryablehttp.NewRequest(http.MethodGet, url, nil)
+	if err != nil {
+		return nil, err
+	}
+	addHeadersToRetryableReq(req, url)
+	return GetRetraybleHTTPClient().Do(req)
+}
+
+// RetryablePost issues a POST to the specified URL using the retryable client
+func RetryablePost(url string, contentType string, body io.Reader) (*http.Response, error) {
+	req, err := retryablehttp.NewRequest(http.MethodPost, url, body)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", contentType)
+	addHeadersToRetryableReq(req, url)
+	return GetRetraybleHTTPClient().Do(req)
+}
+
+func addHeaders(req *http.Request, url string) {
+	for idx := range httpConfig.Headers {
+		h := &httpConfig.Headers[idx]
+		if h.URL == "" || strings.HasPrefix(url, h.URL) {
+			req.Header.Set(h.Key, h.Value)
+		}
+	}
+}
+
+func addHeadersToRetryableReq(req *retryablehttp.Request, url string) {
+	for idx := range httpConfig.Headers {
+		h := &httpConfig.Headers[idx]
+		if h.URL == "" || strings.HasPrefix(url, h.URL) {
+			req.Header.Set(h.Key, h.Value)
+		}
+	}
+}

+ 2 - 2
httpd/auth_utils.go

@@ -31,8 +31,8 @@ const (
 )
 
 var (
-	tokenDuration   = 10 * time.Minute
-	tokenRefreshMin = 5 * time.Minute
+	tokenDuration   = 15 * time.Minute
+	tokenRefreshMin = 10 * time.Minute
 )
 
 type jwtTokenClaims struct {

+ 0 - 1
httpd/httpd.go

@@ -1,7 +1,6 @@
 // Package httpd implements REST API and Web interface for SFTPGo.
 // The OpenAPI 3 schema for the exposed API can be found inside the source tree:
 // https://github.com/drakkan/sftpgo/blob/main/httpd/schema/openapi.yaml
-// A basic Web interface to manage users and connections is provided too
 package httpd
 
 import (

+ 2 - 1
sftpgo.json

@@ -217,7 +217,8 @@
     "retry_max": 3,
     "ca_certificates": [],
     "certificates": [],
-    "skip_tls_verify": false
+    "skip_tls_verify": false,
+    "headers": []
   },
   "kms": {
     "secrets": {

+ 1 - 1
utils/utils.go

@@ -404,7 +404,7 @@ func createDirPathIfMissing(file string, perm os.FileMode) error {
 func GenerateRandomBytes(length int) []byte {
 	b := make([]byte, length)
 	_, err := io.ReadFull(rand.Reader, b)
-	if err != nil {
+	if err == nil {
 		return b
 	}