package ttrpc

import (
	"context"
	"net"
	"sync"

	"github.com/containerd/containerd/log"
	"github.com/gogo/protobuf/proto"
	"github.com/pkg/errors"
	"google.golang.org/grpc/status"
)

// Client is a ttrpc client that issues calls over a single net.Conn.
//
// NewClient starts one background goroutine (run) that owns the connection.
// dispatch hands requests to it over calls. run assigns each request a stream
// ID, writes it with channel, and routes the matching response back to the
// waiting caller. codec encodes and decodes the user's request and response
// values.
type Client struct {
	codec   codec
	conn    net.Conn
	channel *channel
	calls   chan *callRequest

	// closed is closed exactly once (guarded by closeOnce) by Close to ask
	// run to stop. done is closed by run when it exits; once done is closed,
	// dispatch returns err to its callers.
	closed    chan struct{}
	closeOnce sync.Once
	done      chan struct{}
	err       error
}

// NewClient returns a Client that speaks ttrpc over conn and starts the
// background goroutine that sends requests and routes responses.
//
// The client takes ownership of conn: the background goroutine closes it
// when the client shuts down.
func NewClient(conn net.Conn) *Client {
	client := new(Client)
	client.codec = codec{}
	client.conn = conn
	client.channel = newChannel(conn, conn)
	client.calls = make(chan *callRequest)
	client.closed = make(chan struct{})
	client.done = make(chan struct{})

	go client.run()

	return client
}

// callRequest carries one call from dispatch to the run loop.
//
// ctx is the context run passes to send when writing req. If it is left
// unset it is nil, so every constructor should populate it.
// NOTE(review): confirm all callers set ctx.
//
// run writes exactly one value to errs for each callRequest it receives, and
// it does so without a select. errs therefore needs buffer room (dispatch
// creates it with capacity 1).
type callRequest struct {
	ctx  context.Context
	req  *Request
	resp *Response  // response will be written back here
	errs chan error // error written here on completion
}

// Call invokes method on service, sending req and decoding the reply into
// resp. Both req and resp are encoded/decoded with the client's codec.
//
// The returned error is the first of:
//   - a failure to encode req;
//   - an error from dispatch (transport failure, client shutdown, or ctx);
//   - a reply that carries no status;
//   - the server's non-OK status, converted to an error (resp is not touched);
//   - a failure to decode the reply payload into resp.
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
	payload, err := c.codec.Marshal(req)
	if err != nil {
		return err
	}

	var (
		creq = &Request{
			Service: service,
			Method:  method,
			Payload: payload,
		}

		cresp = &Response{}
	)

	if err := c.dispatch(ctx, creq, cresp); err != nil {
		return err
	}

	if cresp.Status == nil {
		return errors.New("no status provided on response")
	}

	// Surface the server's status before touching the payload. An error
	// reply may carry no usable payload, and a decode failure must not mask
	// the real error or overwrite the caller's resp.
	if err := status.ErrorProto(cresp.Status); err != nil {
		return err
	}

	return c.codec.Unmarshal(cresp.Payload, resp)
}

// dispatch hands req to the run loop and waits for the matching response to
// be decoded into resp.
//
// It returns ctx.Err() as soon as ctx is done, whether the request is still
// queued or already in flight. It returns the client's terminal error (c.err)
// once the client has shut down.
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
	// Buffered so the run loop can always deliver the result without
	// blocking, even if this caller has already returned because of ctx.
	errs := make(chan error, 1)
	call := &callRequest{
		ctx:  ctx,
		req:  req,
		resp: resp,
		errs: errs,
	}

	select {
	case c.calls <- call:
	case <-ctx.Done():
		return ctx.Err()
	case <-c.done:
		return c.err
	}

	select {
	case err := <-errs:
		return err
	case <-ctx.Done():
		return ctx.Err()
	case <-c.done:
		return c.err
	}
}

// Close asks the client's run loop to shut down.
//
// Only the first call has any effect; later calls are no-ops. Close does not
// wait for the shutdown to complete, and it always returns nil.
func (c *Client) Close() error {
	signal := func() { close(c.closed) }
	c.closeOnce.Do(signal)
	return nil
}

// message is one frame read by the client's receive goroutine and handed to
// the run loop. It holds the frame header, the frame payload, and any
// non-terminal (rpc status) error the channel returned while reading it.
// Errors that are not an rpc status shut the client down instead of being
// delivered here.
type message struct {
	messageHeader
	p   []byte
	err error
}

// run is the client's single dispatch loop, started by NewClient.
//
// It is the only code that touches the stream ID counter and the table of
// in-flight calls (waiters), so neither needs locking. A helper goroutine
// reads frames from the channel and feeds them to the loop over incoming.
//
// The loop exits in two cases:
//   - the connection fails with an error that is not an rpc status;
//   - Close is called.
//
// Before exiting, run sets c.err, delivers it to every waiter, closes c.done
// (which releases blocked dispatch callers) and finally closes the connection.
func (c *Client) run() {
	var (
		streamID    uint32 = 1
		waiters            = make(map[uint32]*callRequest)
		calls              = c.calls
		incoming           = make(chan *message)
		shutdown           = make(chan struct{})
		shutdownErr error // written by the receiver; read only after shutdown is closed
	)

	go func() {
		defer close(shutdown)

		// Receive in a separate goroutine so the loop below can keep sending
		// requests while waiting for responses.
		for {
			mh, p, err := c.channel.recv(context.TODO())
			if err != nil {
				if _, ok := status.FromError(err); !ok {
					// Errors that are not an rpc status poison the
					// connection: record the cause and shut the client down.
					shutdownErr = err
					return
				}
			}

			msg := &message{messageHeader: mh, err: err}
			if err == nil {
				// Only slice the payload on success. With an rpc status
				// error the buffer may be nil or shorter than mh.Length.
				msg.p = p[:mh.Length]
			}

			select {
			case incoming <- msg:
			case <-c.done:
				return
			}
		}
	}()

	defer c.conn.Close()
	defer close(c.done)

	for {
		select {
		case call := <-calls:
			if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
				call.errs <- err
				continue
			}

			waiters[streamID] = call
			streamID += 2 // enforce odd client initiated request ids
		case msg := <-incoming:
			call, ok := waiters[msg.StreamID]
			if !ok {
				log.L.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
				continue
			}

			call.errs <- c.recv(call.resp, msg)
			delete(waiters, msg.StreamID)
		case <-shutdown:
			// shutdownErr was written before shutdown was closed, so reading
			// it here is race-free.
			c.err = errors.Wrap(shutdownErr, "ttrpc: client shutting down")
			for _, waiter := range waiters {
				waiter.errs <- c.err
			}
			c.Close()
			return
		case <-c.closed:
			// The receiver may still be running, so shutdownErr must not be
			// read here. Give pending and future callers an explicit error
			// rather than nil.
			c.err = errors.New("ttrpc: client closed")
			for _, waiter := range waiters {
				waiter.errs <- c.err
			}
			return
		}
	}
}

// send encodes msg with the client's codec and writes the result to the
// channel as a frame of type mtype on stream streamID. It returns the first
// error, from either encoding or writing.
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
	encoded, err := c.codec.Marshal(msg)
	if err == nil {
		err = c.channel.send(ctx, streamID, mtype, encoded)
	}
	return err
}

// recv decodes a received frame into resp.
//
// It returns the error the receive goroutine attached to msg, if there is
// one. It rejects frames whose type is not messageTypeResponse. Otherwise it
// proto-decodes the payload into resp.
//
// Whenever msg carries no error, the payload buffer is handed back to the
// channel via putmbuf. This includes the rejected-type path.
func (c *Client) recv(resp *Response, msg *message) error {
	if msg.err != nil {
		return msg.err
	}

	// Release the buffer on every path below, not just the success path.
	defer c.channel.putmbuf(msg.p)

	if msg.Type != messageTypeResponse {
		return errors.Errorf("unknown message type %v received", msg.Type)
	}

	return proto.Unmarshal(msg.p, resp)
}