package portallocator

import (
	"errors"
	"fmt"
	"net"
	"sync"
)

// portMap tracks the ports reserved for one IP/protocol pair.
type portMap struct {
	// p is the set of currently reserved ports; a key being present means
	// the port is taken.
	p    map[int]struct{}
	// last is the port most recently handed out by findPort. It stays 0
	// until findPort has picked a port for this pair.
	last int
}

type (
	// protocolMap maps a protocol name to its port bookkeeping. getOrCreate
	// populates it with the keys "tcp" and "udp".
	protocolMap map[string]*portMap
	// ipMapping maps the string form of an IP address (net.IP.String) to the
	// per-protocol bookkeeping for that address.
	ipMapping   map[string]protocolMap
)

// Bounds of the port range that findPort scans when the caller does not ask
// for a specific port. Both ends are inclusive.
const (
	// BeginPortRange is the first port findPort will hand out.
	BeginPortRange = 49153
	// EndPortRange is the last port findPort will hand out before wrapping
	// back to BeginPortRange.
	EndPortRange   = 65535
)

var (
	// ErrAllPortsAllocated is returned by findPort (and therefore by
	// RequestPort) when its scan finds no free port.
	ErrAllPortsAllocated = errors.New("all ports are allocated")
	// ErrUnknownProtocol is returned by validateProto when the protocol is
	// neither "tcp" nor "udp".
	ErrUnknownProtocol   = errors.New("unknown protocol")
)

// Package-level allocator state.
var (
	// mutex is held by RequestPort, ReleasePort and ReleaseAll for the whole
	// call, serializing access to globalMap.
	mutex sync.Mutex

	// defaultIP is substituted by getDefault when a caller passes a nil IP.
	defaultIP = net.ParseIP("0.0.0.0")
	// globalMap holds every reservation, keyed by IP string and then by
	// protocol. ReleaseAll replaces it with an empty map.
	globalMap = ipMapping{}
)

// ErrPortAlreadyAllocated is the error RequestPort returns when a specific
// port was requested but is already reserved. It records the address and port
// of the failed request.
type ErrPortAlreadyAllocated struct {
	ip   string
	port int
}

// NewErrPortAlreadyAllocated builds an ErrPortAlreadyAllocated for the given
// IP string and port.
func NewErrPortAlreadyAllocated(ip string, port int) ErrPortAlreadyAllocated {
	return ErrPortAlreadyAllocated{ip, port}
}

// IP returns the IP string the failed request was made for.
func (pe ErrPortAlreadyAllocated) IP() string { return pe.ip }

// Port returns the port number that was already taken.
func (pe ErrPortAlreadyAllocated) Port() int { return pe.port }

// IPPort returns the address formatted as "<ip>:<port>".
func (pe ErrPortAlreadyAllocated) IPPort() string {
	return fmt.Sprintf("%s:%d", pe.IP(), pe.Port())
}

// Error implements the error interface.
func (pe ErrPortAlreadyAllocated) Error() string {
	return fmt.Sprintf("Bind for %s:%d failed: port is already allocated", pe.IP(), pe.Port())
}

// RequestPort reserves a port for the given IP and protocol and returns it.
//
// A nil ip is replaced by the package default address. proto must be "tcp" or
// "udp"; otherwise the error from validateProto is returned. When port is
// greater than zero that exact port is reserved, or an
// ErrPortAlreadyAllocated is returned if it is already in the set. When port
// is zero or negative, findPort chooses the port.
//
// On any error the returned port is 0.
func RequestPort(ip net.IP, proto string, port int) (int, error) {
	mutex.Lock()
	defer mutex.Unlock()

	if err := validateProto(proto); err != nil {
		return 0, err
	}

	ip = getDefault(ip)
	ports := getOrCreate(ip)[proto]

	// No specific port requested: let findPort choose one.
	if port <= 0 {
		chosen, err := findPort(ip, proto)
		if err != nil {
			return 0, err
		}
		return chosen, nil
	}

	if _, taken := ports.p[port]; taken {
		return 0, NewErrPortAlreadyAllocated(ip.String(), port)
	}

	ports.p[port] = struct{}{}
	return port, nil
}

// ReleasePort removes the reservation of port for the given IP and protocol.
//
// A nil ip is replaced by the package default address. proto must be "tcp" or
// "udp"; otherwise the error from validateProto is returned and nothing is
// changed. Releasing a port that is not currently reserved is a no-op and
// returns nil.
func ReleasePort(ip net.IP, proto string, port int) error {
	mutex.Lock()
	defer mutex.Unlock()

	// Reject unknown protocols up front: the per-IP map only holds "tcp" and
	// "udp", so any other key would yield a nil *portMap and the field access
	// below would panic.
	if err := validateProto(proto); err != nil {
		return err
	}

	ip = getDefault(ip)

	mapping := getOrCreate(ip)[proto]
	delete(mapping.p, port)

	return nil
}

// ReleaseAll drops every reservation for every IP and protocol by replacing
// the global map with an empty one. It always returns nil.
func ReleaseAll() error {
	mutex.Lock()
	globalMap = make(ipMapping)
	mutex.Unlock()

	return nil
}

// getOrCreate returns the per-protocol bookkeeping for ip, keyed in globalMap
// by ip.String(). If the address has no entry yet, one is created with empty
// port sets for "tcp" and "udp" and stored before being returned.
func getOrCreate(ip net.IP) protocolMap {
	key := ip.String()

	if existing, found := globalMap[key]; found {
		return existing
	}

	created := make(protocolMap, 2)
	for _, proto := range []string{"tcp", "udp"} {
		created[proto] = &portMap{p: make(map[int]struct{})}
	}
	globalMap[key] = created

	return created
}

// findPort reserves and returns the next free port for ip/proto inside the
// inclusive range [BeginPortRange, EndPortRange].
//
// The scan starts just after the port handed out by the previous call (or at
// BeginPortRange on first use), wraps around after EndPortRange, and examines
// every port in the range exactly once, so it always terminates. The chosen
// port is added to the reserved set and remembered as the next scan origin.
//
// It returns 0 and ErrAllPortsAllocated when every port in the range is
// reserved.
//
// proto must be a key that getOrCreate populates ("tcp" or "udp"); any other
// value yields a nil *portMap. RequestPort validates proto and holds mutex
// before calling this function.
func findPort(ip net.IP, proto string) (int, error) {
	mapping := getOrCreate(ip)[proto]

	port := mapping.last

	// Use a bounded number of attempts rather than "loop until we are back at
	// last": that condition can be skipped over on wrap-around, and it never
	// re-examines last itself, which may have been released in the meantime.
	for tries := 0; tries <= EndPortRange-BeginPortRange; tries++ {
		port++

		// Handles both the first call (last == 0) and wrapping past the end
		// of the range.
		if port < BeginPortRange || port > EndPortRange {
			port = BeginPortRange
		}

		// Every candidate is checked against the set, including
		// BeginPortRange on first use, since an explicit request may already
		// have reserved it.
		if _, used := mapping.p[port]; used {
			continue
		}

		mapping.p[port] = struct{}{}
		mapping.last = port
		return port, nil
	}

	return 0, ErrAllPortsAllocated
}

// getDefault returns ip unchanged unless it is nil, in which case the
// package-level defaultIP is returned instead.
func getDefault(ip net.IP) net.IP {
	if ip != nil {
		return ip
	}

	return defaultIP
}

// validateProto returns nil when proto is exactly "tcp" or "udp" and
// ErrUnknownProtocol for any other value.
func validateProto(proto string) error {
	switch proto {
	case "tcp", "udp":
		return nil
	default:
		return ErrUnknownProtocol
	}
}