package csrf

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestCookieGenerate exercises Generate on a cookie-backed CSRF instance.
// Each case builds the instance from the listed options, optionally attaches
// a cookie to the request, and then checks the returned token and whether a
// Set-Cookie header was written to the response.
func TestCookieGenerate(t *testing.T) {
	testCases := map[string]struct {
		// Constructor arguments forwarded to NewCookieCSRF.
		Name     string
		Path     string
		Domain   string
		Secure   bool
		HTTPOnly bool

		// ExistingCookie, when non-nil, is added to the request before Generate runs.
		ExistingCookie *http.Cookie

		// ExpectToken, when non-empty, is the exact token Generate must return.
		ExpectToken string
		// ExpectSetCookie is whether the response must carry a Set-Cookie header.
		ExpectSetCookie bool
	}{
		"use existing": {
			Name:           "csrf",
			ExistingCookie: &http.Cookie{Name: "csrf", Value: "existingvalue"},

			ExpectToken:     "existingvalue",
			ExpectSetCookie: false,
		},

		"set missing": {
			Name: "csrf",

			ExpectSetCookie: true,
		},

		"set missing with other cookies": {
			Name:           "csrf",
			ExistingCookie: &http.Cookie{Name: "csrf2", Value: "existingvalue"},

			ExpectSetCookie: true,
		},

		"set missing with cookie options": {
			Name:     "csrf",
			Path:     "/",
			Domain:   "foo.com",
			Secure:   true,
			HTTPOnly: true,

			ExpectSetCookie: true,
		},
	}

	for k, testCase := range testCases {
		// Subtests run synchronously (no t.Parallel), so capturing the loop
		// variable in the closure is safe on every Go version.
		t.Run(k, func(t *testing.T) {
			csrf := NewCookieCSRF(testCase.Name, testCase.Path, testCase.Domain, testCase.Secure, testCase.HTTPOnly)

			req, err := http.NewRequest("GET", "/", nil)
			if err != nil {
				t.Fatalf("Unexpected error building request: %v", err)
			}
			if testCase.ExistingCookie != nil {
				req.AddCookie(testCase.ExistingCookie)
			}

			w := httptest.NewRecorder()
			token, err := csrf.Generate(w, req)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			if len(testCase.ExpectToken) != 0 && token != testCase.ExpectToken {
				t.Fatalf("Expected token %q, got %q", testCase.ExpectToken, token)
			}

			setCookie := w.Header().Get("Set-Cookie")
			if !testCase.ExpectSetCookie {
				if len(setCookie) > 0 {
					t.Errorf("Didn't expect set-cookie header, got %q", setCookie)
				}
				return
			}

			if len(setCookie) == 0 {
				t.Fatalf("Expected set-cookie header")
			}

			// The header must be exactly the serialization of a cookie that
			// carries the returned token and the configured options.
			protoCookie := &http.Cookie{
				Name:     testCase.Name,
				Value:    token,
				Path:     testCase.Path,
				Domain:   testCase.Domain,
				Secure:   testCase.Secure,
				HttpOnly: testCase.HTTPOnly,
			}
			if expected := protoCookie.String(); setCookie != expected {
				t.Errorf("Expected Set-Cookie header of %q, got %q", expected, setCookie)
			}
		})
	}
}

// TestCookieCheck exercises Check on a cookie-backed CSRF instance.
// Each case optionally attaches a cookie to the request, calls Check with
// the given token, and compares the boolean result with the expectation.
// Check is expected to return a nil error in every case.
func TestCookieCheck(t *testing.T) {
	testCases := map[string]struct {
		// Name is the cookie name passed to NewCookieCSRF.
		Name string
		// Token is the value passed to Check.
		Token string
		// ExistingCookie, when non-nil, is added to the request before Check runs.
		ExistingCookie *http.Cookie

		// ExpectCheck is the boolean result Check must return.
		ExpectCheck bool
	}{
		"fail empty token": {
			Name:           "csrf",
			Token:          "",
			ExistingCookie: &http.Cookie{Name: "csrf", Value: "existingvalue"},

			ExpectCheck: false,
		},

		"fail empty cookie": {
			Name:           "csrf",
			Token:          "mytoken",
			ExistingCookie: &http.Cookie{Name: "csrf", Value: ""},

			ExpectCheck: false,
		},

		"fail missing cookie": {
			Name:  "csrf",
			Token: "mytoken",

			ExpectCheck: false,
		},

		"fail mismatch cookie": {
			Name:           "csrf",
			Token:          "mytoken",
			ExistingCookie: &http.Cookie{Name: "csrf", Value: "existingvalue"},

			ExpectCheck: false,
		},

		"fail mismatch cookie name": {
			Name:           "csrf",
			Token:          "mytoken",
			ExistingCookie: &http.Cookie{Name: "csrf2", Value: "mytoken"},

			ExpectCheck: false,
		},

		"pass matching cookie": {
			Name:           "csrf",
			Token:          "existingvalue",
			ExistingCookie: &http.Cookie{Name: "csrf", Value: "existingvalue"},

			ExpectCheck: true,
		},
	}

	for k, testCase := range testCases {
		// Subtests run synchronously (no t.Parallel), so capturing the loop
		// variable in the closure is safe on every Go version.
		t.Run(k, func(t *testing.T) {
			csrf := NewCookieCSRF(testCase.Name, "", "", false, false)

			req, err := http.NewRequest("GET", "/", nil)
			if err != nil {
				t.Fatalf("Unexpected error building request: %v", err)
			}
			if testCase.ExistingCookie != nil {
				req.AddCookie(testCase.ExistingCookie)
			}

			ok, err := csrf.Check(req, testCase.Token)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			if ok != testCase.ExpectCheck {
				t.Errorf("Expected check to return %v, returned %v", testCase.ExpectCheck, ok)
			}
		})
	}
}