moby/pkg/authorization/authz_unix_test.go
Sebastiaan van Stijn ab35df454d
remove pre-go1.17 build-tags
Removed pre-go1.17 build-tags with go fix;

    go mod init
    go fix -mod=readonly ./...
    rm go.mod

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
2023-05-19 20:38:51 +02:00

348 lines
8.6 KiB
Go

//go:build !windows
// TODO Windows: This uses a Unix socket for testing. This might be possible
// to port to Windows using a named pipe instead.
package authorization // import "github.com/docker/docker/pkg/authorization"
import (
"bytes"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"path"
"reflect"
"strings"
"testing"
"github.com/docker/docker/pkg/plugins"
"github.com/docker/go-connections/tlsconfig"
"github.com/gorilla/mux"
)
// pluginAddress is the file name of the unix socket the test plugin server
// listens on; it is created inside the server's temporary directory.
const pluginAddress = "authz-test-plugin.sock"
// TestAuthZRequestPluginError checks that a response carrying an Err field is
// returned by AuthZRequest exactly as the plugin sent it (with a nil Go error),
// and that the plugin received the request unchanged.
func TestAuthZRequestPluginError(t *testing.T) {
	srv := &authZPluginTestServer{t: t}
	srv.start()
	defer srv.stop()

	authz := createTestPlugin(t, srv.socketAddress())

	req := Request{
		User:           "user",
		RequestBody:    []byte("sample body"),
		RequestURI:     "www.authz.com/auth",
		RequestMethod:  http.MethodGet,
		RequestHeaders: map[string]string{"header": "value"},
	}
	srv.replayResponse = Response{Err: "an error"}

	got, err := authz.AuthZRequest(&req)
	if err != nil {
		t.Fatalf("Failed to authorize request %v", err)
	}

	sameResponse := reflect.DeepEqual(srv.replayResponse, *got)
	if !sameResponse {
		t.Fatal("Response must be equal")
	}
	sameRequest := reflect.DeepEqual(req, srv.recordedRequest)
	if !sameRequest {
		t.Fatal("Requests must be equal")
	}
}
// TestAuthZRequestPlugin checks that an allowing response from the plugin is
// returned by AuthZRequest unchanged, and that the plugin received the request
// exactly as it was sent.
func TestAuthZRequestPlugin(t *testing.T) {
	srv := &authZPluginTestServer{t: t}
	srv.start()
	defer srv.stop()

	authz := createTestPlugin(t, srv.socketAddress())

	req := Request{
		User:           "user",
		RequestBody:    []byte("sample body"),
		RequestURI:     "www.authz.com/auth",
		RequestMethod:  http.MethodGet,
		RequestHeaders: map[string]string{"header": "value"},
	}
	srv.replayResponse = Response{Allow: true, Msg: "Sample message"}

	got, err := authz.AuthZRequest(&req)
	if err != nil {
		t.Fatalf("Failed to authorize request %v", err)
	}

	sameResponse := reflect.DeepEqual(srv.replayResponse, *got)
	if !sameResponse {
		t.Fatal("Response must be equal")
	}
	sameRequest := reflect.DeepEqual(req, srv.recordedRequest)
	if !sameRequest {
		t.Fatal("Requests must be equal")
	}
}
// TestAuthZResponsePlugin checks that AuthZResponse returns the plugin's
// response unchanged and that the plugin received the request exactly as it
// was sent.
func TestAuthZResponsePlugin(t *testing.T) {
	server := authZPluginTestServer{t: t}
	server.start()
	defer server.stop()

	authZPlugin := createTestPlugin(t, server.socketAddress())

	request := Request{
		User:        "user",
		RequestURI:  "something.com/auth",
		RequestBody: []byte("sample body"),
	}
	server.replayResponse = Response{
		Allow: true,
		Msg:   "Sample message",
	}

	actualResponse, err := authZPlugin.AuthZResponse(&request)
	if err != nil {
		t.Fatalf("Failed to authorize response %v", err)
	}
	if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
		t.Fatalf("Response must be equal: expected %+v, got %+v", server.replayResponse, *actualResponse)
	}
	if !reflect.DeepEqual(request, server.recordedRequest) {
		t.Fatalf("Requests must be equal: expected %+v, got %+v", request, server.recordedRequest)
	}
}
// TestResponseModifier checks that a header, body and status code set on a
// ResponseModifier end up on the underlying recorder once FlushAll is called,
// even though the status code is set after the body is written.
func TestResponseModifier(t *testing.T) {
	rec := httptest.NewRecorder()
	mod := NewResponseModifier(rec)

	mod.Header().Set("h1", "v1")
	mod.Write([]byte("body"))
	mod.WriteHeader(http.StatusInternalServerError)
	mod.FlushAll()

	if got := rec.Header().Get("h1"); got != "v1" {
		t.Fatalf("Header value must exists %s", got)
	}
	if got := rec.Body.Bytes(); !reflect.DeepEqual(got, []byte("body")) {
		t.Fatalf("Body value must exists %s", got)
	}
	if code := rec.Code; code != http.StatusInternalServerError {
		t.Fatalf("Status code must be correct %d", code)
	}
}
// TestDrainBody exercises drainBody with bodies of several sizes relative to
// maxBodySize. The returned byte slice must have the length given in the
// table (a full copy for the smaller bodies, empty for the oversized one),
// while the returned reader must always yield the complete original body.
func TestDrainBody(t *testing.T) {
	tests := []struct {
		length             int // length is the message length sent to drainBody
		expectedBodyLength int // expectedBodyLength is the expected body length after drainBody is called
	}{
		{10, 10},                           // Small message size
		{maxBodySize - 1, maxBodySize - 1}, // Max message size
		{maxBodySize * 2, 0},               // Large message size (skip copying body)
	}

	for _, test := range tests {
		msg := strings.Repeat("a", test.length)
		body, closer, err := drainBody(io.NopCloser(bytes.NewReader([]byte(msg))))
		if err != nil {
			t.Fatal(err)
		}
		if len(body) != test.expectedBodyLength {
			t.Fatalf("Body must be copied, expected length: '%d', actual length: '%d'", test.expectedBodyLength, len(body))
		}
		if closer == nil {
			t.Fatal("Closer must not be nil")
		}
		modified, err := io.ReadAll(closer)
		if err != nil {
			// The previous message ("must not be nil") stated the opposite of the check.
			t.Fatalf("Error must be nil: '%v'", err)
		}
		if len(modified) != len(msg) {
			t.Fatalf("Result should not be truncated. Original length: '%d', new length: '%d'", len(msg), len(modified))
		}
	}
}
// TestSendBody checks which Content-Type values make sendBody report that the
// body should be forwarded: JSON in any letter case, with or without a
// charset parameter, is accepted; other types and an empty value are not.
// Every case is checked and each failure names the offending Content-Type.
func TestSendBody(t *testing.T) {
	var (
		url       = "nothing.com"
		testcases = []struct {
			contentType string
			expected    bool
		}{
			{
				contentType: "application/json",
				expected:    true,
			},
			{
				contentType: "Application/json",
				expected:    true,
			},
			{
				contentType: "application/JSON",
				expected:    true,
			},
			{
				contentType: "APPLICATION/JSON",
				expected:    true,
			},
			{
				contentType: "application/json; charset=utf-8",
				expected:    true,
			},
			{
				contentType: "application/json;charset=utf-8",
				expected:    true,
			},
			{
				contentType: "application/json; charset=UTF8",
				expected:    true,
			},
			{
				contentType: "application/json;charset=UTF8",
				expected:    true,
			},
			{
				contentType: "text/html",
				expected:    false,
			},
			{
				contentType: "",
				expected:    false,
			},
		}
	)

	for _, testcase := range testcases {
		header := http.Header{}
		header.Set("Content-Type", testcase.contentType)
		if b := sendBody(url, header); b != testcase.expected {
			t.Errorf("Unexpected result for Content-Type %q; Expected: %t, Actual: %t", testcase.contentType, testcase.expected, b)
		}
	}
}
// TestResponseModifierOverride checks that the Override* methods replace the
// header, body and status code previously set on a ResponseModifier, so the
// overridden values are what reach the recorder after FlushAll.
func TestResponseModifierOverride(t *testing.T) {
	rec := httptest.NewRecorder()
	mod := NewResponseModifier(rec)
	mod.Header().Set("h1", "v1")
	mod.Write([]byte("body"))
	mod.WriteHeader(http.StatusInternalServerError)

	// The replacement header is handed to OverrideHeader as JSON bytes.
	replacement := http.Header{}
	replacement.Add("h1", "v2")
	encoded, err := json.Marshal(replacement)
	if err != nil {
		t.Fatalf("override header failed %v", err)
	}

	mod.OverrideHeader(encoded)
	mod.OverrideBody([]byte("override body"))
	mod.OverrideStatusCode(http.StatusNotFound)
	mod.FlushAll()

	if got := rec.Header().Get("h1"); got != "v2" {
		t.Fatalf("Header value must exists %s", got)
	}
	if got := rec.Body.Bytes(); !reflect.DeepEqual(got, []byte("override body")) {
		t.Fatalf("Body value must exists %s", got)
	}
	if code := rec.Code; code != http.StatusNotFound {
		t.Fatalf("Status code must be correct %d", code)
	}
}
// createTestPlugin returns an authorizationPlugin named "plugin" whose client
// connects to the unix socket at socketAddress. TLS verification is skipped.
// The test is aborted if the client cannot be created.
func createTestPlugin(t *testing.T, socketAddress string) *authorizationPlugin {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	client, err := plugins.NewClient("unix:///"+socketAddress, &tlsconfig.Options{InsecureSkipVerify: true})
	if err != nil {
		t.Fatalf("Failed to create client %v", err)
	}
	return &authorizationPlugin{name: "plugin", plugin: client}
}
// authZPluginTestServer is a simple server that implements the authZ plugin
// interface on a unix socket so tests can record what the daemon sends and
// control what the plugin answers.
type authZPluginTestServer struct {
	listener net.Listener
	t        *testing.T
	// recordedRequest stores the most recent request sent from the daemon to the plugin.
	recordedRequest Request
	// replayResponse is the response the plugin sends back to the daemon.
	replayResponse Response
	server         *httptest.Server
	// tmpDir is the temporary directory that contains the plugin's socket.
	tmpDir string
}
// socketAddress returns the path of the plugin's unix socket, which lives
// inside the server's temporary directory.
func (t *authZPluginTestServer) socketAddress() string {
	dir := t.tmpDir
	return path.Join(dir, pluginAddress)
}
// start creates a temporary directory, listens on a unix socket inside it,
// and serves the plugin activation endpoint plus the AuthZ request and
// response endpoints on that socket. Any setup failure aborts the test.
//
// The tests in this file defer stop() only after start returns, so if
// listening fails the temporary directory is removed here to avoid leaking it.
func (t *authZPluginTestServer) start() {
	var err error
	t.tmpDir, err = os.MkdirTemp("", "authz")
	if err != nil {
		t.t.Fatal(err)
	}

	l, err := net.Listen("unix", t.socketAddress())
	if err != nil {
		_ = os.RemoveAll(t.tmpDir)
		t.t.Fatal(err)
	}
	t.listener = l

	r := mux.NewRouter()
	r.HandleFunc("/Plugin.Activate", t.activate)
	r.HandleFunc("/"+AuthZApiRequest, t.auth)
	r.HandleFunc("/"+AuthZApiResponse, t.auth)

	t.server = &httptest.Server{
		Listener: l,
		Config: &http.Server{
			Handler: r,
			Addr:    pluginAddress,
		},
	}
	t.server.Start()
}
// stop shuts down the HTTP test server, deletes the temporary directory that
// held the socket, and finally closes the listener if one was recorded.
// Errors from the cleanup steps are ignored.
func (t *authZPluginTestServer) stop() {
	t.server.Close()
	_ = os.RemoveAll(t.tmpDir)

	if l := t.listener; l != nil {
		l.Close()
	}
}
// auth records the JSON request body sent by the daemon into recordedRequest
// and replies with replayResponse encoded as JSON. It serves both the AuthZ
// request and response endpoints.
//
// auth runs on an HTTP server goroutine, not the test goroutine, and
// t.Fatal/FailNow must only be called from the test goroutine. Failures are
// therefore reported with Errorf and answered with an HTTP error instead.
func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
	t.recordedRequest = Request{}

	body, err := io.ReadAll(r.Body)
	if err != nil {
		t.t.Errorf("reading plugin request body: %v", err)
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	r.Body.Close()

	if err := json.Unmarshal(body, &t.recordedRequest); err != nil {
		t.t.Errorf("decoding plugin request body: %v", err)
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	b, err := json.Marshal(t.replayResponse)
	if err != nil {
		t.t.Errorf("encoding plugin response: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Write(b)
}
// activate answers the plugin activation handshake with a manifest declaring
// that this plugin implements AuthZApiImplements.
//
// activate runs on an HTTP server goroutine, where t.Fatal must not be called.
// An encoding failure is therefore reported with Errorf and an HTTP 500.
func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
	b, err := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
	if err != nil {
		t.t.Errorf("encoding plugin manifest: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Write(b)
}