package main

import (
	"net"
	"testing"
	"time"

	"gotest.tools/v3/assert"
)

// TestUDPOneSided makes sure that the conntrack entry isn't GC'd if the
// backend never writes to the UDP client.
//
// The client keeps sending datagrams for longer than connTrackTimeout while
// the backend stays silent. Every datagram must reach the backend from the
// same proxy source port, i.e. through the same conntrack entry. Once the
// client has been idle for longer than the GC interval, the next datagram
// must arrive from a different source port, showing that the idle entry was
// eventually removed.
func TestUDPOneSided(t *testing.T) {
	frontend, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
	assert.NilError(t, err)
	defer frontend.Close()

	backend, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
	assert.NilError(t, err)
	defer backend.Close()

	type udpMsg struct {
		data  []byte
		saddr *net.UDPAddr
	}
	msgs := make(chan udpMsg)
	// done lets the reader goroutine exit if the test returns early (e.g. on
	// a failed assertion) while it is blocked handing a datagram over;
	// otherwise it would stay parked on the unbuffered send forever.
	done := make(chan struct{})
	defer close(done)
	go func() {
		for {
			buf := make([]byte, 1024)
			n, saddr, err := backend.ReadFromUDP(buf)
			if err != nil {
				// Reads fail once backend is closed by the deferred Close.
				return
			}
			select {
			case msgs <- udpMsg{data: buf[:n], saddr: saddr}:
			case <-done:
				return
			}
		}
	}()

	// recv waits for the next datagram seen by the backend. It fails the test
	// instead of hanging until the global test timeout if the proxy never
	// forwards it.
	const recvTimeout = 5 * time.Second
	recv := func() udpMsg {
		t.Helper()
		select {
		case msg := <-msgs:
			return msg
		case <-time.After(recvTimeout):
			t.Fatalf("timed out after %s waiting for a datagram at the backend", recvTimeout)
			return udpMsg{}
		}
	}

	proxy, err := NewUDPProxy(frontend, backend.LocalAddr().(*net.UDPAddr), ip4)
	assert.NilError(t, err)
	defer proxy.Close()

	// Set before Run starts: the go statement below orders this write before
	// anything the proxy goroutine reads.
	const connTrackTimeout = 1 * time.Second
	proxy.connTrackTimeout = connTrackTimeout

	go proxy.Run()

	client, err := net.DialUDP("udp", nil, frontend.LocalAddr().(*net.UDPAddr))
	assert.NilError(t, err)
	defer client.Close()

	// 15 writes spaced 100ms apart span longer than connTrackTimeout. The
	// entry must survive that long even though the backend never replies,
	// which shows up as an unchanged proxy source port at the backend.
	var expSaddr *net.UDPAddr
	for i := range 15 {
		_, err = client.Write([]byte("hello"))
		assert.NilError(t, err)
		time.Sleep(100 * time.Millisecond)

		msg := recv()
		assert.Equal(t, string(msg.data), "hello")
		if i == 0 {
			expSaddr = msg.saddr
		} else {
			assert.Equal(t, msg.saddr.Port, expSaddr.Port)
		}
	}

	// The conntrack entry is checked every connTrackTimeout, but the latest
	// write might be less than connTrackTimeout ago. So we wait twice the
	// conntrack timeout to make sure the entry has been GC'd.
	time.Sleep(2 * connTrackTimeout)
	_, err = client.Write([]byte("hello"))
	assert.NilError(t, err)

	// A fresh conntrack entry means a fresh proxy-side socket, and therefore
	// (in practice) a different source port at the backend.
	msg := recv()
	assert.Equal(t, string(msg.data), "hello")
	assert.Check(t, msg.saddr.Port != expSaddr.Port)
}