Browse Source

Add unix socket option to http plugin, we have to use this in conjunction with URL parameter as we dont know which path the user wants so if they would like to communicate over unix socket they need to use both, however, the hostname can be whatever they want. We could be a little smarter and actually parse the url, however, increasing code when a user can just define it correctly make no sense

Laurence 1 year ago
parent
commit
4d83b3b49f
1 changed files with 25 additions and 17 deletions
  1. 25 17
      cmd/notification-http/main.go

+ 25 - 17
cmd/notification-http/main.go

@@ -7,8 +7,10 @@ import (
 	"crypto/x509"
 	"crypto/x509"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"net"
 	"net/http"
 	"net/http"
 	"os"
 	"os"
+	"strings"
 
 
 	"github.com/crowdsecurity/crowdsec/pkg/protobufs"
 	"github.com/crowdsecurity/crowdsec/pkg/protobufs"
 	"github.com/hashicorp/go-hclog"
 	"github.com/hashicorp/go-hclog"
@@ -19,6 +21,7 @@ import (
 type PluginConfig struct {
 type PluginConfig struct {
 	Name                string            `yaml:"name"`
 	Name                string            `yaml:"name"`
 	URL                 string            `yaml:"url"`
 	URL                 string            `yaml:"url"`
+	UnixSocket          string            `yaml:"unix_socket"`
 	Headers             map[string]string `yaml:"headers"`
 	Headers             map[string]string `yaml:"headers"`
 	SkipTLSVerification bool              `yaml:"skip_tls_verification"`
 	SkipTLSVerification bool              `yaml:"skip_tls_verification"`
 	Method              string            `yaml:"method"`
 	Method              string            `yaml:"method"`
@@ -66,36 +69,40 @@ func getCertPool(caPath string) (*x509.CertPool, error) {
 	return cp, nil
 	return cp, nil
 }
 }
 
 
-func getTLSClient(tlsVerify bool, caPath, certPath, keyPath string) (*http.Client, error) {
-	var client *http.Client
-
-	caCertPool, err := getCertPool(caPath)
+func getTLSClient(c *PluginConfig) error {
+	caCertPool, err := getCertPool(c.CAPath)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return err
 	}
 	}
 
 
 	tlsConfig := &tls.Config{
 	tlsConfig := &tls.Config{
 		RootCAs:            caCertPool,
 		RootCAs:            caCertPool,
-		InsecureSkipVerify: tlsVerify,
+		InsecureSkipVerify: c.SkipTLSVerification,
 	}
 	}
 
 
-	if certPath != "" && keyPath != "" {
-		logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", certPath, keyPath))
+	if c.CertPath != "" && c.KeyPath != "" {
+		logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", c.CertPath, c.KeyPath))
 
 
-		cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+		cert, err := tls.LoadX509KeyPair(c.CertPath, c.KeyPath)
 		if err != nil {
 		if err != nil {
-			return nil, fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", certPath, keyPath, err)
+			return fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", c.CertPath, c.KeyPath, err)
 		}
 		}
 
 
 		tlsConfig.Certificates = []tls.Certificate{cert}
 		tlsConfig.Certificates = []tls.Certificate{cert}
 	}
 	}
-
-	client = &http.Client{
-		Transport: &http.Transport{
-			TLSClientConfig: tlsConfig,
-		},
+	transport := &http.Transport{
+		TLSClientConfig: tlsConfig,
+	}
+	if c.UnixSocket != "" {
+		logger.Info(fmt.Sprintf("Using socket '%s'", c.UnixSocket))
+		transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
+			return net.Dial("unix", strings.TrimSuffix(c.UnixSocket, "/"))
+		}
+	}
+	c.Client = &http.Client{
+		Transport: transport,
 	}
 	}
-	return client, err
+	return nil
 }
 }
 
 
 func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) {
 func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) {
@@ -135,6 +142,7 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific
 
 
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
 		logger.Warn(fmt.Sprintf("HTTP server returned non 200 status code: %d", resp.StatusCode))
 		logger.Warn(fmt.Sprintf("HTTP server returned non 200 status code: %d", resp.StatusCode))
+		logger.Debug(fmt.Sprintf("HTTP server returned body: %s", string(respData)))
 		return &protobufs.Empty{}, nil
 		return &protobufs.Empty{}, nil
 	}
 	}
 
 
@@ -147,7 +155,7 @@ func (s *HTTPPlugin) Configure(ctx context.Context, config *protobufs.Config) (*
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	d.Client, err = getTLSClient(d.SkipTLSVerification, d.CAPath, d.CertPath, d.KeyPath)
+	err = getTLSClient(&d)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}