Signed-off-by: Raja Sami <raja.sami@tenpearls.com>
| 1 | 1 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,75 @@ |
| 0 |
+package authorization |
|
| 1 |
+ |
|
| 2 |
+import ( |
|
| 3 |
+ "crypto/rand" |
|
| 4 |
+ "crypto/rsa" |
|
| 5 |
+ "crypto/tls" |
|
| 6 |
+ "crypto/x509" |
|
| 7 |
+ "crypto/x509/pkix" |
|
| 8 |
+ "math/big" |
|
| 9 |
+ "net/http" |
|
| 10 |
+ "testing" |
|
| 11 |
+ "time" |
|
| 12 |
+ |
|
| 13 |
+ "github.com/stretchr/testify/require" |
|
| 14 |
+) |
|
| 15 |
+ |
|
| 16 |
+func TestPeerCertificateMarshalJSON(t *testing.T) {
|
|
| 17 |
+ template := &x509.Certificate{
|
|
| 18 |
+ IsCA: true, |
|
| 19 |
+ BasicConstraintsValid: true, |
|
| 20 |
+ SubjectKeyId: []byte{1, 2, 3},
|
|
| 21 |
+ SerialNumber: big.NewInt(1234), |
|
| 22 |
+ Subject: pkix.Name{
|
|
| 23 |
+ Country: []string{"Earth"},
|
|
| 24 |
+ Organization: []string{"Mother Nature"},
|
|
| 25 |
+ }, |
|
| 26 |
+ NotBefore: time.Now(), |
|
| 27 |
+ NotAfter: time.Now().AddDate(5, 5, 5), |
|
| 28 |
+ |
|
| 29 |
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
|
| 30 |
+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, |
|
| 31 |
+ } |
|
| 32 |
+ // generate private key |
|
| 33 |
+ privatekey, err := rsa.GenerateKey(rand.Reader, 2048) |
|
| 34 |
+ require.NoError(t, err) |
|
| 35 |
+ publickey := &privatekey.PublicKey |
|
| 36 |
+ |
|
| 37 |
+ // create a self-signed certificate. template = parent |
|
| 38 |
+ var parent = template |
|
| 39 |
+ raw, err := x509.CreateCertificate(rand.Reader, template, parent, publickey, privatekey) |
|
| 40 |
+ require.NoError(t, err) |
|
| 41 |
+ |
|
| 42 |
+ cert, err := x509.ParseCertificate(raw) |
|
| 43 |
+ require.NoError(t, err) |
|
| 44 |
+ |
|
| 45 |
+ var certs = []*x509.Certificate{cert}
|
|
| 46 |
+ addr := "www.authz.com/auth" |
|
| 47 |
+ req, err := http.NewRequest("GET", addr, nil)
|
|
| 48 |
+ require.NoError(t, err) |
|
| 49 |
+ |
|
| 50 |
+ req.RequestURI = addr |
|
| 51 |
+ req.TLS = &tls.ConnectionState{}
|
|
| 52 |
+ req.TLS.PeerCertificates = certs |
|
| 53 |
+ req.Header.Add("header", "value")
|
|
| 54 |
+ |
|
| 55 |
+ for _, c := range req.TLS.PeerCertificates {
|
|
| 56 |
+ pcObj := PeerCertificate(*c) |
|
| 57 |
+ |
|
| 58 |
+ t.Run("Marshalling :", func(t *testing.T) {
|
|
| 59 |
+ raw, err = pcObj.MarshalJSON() |
|
| 60 |
+ require.NotNil(t, raw) |
|
| 61 |
+ require.Nil(t, err) |
|
| 62 |
+ }) |
|
| 63 |
+ |
|
| 64 |
+ t.Run("UnMarshalling :", func(t *testing.T) {
|
|
| 65 |
+ err := pcObj.UnmarshalJSON(raw) |
|
| 66 |
+ require.Nil(t, err) |
|
| 67 |
+ require.Equal(t, "Earth", pcObj.Subject.Country[0]) |
|
| 68 |
+ require.Equal(t, true, pcObj.IsCA) |
|
| 69 |
+ |
|
| 70 |
+ }) |
|
| 71 |
+ |
|
| 72 |
+ } |
|
| 73 |
+ |
|
| 74 |
+} |
| 0 | 75 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,53 @@ |
| 0 |
+package authorization |
|
| 1 |
+ |
|
| 2 |
+import ( |
|
| 3 |
+ "net/http" |
|
| 4 |
+ "net/http/httptest" |
|
| 5 |
+ "strings" |
|
| 6 |
+ "testing" |
|
| 7 |
+ |
|
| 8 |
+ "github.com/docker/docker/pkg/plugingetter" |
|
| 9 |
+ "github.com/stretchr/testify/require" |
|
| 10 |
+) |
|
| 11 |
+ |
|
| 12 |
+func TestMiddleware(t *testing.T) {
|
|
| 13 |
+ pluginNames := []string{"testPlugin1", "testPlugin2"}
|
|
| 14 |
+ var pluginGetter plugingetter.PluginGetter |
|
| 15 |
+ m := NewMiddleware(pluginNames, pluginGetter) |
|
| 16 |
+ authPlugins := m.getAuthzPlugins() |
|
| 17 |
+ require.Equal(t, 2, len(authPlugins)) |
|
| 18 |
+ require.EqualValues(t, pluginNames[0], authPlugins[0].Name()) |
|
| 19 |
+ require.EqualValues(t, pluginNames[1], authPlugins[1].Name()) |
|
| 20 |
+} |
|
| 21 |
+ |
|
| 22 |
+func TestNewResponseModifier(t *testing.T) {
|
|
| 23 |
+ recorder := httptest.NewRecorder() |
|
| 24 |
+ modifier := NewResponseModifier(recorder) |
|
| 25 |
+ modifier.Header().Set("H1", "V1")
|
|
| 26 |
+ modifier.Write([]byte("body"))
|
|
| 27 |
+ require.False(t, modifier.Hijacked()) |
|
| 28 |
+ modifier.WriteHeader(http.StatusInternalServerError) |
|
| 29 |
+ require.NotNil(t, modifier.RawBody()) |
|
| 30 |
+ |
|
| 31 |
+ raw, err := modifier.RawHeaders() |
|
| 32 |
+ require.NotNil(t, raw) |
|
| 33 |
+ require.Nil(t, err) |
|
| 34 |
+ |
|
| 35 |
+ headerData := strings.Split(strings.TrimSpace(string(raw)), ":") |
|
| 36 |
+ require.EqualValues(t, "H1", strings.TrimSpace(headerData[0])) |
|
| 37 |
+ require.EqualValues(t, "V1", strings.TrimSpace(headerData[1])) |
|
| 38 |
+ |
|
| 39 |
+ modifier.Flush() |
|
| 40 |
+ modifier.FlushAll() |
|
| 41 |
+ |
|
| 42 |
+ if recorder.Header().Get("H1") != "V1" {
|
|
| 43 |
+ t.Fatalf("Header value must exists %s", recorder.Header().Get("H1"))
|
|
| 44 |
+ } |
|
| 45 |
+ |
|
| 46 |
+} |
|
| 47 |
+ |
|
| 48 |
+func setAuthzPlugins(m *Middleware, plugins []Plugin) {
|
|
| 49 |
+ m.mu.Lock() |
|
| 50 |
+ m.plugins = plugins |
|
| 51 |
+ m.mu.Unlock() |
|
| 52 |
+} |
| 0 | 53 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,65 @@ |
| 0 |
+// +build !windows |
|
| 1 |
+ |
|
| 2 |
+package authorization |
|
| 3 |
+ |
|
| 4 |
+import ( |
|
| 5 |
+ "net/http" |
|
| 6 |
+ "net/http/httptest" |
|
| 7 |
+ "testing" |
|
| 8 |
+ |
|
| 9 |
+ "github.com/docker/docker/pkg/plugingetter" |
|
| 10 |
+ "github.com/stretchr/testify/require" |
|
| 11 |
+ "golang.org/x/net/context" |
|
| 12 |
+) |
|
| 13 |
+ |
|
| 14 |
+func TestMiddlewareWrapHandler(t *testing.T) {
|
|
| 15 |
+ server := authZPluginTestServer{t: t}
|
|
| 16 |
+ server.start() |
|
| 17 |
+ defer server.stop() |
|
| 18 |
+ |
|
| 19 |
+ authZPlugin := createTestPlugin(t) |
|
| 20 |
+ pluginNames := []string{authZPlugin.name}
|
|
| 21 |
+ |
|
| 22 |
+ var pluginGetter plugingetter.PluginGetter |
|
| 23 |
+ middleWare := NewMiddleware(pluginNames, pluginGetter) |
|
| 24 |
+ handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
|
| 25 |
+ return nil |
|
| 26 |
+ } |
|
| 27 |
+ |
|
| 28 |
+ authList := []Plugin{authZPlugin}
|
|
| 29 |
+ middleWare.SetPlugins([]string{"My Test Plugin"})
|
|
| 30 |
+ setAuthzPlugins(middleWare, authList) |
|
| 31 |
+ mdHandler := middleWare.WrapHandler(handler) |
|
| 32 |
+ require.NotNil(t, mdHandler) |
|
| 33 |
+ |
|
| 34 |
+ addr := "www.example.com/auth" |
|
| 35 |
+ req, _ := http.NewRequest("GET", addr, nil)
|
|
| 36 |
+ req.RequestURI = addr |
|
| 37 |
+ req.Header.Add("header", "value")
|
|
| 38 |
+ |
|
| 39 |
+ resp := httptest.NewRecorder() |
|
| 40 |
+ ctx := context.Background() |
|
| 41 |
+ |
|
| 42 |
+ t.Run("Error Test Case :", func(t *testing.T) {
|
|
| 43 |
+ server.replayResponse = Response{
|
|
| 44 |
+ Allow: false, |
|
| 45 |
+ Msg: "Server Auth Not Allowed", |
|
| 46 |
+ } |
|
| 47 |
+ if err := mdHandler(ctx, resp, req, map[string]string{}); err == nil {
|
|
| 48 |
+ require.Error(t, err) |
|
| 49 |
+ } |
|
| 50 |
+ |
|
| 51 |
+ }) |
|
| 52 |
+ |
|
| 53 |
+ t.Run("Positive Test Case :", func(t *testing.T) {
|
|
| 54 |
+ server.replayResponse = Response{
|
|
| 55 |
+ Allow: true, |
|
| 56 |
+ Msg: "Server Auth Allowed", |
|
| 57 |
+ } |
|
| 58 |
+ if err := mdHandler(ctx, resp, req, map[string]string{}); err != nil {
|
|
| 59 |
+ require.NoError(t, err) |
|
| 60 |
+ } |
|
| 61 |
+ |
|
| 62 |
+ }) |
|
| 63 |
+ |
|
| 64 |
+} |