/*
 *  Copyright (C) 2013-2019 Cisco Systems, Inc. and/or its affiliates. All rights reserved.
 *  Copyright (C) 2009-2013 Sourcefire, Inc.
 *
 *  Authors: aCaB <acab@clamav.net>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License version 2 as
 *  published by the Free Software Foundation.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 *  MA 02110-1301, USA.
 */

#include <stdio.h>
#include <winsock2.h>
#include <Ws2tcpip.h>
/* #define W2K_DNSAAPI_COMPAT */
#ifdef W2K_DNSAAPI_COMPAT
#include <Wspiapi.h>
#endif
#include <limits.h>
#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#include "net.h"
#include "w32_errno.h"

/*
 * Translate the calling thread's last Winsock error (WSAGetLastError()) into
 * the closest POSIX errno value, so the w32_* wrappers can follow the usual
 * "return -1 and set errno" convention.
 *
 * Winsock codes with no meaningful POSIX counterpart map to EBOGUSWSOCK.
 * Unlisted codes (e.g. resolver errors such as WSAHOST_NOT_FOUND) also map to
 * EBOGUSWSOCK instead of leaving a stale or zero errno behind.
 */
static void wsock2errno(void)
{
    switch (WSAGetLastError()) {
        case WSA_INVALID_HANDLE:
        case WSA_INVALID_PARAMETER:
        case WSAVERNOTSUPPORTED:
        case WSANOTINITIALISED:
        case WSAEINVALIDPROCTABLE:
        case WSAEINVALIDPROVIDER:
        case WSAEPROVIDERFAILEDINIT:
        case WSASYSCALLFAILURE:
        case WSASERVICE_NOT_FOUND:
        case WSATYPE_NOT_FOUND:
            errno = EINVAL;
            break;
        case WSA_OPERATION_ABORTED:
        case WSAENOMORE:
        case WSAECANCELLED:
        case WSA_E_NO_MORE:
        case WSA_E_CANCELLED:
        case WSA_IO_INCOMPLETE:
        case WSA_IO_PENDING:
        case WSAEREFUSED:
        case WSA_QOS_RECEIVERS:
        case WSA_QOS_SENDERS:
        case WSA_QOS_NO_SENDERS:
        case WSA_QOS_NO_RECEIVERS:
        case WSA_QOS_REQUEST_CONFIRMED:
        case WSA_QOS_ADMISSION_FAILURE:
        case WSA_QOS_POLICY_FAILURE:
        case WSA_QOS_BAD_STYLE:
        case WSA_QOS_BAD_OBJECT:
        case WSA_QOS_TRAFFIC_CTRL_ERROR:
        case WSA_QOS_GENERIC_ERROR:
        case WSA_QOS_ESERVICETYPE:
        case WSA_QOS_EFLOWSPEC:
        case WSA_QOS_EPROVSPECBUF:
        case WSA_QOS_EFILTERSTYLE:
        case WSA_QOS_EFILTERTYPE:
        case WSA_QOS_EFILTERCOUNT:
        case WSA_QOS_EOBJLENGTH:
        case WSA_QOS_EFLOWCOUNT:
        case WSA_QOS_EUNKOWNPSOBJ:
        case WSA_QOS_EPOLICYOBJ:
        case WSA_QOS_EFLOWDESC:
        case WSA_QOS_EPSFLOWSPEC:
        case WSA_QOS_EPSFILTERSPEC:
        case WSA_QOS_ESDMODEOBJ:
        case WSA_QOS_ESHAPERATEOBJ:
        case WSA_QOS_RESERVED_PETYPE:
            errno = EBOGUSWSOCK;
            break;
        case WSA_NOT_ENOUGH_MEMORY:
            errno = ENOMEM;
            break;
        case WSAEINTR:
            errno = EINTR;
            break;
        case WSAEBADF:
            errno = EBADF;
            break;
        case WSAEACCES:
            errno = EACCES;
            break;
        case WSAEFAULT:
            errno = EFAULT;
            break;
        case WSAEINVAL:
            errno = EINVAL;
            break;
        case WSAEMFILE:
            errno = EMFILE;
            break;
        case WSAEWOULDBLOCK:
            errno = EAGAIN;
            break;
        case WSAEINPROGRESS:
            errno = EINPROGRESS;
            break;
        case WSAEALREADY:
            errno = EALREADY;
            break;
        case WSAENOTSOCK:
            errno = ENOTSOCK;
            break;
        case WSAEDESTADDRREQ:
            errno = EDESTADDRREQ;
            break;
        case WSAEMSGSIZE:
            errno = EMSGSIZE;
            break;
        case WSAEPROTOTYPE:
            errno = EPROTOTYPE;
            break;
        case WSAENOPROTOOPT:
            errno = ENOPROTOOPT;
            break;
        case WSAEPROTONOSUPPORT:
            errno = EPROTONOSUPPORT;
            break;
        case WSAESOCKTNOSUPPORT:
            errno = ESOCKTNOSUPPORT;
            break;
        case WSAEOPNOTSUPP:
            errno = EOPNOTSUPP;
            break;
        case WSAEPFNOSUPPORT:
            errno = EPFNOSUPPORT;
            break;
        case WSAEAFNOSUPPORT:
            errno = EAFNOSUPPORT;
            break;
        case WSAEADDRINUSE:
            errno = EADDRINUSE;
            break;
        case WSAEADDRNOTAVAIL:
            errno = EADDRNOTAVAIL;
            break;
        case WSASYSNOTREADY:
        case WSAENETDOWN:
            errno = ENETDOWN;
            break;
        case WSAENETUNREACH:
            errno = ENETUNREACH;
            break;
        case WSAENETRESET:
            errno = ENETRESET;
            break;
        case WSAECONNABORTED:
            errno = ECONNABORTED;
            break;
        case WSAECONNRESET:
        case WSAEDISCON:
            errno = ECONNRESET;
            break;
        case WSAENOBUFS:
            errno = ENOBUFS;
            break;
        case WSAEISCONN:
            errno = EISCONN;
            break;
        case WSAENOTCONN:
            errno = ENOTCONN;
            break;
        case WSAESHUTDOWN:
            errno = ESHUTDOWN;
            break;
        case WSAETOOMANYREFS:
            errno = ETOOMANYREFS;
            break;
        case WSAETIMEDOUT:
            errno = ETIMEDOUT;
            break;
        case WSAECONNREFUSED:
            errno = ECONNREFUSED;
            break;
        case WSAELOOP:
            errno = ELOOP;
            break;
        case WSAENAMETOOLONG:
            errno = ENAMETOOLONG;
            break;
        case WSAEHOSTDOWN:
            errno = EHOSTDOWN;
            break;
        case WSAEHOSTUNREACH:
            errno = EHOSTUNREACH;
            break;
        case WSAENOTEMPTY:
            errno = ENOTEMPTY;
            break;
        case WSAEPROCLIM:
        case WSAEUSERS:
            errno = EUSERS;
            break;
        case WSAEDQUOT:
            errno = EDQUOT;
            break;
        case WSAESTALE:
            errno = ESTALE;
            break;
        case WSAEREMOTE:
            errno = EREMOTE;
            break;
        default:
            /* Callers always return failure after calling us; never leave a
             * stale (or zero) errno behind for an unmapped Winsock code. */
            errno = EBOGUSWSOCK;
            break;
    }
}

/*
 * socket(2) replacement: create a Winsock socket.
 * Returns the socket handle as an int, or -1 with errno set from the Winsock
 * error on failure.
 * NOTE(review): SOCKET is pointer-sized; the cast to int assumes handle values
 * fit in an int.
 */
int w32_socket(int domain, int type, int protocol)
{
    SOCKET sock = socket(domain, type, protocol);

    if (sock != INVALID_SOCKET)
        return (int)sock;

    wsock2errno();
    return -1;
}

/*
 * getsockopt(2) replacement.
 * Returns 0 on success, or -1 with errno set from the Winsock error.
 */
int w32_getsockopt(int sockfd, int level, int optname, void *optval, socklen_t *optlen)
{
    int rc = getsockopt((SOCKET)sockfd, level, optname, (char *)optval, optlen);

    if (rc != SOCKET_ERROR)
        return 0;

    wsock2errno();
    return -1;
}

/*
 * setsockopt(2) replacement.
 * Returns 0 on success, or -1 with errno set from the Winsock error.
 */
int w32_setsockopt(int sockfd, int level, int optname, const void *optval, socklen_t optlen)
{
    int rc = setsockopt((SOCKET)sockfd, level, optname, (const char *)optval, optlen);

    if (rc != SOCKET_ERROR)
        return 0;

    wsock2errno();
    return -1;
}

/*
 * bind(2) replacement.
 * Returns 0 on success, or -1 with errno set from the Winsock error.
 */
int w32_bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen)
{
    int rc = bind((SOCKET)sockfd, addr, addrlen);

    if (rc != SOCKET_ERROR)
        return 0;

    wsock2errno();
    return -1;
}

/*
 * connect(2) replacement.
 * Returns 0 on success; any non-zero result from Winsock is treated as
 * failure and yields -1 with errno set from the Winsock error.
 */
int w32_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen)
{
    int rc = connect((SOCKET)sockfd, addr, addrlen);

    if (rc == 0)
        return 0;

    wsock2errno();
    return -1;
}

/*
 * send(2) replacement for Winsock sockets.
 *
 * Returns the number of bytes sent, or -1 with errno set from the Winsock
 * error. Winsock's send() takes an int length, so a request larger than
 * INT_MAX is clamped and comes back as a short write, the same way a POSIX
 * send() may return less than requested. The old conversion could instead
 * turn such a length into a garbage or negative count.
 */
ssize_t w32_send(int sockfd, const void *buf, size_t len, int flags)
{
    int chunk = len > INT_MAX ? INT_MAX : (int)len;
    int ret   = send((SOCKET)sockfd, (const char *)buf, chunk, flags);

    if (ret == SOCKET_ERROR) {
        wsock2errno();
        return -1;
    }
    return (ssize_t)ret;
}

/*
 * recv(2) replacement for Winsock sockets.
 *
 * Returns the number of bytes received, or -1 with errno set from the
 * Winsock error. Winsock's recv() takes an int length, so a buffer larger than
 * INT_MAX is clamped to INT_MAX bytes. The old implicit size_t -> int
 * narrowing could otherwise produce a garbage or negative length.
 */
ssize_t w32_recv(int sockfd, void *buf, size_t len, int flags)
{
    int chunk = len > INT_MAX ? INT_MAX : (int)len;
    int ret   = recv((SOCKET)sockfd, (char *)buf, chunk, flags);

    if (ret == SOCKET_ERROR) {
        wsock2errno();
        return -1;
    }
    return (ssize_t)ret;
}

/*
 * getpeername(2) replacement.
 * Returns Winsock's result unchanged on success, or -1 with errno set from
 * the Winsock error.
 */
int w32_getpeername(int s, struct sockaddr *name, int *namelen)
{
    int rc = getpeername((SOCKET)s, name, namelen);

    if (rc != SOCKET_ERROR)
        return rc;

    wsock2errno();
    return -1;
}

/*
 * Pass-through to Winsock inet_ntoa(): dotted-quad text for an IPv4 address.
 * The returned pointer is Winsock's own buffer; callers do not free it.
 */
char *w32_inet_ntoa(struct in_addr in)
{
    char *text = inet_ntoa(in);

    return text;
}

/*
 * close(2) replacement for sockets.
 * Returns 0 on success, or -1 with errno set from the Winsock error.
 */
int w32_closesocket(int sockfd)
{
    int rc = closesocket((SOCKET)sockfd);

    if (rc != SOCKET_ERROR)
        return 0;

    wsock2errno();
    return -1;
}

/*
 * Pass-through to Winsock getservbyname(); no errno translation is done.
 */
struct servent *w32_getservbyname(const char *name, const char *proto)
{
    struct servent *entry = getservbyname(name, proto);

    return entry;
}

/*
 * getaddrinfo(3) replacement.
 * Returns Winsock's status code unchanged (0 on success). On failure, errno
 * is additionally set from the Winsock last-error value.
 */
int w32_getaddrinfo(const char *node, const char *service, const struct addrinfo *hints, struct addrinfo **res)
{
    int status = getaddrinfo(node, service, hints, res);

    if (status != 0)
        wsock2errno();

    return status;
}

/*
 * Release a list obtained from w32_getaddrinfo().
 * This is a direct hand-off to Winsock freeaddrinfo().
 */
void w32_freeaddrinfo(struct addrinfo *res)
{
    struct addrinfo *list = res;

    freeaddrinfo(list);
}

/*
 * Minimal inet_ntop(3) replacement, built on inet_ntoa(). Only IPv4 is
 * supported.
 *
 * af:   must be AF_INET, otherwise errno = EAFNOSUPPORT and NULL is returned.
 * src:  points at a struct in_addr; it need not be suitably aligned.
 * dst:  caller buffer of `size` bytes that receives the NUL-terminated text.
 *
 * Returns dst on success, matching the inet_ntop() contract rather than
 * leaking inet_ntoa()'s internal buffer. Returns NULL with errno = ENOSPC if
 * the text (plus NUL) does not fit, or with errno set from the Winsock error
 * if inet_ntoa() fails.
 */
const char *w32_inet_ntop(int af, const void *src, char *dst, socklen_t size)
{
    struct in_addr addr;
    const char *text;
    size_t textlen;

    if (af != AF_INET) {
        errno = EAFNOSUPPORT;
        return NULL;
    }
    /* Copy rather than dereference: src may not be aligned for in_addr. */
    memcpy(&addr, src, sizeof addr);
    text = inet_ntoa(addr);
    if (!text) {
        wsock2errno();
        return NULL;
    }
    textlen = strlen(text);
    /* socklen_t is signed on Winsock. A negative size must not be converted
     * to a huge size_t, which would bypass this bound check. */
    if (size <= 0 || textlen >= (size_t)size) {
        errno = ENOSPC;
        return NULL;
    }
    memcpy(dst, text, textlen + 1);
    return dst;
}

/*
 * select(2) replacement.
 * Returns the number of ready sockets (0 on timeout), or -1 with errno set
 * from the Winsock error.
 */
int w32_select(int nfds, fd_set *readfds, fd_set *writefds, fd_set *exceptfds, struct timeval *timeout)
{
    int nready = select(nfds, readfds, writefds, exceptfds, timeout);

    if (nready != SOCKET_ERROR)
        return nready;

    wsock2errno();
    return -1;
}

/*
 * accept(2) replacement.
 *
 * Returns the accepted socket handle as an int, or -1 with errno set from the
 * Winsock error.
 *
 * accept() writes the peer address through `addr`, so the const in this
 * (externally declared) signature is cast away explicitly. Previously it was
 * discarded implicitly, a qualifier-dropping conversion the compiler must
 * diagnose.
 * NOTE(review): callers must still pass writable storage (or NULL) despite the
 * const.
 * NOTE(review): SOCKET is pointer-sized; the cast to int assumes handle values
 * fit in an int.
 */
int w32_accept(SOCKET sockfd, const struct sockaddr *addr, socklen_t *addrlen)
{
    SOCKET client = accept(sockfd, (struct sockaddr *)addr, addrlen);

    if (client == INVALID_SOCKET) {
        wsock2errno();
        return -1;
    }
    return (int)client;
}

/*
 * listen(2) replacement.
 * Returns 0 on success; any non-zero Winsock result yields -1 with errno set
 * from the Winsock error.
 */
int w32_listen(int sockfd, int backlog)
{
    int rc = listen((SOCKET)sockfd, backlog);

    if (rc == 0)
        return 0;

    wsock2errno();
    return -1;
}

/*
 * shutdown(2) replacement.
 * Returns 0 on success; any non-zero Winsock result yields -1 with errno set
 * from the Winsock error.
 */
int w32_shutdown(int sockfd, int how)
{
    int rc = shutdown((SOCKET)sockfd, how);

    if (rc == 0)
        return 0;

    wsock2errno();
    return -1;
}

/*
 * Per-descriptor bookkeeping for poll_with_event(), handed to poll_cb():
 *   setme    - manual-reset event shared by all descriptors of one poll call;
 *              poll_cb() signals it when this socket reports activity.
 *   event    - this socket's own event, associated via WSAEventSelect().
 *   waiter   - wait handle returned by RegisterWaitForSingleObject().
 *   polldata - the caller's pollfd entry whose revents is filled in.
 */
struct w32polldata {
    HANDLE setme, event, waiter;
    struct pollfd *polldata;
};

/*
 * Thread-pool wait callback registered by poll_with_event() for one socket.
 *
 * param:    the socket's struct w32polldata.
 * timedout: TRUE when the registered wait expired; nothing is reported then.
 *
 * The pending network events are converted into poll() revents bits:
 * FD_ACCEPT/FD_READ -> POLLIN, FD_CLOSE -> POLLHUP, and any per-event error
 * -> POLLERR. The shared `setme` event is then signalled to wake the waiter.
 */
VOID CALLBACK poll_cb(PVOID param, BOOLEAN timedout)
{
    struct w32polldata *item = (struct w32polldata *)param;
    struct pollfd *pfd       = item->polldata;
    WSANETWORKEVENTS evt;

    if (timedout)
        return;

    if (WSAEnumNetworkEvents(pfd->fd, item->event, &evt) == SOCKET_ERROR) {
        /* evt is not filled in on failure; report an error instead of
         * reading uninitialized memory, and still wake the waiter. */
        pfd->revents = POLLERR;
        SetEvent(item->setme);
        return;
    }
    if (evt.lNetworkEvents & FD_ACCEPT) {
        pfd->revents |= POLLIN;
        if (evt.iErrorCode[FD_ACCEPT_BIT])
            pfd->revents = POLLERR;
    }
    if (evt.lNetworkEvents & FD_READ) {
        pfd->revents |= POLLIN;
        if (evt.iErrorCode[FD_READ_BIT])
            pfd->revents = POLLERR;
    }
    if (evt.lNetworkEvents & FD_CLOSE) {
        pfd->revents |= POLLHUP;
        if (evt.iErrorCode[FD_CLOSE_BIT])
            pfd->revents = POLLERR;
    }
    SetEvent(item->setme);
}

/*
 * poll(2) emulation for Winsock sockets that can also wait on an extra Win32
 * event handle.
 *
 * fds/nfds: sockets to watch for FD_ACCEPT/FD_READ/FD_CLOSE. Each revents is
 *           cleared on entry and set to POLLIN, POLLHUP and/or POLLERR.
 * timeout:  milliseconds; a negative value waits forever.
 * event:    optional extra handle (may be NULL). If it is the object that
 *           ends the wait, it adds 1 to the return value.
 *
 * Returns the number of descriptors with non-zero revents, plus 1 if `event`
 * fired. Returns -1 if nfds is negative or a resource cannot be allocated.
 *
 * If a socket already has data or a hang-up pending, or a wait cannot be
 * armed, no blocking wait is done and only the state seen so far is reported.
 *
 * NOTE(review): WSAEventSelect() is applied to each socket and cancelled
 * afterwards. Per Winsock documentation this leaves the socket non-blocking;
 * confirm callers expect that.
 */
int poll_with_event(struct pollfd *fds, int nfds, int timeout, HANDLE event)
{
    HANDLE setme[2];
    struct w32polldata *items;
    DWORD wait_ms = timeout < 0 ? INFINITE : (DWORD)timeout;
    int i, ret = 0, reallywait = 1;

    if (nfds < 0) {
        errno = EINVAL;
        return -1;
    }
    if (!nfds) {
        if (event) {
            if (WaitForSingleObject(event, wait_ms) == WAIT_OBJECT_0)
                return 1;
        } else
            Sleep(wait_ms);
        return 0;
    }

    /* Slot 0: manual-reset event poll_cb() signals on socket activity.
     * Slot 1: the caller's optional extra event. */
    setme[0] = CreateEvent(NULL, TRUE, FALSE, NULL);
    if (setme[0] == NULL) {
        fprintf(stderr, "warning: CreateEvent() for variable 'setme' failed in function 'poll_with_event'...\n");
        errno = ENOMEM; /* NOTE(review): closest POSIX code for handle exhaustion */
        return -1;
    }
    setme[1] = event;

    items = malloc((size_t)nfds * sizeof(*items));
    if (items == NULL) { /* oops, malloc() failed */
        fprintf(stderr, "warning: malloc() for variable 'items' failed in function 'poll_with_event'...\n");
        CloseHandle(setme[0]);
        return -1;
    }

    for (i = 0; i < nfds; i++) {
        char c;
        int n;

        fds[i].revents    = 0; /* revents is output-only, as in poll(2) */
        items[i].polldata = &fds[i];
        items[i].setme    = setme[0];
        items[i].waiter   = NULL;
        items[i].event    = CreateEvent(NULL, TRUE, FALSE, NULL);
        if (!items[i].event)
            continue;
        if (WSAEventSelect(fds[i].fd, items[i].event, FD_ACCEPT | FD_READ | FD_CLOSE)) {
            CloseHandle(items[i].event);
            items[i].event = NULL;
            continue;
        }
        /* Ugly workaround for FD_CLOSE not being persistent: peek one byte to
         * detect pending data (POLLIN) or an orderly hang-up (POLLHUP) up
         * front. Better win32 code is possible at the cost of a larger diff
         * vs. the unix netcode. */
        n = recv(fds[i].fd, &c, 1, MSG_PEEK);
        if (n == 0)
            fds[i].revents = POLLHUP;
        else if (n == 1)
            fds[i].revents = POLLIN;
        /* Already ready, or the wait cannot be armed: drop the association
         * and skip the blocking wait below. */
        if (n >= 0 || !RegisterWaitForSingleObject(&items[i].waiter, items[i].event, poll_cb, &items[i], wait_ms, WT_EXECUTEONLYONCE)) {
            WSAEventSelect(fds[i].fd, items[i].event, 0);
            CloseHandle(items[i].event);
            items[i].event = NULL;
            reallywait     = 0;
        }
    }

    if (reallywait) {
        DWORD nhandles = event ? 2 : 1;

        if (WaitForMultipleObjects(nhandles, setme, FALSE, wait_ms) == WAIT_OBJECT_0 + 1)
            ret = 1;
    }

    for (i = 0; i < nfds; i++) {
        /* A non-NULL event here means a wait was successfully registered. */
        if (items[i].event) {
            /* INVALID_HANDLE_VALUE makes UnregisterWaitEx() block until any
             * in-flight poll_cb() has returned, so items[] and the event can
             * be released safely. The previous helper-event approach
             * silently skipped this wait when that helper failed to be
             * created. */
            UnregisterWaitEx(items[i].waiter, INVALID_HANDLE_VALUE);
            WSAEventSelect(fds[i].fd, items[i].event, 0);
            CloseHandle(items[i].event);
        }
        ret += (fds[i].revents != 0);
    }
    free(items);
    CloseHandle(setme[0]);
    return ret;
}

/*
 * Minimal fcntl(2) emulation for Winsock sockets.
 *
 * F_GETFL: always returns 0 (no flags reported; the socket is not queried).
 * F_SETFL: reads the int flags argument and switches the socket to
 *          non-blocking mode if O_NONBLOCK is among the flags, otherwise to
 *          blocking mode. Returns 0, or -1 with errno set from the Winsock
 *          error.
 * Other commands: returns -1 with errno = EINVAL.
 */
int fcntl(int fd, int cmd, ...)
{
    va_list ap;
    int flags;
    u_long nonblock;

    if (cmd == F_GETFL)
        return 0;
    if (cmd != F_SETFL) {
        errno = EINVAL;
        return -1;
    }

    /* POSIX passes the F_SETFL argument as int; every va_start needs a
     * matching va_end. */
    va_start(ap, cmd);
    flags = va_arg(ap, int);
    va_end(ap);

    /* Test the bit rather than compare for equality, so that
     * e.g. (O_RDWR | O_NONBLOCK) still selects non-blocking mode. */
    nonblock = (flags & O_NONBLOCK) != 0;
    if (ioctlsocket((SOCKET)fd, FIONBIO, &nonblock)) {
        wsock2errno();
        return -1;
    }
    return 0;
}