pkg/authorization/authz_unix_test.go
57faef5c
 // +build !windows
 
 // TODO Windows: This uses a Unix socket for testing. This might be possible
 // to port to Windows using a named pipe instead.
 
75c353f0
 package authorization
 
 import (
d1b7e837
 	"bytes"
75c353f0
 	"encoding/json"
 	"io/ioutil"
 	"net"
 	"net/http"
 	"net/http/httptest"
 	"os"
 	"path"
 	"reflect"
1a630234
 	"strings"
d1b7e837
 	"testing"
1a630234
 
8435ea52
 	"github.com/docker/docker/pkg/plugins"
8e034802
 	"github.com/docker/go-connections/tlsconfig"
8435ea52
 	"github.com/gorilla/mux"
75c353f0
 )
 
d1b7e837
 // pluginAddress is the Unix socket the fake authorization plugin listens on.
 // It is a relative path, so the socket is created in the working directory.
 const pluginAddress = "authz-test-plugin.sock"
75c353f0
 
46e3a249
 func TestAuthZRequestPluginError(t *testing.T) {
 	server := authZPluginTestServer{t: t}
f437e2d1
 	server.start()
46e3a249
 	defer server.stop()
 
 	authZPlugin := createTestPlugin(t)
 
 	request := Request{
 		User:           "user",
 		RequestBody:    []byte("sample body"),
d1b7e837
 		RequestURI:     "www.authz.com/auth",
46e3a249
 		RequestMethod:  "GET",
 		RequestHeaders: map[string]string{"header": "value"},
 	}
 	server.replayResponse = Response{
 		Err: "an error",
 	}
 
 	actualResponse, err := authZPlugin.AuthZRequest(&request)
 	if err != nil {
 		t.Fatalf("Failed to authorize request %v", err)
 	}
 
 	if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
d1b7e837
 		t.Fatal("Response must be equal")
46e3a249
 	}
 	if !reflect.DeepEqual(request, server.recordedRequest) {
d1b7e837
 		t.Fatal("Requests must be equal")
46e3a249
 	}
 }
 
75c353f0
 func TestAuthZRequestPlugin(t *testing.T) {
 	server := authZPluginTestServer{t: t}
f437e2d1
 	server.start()
75c353f0
 	defer server.stop()
 
 	authZPlugin := createTestPlugin(t)
 
 	request := Request{
 		User:           "user",
 		RequestBody:    []byte("sample body"),
d1b7e837
 		RequestURI:     "www.authz.com/auth",
75c353f0
 		RequestMethod:  "GET",
 		RequestHeaders: map[string]string{"header": "value"},
 	}
 	server.replayResponse = Response{
 		Allow: true,
 		Msg:   "Sample message",
 	}
 
 	actualResponse, err := authZPlugin.AuthZRequest(&request)
 	if err != nil {
 		t.Fatalf("Failed to authorize request %v", err)
 	}
 
 	if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
d1b7e837
 		t.Fatal("Response must be equal")
75c353f0
 	}
 	if !reflect.DeepEqual(request, server.recordedRequest) {
d1b7e837
 		t.Fatal("Requests must be equal")
75c353f0
 	}
 }
 
 func TestAuthZResponsePlugin(t *testing.T) {
 	server := authZPluginTestServer{t: t}
f437e2d1
 	server.start()
75c353f0
 	defer server.stop()
 
 	authZPlugin := createTestPlugin(t)
 
 	request := Request{
 		User:        "user",
39bcaee4
 		RequestURI:  "something.com/auth",
75c353f0
 		RequestBody: []byte("sample body"),
 	}
 	server.replayResponse = Response{
 		Allow: true,
 		Msg:   "Sample message",
 	}
 
 	actualResponse, err := authZPlugin.AuthZResponse(&request)
 	if err != nil {
 		t.Fatalf("Failed to authorize request %v", err)
 	}
 
 	if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
d1b7e837
 		t.Fatal("Response must be equal")
75c353f0
 	}
 	if !reflect.DeepEqual(request, server.recordedRequest) {
d1b7e837
 		t.Fatal("Requests must be equal")
75c353f0
 	}
 }
 
 func TestResponseModifier(t *testing.T) {
 	r := httptest.NewRecorder()
 	m := NewResponseModifier(r)
 	m.Header().Set("h1", "v1")
 	m.Write([]byte("body"))
6bec735c
 	m.WriteHeader(http.StatusInternalServerError)
75c353f0
 
5ffc810d
 	m.FlushAll()
75c353f0
 	if r.Header().Get("h1") != "v1" {
 		t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
 	}
 	if !reflect.DeepEqual(r.Body.Bytes(), []byte("body")) {
 		t.Fatalf("Body value must exists %s", r.Body.Bytes())
 	}
6bec735c
 	if r.Code != http.StatusInternalServerError {
75c353f0
 		t.Fatalf("Status code must be correct %d", r.Code)
 	}
 }
 
ca5c2abe
 func TestDrainBody(t *testing.T) {
 	tests := []struct {
 		length             int // length is the message length send to drainBody
 		expectedBodyLength int // expectedBodyLength is the expected body length after drainBody is called
 	}{
 		{10, 10}, // Small message size
 		{maxBodySize - 1, maxBodySize - 1}, // Max message size
 		{maxBodySize * 2, 0},               // Large message size (skip copying body)
 
 	}
 
 	for _, test := range tests {
 		msg := strings.Repeat("a", test.length)
 		body, closer, err := drainBody(ioutil.NopCloser(bytes.NewReader([]byte(msg))))
6a966844
 		if err != nil {
 			t.Fatal(err)
 		}
ca5c2abe
 		if len(body) != test.expectedBodyLength {
 			t.Fatalf("Body must be copied, actual length: '%d'", len(body))
 		}
 		if closer == nil {
d1b7e837
 			t.Fatal("Closer must not be nil")
ca5c2abe
 		}
 		modified, err := ioutil.ReadAll(closer)
 		if err != nil {
 			t.Fatalf("Error must not be nil: '%v'", err)
 		}
 		if len(modified) != len(msg) {
 			t.Fatalf("Result should not be truncated. Original length: '%d', new length: '%d'", len(msg), len(modified))
 		}
 	}
 }
 
75c353f0
 func TestResponseModifierOverride(t *testing.T) {
 	r := httptest.NewRecorder()
 	m := NewResponseModifier(r)
 	m.Header().Set("h1", "v1")
 	m.Write([]byte("body"))
6bec735c
 	m.WriteHeader(http.StatusInternalServerError)
75c353f0
 
 	overrideHeader := make(http.Header)
 	overrideHeader.Add("h1", "v2")
 	overrideHeaderBytes, err := json.Marshal(overrideHeader)
 	if err != nil {
 		t.Fatalf("override header failed %v", err)
 	}
 
 	m.OverrideHeader(overrideHeaderBytes)
 	m.OverrideBody([]byte("override body"))
6bec735c
 	m.OverrideStatusCode(http.StatusNotFound)
5ffc810d
 	m.FlushAll()
75c353f0
 	if r.Header().Get("h1") != "v2" {
 		t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
 	}
 	if !reflect.DeepEqual(r.Body.Bytes(), []byte("override body")) {
 		t.Fatalf("Body value must exists %s", r.Body.Bytes())
 	}
6bec735c
 	if r.Code != http.StatusNotFound {
75c353f0
 		t.Fatalf("Status code must be correct %d", r.Code)
 	}
 }
 
 // createTestPlugin creates a new sample authorization plugin
 func createTestPlugin(t *testing.T) *authorizationPlugin {
 	pwd, err := os.Getwd()
 	if err != nil {
6a966844
 		t.Fatal(err)
75c353f0
 	}
 
f3711704
 	client, err := plugins.NewClient("unix:///"+path.Join(pwd, pluginAddress), &tlsconfig.Options{InsecureSkipVerify: true})
75c353f0
 	if err != nil {
 		t.Fatalf("Failed to create client %v", err)
 	}
 
f3711704
 	return &authorizationPlugin{name: "plugin", plugin: client}
75c353f0
 }
 
 // AuthZPluginTestServer is a simple server that implements the authZ plugin interface
 type authZPluginTestServer struct {
 	listener net.Listener
 	t        *testing.T
 	// request stores the request sent from the daemon to the plugin
 	recordedRequest Request
 	// response stores the response sent from the plugin to the daemon
 	replayResponse Response
f437e2d1
 	server         *httptest.Server
75c353f0
 }
 
 // start starts the test server that implements the plugin
 func (t *authZPluginTestServer) start() {
 	r := mux.NewRouter()
d1b7e837
 	l, err := net.Listen("unix", pluginAddress)
 	if err != nil {
 		t.t.Fatal(err)
 	}
75c353f0
 	t.listener = l
 	r.HandleFunc("/Plugin.Activate", t.activate)
 	r.HandleFunc("/"+AuthZApiRequest, t.auth)
 	r.HandleFunc("/"+AuthZApiResponse, t.auth)
f437e2d1
 	t.server = &httptest.Server{
 		Listener: l,
 		Config: &http.Server{
 			Handler: r,
 			Addr:    pluginAddress,
 		},
 	}
 	t.server.Start()
75c353f0
 }
 
 // stop stops the test server that implements the plugin
 func (t *authZPluginTestServer) stop() {
f437e2d1
 	t.server.Close()
75c353f0
 	os.Remove(pluginAddress)
 	if t.listener != nil {
 		t.listener.Close()
 	}
 }
 
 // auth is a used to record/replay the authentication api messages
 func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
 	t.recordedRequest = Request{}
d1b7e837
 	body, err := ioutil.ReadAll(r.Body)
 	if err != nil {
 		t.t.Fatal(err)
 	}
6a966844
 	r.Body.Close()
75c353f0
 	json.Unmarshal(body, &t.recordedRequest)
d1b7e837
 	b, err := json.Marshal(t.replayResponse)
 	if err != nil {
 		t.t.Fatal(err)
 	}
75c353f0
 	w.Write(b)
 }
 
 func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
d1b7e837
 	b, err := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
 	if err != nil {
 		t.t.Fatal(err)
 	}
75c353f0
 	w.Write(b)
 }