Browse code

Add organization restriction to github IDP

Jordan Liggitt authored on 2016/02/10 03:08:20
Showing 8 changed files
... ...
@@ -4,27 +4,40 @@ import (
4 4
 	"encoding/json"
5 5
 	"errors"
6 6
 	"fmt"
7
-	"io/ioutil"
8 7
 	"net/http"
8
+	"strings"
9
+
10
+	"k8s.io/kubernetes/pkg/util/sets"
9 11
 
10 12
 	"github.com/RangelReale/osincli"
11 13
 	"github.com/golang/glog"
12 14
 
13 15
 	authapi "github.com/openshift/origin/pkg/auth/api"
14 16
 	"github.com/openshift/origin/pkg/auth/oauth/external"
17
+	"github.com/openshift/origin/pkg/util/http/links"
15 18
 )
16 19
 
17 20
 const (
18 21
 	githubAuthorizeURL = "https://github.com/login/oauth/authorize"
19 22
 	githubTokenURL     = "https://github.com/login/oauth/access_token"
20 23
 	githubUserApiURL   = "https://api.github.com/user"
24
+	githubUserOrgURL   = "https://api.github.com/user/orgs"
21 25
 	githubOAuthScope   = "user:email"
26
+	githubOrgScope     = "read:org"
27
+
28
+	// https://developer.github.com/v3/#current-version
29
+	// https://developer.github.com/v3/media/#request-specific-version
30
+	githubAccept = "application/vnd.github.v3+json"
22 31
 )
23 32
 
24 33
 type provider struct {
25
-	providerName, clientID, clientSecret string
34
+	providerName         string
35
+	clientID             string
36
+	clientSecret         string
37
+	allowedOrganizations sets.String
26 38
 }
27 39
 
40
+// https://developer.github.com/v3/users/#response
28 41
 type githubUser struct {
29 42
 	ID    uint64
30 43
 	Login string
... ...
@@ -32,16 +45,40 @@ type githubUser struct {
32 32
 	Name  string
33 33
 }
34 34
 
35
-func NewProvider(providerName, clientID, clientSecret string) external.Provider {
36
-	return provider{providerName, clientID, clientSecret}
35
+// https://developer.github.com/v3/orgs/#response
36
+type githubOrg struct {
37
+	ID    uint64
38
+	Login string
37 39
 }
38 40
 
39
-func (p provider) GetTransport() (http.RoundTripper, error) {
41
+func NewProvider(providerName, clientID, clientSecret string, organizations []string) external.Provider {
42
+	allowedOrganizations := sets.NewString()
43
+	for _, org := range organizations {
44
+		if len(org) > 0 {
45
+			allowedOrganizations.Insert(strings.ToLower(org))
46
+		}
47
+	}
48
+
49
+	return &provider{
50
+		providerName:         providerName,
51
+		clientID:             clientID,
52
+		clientSecret:         clientSecret,
53
+		allowedOrganizations: allowedOrganizations,
54
+	}
55
+}
56
+
57
+func (p *provider) GetTransport() (http.RoundTripper, error) {
40 58
 	return nil, nil
41 59
 }
42 60
 
43 61
 // NewConfig implements external/interfaces/Provider.NewConfig
44
-func (p provider) NewConfig() (*osincli.ClientConfig, error) {
62
+func (p *provider) NewConfig() (*osincli.ClientConfig, error) {
63
+	scopes := []string{githubOAuthScope}
64
+	// if we're limiting to specific organizations, we also need to read their org membership
65
+	if len(p.allowedOrganizations) > 0 {
66
+		scopes = append(scopes, githubOrgScope)
67
+	}
68
+
45 69
 	config := &osincli.ClientConfig{
46 70
 		ClientId:                 p.clientID,
47 71
 		ClientSecret:             p.clientSecret,
... ...
@@ -49,7 +86,7 @@ func (p provider) NewConfig() (*osincli.ClientConfig, error) {
49 49
 		SendClientSecretInParams: true,
50 50
 		AuthorizeUrl:             githubAuthorizeURL,
51 51
 		TokenUrl:                 githubTokenURL,
52
-		Scope:                    githubOAuthScope,
52
+		Scope:                    strings.Join(scopes, " "),
53 53
 	}
54 54
 	return config, nil
55 55
 }
... ...
@@ -59,31 +96,26 @@ func (p provider) AddCustomParameters(req *osincli.AuthorizeRequest) {
59 59
 }
60 60
 
61 61
 // GetUserIdentity implements external/interfaces/Provider.GetUserIdentity
62
-func (p provider) GetUserIdentity(data *osincli.AccessData) (authapi.UserIdentityInfo, bool, error) {
63
-	req, _ := http.NewRequest("GET", githubUserApiURL, nil)
64
-	req.Header.Set("Authorization", fmt.Sprintf("bearer %s", data.AccessToken))
65
-
66
-	res, err := http.DefaultClient.Do(req)
67
-	if err != nil {
68
-		return nil, false, err
69
-	}
70
-	defer res.Body.Close()
71
-
72
-	body, err := ioutil.ReadAll(res.Body)
73
-	if err != nil {
74
-		return nil, false, err
75
-	}
76
-
62
+func (p *provider) GetUserIdentity(data *osincli.AccessData) (authapi.UserIdentityInfo, bool, error) {
77 63
 	userdata := githubUser{}
78
-	err = json.Unmarshal(body, &userdata)
79
-	if err != nil {
64
+	if _, err := getJSON(githubUserApiURL, data.AccessToken, &userdata); err != nil {
80 65
 		return nil, false, err
81 66
 	}
82
-
83 67
 	if userdata.ID == 0 {
84 68
 		return nil, false, errors.New("Could not retrieve GitHub id")
85 69
 	}
86 70
 
71
+	if len(p.allowedOrganizations) > 0 {
72
+		userOrgs, err := getUserOrgs(data.AccessToken)
73
+		if err != nil {
74
+			return nil, false, err
75
+		}
76
+
77
+		if !userOrgs.HasAny(p.allowedOrganizations.List()...) {
78
+			return nil, false, fmt.Errorf("User %s is not a member of any allowed organizations %v (user is a member of %v)", userdata.Login, p.allowedOrganizations.List(), userOrgs.List())
79
+		}
80
+	}
81
+
87 82
 	identity := authapi.NewDefaultUserIdentityInfo(p.providerName, fmt.Sprintf("%d", userdata.ID))
88 83
 	if len(userdata.Name) > 0 {
89 84
 		identity.Extra[authapi.IdentityDisplayNameKey] = userdata.Name
... ...
@@ -98,3 +130,67 @@ func (p provider) GetUserIdentity(data *osincli.AccessData) (authapi.UserIdentit
98 98
 
99 99
 	return identity, true, nil
100 100
 }
101
+
102
+// getUserOrgs retrieves the organization membership for the user with the given access token.
103
+func getUserOrgs(token string) (sets.String, error) {
104
+	// start with the empty set, and the initial org url
105
+	userOrgs := sets.NewString()
106
+	orgURL := githubUserOrgURL
107
+	// track urls we've fetched to avoid cycles
108
+	fetchedURLs := sets.NewString(orgURL)
109
+	for {
110
+		// fetch organizations
111
+		organizations := []githubOrg{}
112
+		links, err := getJSON(orgURL, token, &organizations)
113
+		if err != nil {
114
+			return nil, err
115
+		}
116
+		for _, org := range organizations {
117
+			if len(org.Login) > 0 {
118
+				userOrgs.Insert(strings.ToLower(org.Login))
119
+			}
120
+		}
121
+
122
+		// see if we need to page
123
+		// https://developer.github.com/v3/#link-header
124
+		nextURL := links["next"]
125
+		if len(nextURL) == 0 {
126
+			// no next URL, we're done paging
127
+			break
128
+		}
129
+		if fetchedURLs.Has(nextURL) {
130
+			// break to avoid a loop
131
+			break
132
+		}
133
+		// remember to avoid a loop
134
+		fetchedURLs.Insert(nextURL)
135
+		orgURL = nextURL
136
+	}
137
+
138
+	return userOrgs, nil
139
+}
140
+
141
+// getJSON fetches and deserializes JSON into the given object.
142
+// returns a (possibly empty) map of link relations to url strings, or an error.
143
+func getJSON(url string, token string, data interface{}) (map[string]string, error) {
144
+	req, _ := http.NewRequest("GET", url, nil)
145
+	req.Header.Set("Authorization", fmt.Sprintf("bearer %s", token))
146
+	req.Header.Set("Accept", githubAccept)
147
+
148
+	res, err := http.DefaultClient.Do(req)
149
+	if err != nil {
150
+		return nil, err
151
+	}
152
+	defer res.Body.Close()
153
+
154
+	if res.StatusCode != http.StatusOK {
155
+		return nil, fmt.Errorf("Non-200 response from GitHub API call %s: %d", url, res.StatusCode)
156
+	}
157
+
158
+	if err := json.NewDecoder(res.Body).Decode(&data); err != nil {
159
+		return nil, err
160
+	}
161
+
162
+	links := links.ParseLinks(res.Header.Get("Link"))
163
+	return links, nil
164
+}
... ...
@@ -7,5 +7,5 @@ import (
7 7
 )
8 8
 
9 9
 func TestGitHub(t *testing.T) {
10
-	_ = external.Provider(NewProvider("github", "clientid", "clientsecret"))
10
+	_ = external.Provider(NewProvider("github", "clientid", "clientsecret", nil))
11 11
 }
... ...
@@ -675,6 +675,8 @@ type GitHubIdentityProvider struct {
675 675
 	ClientID string
676 676
 	// ClientSecret is the oauth client secret
677 677
 	ClientSecret string
678
+	// Organizations optionally restricts which organizations are allowed to log in
679
+	Organizations []string
678 680
 }
679 681
 
680 682
 type GitLabIdentityProvider struct {
... ...
@@ -628,6 +628,8 @@ type GitHubIdentityProvider struct {
628 628
 	ClientID string `json:"clientID"`
629 629
 	// ClientSecret is the oauth client secret
630 630
 	ClientSecret string `json:"clientSecret"`
631
+	// Organizations optionally restricts which organizations are allowed to log in
632
+	Organizations []string `json:"organizations"`
631 633
 }
632 634
 
633 635
 type GitLabIdentityProvider struct {
... ...
@@ -268,6 +268,7 @@ oauthConfig:
268 268
       clientID: ""
269 269
       clientSecret: ""
270 270
       kind: GitHubIdentityProvider
271
+      organizations: null
271 272
   - challenge: false
272 273
     login: false
273 274
     mappingMethod: ""
... ...
@@ -465,7 +465,7 @@ func (c *AuthConfig) getAuthenticationHandler(mux cmdutil.Mux, errorHandler hand
465 465
 func (c *AuthConfig) getOAuthProvider(identityProvider configapi.IdentityProvider) (external.Provider, error) {
466 466
 	switch provider := identityProvider.Provider.(type) {
467 467
 	case (*configapi.GitHubIdentityProvider):
468
-		return github.NewProvider(identityProvider.Name, provider.ClientID, provider.ClientSecret), nil
468
+		return github.NewProvider(identityProvider.Name, provider.ClientID, provider.ClientSecret, provider.Organizations), nil
469 469
 
470 470
 	case (*configapi.GitLabIdentityProvider):
471 471
 		transport, err := cmdutil.TransportFor(provider.CA, "", "")
472 472
new file mode 100644
... ...
@@ -0,0 +1,25 @@
0
+package links
1
+
2
+import "regexp"
3
+
4
+// Matches URL+rel links defined by https://tools.ietf.org/html/rfc5988
5
+// Examples header values:
6
+//   <http://www.example.com/foo?page=3>; rel="next"
7
+//   <http://www.example.com/foo?page=3>; rel="next", <http://www.example.com/foo?page=1>; rel="prev"
8
+var linkRegex = regexp.MustCompile(`\<(.+?)\>\s*;\s*rel="(.+?)"(?:\s*,\s*)?`)
9
+
10
+// ParseLinks extracts link relations from the given header value.
11
+func ParseLinks(header string) map[string]string {
12
+	links := map[string]string{}
13
+	if len(header) == 0 {
14
+		return links
15
+	}
16
+
17
+	matches := linkRegex.FindAllStringSubmatch(header, -1)
18
+	for _, match := range matches {
19
+		url := match[1]
20
+		rel := match[2]
21
+		links[rel] = url
22
+	}
23
+	return links
24
+}
0 25
new file mode 100644
... ...
@@ -0,0 +1,61 @@
0
+package links
1
+
2
+import (
3
+	"reflect"
4
+	"testing"
5
+)
6
+
7
+func TestLinks(t *testing.T) {
8
+	testcases := map[string]struct {
9
+		header string
10
+		links  map[string]string
11
+	}{
12
+		"empty": {
13
+			header: "",
14
+			links:  map[string]string{},
15
+		},
16
+		"unparseable": {
17
+			header: "foo bar baz",
18
+			links:  map[string]string{},
19
+		},
20
+		"single link": {
21
+			header: `<https://example.com/user/orgs?per_page=1&page=2>; rel="next"`,
22
+			links: map[string]string{
23
+				"next": "https://example.com/user/orgs?per_page=1&page=2",
24
+			},
25
+		},
26
+		"single link with unknown suffix": {
27
+			header: `<https://example.com/user/orgs?per_page=1&page=2>; rel="next", foo bar baz`,
28
+			links: map[string]string{
29
+				"next": "https://example.com/user/orgs?per_page=1&page=2",
30
+			},
31
+		},
32
+		"duplicate link": {
33
+			header: `<https://example.com/user/orgs?per_page=1&page=2>; rel="next", <https://example.com/user/orgs?per_page=1&page=3>; rel="next"`,
34
+			links: map[string]string{
35
+				"next": "https://example.com/user/orgs?per_page=1&page=3",
36
+			},
37
+		},
38
+		"no whitespace": {
39
+			header: `<https://example.com/user/orgs?per_page=1&page=2>;rel="next",<https://example.com/user/orgs?per_page=1&page=8>;rel="last"`,
40
+			links: map[string]string{
41
+				"next": "https://example.com/user/orgs?per_page=1&page=2",
42
+				"last": "https://example.com/user/orgs?per_page=1&page=8",
43
+			},
44
+		},
45
+		"extra whitespace": {
46
+			header: `  <https://example.com/user/orgs?per_page=1&page=2>;  rel="next"  ,		<https://example.com/user/orgs?per_page=1&page=8>		;		rel="last"		`,
47
+			links: map[string]string{
48
+				"next": "https://example.com/user/orgs?per_page=1&page=2",
49
+				"last": "https://example.com/user/orgs?per_page=1&page=8",
50
+			},
51
+		},
52
+	}
53
+
54
+	for k, tc := range testcases {
55
+		links := ParseLinks(tc.header)
56
+		if !reflect.DeepEqual(links, tc.links) {
57
+			t.Errorf("%s: Expected\n%#v\ngot\n%#v", k, tc.links, links)
58
+		}
59
+	}
60
+}