diff --git a/cmd/notification-http/main.go b/cmd/notification-http/main.go index 340d462c1..382f30fea 100644 --- a/cmd/notification-http/main.go +++ b/cmd/notification-http/main.go @@ -7,8 +7,10 @@ import ( "crypto/x509" "fmt" "io" + "net" "net/http" "os" + "strings" "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" @@ -19,6 +21,7 @@ import ( type PluginConfig struct { Name string `yaml:"name"` URL string `yaml:"url"` + UnixSocket string `yaml:"unix_socket"` Headers map[string]string `yaml:"headers"` SkipTLSVerification bool `yaml:"skip_tls_verification"` Method string `yaml:"method"` @@ -66,36 +69,40 @@ func getCertPool(caPath string) (*x509.CertPool, error) { return cp, nil } -func getTLSClient(tlsVerify bool, caPath, certPath, keyPath string) (*http.Client, error) { - var client *http.Client - - caCertPool, err := getCertPool(caPath) +func getTLSClient(c *PluginConfig) error { + caCertPool, err := getCertPool(c.CAPath) if err != nil { - return nil, err + return err } tlsConfig := &tls.Config{ RootCAs: caCertPool, - InsecureSkipVerify: tlsVerify, + InsecureSkipVerify: c.SkipTLSVerification, } - if certPath != "" && keyPath != "" { - logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", certPath, keyPath)) + if c.CertPath != "" && c.KeyPath != "" { + logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", c.CertPath, c.KeyPath)) - cert, err := tls.LoadX509KeyPair(certPath, keyPath) + cert, err := tls.LoadX509KeyPair(c.CertPath, c.KeyPath) if err != nil { - return nil, fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", certPath, keyPath, err) + return fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", c.CertPath, c.KeyPath, err) } tlsConfig.Certificates = []tls.Certificate{cert} } - - client = &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsConfig, - }, + transport := &http.Transport{ + TLSClientConfig: tlsConfig, } - return client, err + if c.UnixSocket != "" { + logger.Info(fmt.Sprintf("Using socket '%s'", c.UnixSocket)) + transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", strings.TrimSuffix(c.UnixSocket, "/")) + } + } + c.Client = &http.Client{ + Transport: transport, + } + return nil } func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { @@ -135,6 +142,7 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific if resp.StatusCode < 200 || resp.StatusCode >= 300 { logger.Warn(fmt.Sprintf("HTTP server returned non 200 status code: %d", resp.StatusCode)) + logger.Debug(fmt.Sprintf("HTTP server returned body: %s", string(respData))) return &protobufs.Empty{}, nil } @@ -147,7 +155,7 @@ func (s *HTTPPlugin) Configure(ctx context.Context, config *protobufs.Config) (* if err != nil { return nil, err } - d.Client, err = getTLSClient(d.SkipTLSVerification, d.CAPath, d.CertPath, d.KeyPath) + err = getTLSClient(&d) if err != nil { return nil, err }