Browse code

Add option to unionauth to fail on error, fail on invalid or expired bearer tokens

Jordan Liggitt authored on 2015/01/29 01:19:27
Showing 7 changed files
... ...
@@ -11,23 +11,29 @@ import (
11 11
 
12 12
 // TODO remove this in favor of kubernetes types
13 13
 
14
-type unionAuthRequestHandler []authenticator.Request
14
+type Authenticator struct {
15
+	Handlers    []authenticator.Request
16
+	FailOnError bool
17
+}
15 18
 
16 19
 // NewUnionAuthentication returns a request authenticator that validates credentials using a chain of authenticator.Request objects
17
-func NewUnionAuthentication(authRequestHandlers []authenticator.Request) authenticator.Request {
18
-	return unionAuthRequestHandler(authRequestHandlers)
20
+func NewUnionAuthentication(authRequestHandlers ...authenticator.Request) authenticator.Request {
21
+	return &Authenticator{Handlers: authRequestHandlers}
19 22
 }
20 23
 
21 24
 // AuthenticateRequest authenticates the request using a chain of authenticator.Request objects.  The first
22 25
 // success returns that identity.  Errors are only returned if no matches are found.
23
-func (authHandler unionAuthRequestHandler) AuthenticateRequest(req *http.Request) (authapi.UserInfo, bool, error) {
26
+func (authHandler *Authenticator) AuthenticateRequest(req *http.Request) (authapi.UserInfo, bool, error) {
24 27
 	errors := []error{}
25
-	for _, currAuthRequestHandler := range authHandler {
28
+	for _, currAuthRequestHandler := range authHandler.Handlers {
26 29
 		info, ok, err := currAuthRequestHandler.AuthenticateRequest(req)
27 30
 		if err == nil && ok {
28 31
 			return info, ok, err
29 32
 		}
30 33
 		if err != nil {
34
+			if authHandler.FailOnError {
35
+				return nil, false, err
36
+			}
31 37
 			errors = append(errors, err)
32 38
 		}
33 39
 	}
... ...
@@ -7,7 +7,6 @@ import (
7 7
 	"testing"
8 8
 
9 9
 	authapi "github.com/openshift/origin/pkg/auth/api"
10
-	"github.com/openshift/origin/pkg/auth/authenticator"
11 10
 )
12 11
 
13 12
 type mockAuthRequestHandler struct {
... ...
@@ -23,7 +22,7 @@ func (mock *mockAuthRequestHandler) AuthenticateRequest(req *http.Request) (auth
23 23
 func TestAuthenticateRequestSecondPasses(t *testing.T) {
24 24
 	handler1 := &mockAuthRequestHandler{}
25 25
 	handler2 := &mockAuthRequestHandler{isAuthenticated: true}
26
-	authRequestHandler := NewUnionAuthentication([]authenticator.Request{handler1, handler2})
26
+	authRequestHandler := NewUnionAuthentication(handler1, handler2)
27 27
 	req, _ := http.NewRequest("GET", "http://example.org", nil)
28 28
 
29 29
 	_, isAuthenticated, err := authRequestHandler.AuthenticateRequest(req)
... ...
@@ -38,7 +37,7 @@ func TestAuthenticateRequestSecondPasses(t *testing.T) {
38 38
 func TestAuthenticateRequestSuppressUnnecessaryErrors(t *testing.T) {
39 39
 	handler1 := &mockAuthRequestHandler{err: errors.New("first")}
40 40
 	handler2 := &mockAuthRequestHandler{isAuthenticated: true}
41
-	authRequestHandler := NewUnionAuthentication([]authenticator.Request{handler1, handler2})
41
+	authRequestHandler := NewUnionAuthentication(handler1, handler2)
42 42
 	req, _ := http.NewRequest("GET", "http://example.org", nil)
43 43
 
44 44
 	_, isAuthenticated, err := authRequestHandler.AuthenticateRequest(req)
... ...
@@ -53,7 +52,7 @@ func TestAuthenticateRequestSuppressUnnecessaryErrors(t *testing.T) {
53 53
 func TestAuthenticateRequestNonePass(t *testing.T) {
54 54
 	handler1 := &mockAuthRequestHandler{}
55 55
 	handler2 := &mockAuthRequestHandler{}
56
-	authRequestHandler := NewUnionAuthentication([]authenticator.Request{handler1, handler2})
56
+	authRequestHandler := NewUnionAuthentication(handler1, handler2)
57 57
 	req, _ := http.NewRequest("GET", "http://example.org", nil)
58 58
 
59 59
 	_, isAuthenticated, err := authRequestHandler.AuthenticateRequest(req)
... ...
@@ -68,7 +67,7 @@ func TestAuthenticateRequestNonePass(t *testing.T) {
68 68
 func TestAuthenticateRequestAdditiveErrors(t *testing.T) {
69 69
 	handler1 := &mockAuthRequestHandler{err: errors.New("first")}
70 70
 	handler2 := &mockAuthRequestHandler{err: errors.New("second")}
71
-	authRequestHandler := NewUnionAuthentication([]authenticator.Request{handler1, handler2})
71
+	authRequestHandler := NewUnionAuthentication(handler1, handler2)
72 72
 	req, _ := http.NewRequest("GET", "http://example.org", nil)
73 73
 
74 74
 	_, isAuthenticated, err := authRequestHandler.AuthenticateRequest(req)
... ...
@@ -2,6 +2,7 @@ package filetoken
2 2
 
3 3
 import (
4 4
 	"encoding/csv"
5
+	"errors"
5 6
 	"io"
6 7
 	"os"
7 8
 
... ...
@@ -52,7 +53,7 @@ func NewTokenAuthenticator(path string) (*TokenAuthenticator, error) {
52 52
 func (a *TokenAuthenticator) AuthenticateToken(value string) (api.UserInfo, bool, error) {
53 53
 	user, ok := a.tokens[value]
54 54
 	if !ok {
55
-		return nil, false, nil
55
+		return nil, false, errors.New("Invalid token")
56 56
 	}
57 57
 	return user, true, nil
58 58
 }
... ...
@@ -249,8 +249,11 @@ func TestAuthenticateTokenNotFound(t *testing.T) {
249 249
 	if found {
250 250
 		t.Error("Found token, but it should be missing!")
251 251
 	}
252
-	if err != nil {
253
-		t.Error("Unexpected error: %v", err)
252
+	if err == nil {
253
+		t.Error("Expected not found error")
254
+	}
255
+	if !apierrs.IsNotFound(err) {
256
+		t.Error("Expected not found error")
254 257
 	}
255 258
 	if userInfo != nil {
256 259
 		t.Error("Unexpected user: %v", userInfo)
... ...
@@ -288,7 +291,7 @@ func TestAuthenticateTokenExpired(t *testing.T) {
288 288
 	if found {
289 289
 		t.Error("Found token, but it should be missing!")
290 290
 	}
291
-	if err != nil {
291
+	if err != ErrExpired {
292 292
 		t.Error("Unexpected error: %v", err)
293 293
 	}
294 294
 	if userInfo != nil {
... ...
@@ -1,10 +1,9 @@
1 1
 package registry
2 2
 
3 3
 import (
4
+	"errors"
4 5
 	"time"
5 6
 
6
-	"github.com/GoogleCloudPlatform/kubernetes/pkg/api/errors"
7
-
8 7
 	"github.com/openshift/origin/pkg/auth/api"
9 8
 	"github.com/openshift/origin/pkg/oauth/registry/accesstoken"
10 9
 	"github.com/openshift/origin/pkg/oauth/scope"
... ...
@@ -14,6 +13,8 @@ type TokenAuthenticator struct {
14 14
 	registry accesstoken.Registry
15 15
 }
16 16
 
17
+var ErrExpired = errors.New("Token is expired")
18
+
17 19
 func NewTokenAuthenticator(registry accesstoken.Registry) *TokenAuthenticator {
18 20
 	return &TokenAuthenticator{
19 21
 		registry: registry,
... ...
@@ -22,14 +23,11 @@ func NewTokenAuthenticator(registry accesstoken.Registry) *TokenAuthenticator {
22 22
 
23 23
 func (a *TokenAuthenticator) AuthenticateToken(value string) (api.UserInfo, bool, error) {
24 24
 	token, err := a.registry.GetAccessToken(value)
25
-	if errors.IsNotFound(err) {
26
-		return nil, false, nil
27
-	}
28 25
 	if err != nil {
29 26
 		return nil, false, err
30 27
 	}
31 28
 	if token.CreationTimestamp.Time.Add(time.Duration(token.ExpiresIn) * time.Second).Before(time.Now()) {
32
-		return nil, false, nil
29
+		return nil, false, ErrExpired
33 30
 	}
34 31
 	return &api.DefaultUserInfo{
35 32
 		Name:  token.UserName,
... ...
@@ -369,7 +369,7 @@ func (c *AuthConfig) getAuthenticationRequestHandler(sessionStore session.Store)
369 369
 		authRequestHandlers = append(authRequestHandlers, c.getAuthenticationRequestHandlerFromType(currType, sessionStore))
370 370
 	}
371 371
 
372
-	authRequestHandler := unionrequest.NewUnionAuthentication(authRequestHandlers)
372
+	authRequestHandler := unionrequest.NewUnionAuthentication(authRequestHandlers...)
373 373
 	return authRequestHandler
374 374
 }
375 375
 
... ...
@@ -316,7 +316,7 @@ func start(cfg *config, args []string) error {
316 316
 		if err != nil {
317 317
 			glog.Fatalf("Error creating TokenAuthenticator: %v", err)
318 318
 		}
319
-		authenticators = append(authenticators, group.NewGroupAdder(bearertoken.New(tokenAuthenticator), []string{authenticatedGroup}))
319
+		authenticators = append(authenticators, bearertoken.New(tokenAuthenticator))
320 320
 
321 321
 		var roots *x509.CertPool
322 322
 		if osmaster.TLS {
... ...
@@ -375,7 +375,7 @@ func start(cfg *config, args []string) error {
375 375
 			opts := x509request.DefaultVerifyOptions()
376 376
 			opts.Roots = roots
377 377
 			certauth := x509request.New(opts, x509request.CommonNameUserConversion)
378
-			authenticators = append(authenticators, group.NewGroupAdder(certauth, []string{authenticatedGroup}))
378
+			authenticators = append(authenticators, certauth)
379 379
 		} else {
380 380
 			// No security, use the same client config for all OpenShift clients
381 381
 			osClientConfig := kclient.Config{Host: cfg.MasterAddr.URL.String(), Version: latest.Version}
... ...
@@ -384,11 +384,15 @@ func start(cfg *config, args []string) error {
384 384
 		}
385 385
 
386 386
 		// TODO: make anonymous auth optional?
387
-		// TODO: should this map to a real user persisted in etcd?
388
-		authenticators = append(authenticators, authenticator.RequestFunc(func(req *http.Request) (api.UserInfo, bool, error) {
389
-			return &api.DefaultUserInfo{Name: unauthenticatedUsername, Groups: []string{unauthenticatedGroup}}, true, nil
390
-		}))
391
-		osmaster.Authenticator = unionrequest.NewUnionAuthentication(authenticators)
387
+		osmaster.Authenticator = &unionrequest.Authenticator{
388
+			FailOnError: true,
389
+			Handlers: []authenticator.Request{
390
+				group.NewGroupAdder(unionrequest.NewUnionAuthentication(authenticators...), []string{authenticatedGroup}),
391
+				authenticator.RequestFunc(func(req *http.Request) (api.UserInfo, bool, error) {
392
+					return &api.DefaultUserInfo{Name: unauthenticatedUsername, Groups: []string{unauthenticatedGroup}}, true, nil
393
+				}),
394
+			},
395
+		}
392 396
 
393 397
 		osmaster.BuildClients()
394 398
 		osmaster.EnsureCORSAllowedOrigins(cfg.CORSAllowedOrigins)