Browse Source

Increase test Coverage of pkg/authorization

Signed-off-by: Raja Sami <raja.sami@tenpearls.com>
Raja Sami 8 năm trước cách đây
mục cha
commit
f1eb0c0ebb

+ 75 - 0
pkg/authorization/api_test.go

@@ -0,0 +1,75 @@
+package authorization
+
+import (
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"math/big"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestPeerCertificateMarshalJSON(t *testing.T) {
+	template := &x509.Certificate{
+		IsCA: true,
+		BasicConstraintsValid: true,
+		SubjectKeyId:          []byte{1, 2, 3},
+		SerialNumber:          big.NewInt(1234),
+		Subject: pkix.Name{
+			Country:      []string{"Earth"},
+			Organization: []string{"Mother Nature"},
+		},
+		NotBefore: time.Now(),
+		NotAfter:  time.Now().AddDate(5, 5, 5),
+
+		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
+		KeyUsage:    x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+	}
+	// generate private key
+	privatekey, err := rsa.GenerateKey(rand.Reader, 2048)
+	require.NoError(t, err)
+	publickey := &privatekey.PublicKey
+
+	// create a self-signed certificate. template = parent
+	var parent = template
+	raw, err := x509.CreateCertificate(rand.Reader, template, parent, publickey, privatekey)
+	require.NoError(t, err)
+
+	cert, err := x509.ParseCertificate(raw)
+	require.NoError(t, err)
+
+	var certs = []*x509.Certificate{cert}
+	addr := "www.authz.com/auth"
+	req, err := http.NewRequest("GET", addr, nil)
+	require.NoError(t, err)
+
+	req.RequestURI = addr
+	req.TLS = &tls.ConnectionState{}
+	req.TLS.PeerCertificates = certs
+	req.Header.Add("header", "value")
+
+	for _, c := range req.TLS.PeerCertificates {
+		pcObj := PeerCertificate(*c)
+
+		t.Run("Marshalling :", func(t *testing.T) {
+			raw, err = pcObj.MarshalJSON()
+			require.NotNil(t, raw)
+			require.Nil(t, err)
+		})
+
+		t.Run("UnMarshalling :", func(t *testing.T) {
+			err := pcObj.UnmarshalJSON(raw)
+			require.Nil(t, err)
+			require.Equal(t, "Earth", pcObj.Subject.Country[0])
+			require.Equal(t, true, pcObj.IsCA)
+
+		})
+
+	}
+
+}

+ 53 - 0
pkg/authorization/middleware_test.go

@@ -0,0 +1,53 @@
+package authorization
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/docker/docker/pkg/plugingetter"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMiddleware(t *testing.T) {
+	pluginNames := []string{"testPlugin1", "testPlugin2"}
+	var pluginGetter plugingetter.PluginGetter
+	m := NewMiddleware(pluginNames, pluginGetter)
+	authPlugins := m.getAuthzPlugins()
+	require.Equal(t, 2, len(authPlugins))
+	require.EqualValues(t, pluginNames[0], authPlugins[0].Name())
+	require.EqualValues(t, pluginNames[1], authPlugins[1].Name())
+}
+
+func TestNewResponseModifier(t *testing.T) {
+	recorder := httptest.NewRecorder()
+	modifier := NewResponseModifier(recorder)
+	modifier.Header().Set("H1", "V1")
+	modifier.Write([]byte("body"))
+	require.False(t, modifier.Hijacked())
+	modifier.WriteHeader(http.StatusInternalServerError)
+	require.NotNil(t, modifier.RawBody())
+
+	raw, err := modifier.RawHeaders()
+	require.NotNil(t, raw)
+	require.Nil(t, err)
+
+	headerData := strings.Split(strings.TrimSpace(string(raw)), ":")
+	require.EqualValues(t, "H1", strings.TrimSpace(headerData[0]))
+	require.EqualValues(t, "V1", strings.TrimSpace(headerData[1]))
+
+	modifier.Flush()
+	modifier.FlushAll()
+
+	if recorder.Header().Get("H1") != "V1" {
+		t.Fatalf("Header value must exists %s", recorder.Header().Get("H1"))
+	}
+
+}
+
+func setAuthzPlugins(m *Middleware, plugins []Plugin) {
+	m.mu.Lock()
+	m.plugins = plugins
+	m.mu.Unlock()
+}

+ 65 - 0
pkg/authorization/middleware_unix_test.go

@@ -0,0 +1,65 @@
+// +build !windows
+
+package authorization
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/docker/docker/pkg/plugingetter"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/net/context"
+)
+
+func TestMiddlewareWrapHandler(t *testing.T) {
+	server := authZPluginTestServer{t: t}
+	server.start()
+	defer server.stop()
+
+	authZPlugin := createTestPlugin(t)
+	pluginNames := []string{authZPlugin.name}
+
+	var pluginGetter plugingetter.PluginGetter
+	middleWare := NewMiddleware(pluginNames, pluginGetter)
+	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		return nil
+	}
+
+	authList := []Plugin{authZPlugin}
+	middleWare.SetPlugins([]string{"My Test Plugin"})
+	setAuthzPlugins(middleWare, authList)
+	mdHandler := middleWare.WrapHandler(handler)
+	require.NotNil(t, mdHandler)
+
+	addr := "www.example.com/auth"
+	req, _ := http.NewRequest("GET", addr, nil)
+	req.RequestURI = addr
+	req.Header.Add("header", "value")
+
+	resp := httptest.NewRecorder()
+	ctx := context.Background()
+
+	t.Run("Error Test Case :", func(t *testing.T) {
+		server.replayResponse = Response{
+			Allow: false,
+			Msg:   "Server Auth Not Allowed",
+		}
+		if err := mdHandler(ctx, resp, req, map[string]string{}); err == nil {
+			require.Error(t, err)
+		}
+
+	})
+
+	t.Run("Positive Test Case :", func(t *testing.T) {
+		server.replayResponse = Response{
+			Allow: true,
+			Msg:   "Server Auth Allowed",
+		}
+		if err := mdHandler(ctx, resp, req, map[string]string{}); err != nil {
+			require.NoError(t, err)
+		}
+
+	})
+
+}