123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276 |
- // Copyright (C) 2019-2023 Nicola Murino
- //
- // This program is free software: you can redistribute it and/or modify
- // it under the terms of the GNU Affero General Public License as published
- // by the Free Software Foundation, version 3.
- //
- // This program is distributed in the hope that it will be useful,
- // but WITHOUT ANY WARRANTY; without even the implied warranty of
- // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- // GNU Affero General Public License for more details.
- //
- // You should have received a copy of the GNU Affero General Public License
- // along with this program. If not, see <https://www.gnu.org/licenses/>.
- // Package httpclient provides HTTP client configuration for SFTPGo hooks
- package httpclient
- import (
- "crypto/tls"
- "crypto/x509"
- "fmt"
- "io"
- "net/http"
- "os"
- "path/filepath"
- "strings"
- "time"
- "github.com/hashicorp/go-retryablehttp"
- "github.com/drakkan/sftpgo/v2/internal/logger"
- "github.com/drakkan/sftpgo/v2/internal/util"
- )
// TLSKeyPair holds the file paths of a client certificate and its private
// key, used for mutual TLS. Relative paths are resolved against the
// configuration directory when the pair is loaded.
type TLSKeyPair struct {
	// Cert is the path of the certificate file.
	Cert string `json:"cert" mapstructure:"cert"`
	// Key is the path of the private key file matching Cert.
	Key string `json:"key" mapstructure:"key"`
}
// Header is an HTTP header that the client helpers attach to outgoing
// requests.
//
// When URL is empty the header is sent with every request; otherwise it is
// sent only to requests whose URL starts with URL.
type Header struct {
	// Key is the header name.
	Key string `json:"key" mapstructure:"key"`
	// Value is the header value.
	Value string `json:"value" mapstructure:"value"`
	// URL optionally restricts the header to requests whose URL has this prefix.
	URL string `json:"url" mapstructure:"url"`
}
- // Config defines the configuration for HTTP clients.
- // HTTP clients are used for executing hooks such as the ones used for
- // custom actions, external authentication and pre-login user modifications
- type Config struct {
- // Timeout specifies a time limit, in seconds, for a request
- Timeout float64 `json:"timeout" mapstructure:"timeout"`
- // RetryWaitMin defines the minimum waiting time between attempts in seconds
- RetryWaitMin int `json:"retry_wait_min" mapstructure:"retry_wait_min"`
- // RetryWaitMax defines the minimum waiting time between attempts in seconds
- RetryWaitMax int `json:"retry_wait_max" mapstructure:"retry_wait_max"`
- // RetryMax defines the maximum number of attempts
- RetryMax int `json:"retry_max" mapstructure:"retry_max"`
- // CACertificates defines extra CA certificates to trust.
- // The paths can be absolute or relative to the config dir.
- // Adding trusted CA certificates is a convenient way to use self-signed
- // certificates without defeating the purpose of using TLS
- CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"`
- // Certificates defines the certificates to use for mutual TLS
- Certificates []TLSKeyPair `json:"certificates" mapstructure:"certificates"`
- // if enabled the HTTP client accepts any TLS certificate presented by
- // the server and any host name in that certificate.
- // In this mode, TLS is susceptible to man-in-the-middle attacks.
- // This should be used only for testing.
- SkipTLSVerify bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"`
- // Headers defines a list of http headers to add to each request
- Headers []Header `json:"headers" mapstructure:"headers"`
- customTransport *http.Transport
- }
- const logSender = "httpclient"
- var httpConfig Config
- // Initialize configures HTTP clients
- func (c *Config) Initialize(configDir string) error {
- if c.Timeout <= 0 {
- return fmt.Errorf("invalid timeout: %v", c.Timeout)
- }
- rootCAs, err := c.loadCACerts(configDir)
- if err != nil {
- return err
- }
- customTransport := http.DefaultTransport.(*http.Transport).Clone()
- if customTransport.TLSClientConfig != nil {
- customTransport.TLSClientConfig.RootCAs = rootCAs
- } else {
- customTransport.TLSClientConfig = &tls.Config{
- RootCAs: rootCAs,
- }
- }
- customTransport.TLSClientConfig.InsecureSkipVerify = c.SkipTLSVerify
- c.customTransport = customTransport
- err = c.loadCertificates(configDir)
- if err != nil {
- return err
- }
- var headers []Header
- for _, h := range c.Headers {
- if h.Key != "" && h.Value != "" {
- headers = append(headers, h)
- }
- }
- c.Headers = headers
- httpConfig = *c
- return nil
- }
- // loadCACerts returns system cert pools and try to add the configured
- // CA certificates to it
- func (c *Config) loadCACerts(configDir string) (*x509.CertPool, error) {
- if len(c.CACertificates) == 0 {
- return nil, nil
- }
- rootCAs, err := x509.SystemCertPool()
- if err != nil {
- rootCAs = x509.NewCertPool()
- }
- for _, ca := range c.CACertificates {
- if !util.IsFileInputValid(ca) {
- return nil, fmt.Errorf("unable to load invalid CA certificate: %q", ca)
- }
- if !filepath.IsAbs(ca) {
- ca = filepath.Join(configDir, ca)
- }
- certs, err := os.ReadFile(ca)
- if err != nil {
- return nil, fmt.Errorf("unable to load CA certificate: %v", err)
- }
- if rootCAs.AppendCertsFromPEM(certs) {
- logger.Debug(logSender, "", "CA certificate %q added to the trusted certificates", ca)
- } else {
- return nil, fmt.Errorf("unable to add CA certificate %q to the trusted cetificates", ca)
- }
- }
- return rootCAs, nil
- }
- func (c *Config) loadCertificates(configDir string) error {
- if len(c.Certificates) == 0 {
- return nil
- }
- for _, keyPair := range c.Certificates {
- cert := keyPair.Cert
- key := keyPair.Key
- if !util.IsFileInputValid(cert) {
- return fmt.Errorf("unable to load invalid certificate: %q", cert)
- }
- if !util.IsFileInputValid(key) {
- return fmt.Errorf("unable to load invalid key: %q", key)
- }
- if !filepath.IsAbs(cert) {
- cert = filepath.Join(configDir, cert)
- }
- if !filepath.IsAbs(key) {
- key = filepath.Join(configDir, key)
- }
- tlsCert, err := tls.LoadX509KeyPair(cert, key)
- if err != nil {
- return fmt.Errorf("unable to load key pair %q, %q: %v", cert, key, err)
- }
- x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
- if err == nil {
- logger.Debug(logSender, "", "adding leaf certificate for key pair %q, %q", cert, key)
- tlsCert.Leaf = x509Cert
- }
- logger.Debug(logSender, "", "client certificate %q and key %q successfully loaded", cert, key)
- c.customTransport.TLSClientConfig.Certificates = append(c.customTransport.TLSClientConfig.Certificates, tlsCert)
- }
- return nil
- }
- // GetHTTPClient returns a new HTTP client with the configured parameters
- func GetHTTPClient() *http.Client {
- return &http.Client{
- Timeout: time.Duration(httpConfig.Timeout * float64(time.Second)),
- Transport: httpConfig.customTransport,
- }
- }
- // GetRetraybleHTTPClient returns an HTTP client that retry a request on error.
- // It uses the configured retry parameters
- func GetRetraybleHTTPClient() *retryablehttp.Client {
- client := retryablehttp.NewClient()
- client.HTTPClient.Timeout = time.Duration(httpConfig.Timeout * float64(time.Second))
- client.HTTPClient.Transport.(*http.Transport).TLSClientConfig = httpConfig.customTransport.TLSClientConfig
- client.Logger = &logger.LeveledLogger{Sender: "RetryableHTTPClient"}
- client.RetryWaitMin = time.Duration(httpConfig.RetryWaitMin) * time.Second
- client.RetryWaitMax = time.Duration(httpConfig.RetryWaitMax) * time.Second
- client.RetryMax = httpConfig.RetryMax
- return client
- }
- // Get issues a GET to the specified URL
- func Get(url string) (*http.Response, error) {
- req, err := http.NewRequest(http.MethodGet, url, nil)
- if err != nil {
- return nil, err
- }
- addHeaders(req, url)
- client := GetHTTPClient()
- defer client.CloseIdleConnections()
- return client.Do(req)
- }
- // Post issues a POST to the specified URL
- func Post(url string, contentType string, body io.Reader) (*http.Response, error) {
- req, err := http.NewRequest(http.MethodPost, url, body)
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", contentType)
- addHeaders(req, url)
- client := GetHTTPClient()
- defer client.CloseIdleConnections()
- return client.Do(req)
- }
- // RetryableGet issues a GET to the specified URL using the retryable client
- func RetryableGet(url string) (*http.Response, error) {
- req, err := retryablehttp.NewRequest(http.MethodGet, url, nil)
- if err != nil {
- return nil, err
- }
- addHeadersToRetryableReq(req, url)
- client := GetRetraybleHTTPClient()
- defer client.HTTPClient.CloseIdleConnections()
- return client.Do(req)
- }
- // RetryablePost issues a POST to the specified URL using the retryable client
- func RetryablePost(url string, contentType string, body io.Reader) (*http.Response, error) {
- req, err := retryablehttp.NewRequest(http.MethodPost, url, body)
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", contentType)
- addHeadersToRetryableReq(req, url)
- client := GetRetraybleHTTPClient()
- defer client.HTTPClient.CloseIdleConnections()
- return client.Do(req)
- }
- func addHeaders(req *http.Request, url string) {
- for idx := range httpConfig.Headers {
- h := &httpConfig.Headers[idx]
- if h.URL == "" || strings.HasPrefix(url, h.URL) {
- req.Header.Set(h.Key, h.Value)
- }
- }
- }
- func addHeadersToRetryableReq(req *retryablehttp.Request, url string) {
- for idx := range httpConfig.Headers {
- h := &httpConfig.Headers[idx]
- if h.URL == "" || strings.HasPrefix(url, h.URL) {
- req.Header.Set(h.Key, h.Value)
- }
- }
- }
|