Prechádzať zdrojové kódy

cscli: Add user-agent to all hub requests (#2915)

* cscli: Add user-agent to all hub requests

* fix unit test and avoid httpmock

* fix windows test
mmetc 1 rok pred
rodič
commit
2e1ddec107
2 zmenil súbory, kde vykonal 48 pridanie a 28 odobranie
  1. 14 0
      pkg/cwhub/cwhub.go
  2. 34 28
      pkg/cwhub/dataset_test.go

+ 14 - 0
pkg/cwhub/cwhub.go

@@ -7,10 +7,24 @@ import (
 	"sort"
 	"strings"
 	"time"
+
+	"github.com/crowdsecurity/go-cs-lib/version"
 )
 
+// hubTransport wraps a Transport to set a custom User-Agent.
+type hubTransport struct {
+	http.RoundTripper
+}
+
+func (t *hubTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+	req.Header.Set("User-Agent", "crowdsec/"+version.String())
+	return t.RoundTripper.RoundTrip(req)
+}
+
+// hubClient is the HTTP client used to communicate with the CrowdSec Hub.
 var hubClient = &http.Client{
 	Timeout: 120 * time.Second,
+	Transport: &hubTransport{http.DefaultTransport},
 }
 
 // safePath returns a joined path and ensures that it does not escape the base directory.

+ 34 - 28
pkg/cwhub/dataset_test.go

@@ -1,50 +1,56 @@
 package cwhub
 
 import (
+	"io"
+	"net/http"
+	"net/http/httptest"
 	"os"
+	"path/filepath"
 	"testing"
 
-	"github.com/jarcoal/httpmock"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+
+	"github.com/crowdsecurity/go-cs-lib/cstest"
 )
 
 func TestDownloadFile(t *testing.T) {
-	examplePath := "./example.txt"
-	defer os.Remove(examplePath)
-
-	httpmock.Activate()
-	defer httpmock.DeactivateAndReset()
-
-	// OK
-	httpmock.RegisterResponder(
-		"GET",
-		"https://example.com/xx",
-		httpmock.NewStringResponder(200, "example content oneoneone"),
-	)
-
-	httpmock.RegisterResponder(
-		"GET",
-		"https://example.com/x",
-		httpmock.NewStringResponder(404, "not found"),
-	)
-
-	err := downloadFile("https://example.com/xx", examplePath)
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/xx":
+			w.WriteHeader(http.StatusOK)
+			_, _ = io.WriteString(w, "example content oneoneone")
+		default:
+			w.WriteHeader(http.StatusNotFound)
+			_, _ = io.WriteString(w, "not found")
+		}
+	}))
+	defer ts.Close()
+
+	dest := filepath.Join(t.TempDir(), "example.txt")
+	defer os.Remove(dest)
+
+	err := downloadFile(ts.URL+"/xx", dest)
 	require.NoError(t, err)
 
-	content, err := os.ReadFile(examplePath)
+	content, err := os.ReadFile(dest)
 	assert.Equal(t, "example content oneoneone", string(content))
 	require.NoError(t, err)
 
 	// bad uri
-	err = downloadFile("https://zz.com", examplePath)
-	require.Error(t, err)
+	err = downloadFile("https://zz.com", dest)
+	cstest.RequireErrorContains(t, err, "lookup zz.com")
+	cstest.RequireErrorContains(t, err, "no such host")
 
 	// 404
-	err = downloadFile("https://example.com/x", examplePath)
-	require.Error(t, err)
+	err = downloadFile(ts.URL+"/x", dest)
+	cstest.RequireErrorContains(t, err, "bad http code 404")
 
 	// bad target
-	err = downloadFile("https://example.com/xx", "")
-	require.Error(t, err)
+	err = downloadFile(ts.URL+"/xx", "")
+	cstest.RequireErrorContains(t, err, cstest.PathNotFoundMessage)
+
+	// destination directory does not exist
+	err = downloadFile(ts.URL+"/xx", filepath.Join(t.TempDir(), "missing/example.txt"))
+	cstest.RequireErrorContains(t, err, cstest.PathNotFoundMessage)
 }