LAPI: local api unix socket support (#2770)

This commit is contained in:
mmetc 2024-03-14 10:43:02 +01:00 committed by GitHub
parent 2a7e8383c8
commit 6c042f18f0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 718 additions and 139 deletions

5
.gitignore vendored
View file

@ -6,7 +6,10 @@
*.dylib *.dylib
*~ *~
.pc .pc
# IDEs
.vscode .vscode
.idea
# If vendor is included, allow prebuilt (wasm?) libraries. # If vendor is included, allow prebuilt (wasm?) libraries.
!vendor/**/*.so !vendor/**/*.so
@ -34,7 +37,7 @@ test/coverage/*
*.swo *.swo
# Dependencies are not vendored by default, but a tarball is created by "make vendor" # Dependencies are not vendored by default, but a tarball is created by "make vendor"
# and provided in the release. Used by freebsd, gentoo, etc. # and provided in the release. Used by gentoo, etc.
vendor/ vendor/
vendor.tgz vendor.tgz

View file

@ -100,6 +100,7 @@ API Client:
{{- if .API.Server }} {{- if .API.Server }}
Local API Server{{if and .API.Server.Enable (not (ValueBool .API.Server.Enable))}} (disabled){{end}}: Local API Server{{if and .API.Server.Enable (not (ValueBool .API.Server.Enable))}} (disabled){{end}}:
- Listen URL : {{.API.Server.ListenURI}} - Listen URL : {{.API.Server.ListenURI}}
- Listen Socket : {{.API.Server.ListenSocket}}
- Profile File : {{.API.Server.ProfilesPath}} - Profile File : {{.API.Server.ProfilesPath}}
{{- if .API.Server.TLS }} {{- if .API.Server.TLS }}

View file

@ -44,7 +44,9 @@ func (cli *cliLapi) status() error {
password := strfmt.Password(cfg.API.Client.Credentials.Password) password := strfmt.Password(cfg.API.Client.Credentials.Password)
login := cfg.API.Client.Credentials.Login login := cfg.API.Client.Credentials.Login
apiurl, err := url.Parse(cfg.API.Client.Credentials.URL) origURL := cfg.API.Client.Credentials.URL
apiURL, err := url.Parse(origURL)
if err != nil { if err != nil {
return fmt.Errorf("parsing api url: %w", err) return fmt.Errorf("parsing api url: %w", err)
} }
@ -59,7 +61,7 @@ func (cli *cliLapi) status() error {
return fmt.Errorf("failed to get scenarios: %w", err) return fmt.Errorf("failed to get scenarios: %w", err)
} }
Client, err = apiclient.NewDefaultClient(apiurl, Client, err = apiclient.NewDefaultClient(apiURL,
LAPIURLPrefix, LAPIURLPrefix,
fmt.Sprintf("crowdsec/%s", version.String()), fmt.Sprintf("crowdsec/%s", version.String()),
nil) nil)
@ -74,7 +76,8 @@ func (cli *cliLapi) status() error {
} }
log.Infof("Loaded credentials from %s", cfg.API.Client.CredentialsFilePath) log.Infof("Loaded credentials from %s", cfg.API.Client.CredentialsFilePath)
log.Infof("Trying to authenticate with username %s on %s", login, apiurl) // use the original string because apiURL would print 'http://unix/'
log.Infof("Trying to authenticate with username %s on %s", login, origURL)
_, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t)
if err != nil { if err != nil {
@ -101,23 +104,7 @@ func (cli *cliLapi) register(apiURL string, outputFile string, machine string) e
password := strfmt.Password(generatePassword(passwordLength)) password := strfmt.Password(generatePassword(passwordLength))
if apiURL == "" { apiurl, err := prepareAPIURL(cfg.API.Client, apiURL)
if cfg.API.Client == nil || cfg.API.Client.Credentials == nil || cfg.API.Client.Credentials.URL == "" {
return fmt.Errorf("no Local API URL. Please provide it in your configuration or with the -u parameter")
}
apiURL = cfg.API.Client.Credentials.URL
}
/*URL needs to end with /, but user doesn't care*/
if !strings.HasSuffix(apiURL, "/") {
apiURL += "/"
}
/*URL needs to start with http://, but user doesn't care*/
if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") {
apiURL = "http://" + apiURL
}
apiurl, err := url.Parse(apiURL)
if err != nil { if err != nil {
return fmt.Errorf("parsing api url: %w", err) return fmt.Errorf("parsing api url: %w", err)
} }
@ -173,13 +160,36 @@ func (cli *cliLapi) register(apiURL string, outputFile string, machine string) e
return nil return nil
} }
// prepareAPIURL checks/fixes a LAPI connection url (http, https or socket) and returns an URL struct
func prepareAPIURL(clientCfg *csconfig.LocalApiClientCfg, apiURL string) (*url.URL, error) {
if apiURL == "" {
if clientCfg == nil || clientCfg.Credentials == nil || clientCfg.Credentials.URL == "" {
return nil, errors.New("no Local API URL. Please provide it in your configuration or with the -u parameter")
}
apiURL = clientCfg.Credentials.URL
}
// URL needs to end with /, but user doesn't care
if !strings.HasSuffix(apiURL, "/") {
apiURL += "/"
}
// URL needs to start with http://, but user doesn't care
if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") && !strings.HasPrefix(apiURL, "/") {
apiURL = "http://" + apiURL
}
return url.Parse(apiURL)
}
func (cli *cliLapi) newStatusCmd() *cobra.Command { func (cli *cliLapi) newStatusCmd() *cobra.Command {
cmdLapiStatus := &cobra.Command{ cmdLapiStatus := &cobra.Command{
Use: "status", Use: "status",
Short: "Check authentication to Local API (LAPI)", Short: "Check authentication to Local API (LAPI)",
Args: cobra.MinimumNArgs(0), Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true, DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(_ *cobra.Command, _ []string) error {
return cli.status() return cli.status()
}, },
} }

View file

@ -0,0 +1,49 @@
package main
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
)
func TestPrepareAPIURL_NoProtocol(t *testing.T) {
url, err := prepareAPIURL(nil, "localhost:81")
require.NoError(t, err)
assert.Equal(t, "http://localhost:81/", url.String())
}
func TestPrepareAPIURL_Http(t *testing.T) {
url, err := prepareAPIURL(nil, "http://localhost:81")
require.NoError(t, err)
assert.Equal(t, "http://localhost:81/", url.String())
}
func TestPrepareAPIURL_Https(t *testing.T) {
url, err := prepareAPIURL(nil, "https://localhost:81")
require.NoError(t, err)
assert.Equal(t, "https://localhost:81/", url.String())
}
func TestPrepareAPIURL_UnixSocket(t *testing.T) {
url, err := prepareAPIURL(nil, "/path/socket")
require.NoError(t, err)
assert.Equal(t, "/path/socket/", url.String())
}
func TestPrepareAPIURL_Empty(t *testing.T) {
_, err := prepareAPIURL(nil, "")
require.Error(t, err)
}
func TestPrepareAPIURL_Empty_ConfigOverride(t *testing.T) {
url, err := prepareAPIURL(&csconfig.LocalApiClientCfg{
Credentials: &csconfig.ApiCredentialsCfg{
URL: "localhost:80",
},
}, "")
require.NoError(t, err)
assert.Equal(t, "http://localhost:80/", url.String())
}

View file

@ -318,8 +318,8 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri
if apiURL == "" { if apiURL == "" {
if clientCfg != nil && clientCfg.Credentials != nil && clientCfg.Credentials.URL != "" { if clientCfg != nil && clientCfg.Credentials != nil && clientCfg.Credentials.URL != "" {
apiURL = clientCfg.Credentials.URL apiURL = clientCfg.Credentials.URL
} else if serverCfg != nil && serverCfg.ListenURI != "" { } else if serverCfg.ClientURL() != "" {
apiURL = "http://" + serverCfg.ListenURI apiURL = serverCfg.ClientURL()
} else { } else {
return errors.New("unable to dump an api URL. Please provide it in your configuration or with the -u parameter") return errors.New("unable to dump an api URL. Please provide it in your configuration or with the -u parameter")
} }

View file

@ -22,8 +22,7 @@ def test_missing_key_file(crowdsec, flavor):
} }
with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs: with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs:
# XXX: this message appears twice, is that normal? cs.wait_for_log("*local API server stopped with error: missing TLS key file*")
cs.wait_for_log("*while starting API server: missing TLS key file*")
def test_missing_cert_file(crowdsec, flavor): def test_missing_cert_file(crowdsec, flavor):
@ -35,7 +34,7 @@ def test_missing_cert_file(crowdsec, flavor):
} }
with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs: with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs:
cs.wait_for_log("*while starting API server: missing TLS cert file*") cs.wait_for_log("*local API server stopped with error: missing TLS cert file*")
def test_tls_missing_ca(crowdsec, flavor, certs_dir): def test_tls_missing_ca(crowdsec, flavor, certs_dir):

View file

@ -70,9 +70,14 @@ func (t *JWTTransport) refreshJwtToken() error {
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
transport := t.Transport
if transport == nil {
transport = http.DefaultTransport
}
client := &http.Client{ client := &http.Client{
Transport: &retryRoundTripper{ Transport: &retryRoundTripper{
next: http.DefaultTransport, next: transport,
maxAttempts: 5, maxAttempts: 5,
withBackOff: true, withBackOff: true,
retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError}, retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError},
@ -153,7 +158,7 @@ func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error)
req.Header.Add("User-Agent", t.UserAgent) req.Header.Add("User-Agent", t.UserAgent)
} }
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token)) req.Header.Add("Authorization", "Bearer "+t.Token)
return req, nil return req, nil
} }
@ -166,7 +171,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
} }
if log.GetLevel() >= log.TraceLevel { if log.GetLevel() >= log.TraceLevel {
//requestToDump := cloneRequest(req) // requestToDump := cloneRequest(req)
dump, _ := httputil.DumpRequest(req, true) dump, _ := httputil.DumpRequest(req, true)
log.Tracef("req-jwt: %s", string(dump)) log.Tracef("req-jwt: %s", string(dump))
} }

View file

@ -5,8 +5,10 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
@ -67,12 +69,18 @@ func NewClient(config *Config) (*ApiClient, error) {
MachineID: &config.MachineID, MachineID: &config.MachineID,
Password: &config.Password, Password: &config.Password,
Scenarios: config.Scenarios, Scenarios: config.Scenarios,
URL: config.URL,
UserAgent: config.UserAgent, UserAgent: config.UserAgent,
VersionPrefix: config.VersionPrefix, VersionPrefix: config.VersionPrefix,
UpdateScenario: config.UpdateScenario, UpdateScenario: config.UpdateScenario,
} }
transport, baseURL := createTransport(config.URL)
if transport != nil {
t.Transport = transport
}
t.URL = baseURL
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
tlsconfig.RootCAs = CaCertPool tlsconfig.RootCAs = CaCertPool
@ -84,7 +92,7 @@ func NewClient(config *Config) (*ApiClient, error) {
ht.TLSClientConfig = &tlsconfig ht.TLSClientConfig = &tlsconfig
} }
c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL} c := &ApiClient{client: t.Client(), BaseURL: baseURL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL}
c.common.client = c c.common.client = c
c.Decisions = (*DecisionsService)(&c.common) c.Decisions = (*DecisionsService)(&c.common)
c.Alerts = (*AlertsService)(&c.common) c.Alerts = (*AlertsService)(&c.common)
@ -98,9 +106,14 @@ func NewClient(config *Config) (*ApiClient, error) {
} }
func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *http.Client) (*ApiClient, error) { func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *http.Client) (*ApiClient, error) {
transport, baseURL := createTransport(URL)
if client == nil { if client == nil {
client = &http.Client{} client = &http.Client{}
if transport != nil {
client.Transport = transport
} else {
if ht, ok := http.DefaultTransport.(*http.Transport); ok { if ht, ok := http.DefaultTransport.(*http.Transport); ok {
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
tlsconfig.RootCAs = CaCertPool tlsconfig.RootCAs = CaCertPool
@ -113,8 +126,9 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt
client.Transport = ht client.Transport = ht
} }
} }
}
c := &ApiClient{client: client, BaseURL: URL, UserAgent: userAgent, URLPrefix: prefix} c := &ApiClient{client: client, BaseURL: baseURL, UserAgent: userAgent, URLPrefix: prefix}
c.common.client = c c.common.client = c
c.Decisions = (*DecisionsService)(&c.common) c.Decisions = (*DecisionsService)(&c.common)
c.Alerts = (*AlertsService)(&c.common) c.Alerts = (*AlertsService)(&c.common)
@ -128,10 +142,13 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt
} }
func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) { func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
transport, baseURL := createTransport(config.URL)
if client == nil { if client == nil {
client = &http.Client{} client = &http.Client{}
} if transport != nil {
client.Transport = transport
} else {
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
if Cert != nil { if Cert != nil {
tlsconfig.RootCAs = CaCertPool tlsconfig.RootCAs = CaCertPool
@ -139,7 +156,12 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
} }
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig
c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix} }
} else if client.Transport == nil && transport != nil {
client.Transport = transport
}
c := &ApiClient{client: client, BaseURL: baseURL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix}
c.common.client = c c.common.client = c
c.Decisions = (*DecisionsService)(&c.common) c.Decisions = (*DecisionsService)(&c.common)
c.Alerts = (*AlertsService)(&c.common) c.Alerts = (*AlertsService)(&c.common)
@ -158,11 +180,31 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
return c, nil return c, nil
} }
func createTransport(url *url.URL) (*http.Transport, *url.URL) {
urlString := url.String()
// TCP transport
if !strings.HasPrefix(urlString, "/") {
return nil, url
}
// Unix transport
url.Path = "/"
url.Host = "unix"
url.Scheme = "http"
return &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", strings.TrimSuffix(urlString, "/"))
},
}, url
}
type Response struct { type Response struct {
Response *http.Response Response *http.Response
//add our pagination stuff // add our pagination stuff
//NextPage int // NextPage int
//... // ...
} }
func newResponse(r *http.Response) *Response { func newResponse(r *http.Response) *Response {
@ -170,14 +212,14 @@ func newResponse(r *http.Response) *Response {
} }
type ListOpts struct { type ListOpts struct {
//Page int // Page int
//PerPage int // PerPage int
} }
type DeleteOpts struct { type DeleteOpts struct {
//?? // ??
} }
type AddOpts struct { type AddOpts struct {
//?? // ??
} }

View file

@ -3,10 +3,13 @@ package apiclient
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"path"
"runtime" "runtime"
"strings"
"testing" "testing"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -34,12 +37,50 @@ func setupWithPrefix(urlPrefix string) (*http.ServeMux, string, func()) {
apiHandler := http.NewServeMux() apiHandler := http.NewServeMux()
apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux)) apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux))
// server is a test HTTP server used to provide mock API responses.
server := httptest.NewServer(apiHandler) server := httptest.NewServer(apiHandler)
return mux, server.URL, server.Close return mux, server.URL, server.Close
} }
// toUNCPath converts a Windows file path to a UNC path.
// This is necessary because the Go http package does not support Windows file paths.
func toUNCPath(path string) (string, error) {
colonIdx := strings.Index(path, ":")
if colonIdx == -1 {
return "", fmt.Errorf("invalid path format, missing drive letter: %s", path)
}
// URL parsing does not like backslashes
remaining := strings.ReplaceAll(path[colonIdx+1:], "\\", "/")
uncPath := "//localhost/" + path[:colonIdx] + "$" + remaining
return uncPath, nil
}
func setupUnixSocketWithPrefix(socket string, urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) {
var err error
if runtime.GOOS == "windows" {
socket, err = toUNCPath(socket)
if err != nil {
log.Fatalf("converting to UNC path: %s", err)
}
}
mux = http.NewServeMux()
baseURLPath := "/" + urlPrefix
apiHandler := http.NewServeMux()
apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux))
server := httptest.NewUnstartedServer(apiHandler)
l, _ := net.Listen("unix", socket)
_ = server.Listener.Close()
server.Listener = l
server.Start()
return mux, socket, server.Close
}
func testMethod(t *testing.T, r *http.Request, want string) { func testMethod(t *testing.T, r *http.Request, want string) {
t.Helper() t.Helper()
assert.Equal(t, want, r.Method) assert.Equal(t, want, r.Method)
@ -77,6 +118,49 @@ func TestNewClientOk(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
} }
func TestNewClientOk_UnixSocket(t *testing.T) {
tmpDir := t.TempDir()
socket := path.Join(tmpDir, "socket")
mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1")
defer teardown()
apiURL, err := url.Parse(urlx)
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
client, err := NewClient(&Config{
MachineID: "test_login",
Password: "test_password",
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
/*mock login*/
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
w.WriteHeader(http.StatusOK)
})
_, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{})
if err != nil {
t.Fatalf("test Unable to list alerts : %+v", err)
}
if resp.Response.StatusCode != http.StatusOK {
t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated)
}
}
func TestNewClientKo(t *testing.T) { func TestNewClientKo(t *testing.T) {
mux, urlx, teardown := setup() mux, urlx, teardown := setup()
defer teardown() defer teardown()
@ -131,6 +215,33 @@ func TestNewDefaultClient(t *testing.T) {
log.Printf("err-> %s", err) log.Printf("err-> %s", err)
} }
func TestNewDefaultClient_UnixSocket(t *testing.T) {
tmpDir := t.TempDir()
socket := path.Join(tmpDir, "socket")
mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1")
defer teardown()
apiURL, err := url.Parse(urlx)
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
client, err := NewDefaultClient(apiURL, "/v1", "", nil)
if err != nil {
t.Fatalf("new api client: %s", err)
}
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"code": 401, "message" : "brr"}`))
})
_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
assert.Contains(t, err.Error(), `performing request: API error: brr`)
log.Printf("err-> %s", err)
}
func TestNewClientRegisterKO(t *testing.T) { func TestNewClientRegisterKO(t *testing.T) {
apiURL, err := url.Parse("http://127.0.0.1:4242/") apiURL, err := url.Parse("http://127.0.0.1:4242/")
require.NoError(t, err) require.NoError(t, err)
@ -143,10 +254,10 @@ func TestNewClientRegisterKO(t *testing.T) {
VersionPrefix: "v1", VersionPrefix: "v1",
}, &http.Client{}) }, &http.Client{})
if runtime.GOOS != "windows" { if runtime.GOOS == "windows" {
cstest.RequireErrorContains(t, err, "dial tcp 127.0.0.1:4242: connect: connection refused")
} else {
cstest.RequireErrorContains(t, err, " No connection could be made because the target machine actively refused it.") cstest.RequireErrorContains(t, err, " No connection could be made because the target machine actively refused it.")
} else {
cstest.RequireErrorContains(t, err, "dial tcp 127.0.0.1:4242: connect: connection refused")
} }
} }
@ -178,6 +289,41 @@ func TestNewClientRegisterOK(t *testing.T) {
log.Printf("->%T", client) log.Printf("->%T", client)
} }
func TestNewClientRegisterOK_UnixSocket(t *testing.T) {
log.SetLevel(log.TraceLevel)
tmpDir := t.TempDir()
socket := path.Join(tmpDir, "socket")
mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1")
defer teardown()
/*mock login*/
mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "POST")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
apiURL, err := url.Parse(urlx)
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
client, err := RegisterClient(&Config{
MachineID: "test_login",
Password: "test_password",
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
VersionPrefix: "v1",
}, &http.Client{})
if err != nil {
t.Fatalf("while registering client : %s", err)
}
log.Printf("->%T", client)
}
func TestNewClientBadAnswer(t *testing.T) { func TestNewClientBadAnswer(t *testing.T) {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)

View file

@ -32,6 +32,7 @@ const keyLength = 32
type APIServer struct { type APIServer struct {
URL string URL string
UnixSocket string
TLS *csconfig.TLSCfg TLS *csconfig.TLSCfg
dbClient *database.Client dbClient *database.Client
logFile string logFile string
@ -66,7 +67,7 @@ func recoverFromPanic(c *gin.Context) {
// because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go // because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go
// and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them // and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them
if strErr, ok := err.(error); ok { if strErr, ok := err.(error); ok {
//stolen from http2/server.go in x/net // stolen from http2/server.go in x/net
var ( var (
errClientDisconnected = errors.New("client disconnected") errClientDisconnected = errors.New("client disconnected")
errClosedBody = errors.New("body closed by handler") errClosedBody = errors.New("body closed by handler")
@ -124,10 +125,10 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro
logger := &lumberjack.Logger{ logger := &lumberjack.Logger{
Filename: logFile, Filename: logFile,
MaxSize: 500, //megabytes MaxSize: 500, // megabytes
MaxBackups: 3, MaxBackups: 3,
MaxAge: 28, //days MaxAge: 28, // days
Compress: true, //disabled by default Compress: true, // disabled by default
} }
if config.LogMaxSize != 0 { if config.LogMaxSize != 0 {
@ -176,6 +177,13 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
router.ForwardedByClientIP = false router.ForwardedByClientIP = false
// set the remore address of the request to 127.0.0.1 if it comes from a unix socket
router.Use(func(c *gin.Context) {
if c.Request.RemoteAddr == "@" {
c.Request.RemoteAddr = "127.0.0.1:65535"
}
})
if config.TrustedProxies != nil && config.UseForwardedForHeaders { if config.TrustedProxies != nil && config.UseForwardedForHeaders {
if err = router.SetTrustedProxies(*config.TrustedProxies); err != nil { if err = router.SetTrustedProxies(*config.TrustedProxies); err != nil {
return nil, fmt.Errorf("while setting trusted_proxies: %w", err) return nil, fmt.Errorf("while setting trusted_proxies: %w", err)
@ -267,6 +275,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
return &APIServer{ return &APIServer{
URL: config.ListenURI, URL: config.ListenURI,
UnixSocket: config.ListenSocket,
TLS: config.TLS, TLS: config.TLS,
logFile: logFile, logFile: logFile,
dbClient: dbClient, dbClient: dbClient,
@ -317,11 +326,11 @@ func (s *APIServer) Run(apiReady chan bool) error {
return nil return nil
}) })
//csConfig.API.Server.ConsoleConfig.ShareCustomScenarios // csConfig.API.Server.ConsoleConfig.ShareCustomScenarios
if s.apic.apiClient.IsEnrolled() { if s.apic.apiClient.IsEnrolled() {
if s.consoleConfig.IsPAPIEnabled() { if s.consoleConfig.IsPAPIEnabled() {
if s.papi.URL != "" { if s.papi.URL != "" {
log.Infof("Starting PAPI decision receiver") log.Info("Starting PAPI decision receiver")
s.papi.pullTomb.Go(func() error { s.papi.pullTomb.Go(func() error {
if err := s.papi.Pull(); err != nil { if err := s.papi.Pull(); err != nil {
log.Errorf("papi pull: %s", err) log.Errorf("papi pull: %s", err)
@ -353,29 +362,31 @@ func (s *APIServer) Run(apiReady chan bool) error {
}) })
} }
s.httpServerTomb.Go(func() error { s.listenAndServeURL(apiReady); return nil }) s.httpServerTomb.Go(func() error {
return s.listenAndServeLAPI(apiReady)
})
if err := s.httpServerTomb.Wait(); err != nil {
return fmt.Errorf("local API server stopped with error: %w", err)
}
return nil return nil
} }
// listenAndServeURL starts the http server and blocks until it's closed // listenAndServeLAPI starts the http server and blocks until it's closed
// it also updates the URL field with the actual address the server is listening on // it also updates the URL field with the actual address the server is listening on
// it's meant to be run in a separate goroutine // it's meant to be run in a separate goroutine
func (s *APIServer) listenAndServeURL(apiReady chan bool) { func (s *APIServer) listenAndServeLAPI(apiReady chan bool) error {
serverError := make(chan error, 1) var (
tcpListener net.Listener
unixListener net.Listener
err error
serverError = make(chan error, 2)
listenerClosed = make(chan struct{})
)
go func() { startServer := func(listener net.Listener, canTLS bool) {
listener, err := net.Listen("tcp", s.URL) if canTLS && s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") {
if err != nil {
serverError <- fmt.Errorf("listening on %s: %w", s.URL, err)
return
}
s.URL = listener.Addr().String()
log.Infof("CrowdSec Local API listening on %s", s.URL)
apiReady <- true
if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") {
if s.TLS.KeyFilePath == "" { if s.TLS.KeyFilePath == "" {
serverError <- errors.New("missing TLS key file") serverError <- errors.New("missing TLS key file")
return return
@ -391,25 +402,71 @@ func (s *APIServer) listenAndServeURL(apiReady chan bool) {
err = s.httpServer.Serve(listener) err = s.httpServer.Serve(listener)
} }
if err != nil && err != http.ErrServerClosed { switch {
serverError <- fmt.Errorf("while serving local API: %w", err) case errors.Is(err, http.ErrServerClosed):
break
case err != nil:
serverError <- err
}
}
// Starting TCP listener
go func() {
if s.URL == "" {
return return
} }
tcpListener, err = net.Listen("tcp", s.URL)
if err != nil {
serverError <- fmt.Errorf("listening on %s: %w", s.URL, err)
return
}
log.Infof("CrowdSec Local API listening on %s", s.URL)
startServer(tcpListener, true)
}() }()
// Starting Unix socket listener
go func() {
if s.UnixSocket == "" {
return
}
_ = os.RemoveAll(s.UnixSocket)
unixListener, err = net.Listen("unix", s.UnixSocket)
if err != nil {
serverError <- fmt.Errorf("while creating unix listener: %w", err)
return
}
log.Infof("CrowdSec Local API listening on Unix socket %s", s.UnixSocket)
startServer(unixListener, false)
}()
apiReady <- true
select { select {
case err := <-serverError: case err := <-serverError:
log.Fatalf("while starting API server: %s", err) return err
case <-s.httpServerTomb.Dying(): case <-s.httpServerTomb.Dying():
log.Infof("Shutting down API server") log.Info("Shutting down API server")
// do we need a graceful shutdown here?
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := s.httpServer.Shutdown(ctx); err != nil { if err := s.httpServer.Shutdown(ctx); err != nil {
log.Errorf("while shutting down http server: %s", err) log.Errorf("while shutting down http server: %v", err)
}
close(listenerClosed)
case <-listenerClosed:
if s.UnixSocket != "" {
_ = os.RemoveAll(s.UnixSocket)
} }
} }
return nil
} }
func (s *APIServer) Close() { func (s *APIServer) Close() {
@ -437,7 +494,7 @@ func (s *APIServer) Shutdown() error {
} }
} }
//close io.writer logger given to gin // close io.writer logger given to gin
if pipe, ok := gin.DefaultErrorWriter.(*io.PipeWriter); ok { if pipe, ok := gin.DefaultErrorWriter.(*io.PipeWriter); ok {
pipe.Close() pipe.Close()
} }

View file

@ -174,7 +174,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
// if coming from cscli, alert already has decisions // if coming from cscli, alert already has decisions
if len(alert.Decisions) != 0 { if len(alert.Decisions) != 0 {
//alert already has a decision (cscli decisions add etc.), generate uuid here // alert already has a decision (cscli decisions add etc.), generate uuid here
for _, decision := range alert.Decisions { for _, decision := range alert.Decisions {
decision.UUID = uuid.NewString() decision.UUID = uuid.NewString()
} }
@ -323,12 +323,13 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) {
var err error var err error
incomingIP := gctx.ClientIP() incomingIP := gctx.ClientIP()
if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) { if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) {
gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)})
return return
} }
decisionIDStr := gctx.Param("alert_id") decisionIDStr := gctx.Param("alert_id")
decisionID, err := strconv.Atoi(decisionIDStr) decisionID, err := strconv.Atoi(decisionIDStr)
if err != nil { if err != nil {
gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"})
@ -349,7 +350,7 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) {
// DeleteAlerts deletes alerts from the database based on the specified filter // DeleteAlerts deletes alerts from the database based on the specified filter
func (c *Controller) DeleteAlerts(gctx *gin.Context) { func (c *Controller) DeleteAlerts(gctx *gin.Context) {
incomingIP := gctx.ClientIP() incomingIP := gctx.ClientIP()
if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) { if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) {
gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)})
return return
} }

View file

@ -2,7 +2,9 @@ package v1
import ( import (
"errors" "errors"
"net"
"net/http" "net/http"
"strings"
jwt "github.com/appleboy/gin-jwt/v2" jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -25,6 +27,14 @@ func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) {
return bouncerInfo, nil return bouncerInfo, nil
} }
func isUnixSocket(c *gin.Context) bool {
if localAddr, ok := c.Request.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
return strings.HasPrefix(localAddr.Network(), "unix")
}
return false
}
func getMachineIDFromContext(ctx *gin.Context) (string, error) { func getMachineIDFromContext(ctx *gin.Context) (string, error) {
claims := jwt.ExtractClaims(ctx) claims := jwt.ExtractClaims(ctx)
if claims == nil { if claims == nil {
@ -47,8 +57,16 @@ func getMachineIDFromContext(ctx *gin.Context) (string, error) {
func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc { func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc {
return func(gctx *gin.Context) { return func(gctx *gin.Context) {
if !option {
return
}
if isUnixSocket(gctx) {
return
}
incomingIP := gctx.ClientIP() incomingIP := gctx.ClientIP()
if option && incomingIP != "127.0.0.1" && incomingIP != "::1" { if incomingIP != "127.0.0.1" && incomingIP != "::1" {
gctx.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) gctx.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
gctx.Abort() gctx.Abort()
} }

View file

@ -82,10 +82,10 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {
bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
bouncer, err := a.DbClient.SelectBouncerByName(bouncerName) bouncer, err := a.DbClient.SelectBouncerByName(bouncerName)
//This is likely not the proper way, but isNotFound does not seem to work // This is likely not the proper way, but isNotFound does not seem to work
if err != nil && strings.Contains(err.Error(), "bouncer not found") { if err != nil && strings.Contains(err.Error(), "bouncer not found") {
//Because we have a valid cert, automatically create the bouncer in the database if it does not exist // Because we have a valid cert, automatically create the bouncer in the database if it does not exist
//Set a random API key, but it will never be used // Set a random API key, but it will never be used
apiKey, err := GenerateAPIKey(dummyAPIKeySize) apiKey, err := GenerateAPIKey(dummyAPIKeySize)
if err != nil { if err != nil {
logger.Errorf("error generating mock api key: %s", err) logger.Errorf("error generating mock api key: %s", err)
@ -100,11 +100,11 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {
return nil return nil
} }
} else if err != nil { } else if err != nil {
//error while selecting bouncer // error while selecting bouncer
logger.Errorf("while selecting bouncers: %s", err) logger.Errorf("while selecting bouncers: %s", err)
return nil return nil
} else if bouncer.AuthType != types.TlsAuthType { } else if bouncer.AuthType != types.TlsAuthType {
//bouncer was found in DB // bouncer was found in DB
logger.Errorf("bouncer isn't allowed to auth by TLS") logger.Errorf("bouncer isn't allowed to auth by TLS")
return nil return nil
} }
@ -139,8 +139,10 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
var bouncer *ent.Bouncer var bouncer *ent.Bouncer
clientIP := c.ClientIP()
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"ip": c.ClientIP(), "ip": clientIP,
}) })
if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
@ -152,6 +154,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
if bouncer == nil { if bouncer == nil {
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
return return
} }
@ -160,7 +163,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
}) })
if bouncer.IPAddress == "" { if bouncer.IPAddress == "" {
if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
@ -169,11 +172,11 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
} }
} }
//Don't update IP on HEAD request, as it's used by the appsec to check the validity of the API key provided // Don't update IP on HEAD request, as it's used by the appsec to check the validity of the API key provided
if bouncer.IPAddress != c.ClientIP() && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead { if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead {
log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress) log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress)
if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
@ -199,6 +202,5 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
} }
c.Set(BouncerContextKey, bouncer) c.Set(BouncerContextKey, bouncer)
c.Next()
} }
} }

View file

@ -61,6 +61,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
if j.TlsAuth == nil { if j.TlsAuth == nil {
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
return nil, errors.New("TLS auth is not configured") return nil, errors.New("TLS auth is not configured")
} }
@ -76,7 +77,8 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
if !validCert { if !validCert {
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
return nil, fmt.Errorf("failed cert authentication")
return nil, errors.New("failed cert authentication")
} }
ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
@ -85,9 +87,9 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
Where(machine.MachineId(ret.machineID)). Where(machine.MachineId(ret.machineID)).
First(j.DbClient.CTX) First(j.DbClient.CTX)
if ent.IsNotFound(err) { if ent.IsNotFound(err) {
//Machine was not found, let's create it // Machine was not found, let's create it
log.Infof("machine %s not found, create it", ret.machineID) log.Infof("machine %s not found, create it", ret.machineID)
//let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli) // let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli)
pwd, err := GenerateAPIKey(dummyAPIKeySize) pwd, err := GenerateAPIKey(dummyAPIKeySize)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -95,7 +97,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
"cn": extractedCN, "cn": extractedCN,
}).Errorf("error generating password: %s", err) }).Errorf("error generating password: %s", err)
return nil, fmt.Errorf("error generating password") return nil, errors.New("error generating password")
} }
password := strfmt.Password(pwd) password := strfmt.Password(pwd)
@ -110,6 +112,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
if ret.clientMachine.AuthType != types.TlsAuthType { if ret.clientMachine.AuthType != types.TlsAuthType {
return nil, fmt.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType) return nil, fmt.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType)
} }
ret.machineID = ret.clientMachine.MachineId ret.machineID = ret.clientMachine.MachineId
} }
@ -213,18 +216,20 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
} }
} }
clientIP := c.ClientIP()
if auth.clientMachine.IpAddress == "" { if auth.clientMachine.IpAddress == "" {
err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID) err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID)
if err != nil { if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err) log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err)
return nil, jwt.ErrFailedAuthentication return nil, jwt.ErrFailedAuthentication
} }
} }
if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" { if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" {
log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress) log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress)
err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID) err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID)
if err != nil { if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err)
return nil, jwt.ErrFailedAuthentication return nil, jwt.ErrFailedAuthentication
@ -233,13 +238,14 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
useragent := strings.Split(c.Request.UserAgent(), "/") useragent := strings.Split(c.Request.UserAgent(), "/")
if len(useragent) != 2 { if len(useragent) != 2 {
log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), c.ClientIP()) log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), clientIP)
return nil, jwt.ErrFailedAuthentication return nil, jwt.ErrFailedAuthentication
} }
if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil { if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil {
log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err) log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err)
log.Errorf("bad user agent from : %s", c.ClientIP()) log.Errorf("bad user agent from : %s", clientIP)
return nil, jwt.ErrFailedAuthentication return nil, jwt.ErrFailedAuthentication
} }
@ -323,8 +329,9 @@ func NewJWT(dbClient *database.Client) (*JWT, error) {
errInit := ret.MiddlewareInit() errInit := ret.MiddlewareInit()
if errInit != nil { if errInit != nil {
return &JWT{}, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) return &JWT{}, errors.New("authMiddleware.MiddlewareInit() Error:" + errInit.Error())
} }
jwtMiddleware.Middleware = ret jwtMiddleware.Middleware = ret
return jwtMiddleware, nil return jwtMiddleware, nil

View file

@ -141,12 +141,25 @@ func (l *LocalApiClientCfg) Load() error {
} }
if l.Credentials != nil && l.Credentials.URL != "" { if l.Credentials != nil && l.Credentials.URL != "" {
if !strings.HasSuffix(l.Credentials.URL, "/") { // don't append a trailing slash if the URL is a unix socket
if strings.HasPrefix(l.Credentials.URL, "http") && !strings.HasSuffix(l.Credentials.URL, "/") {
l.Credentials.URL += "/" l.Credentials.URL += "/"
} }
} }
if l.Credentials.Login != "" && (l.Credentials.CertPath != "" || l.Credentials.KeyPath != "") { // is the configuration asking for client authentication via TLS?
credTLSClientAuth := l.Credentials.CertPath != "" || l.Credentials.KeyPath != ""
// is the configuration asking for TLS encryption and server authentication?
credTLS := credTLSClientAuth || l.Credentials.CACertPath != ""
credSocket := strings.HasPrefix(l.Credentials.URL, "/")
if credTLS && credSocket {
return errors.New("cannot use TLS with a unix socket")
}
if credTLSClientAuth && l.Credentials.Login != "" {
return errors.New("user/password authentication and TLS authentication are mutually exclusive") return errors.New("user/password authentication and TLS authentication are mutually exclusive")
} }
@ -187,10 +200,10 @@ func (l *LocalApiClientCfg) Load() error {
return nil return nil
} }
func (lapiCfg *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) { func (c *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) {
trustedIPs := make([]net.IPNet, 0) trustedIPs := make([]net.IPNet, 0)
for _, ip := range lapiCfg.TrustedIPs { for _, ip := range c.TrustedIPs {
cidr := toValidCIDR(ip) cidr := toValidCIDR(ip)
_, ipNet, err := net.ParseCIDR(cidr) _, ipNet, err := net.ParseCIDR(cidr)
@ -225,6 +238,7 @@ type CapiWhitelist struct {
type LocalApiServerCfg struct { type LocalApiServerCfg struct {
Enable *bool `yaml:"enable"` Enable *bool `yaml:"enable"`
ListenURI string `yaml:"listen_uri,omitempty"` // 127.0.0.1:8080 ListenURI string `yaml:"listen_uri,omitempty"` // 127.0.0.1:8080
ListenSocket string `yaml:"listen_socket,omitempty"`
TLS *TLSCfg `yaml:"tls"` TLS *TLSCfg `yaml:"tls"`
DbConfig *DatabaseCfg `yaml:"-"` DbConfig *DatabaseCfg `yaml:"-"`
LogDir string `yaml:"-"` LogDir string `yaml:"-"`
@ -248,6 +262,22 @@ type LocalApiServerCfg struct {
CapiWhitelists *CapiWhitelist `yaml:"-"` CapiWhitelists *CapiWhitelist `yaml:"-"`
} }
func (c *LocalApiServerCfg) ClientURL() string {
if c == nil {
return ""
}
if c.ListenSocket != "" {
return c.ListenSocket
}
if c.ListenURI != "" {
return "http://" + c.ListenURI
}
return ""
}
func (c *Config) LoadAPIServer(inCli bool) error { func (c *Config) LoadAPIServer(inCli bool) error {
if c.DisableAPI { if c.DisableAPI {
log.Warning("crowdsec local API is disabled from flag") log.Warning("crowdsec local API is disabled from flag")
@ -255,7 +285,9 @@ func (c *Config) LoadAPIServer(inCli bool) error {
if c.API.Server == nil { if c.API.Server == nil {
log.Warning("crowdsec local API is disabled") log.Warning("crowdsec local API is disabled")
c.DisableAPI = true c.DisableAPI = true
return nil return nil
} }
@ -266,6 +298,7 @@ func (c *Config) LoadAPIServer(inCli bool) error {
if !*c.API.Server.Enable { if !*c.API.Server.Enable {
log.Warning("crowdsec local API is disabled because 'enable' is set to false") log.Warning("crowdsec local API is disabled because 'enable' is set to false")
c.DisableAPI = true c.DisableAPI = true
} }
@ -273,8 +306,8 @@ func (c *Config) LoadAPIServer(inCli bool) error {
return nil return nil
} }
if c.API.Server.ListenURI == "" { if c.API.Server.ListenURI == "" && c.API.Server.ListenSocket == "" {
return errors.New("no listen_uri specified") return errors.New("no listen_uri or listen_socket specified")
} }
// inherit log level from common, then api->server // inherit log level from common, then api->server
@ -393,21 +426,21 @@ func parseCapiWhitelists(fd io.Reader) (*CapiWhitelist, error) {
return ret, nil return ret, nil
} }
func (s *LocalApiServerCfg) LoadCapiWhitelists() error { func (c *LocalApiServerCfg) LoadCapiWhitelists() error {
if s.CapiWhitelistsPath == "" { if c.CapiWhitelistsPath == "" {
return nil return nil
} }
fd, err := os.Open(s.CapiWhitelistsPath) fd, err := os.Open(c.CapiWhitelistsPath)
if err != nil { if err != nil {
return fmt.Errorf("while opening capi whitelist file: %w", err) return fmt.Errorf("while opening capi whitelist file: %w", err)
} }
defer fd.Close() defer fd.Close()
s.CapiWhitelists, err = parseCapiWhitelists(fd) c.CapiWhitelists, err = parseCapiWhitelists(fd)
if err != nil { if err != nil {
return fmt.Errorf("while parsing capi whitelist file '%s': %w", s.CapiWhitelistsPath, err) return fmt.Errorf("while parsing capi whitelist file '%s': %w", c.CapiWhitelistsPath, err)
} }
return nil return nil

View file

@ -32,20 +32,20 @@ teardown() {
} }
@test "lapi (no .api.server.listen_uri)" { @test "lapi (no .api.server.listen_uri)" {
rune -0 config_set 'del(.api.server.listen_uri)' rune -0 config_set 'del(.api.server.listen_socket) | del(.api.server.listen_uri)'
rune -1 "${CROWDSEC}" -no-cs rune -1 "${CROWDSEC}" -no-cs
assert_stderr --partial "no listen_uri specified" assert_stderr --partial "no listen_uri or listen_socket specified"
} }
@test "lapi (bad .api.server.listen_uri)" { @test "lapi (bad .api.server.listen_uri)" {
rune -0 config_set '.api.server.listen_uri="127.0.0.1:-80"' rune -0 config_set 'del(.api.server.listen_socket) | .api.server.listen_uri="127.0.0.1:-80"'
rune -1 "${CROWDSEC}" -no-cs rune -1 "${CROWDSEC}" -no-cs
assert_stderr --partial "while starting API server: listening on 127.0.0.1:-80: listen tcp: address -80: invalid port" assert_stderr --partial "local API server stopped with error: listening on 127.0.0.1:-80: listen tcp: address -80: invalid port"
} }
@test "lapi (listen on random port)" { @test "lapi (listen on random port)" {
config_set '.common.log_media="stdout"' config_set '.common.log_media="stdout"'
rune -0 config_set '.api.server.listen_uri="127.0.0.1:0"' rune -0 config_set 'del(.api.server.listen_socket) | .api.server.listen_uri="127.0.0.1:0"'
rune -0 wait-for --err "CrowdSec Local API listening on 127.0.0.1:" "${CROWDSEC}" -no-cs rune -0 wait-for --err "CrowdSec Local API listening on 127.0.0.1:" "${CROWDSEC}" -no-cs
} }

View file

@ -100,10 +100,14 @@ teardown() {
# check that LAPI configuration is loaded (human and json, not shows in raw) # check that LAPI configuration is loaded (human and json, not shows in raw)
sock=$(config_get '.api.server.listen_socket')
rune -0 cscli config show -o human rune -0 cscli config show -o human
assert_line --regexp ".*- URL +: http://127.0.0.1:8080/" assert_line --regexp ".*- URL +: http://127.0.0.1:8080/"
assert_line --regexp ".*- Login +: githubciXXXXXXXXXXXXXXXXXXXXXXXX([a-zA-Z0-9]{16})?" assert_line --regexp ".*- Login +: githubciXXXXXXXXXXXXXXXXXXXXXXXX([a-zA-Z0-9]{16})?"
assert_line --regexp ".*- Credentials File +: .*/local_api_credentials.yaml" assert_line --regexp ".*- Credentials File +: .*/local_api_credentials.yaml"
assert_line --regexp ".*- Listen URL +: 127.0.0.1:8080"
assert_line --regexp ".*- Listen Socket +: $sock"
rune -0 cscli config show -o json rune -0 cscli config show -o json
rune -0 jq -c '.API.Client.Credentials | [.url,.login[0:32]]' <(output) rune -0 jq -c '.API.Client.Credentials | [.url,.login[0:32]]' <(output)
@ -212,7 +216,6 @@ teardown() {
assert_stderr --partial "Loaded credentials from" assert_stderr --partial "Loaded credentials from"
assert_stderr --partial "Trying to authenticate with username" assert_stderr --partial "Trying to authenticate with username"
assert_stderr --partial " on http://127.0.0.1:8080/"
assert_stderr --partial "You can successfully interact with Local API (LAPI)" assert_stderr --partial "You can successfully interact with Local API (LAPI)"
} }

158
test/bats/09_socket.bats Normal file
View file

@ -0,0 +1,158 @@
#!/usr/bin/env bats
# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si:
set -u
setup_file() {
load "../lib/setup_file.sh"
sockdir=$(TMPDIR="$BATS_FILE_TMPDIR" mktemp -u)
export sockdir
mkdir -p "$sockdir"
socket="$sockdir/crowdsec_api.sock"
export socket
LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path')
export LOCAL_API_CREDENTIALS
}
teardown_file() {
load "../lib/teardown_file.sh"
}
setup() {
load "../lib/setup.sh"
load "../lib/bats-file/load.bash"
./instance-data load
config_set ".api.server.listen_socket=strenv(socket)"
}
teardown() {
./instance-crowdsec stop
}
#----------
@test "cscli - connects from existing machine with socket" {
config_set "$LOCAL_API_CREDENTIALS" ".url=strenv(socket)"
./instance-crowdsec start
rune -0 cscli lapi status
assert_stderr --regexp "Trying to authenticate with username .* on $socket"
assert_stderr --partial "You can successfully interact with Local API (LAPI)"
}
@test "crowdsec - listen on both socket and TCP" {
./instance-crowdsec start
rune -0 cscli lapi status
assert_stderr --regexp "Trying to authenticate with username .* on http://127.0.0.1:8080/"
assert_stderr --partial "You can successfully interact with Local API (LAPI)"
config_set "$LOCAL_API_CREDENTIALS" ".url=strenv(socket)"
rune -0 cscli lapi status
assert_stderr --regexp "Trying to authenticate with username .* on $socket"
assert_stderr --partial "You can successfully interact with Local API (LAPI)"
}
@test "cscli - authenticate new machine with socket" {
# verify that if a listen_uri and a socket are set, the socket is used
# by default when creating a local machine.
rune -0 cscli machines delete "$(cscli machines list -o json | jq -r '.[].machineId')"
# this one should be using the socket
rune -0 cscli machines add --auto --force
using=$(config_get "$LOCAL_API_CREDENTIALS" ".url")
assert [ "$using" = "$socket" ]
# disable the agent because it counts as a first authentication
config_disable_agent
./instance-crowdsec start
# the machine does not have an IP yet
rune -0 cscli machines list -o json
rune -0 jq -r '.[].ipAddress' <(output)
assert_output null
# upon first authentication, it's assigned to localhost
rune -0 cscli lapi status
rune -0 cscli machines list -o json
rune -0 jq -r '.[].ipAddress' <(output)
assert_output 127.0.0.1
}
bouncer_http() {
URI="$1"
curl -fs -H "X-Api-Key: $API_KEY" "http://localhost:8080$URI"
}
bouncer_socket() {
URI="$1"
curl -fs -H "X-Api-Key: $API_KEY" --unix-socket "$socket" "http://localhost$URI"
}
@test "lapi - connects from existing bouncer with socket" {
./instance-crowdsec start
API_KEY=$(cscli bouncers add testbouncer -o raw)
export API_KEY
# the bouncer does not have an IP yet
rune -0 cscli bouncers list -o json
rune -0 jq -r '.[].ip_address' <(output)
assert_output ""
# upon first authentication, it's assigned to localhost
rune -0 bouncer_socket '/v1/decisions'
assert_output 'null'
refute_stderr
rune -0 cscli bouncers list -o json
rune -0 jq -r '.[].ip_address' <(output)
assert_output "127.0.0.1"
# we can still use TCP of course
rune -0 bouncer_http '/v1/decisions'
assert_output 'null'
refute_stderr
}
@test "lapi - listen on socket only" {
config_set "del(.api.server.listen_uri)"
mkdir -p "$sockdir"
# agent is not able to connect right now
config_disable_agent
./instance-crowdsec start
API_KEY=$(cscli bouncers add testbouncer -o raw)
export API_KEY
# now we can't
rune -1 cscli lapi status
assert_stderr --partial "connection refused"
rune -7 bouncer_http '/v1/decisions'
refute_output
refute_stderr
# here we can
config_set "$LOCAL_API_CREDENTIALS" ".url=strenv(socket)"
rune -0 cscli lapi status
rune -0 bouncer_socket '/v1/decisions'
assert_output 'null'
refute_stderr
}

View file

@ -120,7 +120,50 @@ teardown() {
rune -0 jq -c '[. | length, .[0].machineId[0:32], .[0].isValidated, .[0].ipAddress, .[0].auth_type]' <(output) rune -0 jq -c '[. | length, .[0].machineId[0:32], .[0].isValidated, .[0].ipAddress, .[0].auth_type]' <(output)
assert_output '[1,"localhost@127.0.0.1",true,"127.0.0.1","tls"]' assert_output '[1,"localhost@127.0.0.1",true,"127.0.0.1","tls"]'
cscli machines delete localhost@127.0.0.1 rune -0 cscli machines delete localhost@127.0.0.1
}
@test "a machine can still connect with a unix socket, no TLS" {
sock=$(config_get '.api.server.listen_socket')
export sock
# an agent is a machine too
config_disable_agent
./instance-crowdsec start
rune -0 cscli machines add with-socket --auto --force
rune -0 cscli lapi status
rune -0 cscli machines list -o json
rune -0 jq -c '[. | length, .[0].machineId[0:32], .[0].isValidated, .[0].ipAddress, .[0].auth_type]' <(output)
assert_output '[1,"with-socket",true,"127.0.0.1","password"]'
# TLS cannot be used with a unix socket
config_set "${CONFIG_DIR}/local_api_credentials.yaml" '
.ca_cert_path=strenv(tmpdir) + "/bundle.pem"
'
rune -1 cscli lapi status
assert_stderr --partial "loading api client: cannot use TLS with a unix socket"
config_set "${CONFIG_DIR}/local_api_credentials.yaml" '
del(.ca_cert_path) |
.key_path=strenv(tmpdir) + "/agent-key.pem"
'
rune -1 cscli lapi status
assert_stderr --partial "loading api client: cannot use TLS with a unix socket"
config_set "${CONFIG_DIR}/local_api_credentials.yaml" '
del(.key_path) |
.cert_path=strenv(tmpdir) + "/agent.pem"
'
rune -1 cscli lapi status
assert_stderr --partial "loading api client: cannot use TLS with a unix socket"
rune -0 cscli machines delete with-socket
} }
@test "invalid cert for agent" { @test "invalid cert for agent" {

View file

@ -58,6 +58,7 @@ config_prepare() {
# remove trailing slash from CONFIG_DIR # remove trailing slash from CONFIG_DIR
# since it's assumed to be missing during the tests # since it's assumed to be missing during the tests
yq e -i ' yq e -i '
.api.server.listen_socket="/run/crowdsec.sock" |
.config_paths.config_dir |= sub("/$", "") .config_paths.config_dir |= sub("/$", "")
' "${CONFIG_DIR}/config.yaml" ' "${CONFIG_DIR}/config.yaml"
} }

View file

@ -57,7 +57,6 @@ config_generate() {
cp ../config/profiles.yaml \ cp ../config/profiles.yaml \
../config/simulation.yaml \ ../config/simulation.yaml \
../config/local_api_credentials.yaml \
../config/online_api_credentials.yaml \ ../config/online_api_credentials.yaml \
"${CONFIG_DIR}/" "${CONFIG_DIR}/"
@ -95,6 +94,7 @@ config_generate() {
.db_config.db_path=strenv(DATA_DIR)+"/crowdsec.db" | .db_config.db_path=strenv(DATA_DIR)+"/crowdsec.db" |
.db_config.use_wal=true | .db_config.use_wal=true |
.api.client.credentials_path=strenv(CONFIG_DIR)+"/local_api_credentials.yaml" | .api.client.credentials_path=strenv(CONFIG_DIR)+"/local_api_credentials.yaml" |
.api.server.listen_socket=strenv(DATA_DIR)+"/crowdsec.sock" |
.api.server.profiles_path=strenv(CONFIG_DIR)+"/profiles.yaml" | .api.server.profiles_path=strenv(CONFIG_DIR)+"/profiles.yaml" |
.api.server.console_path=strenv(CONFIG_DIR)+"/console.yaml" | .api.server.console_path=strenv(CONFIG_DIR)+"/console.yaml" |
del(.api.server.online_client) del(.api.server.online_client)
@ -119,7 +119,8 @@ make_init_data() {
./bin/preload-hub-items ./bin/preload-hub-items
"$CSCLI" --warning machines add githubciXXXXXXXXXXXXXXXXXXXXXXXX --auto --force # force TCP, the default would be unix socket
"$CSCLI" --warning machines add githubciXXXXXXXXXXXXXXXXXXXXXXXX --url http://127.0.0.1:8080 --auto --force
mkdir -p "$LOCAL_INIT_DIR" mkdir -p "$LOCAL_INIT_DIR"