package csrf

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestCookieGenerate checks the two Generate paths:
//   - If the request already carries a cookie with the configured name,
//     Generate returns that cookie's value and writes no Set-Cookie header.
//   - Otherwise, Generate writes a Set-Cookie header. Its serialization must
//     match an http.Cookie built from the configured name, path, domain,
//     Secure, and HttpOnly settings plus the returned token.
func TestCookieGenerate(t *testing.T) {
	testCases := map[string]struct {
		Name     string
		Path     string
		Domain   string
		Secure   bool
		HTTPOnly bool

		// ExistingCookie, when non-nil, is attached to the request before Generate runs.
		ExistingCookie *http.Cookie

		// ExpectToken, when non-empty, is the exact token Generate must return;
		// ExpectSetCookie reports whether a Set-Cookie header must be written.
		ExpectToken     string
		ExpectSetCookie bool
	}{
		"use existing": {
			Name:            "csrf",
			ExistingCookie:  &http.Cookie{Name: "csrf", Value: "existingvalue"},
			ExpectToken:     "existingvalue",
			ExpectSetCookie: false,
		},
		"set missing": {
			Name:            "csrf",
			ExpectSetCookie: true,
		},
		"set missing with other cookies": {
			Name:            "csrf",
			ExistingCookie:  &http.Cookie{Name: "csrf2", Value: "existingvalue"},
			ExpectSetCookie: true,
		},
		"set missing with cookie options": {
			Name:            "csrf",
			Path:            "/",
			Domain:          "foo.com",
			Secure:          true,
			HTTPOnly:        true,
			ExpectSetCookie: true,
		},
	}

	for k, testCase := range testCases {
		csrf := NewCookieCSRF(testCase.Name, testCase.Path, testCase.Domain, testCase.Secure, testCase.HTTPOnly)

		req, err := http.NewRequest("GET", "/", nil)
		if err != nil {
			t.Errorf("%s: Unexpected error building request: %v", k, err)
			continue
		}
		if testCase.ExistingCookie != nil {
			req.AddCookie(testCase.ExistingCookie)
		}

		w := httptest.NewRecorder()

		token, err := csrf.Generate(w, req)
		if err != nil {
			t.Errorf("%s: Unexpected error: %v", k, err)
			continue
		}

		if len(testCase.ExpectToken) != 0 && token != testCase.ExpectToken {
			t.Errorf("%s: Expected token %q, got %q", k, testCase.ExpectToken, token)
			continue
		}

		setCookie := w.Header().Get("Set-Cookie")

		if !testCase.ExpectSetCookie {
			if len(setCookie) > 0 {
				t.Errorf("%s: Didn't expect set-cookie header, got %s", k, setCookie)
			}
			continue
		}

		if len(setCookie) == 0 {
			t.Errorf("%s: Expected set-cookie header", k)
			continue
		}

		// The Set-Cookie comparison below is built from the returned token, so an
		// empty token would still match a "name=" header. Reject it explicitly;
		// an empty token is never valid (see "fail empty token" in TestCookieCheck).
		if len(token) == 0 {
			t.Errorf("%s: Expected a non-empty generated token", k)
			continue
		}

		// A cookie with a different name must not be mistaken for the CSRF
		// cookie. If it were, its value would be returned and re-set, and the
		// Set-Cookie comparison below would not notice.
		if existing := testCase.ExistingCookie; existing != nil && existing.Name != testCase.Name && token == existing.Value {
			t.Errorf("%s: Generate reused the value of cookie %q instead of issuing a new %q token", k, existing.Name, testCase.Name)
			continue
		}

		protoCookie := &http.Cookie{
			Name:     testCase.Name,
			Value:    token,
			Path:     testCase.Path,
			Domain:   testCase.Domain,
			Secure:   testCase.Secure,
			HttpOnly: testCase.HTTPOnly,
		}
		if setCookie != protoCookie.String() {
			t.Errorf("%s: Expected Set-Cookie header of %q, got %q", k, protoCookie.String(), setCookie)
			continue
		}
	}
}

// TestCookieCheck checks that Check accepts a token only when it equals the
// non-empty value of the request cookie with the configured name. An empty
// token, an empty or missing cookie, a value mismatch, or a cookie with a
// different name must all be rejected without an error.
func TestCookieCheck(t *testing.T) {
	testCases := map[string]struct {
		Name           string
		Token          string
		ExistingCookie *http.Cookie
		ExpectCheck    bool
	}{
		"fail empty token": {
			Name:           "csrf",
			Token:          "",
			ExistingCookie: &http.Cookie{Name: "csrf", Value: "existingvalue"},
			ExpectCheck:    false,
		},
		"fail empty cookie": {
			Name:           "csrf",
			Token:          "mytoken",
			ExistingCookie: &http.Cookie{Name: "csrf", Value: ""},
			ExpectCheck:    false,
		},
		"fail missing cookie": {
			Name:        "csrf",
			Token:       "mytoken",
			ExpectCheck: false,
		},
		"fail mismatch cookie": {
			Name:           "csrf",
			Token:          "mytoken",
			ExistingCookie: &http.Cookie{Name: "csrf", Value: "existingvalue"},
			ExpectCheck:    false,
		},
		"fail mismatch cookie name": {
			Name:           "csrf",
			Token:          "mytoken",
			ExistingCookie: &http.Cookie{Name: "csrf2", Value: "mytoken"},
			ExpectCheck:    false,
		},
		"pass matching cookie": {
			Name:           "csrf",
			Token:          "existingvalue",
			ExistingCookie: &http.Cookie{Name: "csrf", Value: "existingvalue"},
			ExpectCheck:    true,
		},
	}

	for k, testCase := range testCases {
		csrf := NewCookieCSRF(testCase.Name, "", "", false, false)

		req, err := http.NewRequest("GET", "/", nil)
		if err != nil {
			t.Errorf("%s: Unexpected error building request: %v", k, err)
			continue
		}
		if testCase.ExistingCookie != nil {
			req.AddCookie(testCase.ExistingCookie)
		}

		ok, err := csrf.Check(req, testCase.Token)
		if err != nil {
			t.Errorf("%s: Unexpected error: %v", k, err)
			continue
		}
		if ok != testCase.ExpectCheck {
			t.Errorf("%s: Expected check to return %v, returned %v", k, testCase.ExpectCheck, ok)
			continue
		}
	}
}