package portallocator
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"syscall"
"github.com/containerd/log"
"github.com/ishidawataru/sctp"
"github.com/moby/moby/v2/daemon/libnetwork/types"
"golang.org/x/net/bpf"
"golang.org/x/sys/unix"
)
// maxAllocateAttempts caps how many times [OSAllocator.RequestPortsInRange]
// will try to reserve and bind a port before giving up and returning the last
// error it saw. The value is arbitrary.
const maxAllocateAttempts = 10
// OSAllocator allocates host ports in two steps: it first reserves the port in
// a PortAllocator, then binds (and, for TCP and SCTP, listens on) a socket for
// each requested address so that the OS confirms the port is really free.
type OSAllocator struct {
// allocator is used to logically reserve ports, to avoid those we know
// are already in use. This is useful to ensure callers don't burn their
// retry budget unnecessarily.
allocator *PortAllocator
}
// NewOSAllocator returns an OSAllocator whose logical port reservations are
// tracked by the PortAllocator returned by Get.
func NewOSAllocator() OSAllocator {
	var osa OSAllocator
	osa.allocator = Get()
	return osa
}
// RequestPortsInRange reserves one port from the range [portStart, portEnd]
// that is available on every address in addrs, then binds/listens on each of
// those addresses so that the port is also allocated from the OS.
//
// On success it returns the port and the bound sockets. Those sockets carry a
// filter that stops the kernel from accepting connections on them. Callers
// must call DetachSocketFilter once they are ready to accept connections (e.g.
// after setting up DNAT rules, and before starting the userland proxy), and
// they are responsible for closing the returned sockets.
//
// On failure it returns the error from the last attempt. At most
// maxAllocateAttempts attempts are made; there is no retry when a single
// explicit port was requested, or when every port of the range is already
// allocated.
//
// It's safe for concurrent use.
func (pa OSAllocator) RequestPortsInRange(addrs []net.IP, proto types.Protocol, portStart, portEnd int) (_ int, _ []*os.File, retErr error) {
	// An explicitly chosen port leaves no alternative to fall back on.
	explicitPort := portStart != 0 && portStart == portEnd

	var lastErr error
	for attempt := 1; attempt <= maxAllocateAttempts; attempt++ {
		port, socks, err := pa.attemptAllocation(addrs, proto, portStart, portEnd)
		switch {
		case err == nil:
			return port, socks, nil
		case explicitPort:
			// There is no point in immediately retrying the very same port.
			log.G(context.TODO()).WithError(err).Warnf("Failed to allocate port")
			return 0, nil, err
		case errors.Is(err, errAllPortsAllocated):
			// The whole range is taken; another attempt cannot succeed.
			return 0, nil, err
		}

		log.G(context.TODO()).WithFields(log.Fields{
			"error":   err,
			"attempt": attempt,
		}).Warn("Failed to allocate port")
		lastErr = err
	}

	// Retry budget exhausted: report the most recent failure.
	return 0, nil, lastErr
}
// attemptAllocation performs a single allocation attempt: it reserves a port
// from the logical allocator, then binds/listens on that port for each of
// addrs. A bind/listen failure means the allocator believed the port was free
// while some other process was already using it.
//
// When an error is returned, every socket bound so far is closed and the
// logical reservation is released for all addrs.
func (pa OSAllocator) attemptAllocation(addrs []net.IP, proto types.Protocol, portStart, portEnd int) (_ int, _ []*os.File, retErr error) {
	port, err := pa.allocator.RequestPortsInRange(addrs, proto.String(), portStart, portEnd)
	if err != nil {
		return 0, nil, err
	}

	var boundSocks []*os.File
	defer func() {
		if retErr == nil {
			return
		}
		// Undo in reverse order: close what was bound, then give the
		// reservation back.
		for idx, sock := range boundSocks {
			closeErr := sock.Close()
			if closeErr == nil {
				continue
			}
			log.G(context.TODO()).WithFields(log.Fields{
				"addr": addrs[idx],
				"port": port,
			}).WithError(closeErr).Warnf("failed to close socket during port allocation")
		}
		for _, ip := range addrs {
			pa.allocator.ReleasePort(ip, proto.String(), port)
		}
	}()

	for _, ip := range addrs {
		hostIP, _ := netip.AddrFromSlice(ip)
		target := netip.AddrPortFrom(hostIP.Unmap(), uint16(port))

		var (
			sock    *os.File
			bindErr error
		)
		switch proto {
		case types.TCP:
			sock, bindErr = listenTCP(target)
		case types.UDP:
			sock, bindErr = bindTCPOrUDP(target, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
		case types.SCTP:
			sock, bindErr = listenSCTP(target)
		default:
			return 0, nil, fmt.Errorf("protocol %s not supported", proto)
		}
		if bindErr != nil {
			return 0, nil, bindErr
		}

		boundSocks = append(boundSocks, sock)
	}

	return port, boundSocks, nil
}
// ReleasePorts gives back the logical reservation of port, for protocol proto,
// on every address in addrs. The sockets returned by [RequestPortsInRange] are
// not closed here; closing them remains the caller's responsibility.
func (pa OSAllocator) ReleasePorts(addrs []net.IP, proto types.Protocol, port int) {
	for i := range addrs {
		pa.allocator.ReleasePort(addrs[i], proto.String(), port)
	}
}
// listenTCP binds a TCP socket to addr (via bindTCPOrUDP) and puts it in the
// listening state.
//
// On success the caller owns the returned file and must close it. On failure
// no file is returned, and the socket is closed before returning so that
// neither the file descriptor nor the port stays held.
func listenTCP(addr netip.AddrPort) (_ *os.File, retErr error) {
	boundSocket, err := bindTCPOrUDP(addr, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
	if err != nil {
		return nil, err
	}
	defer func() {
		if retErr != nil {
			// Best-effort cleanup: the listen error is the one worth reporting.
			_ = boundSocket.Close()
		}
	}()

	somaxconn := -1 // silently capped to "/proc/sys/net/core/somaxconn"
	if err := syscall.Listen(int(boundSocket.Fd()), somaxconn); err != nil {
		return nil, fmt.Errorf("failed to listen on tcp socket: %w", err)
	}

	return boundSocket, nil
}
// bindTCPOrUDP creates a socket of type typ for protocol proto, configures it
// and binds it to addr. The address family is AF_INET when addr is an IPv4 (or
// IPv4-mapped IPv6) address, and AF_INET6 otherwise.
//
// The returned socket has a drop-all socket filter attached; callers must call
// DetachSocketFilter before they can receive anything on it. On error the
// socket is closed and no file is returned.
func bindTCPOrUDP(addr netip.AddrPort, typ int, proto types.Protocol) (_ *os.File, retErr error) {
	var (
		family   int
		sockaddr syscall.Sockaddr
	)
	hostIP, hostPort := addr.Addr(), int(addr.Port())
	if hostIP.Unmap().Is4() {
		family = syscall.AF_INET
		sockaddr = &syscall.SockaddrInet4{Addr: hostIP.As4(), Port: hostPort}
	} else {
		family = syscall.AF_INET6
		sockaddr = &syscall.SockaddrInet6{Addr: hostIP.Unmap().As16(), Port: hostPort}
	}
	isV6 := family == syscall.AF_INET6

	fd, err := syscall.Socket(family, typ|syscall.SOCK_CLOEXEC, int(proto))
	if err != nil {
		return nil, fmt.Errorf("failed to create socket for %s/%s: %w", addr, proto, err)
	}
	// From here on, any failure must not leave the descriptor open.
	defer func() {
		if retErr != nil {
			syscall.Close(fd)
		}
	}()

	if proto == syscall.IPPROTO_TCP {
		if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
			return nil, fmt.Errorf("failed to setsockopt(SO_REUSEADDR) for %s/%s: %w", addr, proto, err)
		}
	}

	// Binding (and later listening) is how we make sure the port is free and
	// that no other process is racing us for it. The downside is that
	// connections could be accepted before DNAT rules are inserted, and would
	// then never reach the container. A drop-all socket filter prevents that —
	// TCP SYNs will be re-transmitted anyway. Callers must call
	// DetachSocketFilter.
	//
	// The filter is attached _before_ bind so that no UDP datagram can fill
	// the receive queue in between.
	if err := setSocketFilter(fd); err != nil {
		return nil, fmt.Errorf("failed to set drop packets filter for %s/%s: %w", addr, proto, err)
	}

	if isV6 {
		// The result of this call is intentionally not checked.
		syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 1)
	}

	if typ == syscall.SOCK_DGRAM {
		// Ask for packet info on UDP sockets to learn each datagram's
		// destination address: it is used as the source address of the
		// replies coming back from the container.
		level, option, optionName := syscall.IPPROTO_IP, syscall.IP_PKTINFO, "IP_PKTINFO"
		if isV6 {
			level, option, optionName = syscall.IPPROTO_IPV6, syscall.IPV6_RECVPKTINFO, "IPV6_RECVPKTINFO"
		}
		if err := syscall.SetsockoptInt(fd, level, option, 1); err != nil {
			return nil, fmt.Errorf("failed to setsockopt(%s) for %s/%s: %w", optionName, addr, proto, err)
		}
	}

	if err := syscall.Bind(fd, sockaddr); err != nil {
		return nil, fmt.Errorf("failed to bind host port %s/%s: %w", addr, proto, err)
	}

	if file := os.NewFile(uintptr(fd), "listener"); file != nil {
		return file, nil
	}
	return nil, fmt.Errorf("failed to convert socket to file for %s/%s", addr, proto)
}
// listenSCTP binds an SCTP socket to addr (via bindSCTP) and puts it in the
// listening state. It is based on sctp.ListenSCTP.
//
// On success the caller owns the returned file and must close it. On failure
// no file is returned, and the socket is closed before returning so that
// neither the file descriptor nor the port stays held.
func listenSCTP(addr netip.AddrPort) (_ *os.File, retErr error) {
	boundSocket, err := bindSCTP(addr)
	if err != nil {
		return nil, err
	}
	defer func() {
		if retErr != nil {
			// Best-effort cleanup: the listen error is the one worth reporting.
			_ = boundSocket.Close()
		}
	}()

	somaxconn := -1 // silently capped to "/proc/sys/net/core/somaxconn"
	if err := syscall.Listen(int(boundSocket.Fd()), somaxconn); err != nil {
		return nil, fmt.Errorf("failed to listen on sctp socket: %w", err)
	}

	return boundSocket, nil
}
func bindSCTP(addr netip.AddrPort) (_ *os.File, retErr error) {
domain := syscall.AF_INET
if addr.Addr().Unmap().Is6() {
domain = syscall.AF_INET6
}
sd, err := syscall.Socket(domain, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, syscall.IPPROTO_SCTP)
if err != nil {
return nil, fmt.Errorf("failed to create socket for %s/sctp: %w", addr, err)
}
defer func() {
if retErr != nil {
syscall.Close(sd)
}
}()
if domain == syscall.AF_INET6 {
syscall.SetsockoptInt(sd, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 1)
}
if errno := setSCTPInitMsg(sd, sctp.InitMsg{NumOstreams: sctp.SCTP_MAX_STREAM}); errno != 0 {
return nil, errno
}
if err := sctp.SCTPBind(sd,
&sctp.SCTPAddr{IPAddrs: []net.IPAddr{{IP: addr.Addr().Unmap().AsSlice()}}, Port: int(addr.Port())},
sctp.SCTP_BINDX_ADD_ADDR); err != nil {
return nil, fmt.Errorf("failed to bind host port %s/sctp: %w", addr, err)
}
// We need to listen to make sure that the port is free, and no other process is racing against us to acquire this
// port. But listening means that connections could be accepted before DNAT rules are inserted, and they'd never
// reach the container. To avoid this, set a socket filter to drop all connections — SCTP handshake will be
// re-transmitted anyway. Callers must call DetachSocketFilter.
if err := setSocketFilter(sd); err != nil {
return nil, fmt.Errorf("failed to set drop packets filter for %s/sctp: %w", addr, err)
}
boundSocket := os.NewFile(uintptr(sd), "listener")
if boundSocket == nil {
return nil, fmt.Errorf("failed to convert socket %s/sctp", addr)
}
return boundSocket, nil
}
// DetachSocketFilter removes, from the socket behind f, the BPF filter that
// was attached during port allocation to keep the kernel from accepting
// connections before DNAT rules are inserted.
func DetachSocketFilter(f *os.File) error {
	// The option value is ignored for SO_DETACH_FILTER.
	const ignored = 0

	fd := int(f.Fd())
	return unix.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_DETACH_FILTER, ignored)
}
// setSocketFilter attaches to socket sd a cBPF program that drops all packets.
// To start receiving packets on this socket, callers must call
// DetachSocketFilter.
func setSocketFilter(sd int) error {
	// A cBPF program attached to a socket with SO_ATTACH_FILTER and
	// returning 0 tells the kernel to drop all packets.
	dropAll := []bpf.Instruction{
		bpf.RetConstant{Val: 0x0},
	}

	asm, err := bpf.Assemble(dropAll)
	if err != nil {
		// (bpf.RetConstant).Assemble() doesn't return an error, so this should
		// be unreachable code.
		return fmt.Errorf("attaching socket filter: %w", err)
	}
	// Make sure the asm slice is not GC'd before setsockopt is called
	defer runtime.KeepAlive(asm)

	if len(asm) == 0 {
		return errors.New("attaching socket filter: empty BPF program")
	}

	// Translate the assembled instructions into the layout setsockopt expects.
	filters := make([]unix.SockFilter, 0, len(asm))
	for _, ins := range asm {
		filters = append(filters, unix.SockFilter{
			Code: ins.Op,
			Jt:   ins.Jt,
			Jf:   ins.Jf,
			K:    ins.K,
		})
	}

	prog := unix.SockFprog{
		Len:    uint16(len(filters)),
		Filter: &filters[0],
	}
	return unix.SetsockoptSockFprog(sd, syscall.SOL_SOCKET, syscall.SO_ATTACH_FILTER, &prog)
}