// +build !windows

package authz

import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"

	"github.com/docker/docker/integration-cli/daemon"
	"github.com/docker/docker/internal/test/environment"
	"github.com/docker/docker/pkg/authorization"
	"github.com/docker/docker/pkg/plugins"
)

// Package-level state shared by every test in this package.
var (
	testEnv *environment.Execution // initialized by TestMain via environment.New
	d       *daemon.Daemon         // daemon for the current test; assigned by setupTest
	server  *httptest.Server       // fake authz plugin server; started by setupSuite, closed by teardownSuite
)

// dockerdBinary is the daemon binary name that setupTest passes to daemon.New.
const dockerdBinary = "dockerd"

// TestMain builds the shared test environment and calls
// environment.EnsureFrozenImagesLinux on it. Either failure is printed and the
// process exits with status 1. Otherwise it prints the environment, runs every
// test between setupSuite and teardownSuite, and exits with the status
// reported by m.Run.
func TestMain(m *testing.M) {
	var err error
	if testEnv, err = environment.New(); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	if err = environment.EnsureFrozenImagesLinux(testEnv); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	testEnv.Print()

	// os.Exit skips deferred calls, so the suite teardown is invoked
	// explicitly before exiting.
	setupSuite()
	status := m.Run()
	teardownSuite()
	os.Exit(status)
}

// setupTest calls environment.ProtectAll for t and assigns a new daemon
// (built by daemon.New with the environment's experimental flag) to the
// package-level d. It returns a teardown func that stops d when it is non-nil
// and then cleans the test environment.
func setupTest(t *testing.T) func() {
	environment.ProtectAll(t, testEnv)

	cfg := daemon.Config{Experimental: testEnv.DaemonInfo.ExperimentalBuild}
	d = daemon.New(t, "", dockerdBinary, cfg)

	teardown := func() {
		// The package-level d is read when teardown runs, not when it is
		// created, so whatever the test left in d is what gets stopped.
		if d == nil {
			testEnv.Clean(t)
			return
		}
		d.Stop(t)
		testEnv.Clean(t)
	}
	return teardown
}

// setupSuite starts an httptest server (stored in the package-level server)
// that impersonates an authorization plugin. It serves three endpoints:
//   - /Plugin.Activate replies with a plugins.Manifest implementing
//     authorization.AuthZApiImplements.
//   - /AuthZPlugin.AuthZReq and /AuthZPlugin.AuthZRes decode the forwarded
//     authorization.Request and check it with assertBody and
//     assertAuthHeaders. They record bookkeeping in ctrl and reply with
//     ctrl.reqRes / ctrl.resRes, or an unconditional allow when
//     isAllowed(RequestURI) is true.
//
// A reply whose Err is non-empty is sent with HTTP 500. Read, decode and
// encode failures inside a handler panic with a message naming the endpoint.
func setupSuite() {
	const (
		activatePath = "/Plugin.Activate"
		authZReqPath = "/AuthZPlugin.AuthZReq"
		authZResPath = "/AuthZPlugin.AuthZRes"
	)

	mux := http.NewServeMux()
	server = httptest.NewServer(mux)

	// readAuthZRequest reads r.Body fully and unmarshals it into an
	// authorization.Request; path only appears in panic messages.
	// Closing the body is left to the calling handler.
	readAuthZRequest := func(r *http.Request, path string) authorization.Request {
		raw, err := ioutil.ReadAll(r.Body)
		if err != nil {
			panic("could not read body for " + path + ": " + err.Error())
		}
		var req authorization.Request
		if err = json.Unmarshal(raw, &req); err != nil {
			panic("could not unmarshal json for " + path + ": " + err.Error())
		}
		return req
	}

	mux.HandleFunc(activatePath, func(w http.ResponseWriter, r *http.Request) {
		manifest := plugins.Manifest{Implements: []string{authorization.AuthZApiImplements}}
		payload, err := json.Marshal(manifest)
		if err != nil {
			panic("could not marshal json for " + activatePath + ": " + err.Error())
		}
		w.Write(payload)
	})

	mux.HandleFunc(authZReqPath, func(w http.ResponseWriter, r *http.Request) {
		defer r.Body.Close()
		authReq := readAuthZRequest(r, authZReqPath)

		assertBody(authReq.RequestURI, authReq.RequestHeaders, authReq.RequestBody)
		assertAuthHeaders(authReq.RequestHeaders)

		// Only requests targeting the server version API are counted.
		if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
			ctrl.versionReqCount++
		}
		ctrl.requestsURIs = append(ctrl.requestsURIs, authReq.RequestURI)

		verdict := ctrl.reqRes
		if isAllowed(authReq.RequestURI) {
			verdict = authorization.Response{Allow: true}
		}
		if verdict.Err != "" {
			w.WriteHeader(http.StatusInternalServerError)
		}
		payload, err := json.Marshal(verdict)
		if err != nil {
			panic("could not marshal json for " + authZReqPath + ": " + err.Error())
		}

		ctrl.reqUser = authReq.User
		w.Write(payload)
	})

	mux.HandleFunc(authZResPath, func(w http.ResponseWriter, r *http.Request) {
		defer r.Body.Close()
		authReq := readAuthZRequest(r, authZResPath)

		assertBody(authReq.RequestURI, authReq.ResponseHeaders, authReq.ResponseBody)
		assertAuthHeaders(authReq.ResponseHeaders)

		// Only responses for the server version API are counted.
		if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
			ctrl.versionResCount++
		}

		verdict := ctrl.resRes
		if isAllowed(authReq.RequestURI) {
			verdict = authorization.Response{Allow: true}
		}
		if verdict.Err != "" {
			w.WriteHeader(http.StatusInternalServerError)
		}
		payload, err := json.Marshal(verdict)
		if err != nil {
			panic("could not marshal json for " + authZResPath + ": " + err.Error())
		}

		ctrl.resUser = authReq.User
		w.Write(payload)
	})
}

// teardownSuite closes the fake authorization plugin server started by
// setupSuite; it does nothing when no server was created.
func teardownSuite() {
	if server != nil {
		server.Close()
	}
}

// assertAuthHeaders verifies that credential-bearing headers were stripped
// before being forwarded to the plugin. It panics if any header name contains
// "auth" or "x-registry" (case-insensitively). The error result is always nil;
// failures are reported by panicking.
func assertAuthHeaders(headers map[string]string) error {
	for name := range headers {
		lowered := strings.ToLower(name)
		switch {
		case strings.Contains(lowered, "auth"), strings.Contains(lowered, "x-registry"):
			panic(fmt.Sprintf("Found authentication headers in request '%v'", headers))
		}
	}
	return nil
}

// assertBody panics if a body reached the plugin when it should have been
// stripped:
//   - a non-empty body is never allowed when requestURI contains "auth"
//     (case-insensitive);
//   - otherwise a non-empty body is only allowed when a Content-Type header
//     (key matched case-insensitively) declares a text/* or application/json
//     media type. Media-type parameters such as "; charset=utf-8", surrounding
//     whitespace and letter case are ignored.
func assertBody(requestURI string, headers map[string]string, body []byte) {
	if strings.Contains(strings.ToLower(requestURI), "auth") && len(body) > 0 {
		panic("Body included for authentication endpoint " + string(body))
	}

	for k, v := range headers {
		// The media-type test must apply only to Content-Type. The previous
		// `a && b || c` form parsed as `(a && b) || c`, so any header whose
		// value was "application/json" (e.g. Accept) disabled the check.
		if !strings.EqualFold(k, "Content-Type") {
			continue
		}
		mediaType := v
		if i := strings.IndexByte(mediaType, ';'); i >= 0 {
			mediaType = mediaType[:i]
		}
		mediaType = strings.ToLower(strings.TrimSpace(mediaType))
		if strings.HasPrefix(mediaType, "text/") || mediaType == "application/json" {
			return
		}
	}
	if len(body) > 0 {
		panic(fmt.Sprintf("Body included while it should not (Headers: '%v')", headers))
	}
}