package flightcontrol

import (
	"context"
	"io"
	"runtime"
	"sort"
	"sync"
	"time"

	"github.com/moby/buildkit/util/progress"
	"github.com/pkg/errors"
)

// flightcontrol is like singleflight but with support for cancellation and
// nested progress reporting

var (
	// errRetry is returned by call.wait when the call it tried to join has
	// already finished (or had its shared context cancelled); Group.Do reacts
	// to it by looking the key up again.
	errRetry        = errors.Errorf("retry")
	// errRetryTimeout is what Group.Do reports (wrapped) once its retry
	// backoff has grown past the allowed limit.
	errRetryTimeout = errors.Errorf("exceeded retry timeout")
)

// contextKeyT is a dedicated string type for this package's context key, so
// the key cannot compare equal to a plain string used as a key elsewhere.
type contextKeyT string

// contextKey is the key for which call.Value returns the call's shared
// progressState.
var contextKey = contextKeyT("buildkit/util/flightcontrol.progress")

// Group is a flightcontrol synchronization group. It tracks the in-flight
// call for each key so that concurrent Do invocations with the same key can
// share it. The map is created on first use, so the zero value is ready to
// use.
type Group struct {
	mu sync.Mutex       // protects m
	m  map[string]*call // lazily initialized
}

// Do executes fn synchronized by key and returns its result.
//
// The call is routed through g.do. While g.do reports errRetry (the call it
// tried to join was already finishing), Do looks the key up again: the first
// retry happens immediately, later ones are separated by a pause that starts
// at 1ms and doubles every round. Once the pause has grown to 3s or more, Do
// gives up and returns errRetryTimeout wrapped with "flightcontrol".
//
// If ctx is cancelled while Do is pausing between retries it returns
// (nil, ctx.Err()) right away instead of sleeping out the remaining backoff.
func (g *Group) Do(ctx context.Context, key string, fn func(ctx context.Context) (interface{}, error)) (v interface{}, err error) {
	var backoff time.Duration
	for {
		v, err = g.do(ctx, key, fn)
		// errors.Is reports false for a nil error, so this also covers success.
		if !errors.Is(err, errRetry) {
			return v, err
		}
		// backoff logic
		if backoff >= 3*time.Second {
			err = errors.Wrapf(errRetryTimeout, "flightcontrol")
			return v, err
		}
		runtime.Gosched()
		if backoff == 0 {
			// First retry: no pause, just arm the backoff for the next round.
			backoff = time.Millisecond
			continue
		}
		// Pause, but stay responsive to cancellation of the caller's context.
		timer := time.NewTimer(backoff)
		select {
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		case <-timer.C:
		}
		backoff *= 2
	}
}

// do makes a single attempt to run fn under key. If a call is already
// registered for key the caller joins it as an additional waiter; otherwise a
// new call is registered together with a goroutine that removes it from the
// map once it is ready. Either way the result comes from call.wait.
func (g *Group) do(ctx context.Context, key string, fn func(ctx context.Context) (interface{}, error)) (interface{}, error) {
	g.mu.Lock()
	if g.m == nil {
		g.m = make(map[string]*call)
	}

	c, inFlight := g.m[key]
	if !inFlight {
		c = newCall(fn)
		g.m[key] = c
		go func() {
			// cleanup after a caller has returned
			<-c.ready
			g.mu.Lock()
			delete(g.m, key)
			g.mu.Unlock()
			// Signals waiters parked in call.wait that the key is free again.
			close(c.cleaned)
		}()
	}
	g.mu.Unlock()

	return c.wait(ctx)
}

// call is the state of one in-flight invocation shared by every waiter that
// joined it under the same key. Its Deadline, Done, Err and Value methods are
// what make the embedding sharedContext usable as a context.Context.
type call struct {
	// mu is held while ctxs is appended to or read, and while run stores
	// result and err.
	mu      sync.Mutex
	result  interface{} // value returned by fn; set by run before ready is closed
	err     error       // error returned by fn; set by run before ready is closed
	ready   chan struct{} // closed by run after fn has returned
	cleaned chan struct{} // closed by Group.do's cleanup goroutine after the key is deleted

	ctx  *sharedContext    // parent of the context fn runs with
	ctxs []context.Context // one (cancellable) context per registered waiter
	fn   func(ctx context.Context) (interface{}, error)
	once sync.Once // ensures run is executed a single time no matter how many waiters start it

	closeProgressWriter func()          // returned by progress.NewContext; deferred in run
	progressState       *progressState  // fans progress out to the waiters' writers
	progressCtx         context.Context // returned by progress.NewContext; consulted first by Value
}

// newCall builds the shared state for a fresh invocation of fn. It creates a
// progress reader/context pair rooted at context.Background(), wires the
// context and its close function into the call, attaches a new shared context
// and starts a goroutine that feeds the reader into the call's progressState.
// fn itself is not started here.
func newCall(fn func(ctx context.Context) (interface{}, error)) *call {
	pr, pctx, closeProgressWriter := progress.NewContext(context.Background())

	c := &call{
		fn:                  fn,
		ready:               make(chan struct{}),
		cleaned:             make(chan struct{}),
		progressState:       newProgressState(),
		progressCtx:         pctx,
		closeProgressWriter: closeProgressWriter,
	}
	c.ctx = newContext(c) // newSharedContext

	go c.progressState.run(pr) // TODO: remove this, wrap writer instead

	return c
}

// run invokes c.fn with a cancellable context derived from the shared
// context, records the outcome on c under c.mu and then closes c.ready. On
// the way out the derived context is cancelled and the progress writer is
// closed (in that order, after ready has been closed).
func (c *call) run() {
	defer c.closeProgressWriter()

	fnCtx, stop := context.WithCancel(c.ctx)
	defer stop()

	res, fnErr := c.fn(fnCtx)

	c.mu.Lock()
	c.result, c.err = res, fnErr
	c.mu.Unlock()

	// Wakes every waiter blocked in wait and the cleanup goroutine in Group.do.
	close(c.ready)
}

// wait registers ctx as one more waiter on the call, makes sure fn has been
// started, and blocks until either the call is ready or ctx is done.
//
// It returns errRetry (after waiting for c.cleaned) when the call had already
// finished or its shared context was already done at the time of joining;
// Group.Do reacts to that by looking the key up again. If ctx ends first and
// other waiters are still active it returns ctx.Err(); if ctx was the last
// active waiter it waits for fn to return and hands back fn's result.
func (c *call) wait(ctx context.Context) (v interface{}, err error) {
	c.mu.Lock()
	// detect case where caller has just returned, let it clean up before
	select {
	case <-c.ready:
		c.mu.Unlock()
		// Block until the key has been removed from the group's map, so the
		// retry does not find this same finished call again.
		<-c.cleaned
		return nil, errRetry
	case <-c.ctx.done: // could return if no error
		c.mu.Unlock()
		<-c.cleaned
		return nil, errRetry
	default:
	}

	// If the caller's context carries a progress writer, attach it to the
	// shared progress state (add only keeps writers that implement
	// rawProgressWriter).
	pw, ok, ctx := progress.FromContext(ctx)
	if ok {
		c.progressState.add(pw)
	}

	// The waiter's own cancellable context: cancelled when wait returns.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Still under c.mu: checkDone, Deadline and Value read this slice.
	c.ctxs = append(c.ctxs, ctx)

	c.mu.Unlock()

	// Every waiter attempts to start fn; sync.Once lets only one through.
	go c.once.Do(c.run)

	select {
	case <-ctx.Done():
		if c.ctx.checkDone() {
			// if this cancelled the last context, then wait for function to shut down
			// and don't accept any more callers
			<-c.ready
			return c.result, c.err
		}
		// Other waiters are still active: detach only this waiter's writer.
		if ok {
			c.progressState.close(pw)
		}
		return nil, ctx.Err()
	case <-c.ready:
		return c.result, c.err // shared not implemented yet
	}
}

// Deadline reports the deadline of the first registered waiter context that
// is still active and has a deadline set. When no such context exists it
// returns the zero time and false.
func (c *call) Deadline() (deadline time.Time, ok bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for i := range c.ctxs {
		waiter := c.ctxs[i]
		// Waiters whose context has already ended do not contribute.
		select {
		case <-waiter.Done():
			continue
		default:
		}
		if dl, found := waiter.Deadline(); found {
			return dl, found
		}
	}
	return time.Time{}, false
}

// Done returns the shared context's done channel, which checkDone closes
// once every registered waiter context has ended.
func (c *call) Done() <-chan struct{} {
	return c.ctx.done
}

// Err returns nil while the shared context is still open. After it has been
// closed it returns the error recorded on the shared context by checkDone.
func (c *call) Err() error {
	select {
	case <-c.ctx.Done():
	default:
		return nil
	}
	return c.ctx.err
}

// Value resolves key for the shared context. The package's own contextKey
// yields the call's progressState. Any other key is looked up first in the
// call's progress context and then in the first registered waiter context;
// a context that is already done is skipped. Returns nil when nothing
// matches.
func (c *call) Value(key interface{}) interface{} {
	if key == contextKey {
		return c.progressState
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// lookup returns the value stored under key in src, or nil when src has
	// already ended or holds no such value.
	lookup := func(src context.Context) interface{} {
		select {
		case <-src.Done():
			return nil
		default:
			return src.Value(key)
		}
	}

	if found := lookup(c.progressCtx); found != nil {
		return found
	}
	if len(c.ctxs) == 0 {
		return nil
	}
	if found := lookup(c.ctxs[0]); found != nil {
		return found
	}
	return nil
}

// sharedContext is the context fn runs under (call.run derives its context
// from it). It embeds *call and so picks up call's Deadline, Done, Err and
// Value methods.
type sharedContext struct {
	*call
	done chan struct{} // closed by checkDone once all waiter contexts are done
	err  error         // set by checkDone right before done is closed
}

// newContext returns a sharedContext bound to c with a fresh, open done
// channel.
func newContext(c *call) *sharedContext {
	sc := &sharedContext{call: c}
	sc.done = make(chan struct{})
	return sc
}

// checkDone reports whether the shared context is finished. If it has not
// been closed yet, every registered waiter context is inspected: as soon as
// one is still active the answer is false; if all of them are done, the error
// of the last one inspected is stored in sc.err, sc.done is closed and the
// answer is true. The whole check runs under the call's mutex.
func (sc *sharedContext) checkDone() bool {
	sc.mu.Lock()
	defer sc.mu.Unlock()

	// Already closed by an earlier invocation.
	select {
	case <-sc.done:
		return true
	default:
	}

	var lastErr error
	for _, waiter := range sc.ctxs {
		select {
		case <-waiter.Done():
			lastErr = waiter.Err()
		default:
			// At least one waiter is still interested in the result.
			return false
		}
	}

	sc.err = lastErr
	close(sc.done)
	return true
}

// rawProgressWriter is the capability progressState needs from a waiter's
// progress writer: forwarding an existing progress item as-is and being
// closed. add and close ignore writers that do not implement it.
type rawProgressWriter interface {
	WriteRawProgress(*progress.Progress) error
	Close() error
}

// progressState collects the progress items produced by a call and forwards
// them to the progress writers of all attached waiters.
type progressState struct {
	// mu is taken by run, add and close whenever they modify the fields below.
	mu      sync.Mutex
	items   map[string]*progress.Progress // latest item seen per progress ID; replayed to late joiners by add
	writers []rawProgressWriter           // writers currently receiving updates
	done    bool                          // set by run once the reader has reported io.EOF
}

// newProgressState returns a progressState with an empty item map and no
// writers attached.
func newProgressState() *progressState {
	ps := new(progressState)
	ps.items = map[string]*progress.Progress{}
	return ps
}

// run drains pr until a read fails. Every item read is forwarded to all
// currently attached writers and remembered in ps.items (keyed by its ID) so
// that add can replay it to writers that join later.
//
// When the reader reports io.EOF the state is marked done and every attached
// writer is closed. The writer list is detached from ps under ps.mu before
// the writers are closed: this keeps the iteration from racing with close(),
// which rewrites the slice in place, and guarantees a writer is not closed a
// second time by a later close() call. Any other read error just ends the
// loop without closing the writers.
//
// Errors returned by WriteRawProgress and Close are ignored; delivery of
// progress is best-effort.
func (ps *progressState) run(pr progress.Reader) {
	for {
		batch, err := pr.Read(context.TODO())
		if err != nil {
			if errors.Is(err, io.EOF) {
				ps.mu.Lock()
				ps.done = true
				writers := ps.writers
				ps.writers = nil
				ps.mu.Unlock()
				for _, w := range writers {
					w.Close()
				}
			}
			return
		}
		ps.mu.Lock()
		for _, p := range batch {
			for _, w := range ps.writers {
				w.WriteRawProgress(p)
			}
			ps.items[p.ID] = p
		}
		ps.mu.Unlock()
	}
}

// add attaches pw to the progress state. Writers that do not implement
// rawProgressWriter are ignored. The items recorded so far are replayed to
// the writer in timestamp order; afterwards the writer is either registered
// for future updates or, if the state is already done, closed right away.
func (ps *progressState) add(pw progress.Writer) {
	rw, ok := pw.(rawProgressWriter)
	if !ok {
		return
	}

	ps.mu.Lock()
	defer ps.mu.Unlock()

	// Map iteration order is unspecified, so order the backlog by timestamp.
	backlog := make([]*progress.Progress, 0, len(ps.items))
	for id := range ps.items {
		backlog = append(backlog, ps.items[id])
	}
	sort.Slice(backlog, func(a, b int) bool {
		return backlog[a].Timestamp.Before(backlog[b].Timestamp)
	})
	for idx := range backlog {
		rw.WriteRawProgress(backlog[idx])
	}

	if !ps.done {
		ps.writers = append(ps.writers, rw)
		return
	}
	rw.Close()
}

// close detaches pw from the progress state: the first registered writer
// equal to pw is closed and removed from the list. Writers that do not
// implement rawProgressWriter, or that are not registered, are left alone.
func (ps *progressState) close(pw progress.Writer) {
	rw, ok := pw.(rawProgressWriter)
	if !ok {
		return
	}

	ps.mu.Lock()
	defer ps.mu.Unlock()

	for idx := range ps.writers {
		if ps.writers[idx] != rw {
			continue
		}
		ps.writers[idx].Close()
		ps.writers = append(ps.writers[:idx], ps.writers[idx+1:]...)
		return
	}
}