httpclient: allow to set custom headers

This commit is contained in:
Nicola Murino 2021-05-25 08:36:01 +02:00
parent 1223957f91
commit 600268ebb8
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
12 changed files with 204 additions and 47 deletions

View file

@ -155,12 +155,10 @@ func (h *defaultActionHandler) handleHTTP(notification *ActionNotification) erro
startTime := time.Now()
respCode := 0
httpClient := httpclient.GetRetraybleHTTPClient()
var b bytes.Buffer
_ = json.NewEncoder(&b).Encode(notification)
resp, err := httpClient.Post(u.String(), "application/json", &b)
resp, err := httpclient.RetryablePost(Config.Actions.Hook, "application/json", &b)
if err == nil {
respCode = resp.StatusCode
resp.Body.Close()

View file

@ -420,8 +420,7 @@ func (c *Configuration) ExecuteStartupHook() error {
return err
}
startTime := time.Now()
httpClient := httpclient.GetRetraybleHTTPClient()
resp, err := httpClient.Get(url.String())
resp, err := httpclient.RetryableGet(url.String())
if err != nil {
logger.Warn(logSender, "", "Error executing startup hook: %v", err)
return err
@ -457,13 +456,12 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error {
ipAddr, c.PostConnectHook, err)
return err
}
httpClient := httpclient.GetRetraybleHTTPClient()
q := url.Query()
q.Add("ip", ipAddr)
q.Add("protocol", protocol)
url.RawQuery = q.Encode()
resp, err := httpClient.Get(url.String())
resp, err := httpclient.RetryableGet(url.String())
if err != nil {
logger.Warn(protocol, "", "Login from ip %#v denied, error executing post connect hook: %v", ipAddr, err)
return err

View file

@ -2631,8 +2631,7 @@ func TestNonLocalCrossRenameNonLocalBaseUser(t *testing.T) {
}
func TestProxyProtocol(t *testing.T) {
httpClient := httpclient.GetHTTPClient()
resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
resp, err := httpclient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
if assert.NoError(t, err) {
defer resp.Body.Close()
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)

View file

@ -256,6 +256,7 @@ func Init() {
CACertificates: nil,
Certificates: nil,
SkipTLSVerify: false,
Headers: nil,
},
KMSConfig: kms.Configuration{
Secrets: kms.Secrets{
@ -577,6 +578,7 @@ func loadBindingsFromEnv() {
getWebDAVDBindingFromEnv(idx)
getHTTPDBindingFromEnv(idx)
getHTTPClientCertificatesFromEnv(idx)
getHTTPClientHeadersFromEnv(idx)
}
}
@ -889,6 +891,33 @@ func getHTTPClientCertificatesFromEnv(idx int) {
}
}
func getHTTPClientHeadersFromEnv(idx int) {
header := httpclient.Header{}
key, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__KEY", idx))
if ok {
header.Key = key
}
value, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__VALUE", idx))
if ok {
header.Value = value
}
url, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__URL", idx))
if ok {
header.URL = url
}
if header.Key != "" && header.Value != "" {
if len(globalConf.HTTPConfig.Headers) > idx {
globalConf.HTTPConfig.Headers[idx] = header
} else {
globalConf.HTTPConfig.Headers = append(globalConf.HTTPConfig.Headers, header)
}
}
}
func setViperDefaults() {
viper.SetDefault("common.idle_timeout", globalConf.Common.IdleTimeout)
viper.SetDefault("common.upload_mode", globalConf.Common.UploadMode)

View file

@ -751,6 +751,77 @@ func TestHTTPClientCertificatesFromEnv(t *testing.T) {
require.Equal(t, "key9", config.GetHTTPConfig().Certificates[1].Key)
}
func TestHTTPClientHeadersFromEnv(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
httpConf := config.GetHTTPConfig()
httpConf.Headers = append(httpConf.Headers, httpclient.Header{
Key: "key",
Value: "value",
URL: "url",
})
c := make(map[string]httpclient.Config)
c["http"] = httpConf
jsonConf, err := json.Marshal(c)
require.NoError(t, err)
err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
require.NoError(t, err)
err = config.LoadConfig(configDir, confName)
require.NoError(t, err)
require.Len(t, config.GetHTTPConfig().Headers, 1)
require.Equal(t, "key", config.GetHTTPConfig().Headers[0].Key)
require.Equal(t, "value", config.GetHTTPConfig().Headers[0].Value)
require.Equal(t, "url", config.GetHTTPConfig().Headers[0].URL)
os.Setenv("SFTPGO_HTTP__HEADERS__0__KEY", "key0")
os.Setenv("SFTPGO_HTTP__HEADERS__0__VALUE", "value0")
os.Setenv("SFTPGO_HTTP__HEADERS__0__URL", "url0")
os.Setenv("SFTPGO_HTTP__HEADERS__8__KEY", "key8")
os.Setenv("SFTPGO_HTTP__HEADERS__9__KEY", "key9")
os.Setenv("SFTPGO_HTTP__HEADERS__9__VALUE", "value9")
os.Setenv("SFTPGO_HTTP__HEADERS__9__URL", "url9")
t.Cleanup(func() {
os.Unsetenv("SFTPGO_HTTP__HEADERS__0__KEY")
os.Unsetenv("SFTPGO_HTTP__HEADERS__0__VALUE")
os.Unsetenv("SFTPGO_HTTP__HEADERS__0__URL")
os.Unsetenv("SFTPGO_HTTP__HEADERS__8__KEY")
os.Unsetenv("SFTPGO_HTTP__HEADERS__9__KEY")
os.Unsetenv("SFTPGO_HTTP__HEADERS__9__VALUE")
os.Unsetenv("SFTPGO_HTTP__HEADERS__9__URL")
})
err = config.LoadConfig(configDir, confName)
require.NoError(t, err)
require.Len(t, config.GetHTTPConfig().Headers, 2)
require.Equal(t, "key0", config.GetHTTPConfig().Headers[0].Key)
require.Equal(t, "value0", config.GetHTTPConfig().Headers[0].Value)
require.Equal(t, "url0", config.GetHTTPConfig().Headers[0].URL)
require.Equal(t, "key9", config.GetHTTPConfig().Headers[1].Key)
require.Equal(t, "value9", config.GetHTTPConfig().Headers[1].Value)
require.Equal(t, "url9", config.GetHTTPConfig().Headers[1].URL)
err = os.Remove(configFilePath)
assert.NoError(t, err)
config.Init()
err = config.LoadConfig(configDir, "")
require.NoError(t, err)
require.Len(t, config.GetHTTPConfig().Headers, 2)
require.Equal(t, "key0", config.GetHTTPConfig().Headers[0].Key)
require.Equal(t, "value0", config.GetHTTPConfig().Headers[0].Value)
require.Equal(t, "url0", config.GetHTTPConfig().Headers[0].URL)
require.Equal(t, "key9", config.GetHTTPConfig().Headers[1].Key)
require.Equal(t, "value9", config.GetHTTPConfig().Headers[1].Value)
require.Equal(t, "url9", config.GetHTTPConfig().Headers[1].URL)
}
func TestConfigFromEnv(t *testing.T) {
reset()

View file

@ -1912,15 +1912,14 @@ func validateKeyboardAuthResponse(response keyboardAuthHookResponse) error {
return nil
}
func sendKeyboardAuthHTTPReq(url *url.URL, request keyboardAuthHookRequest) (keyboardAuthHookResponse, error) {
func sendKeyboardAuthHTTPReq(url string, request keyboardAuthHookRequest) (keyboardAuthHookResponse, error) {
var response keyboardAuthHookResponse
httpClient := httpclient.GetHTTPClient()
reqAsJSON, err := json.Marshal(request)
if err != nil {
providerLog(logger.LevelWarn, "error serializing keyboard interactive auth request: %v", err)
return response, err
}
resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(reqAsJSON))
resp, err := httpclient.Post(url, "application/json", bytes.NewBuffer(reqAsJSON))
if err != nil {
providerLog(logger.LevelWarn, "error getting keyboard interactive auth hook HTTP response: %v", err)
return response, err
@ -1935,12 +1934,6 @@ func sendKeyboardAuthHTTPReq(url *url.URL, request keyboardAuthHookRequest) (key
func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
authResult := 0
var url *url.URL
url, err := url.Parse(authHook)
if err != nil {
providerLog(logger.LevelWarn, "invalid url for keyboard interactive hook %#v, error: %v", authHook, err)
return authResult, err
}
requestID := xid.New().String()
req := keyboardAuthHookRequest{
Username: user.Username,
@ -1949,8 +1942,9 @@ func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh.
RequestID: requestID,
}
var response keyboardAuthHookResponse
var err error
for {
response, err = sendKeyboardAuthHTTPReq(url, req)
response, err = sendKeyboardAuthHTTPReq(authHook, req)
if err != nil {
return authResult, err
}
@ -2120,12 +2114,6 @@ func isCheckPasswordHookDefined(protocol string) bool {
func getPasswordHookResponse(username, password, ip, protocol string) ([]byte, error) {
if strings.HasPrefix(config.CheckPasswordHook, "http") {
var result []byte
var url *url.URL
url, err := url.Parse(config.CheckPasswordHook)
if err != nil {
providerLog(logger.LevelWarn, "invalid url for check password hook %#v, error: %v", config.CheckPasswordHook, err)
return result, err
}
req := checkPasswordRequest{
Username: username,
Password: password,
@ -2136,8 +2124,7 @@ func getPasswordHookResponse(username, password, ip, protocol string) ([]byte, e
if err != nil {
return result, err
}
httpClient := httpclient.GetHTTPClient()
resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(reqAsJSON))
resp, err := httpclient.Post(config.CheckPasswordHook, "application/json", bytes.NewBuffer(reqAsJSON))
if err != nil {
providerLog(logger.LevelWarn, "error getting check password hook response: %v", err)
return result, err
@ -2192,8 +2179,8 @@ func getPreLoginHookResponse(loginMethod, ip, protocol string, userAsJSON []byte
q.Add("ip", ip)
q.Add("protocol", protocol)
url.RawQuery = q.Encode()
httpClient := httpclient.GetHTTPClient()
resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
resp, err := httpclient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
if err != nil {
providerLog(logger.LevelWarn, "error getting pre-login hook response: %v", err)
return result, err
@ -2318,8 +2305,7 @@ func ExecutePostLoginHook(user *User, loginMethod, ip, protocol string, err erro
startTime := time.Now()
respCode := 0
httpClient := httpclient.GetRetraybleHTTPClient()
resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
if err == nil {
respCode = resp.StatusCode
resp.Body.Close()
@ -2353,14 +2339,7 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip,
}
}
if strings.HasPrefix(config.ExternalAuthHook, "http") {
var url *url.URL
var result []byte
url, err := url.Parse(config.ExternalAuthHook)
if err != nil {
providerLog(logger.LevelWarn, "invalid url for external auth hook %#v, error: %v", config.ExternalAuthHook, err)
return result, err
}
httpClient := httpclient.GetHTTPClient()
authRequest := make(map[string]string)
authRequest["username"] = username
authRequest["ip"] = ip
@ -2377,7 +2356,7 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip,
providerLog(logger.LevelWarn, "error serializing external auth request: %v", err)
return result, err
}
resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(authRequestAsJSON))
resp, err := httpclient.Post(config.ExternalAuthHook, "application/json", bytes.NewBuffer(authRequestAsJSON))
if err != nil {
providerLog(logger.LevelWarn, "error getting external auth hook HTTP response: %v", err)
return result, err
@ -2561,8 +2540,7 @@ func executeAction(operation string, user *User) {
q.Add("action", operation)
url.RawQuery = q.Encode()
startTime := time.Now()
httpClient := httpclient.GetRetraybleHTTPClient()
resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(userAsJSON))
respCode := 0
if err == nil {
respCode = resp.StatusCode

View file

@ -248,6 +248,10 @@ The configuration file contains the following sections:
- `cert`, string. Path to the certificate file. The path can be absolute or relative to the config dir.
- `key`, string. Path to the key file. The path can be absolute or relative to the config dir.
- `skip_tls_verify`, boolean. if enabled the HTTP client accepts any TLS certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
- `headers`, list of structs. You can define a list of http headers to add to each hook. Each struct has the following fields:
- `key`, string
- `value`, string. The header is silently ignored if `key` or `value` are empty
- `url`, string, optional. If not empty, the header will be added only if the request URL starts with the one specified here
- **kms**, configuration for the Key Management Service, more details can be found [here](./kms.md)
- `secrets`
- `url`

View file

@ -4,9 +4,11 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/hashicorp/go-retryablehttp"
@ -21,6 +23,15 @@ type TLSKeyPair struct {
Key string `json:"key" mapstructure:"key"`
}
// Header defines an HTTP header.
// If the URL is not empty, the header is added only if the
// requested URL starts with the one specified
type Header struct {
Key string `json:"key" mapstructure:"key"`
Value string `json:"value" mapstructure:"value"`
URL string `json:"url" mapstructure:"url"`
}
// Config defines the configuration for HTTP clients.
// HTTP clients are used for executing hooks such as the ones used for
// custom actions, external authentication and pre-login user modifications
@ -44,7 +55,9 @@ type Config struct {
// the server and any host name in that certificate.
// In this mode, TLS is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
SkipTLSVerify bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"`
SkipTLSVerify bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"`
// Headers defines a list of http headers to add to each request
Headers []Header `json:"headers" mapstructure:"headers"`
customTransport *http.Transport
tlsConfig *tls.Config
}
@ -76,6 +89,13 @@ func (c *Config) Initialize(configDir string) error {
if err != nil {
return err
}
var headers []Header
for _, h := range c.Headers {
if h.Key != "" && h.Value != "" {
headers = append(headers, h)
}
}
c.Headers = headers
httpConfig = *c
return nil
}
@ -162,3 +182,63 @@ func GetRetraybleHTTPClient() *retryablehttp.Client {
return client
}
// Get issues a GET to the specified URL
func Get(url string) (*http.Response, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
addHeaders(req, url)
return GetHTTPClient().Do(req)
}
// Post issues a POST to the specified URL
func Post(url string, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest(http.MethodPost, url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
addHeaders(req, url)
return GetHTTPClient().Do(req)
}
// RetryableGet issues a GET to the specified URL using the retryable client
func RetryableGet(url string) (*http.Response, error) {
req, err := retryablehttp.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
addHeadersToRetryableReq(req, url)
return GetRetraybleHTTPClient().Do(req)
}
// RetryablePost issues a POST to the specified URL using the retryable client
func RetryablePost(url string, contentType string, body io.Reader) (*http.Response, error) {
req, err := retryablehttp.NewRequest(http.MethodPost, url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
addHeadersToRetryableReq(req, url)
return GetRetraybleHTTPClient().Do(req)
}
func addHeaders(req *http.Request, url string) {
for idx := range httpConfig.Headers {
h := &httpConfig.Headers[idx]
if h.URL == "" || strings.HasPrefix(url, h.URL) {
req.Header.Set(h.Key, h.Value)
}
}
}
func addHeadersToRetryableReq(req *retryablehttp.Request, url string) {
for idx := range httpConfig.Headers {
h := &httpConfig.Headers[idx]
if h.URL == "" || strings.HasPrefix(url, h.URL) {
req.Header.Set(h.Key, h.Value)
}
}
}

View file

@ -31,8 +31,8 @@ const (
)
var (
tokenDuration = 10 * time.Minute
tokenRefreshMin = 5 * time.Minute
tokenDuration = 15 * time.Minute
tokenRefreshMin = 10 * time.Minute
)
type jwtTokenClaims struct {

View file

@ -1,7 +1,6 @@
// Package httpd implements REST API and Web interface for SFTPGo.
// The OpenAPI 3 schema for the exposed API can be found inside the source tree:
// https://github.com/drakkan/sftpgo/blob/main/httpd/schema/openapi.yaml
// A basic Web interface to manage users and connections is provided too
package httpd
import (

View file

@ -217,7 +217,8 @@
"retry_max": 3,
"ca_certificates": [],
"certificates": [],
"skip_tls_verify": false
"skip_tls_verify": false,
"headers": []
},
"kms": {
"secrets": {

View file

@ -404,7 +404,7 @@ func createDirPathIfMissing(file string, perm os.FileMode) error {
func GenerateRandomBytes(length int) []byte {
b := make([]byte, length)
_, err := io.ReadFull(rand.Reader, b)
if err != nil {
if err == nil {
return b
}