/*
 *  Copyright 2006 Everton da Silva Marques <everton.marques@gmail.com>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

#if HAVE_CONFIG_H
#include "clamav-config.h"
#endif

#include "nonblock.h"

#include <stdio.h>
#include <stdlib.h>
#ifdef	HAVE_UNISTD_H
#include <unistd.h>
#endif
#include <string.h>
#include <ctype.h>
#ifndef	_WIN32
#include <netinet/in.h>
#include <netdb.h>
#include <sys/socket.h>
#include <sys/time.h>
#endif
#include <sys/types.h>
#include <time.h>
#include <fcntl.h>
#include <sys/stat.h>
#include <errno.h>

#include "shared/output.h"
#include "libclamav/clamav.h"

#ifdef SO_ERROR

#ifndef timercmp
# define timercmp(a, b, cmp)          \
  (((a)->tv_sec == (b)->tv_sec) ?     \
   ((a)->tv_usec cmp (b)->tv_usec) :  \
   ((a)->tv_sec cmp (b)->tv_sec))
#endif /* timercmp */

#ifndef timersub
# define timersub(a, b, result)                       \
  do {                                                \
    (result)->tv_sec = (a)->tv_sec - (b)->tv_sec;     \
    (result)->tv_usec = (a)->tv_usec - (b)->tv_usec;  \
    if ((result)->tv_usec < 0) {                      \
      --(result)->tv_sec;                             \
      (result)->tv_usec += 1000000;                   \
    }                                                 \
  } while (0)
#endif /* timersub */

#define NONBLOCK_SELECT_MAX_FAILURES 3
#define NONBLOCK_MAX_BOGUS_LOOPS     10
#undef  NONBLOCK_DEBUG

/*
 * connect_error(): read the pending error status (SO_ERROR) of 'sock',
 * typically after a non-blocking connect() has completed.
 *
 * Returns 0 if SO_ERROR reports no error, -1 if it reports an error or if
 * the status could not be read at all.  Failures are logged.
 */
static int connect_error(int sock)
{
	int optval = 0;
	socklen_t optlen = sizeof(optval);

	/* If getsockopt() itself fails, optval is not written; treat that as a
	 * failed connection instead of trusting an unset value. */
	if (getsockopt(sock, SOL_SOCKET, SO_ERROR, &optval, &optlen)) {
		int e = errno;
		logg("connect_error: getsockopt(SO_ERROR): fd=%d errno=%d: %s\n",
		     sock, e, strerror(e));
		return -1;
	}

	if (optval) {
		logg("connect_error: getsockopt(SO_ERROR): fd=%d error=%d: %s\n",
		     sock, optval, strerror(optval));
		return -1;
	}

	return 0;
}

/*
 * nonblock_connect(): start connect() on 'sock' and wait up to 'secs'
 * seconds for it to complete.  The caller is expected to have put 'sock'
 * into non-blocking mode beforehand.
 *
 * Returns 0 once connected, -1 on error or timeout.  Completion is detected
 * by select() reporting the socket writable; the outcome is then read from
 * SO_ERROR via connect_error().
 */
static int nonblock_connect(int sock, const struct sockaddr *addr, socklen_t addrlen, int secs)
{
	/* Max. of unexpected select() failures */
	int select_failures = NONBLOCK_SELECT_MAX_FAILURES;
	/* Max. of useless loops */
	int bogus_loops = NONBLOCK_MAX_BOGUS_LOOPS;
	struct timeval timeout;  /* When we should time out */
	int numfd;               /* Highest fdset fd plus 1 */

	/* Calculate into 'timeout' when we should time out */
	gettimeofday(&timeout, 0);
	timeout.tv_sec += secs;

	/* Launch (possibly) non-blocking connect() request */
	if (connect(sock, addr, addrlen)) {
		int e = errno;
#ifdef NONBLOCK_DEBUG
		logg("DEBUG nonblock_connect: connect(): fd=%d errno=%d: %s\n",
		     sock, e, strerror(e));
#endif
		switch (e) {
		case EALREADY:
		case EINPROGRESS:
		case EAGAIN:
			break; /* wait for connection */
		case EISCONN:
			return 0; /* connected */
		default:
			logg("nonblock_connect: connect(): fd=%d errno=%d: %s\n",
			     sock, e, strerror(e));
			return -1; /* failed */
		}
	}
	else {
		/* connect() completed immediately */
		return connect_error(sock);
	}

#ifndef _WIN32
	/* FD_SET() on a descriptor outside [0, FD_SETSIZE) writes past the end
	 * of the fd_set; refuse instead of corrupting the stack. */
	if (sock < 0 || sock >= FD_SETSIZE) {
		logg("nonblock_connect: fd=%d too large for select() (FD_SETSIZE=%d)\n",
		     sock, (int) FD_SETSIZE);
		return -1; /* failed */
	}
#endif

	numfd = sock + 1; /* Highest fdset fd plus 1 */

	for (;;) {
		fd_set fds;
		struct timeval now;
		struct timeval wait;
		int n;

		/* Force timeout if we ran out of time */
		gettimeofday(&now, 0);
		if (timercmp(&now, &timeout, >)) {
			logg("nonblock_connect: connect timing out (%d secs)\n",
			     secs);
			break; /* failed */
		}

		/* Calculate into 'wait' how long to wait */
		timersub(&timeout, &now, &wait); /* wait = timeout - now */

		/* Init fds with 'sock' as the only fd */
		FD_ZERO(&fds);
		FD_SET(sock, &fds);

		n = select(numfd, 0, &fds, 0, &wait);
		if (n < 0) {
			int e = errno;

			/* Interrupted by a signal: not a select() failure.  The
			 * deadline check at the top of the loop still bounds us. */
			if (e == EINTR) {
				continue;
			}

			logg("nonblock_connect: select() failure %d: errno=%d: %s\n",
			     select_failures, e, strerror(e));
			if (--select_failures >= 0) {
				continue; /* keep waiting */
			}
			break; /* failed */
		}

#ifdef NONBLOCK_DEBUG
		logg("DEBUG nonblock_connect: select = %d\n", n);
#endif

		if (n) {
			/* Writable: connect() finished, successfully or not */
			return connect_error(sock);
		}

		/* Select returned, but there is no work to do... */
		if (--bogus_loops < 0) {
			logg("nonblock_connect: giving up due to excessive bogus loops\n");
			break; /* failed */
		}

	} /* for loop: keep waiting */

	return -1; /* failed */
}

/* Returns non-zero if 'e' is the errno a non-blocking recv() uses for
 * "no data available right now". */
static int nonblock_would_block(int e)
{
#ifdef EWOULDBLOCK
	if (e == EWOULDBLOCK) {
		return 1;
	}
#endif
	return e == EAGAIN;
}

/*
 * nonblock_recv(): wait up to 'secs' seconds for 'sock' to become readable,
 * then call recv(sock, buf, len, flags).  The caller is expected to have put
 * 'sock' into non-blocking mode beforehand.
 *
 * Returns the value recv() returned (errno left as recv() set it), or -1 on
 * timeout, repeated select() failure or too many empty wake-ups.
 */
static ssize_t nonblock_recv(int sock, void *buf, size_t len, int flags, int secs)
{
	/* Max. of unexpected select() failures */
	int select_failures = NONBLOCK_SELECT_MAX_FAILURES;
	/* Max. of useless loops */
	int bogus_loops = NONBLOCK_MAX_BOGUS_LOOPS;
	struct timeval timeout;  /* When we should time out */
	int numfd;               /* Highest fdset fd plus 1 */

	/* Calculate into 'timeout' when we should time out */
	gettimeofday(&timeout, 0);
	timeout.tv_sec += secs;

#ifndef _WIN32
	/* FD_SET() on a descriptor outside [0, FD_SETSIZE) writes past the end
	 * of the fd_set; refuse instead of corrupting the stack. */
	if (sock < 0 || sock >= FD_SETSIZE) {
		logg("nonblock_recv: fd=%d too large for select() (FD_SETSIZE=%d)\n",
		     sock, (int) FD_SETSIZE);
		return -1; /* failed */
	}
#endif

	numfd = sock + 1; /* Highest fdset fd plus 1 */

	for (;;) {
		fd_set fds;
		struct timeval now;
		struct timeval wait;
		int n;

		/* Force timeout if we ran out of time */
		gettimeofday(&now, 0);
		if (timercmp(&now, &timeout, >)) {
			logg("nonblock_recv: recv timing out (%d secs)\n", secs);
			break; /* failed */
		}

		/* Calculate into 'wait' how long to wait */
		timersub(&timeout, &now, &wait); /* wait = timeout - now */

		/* Init fds with 'sock' as the only fd */
		FD_ZERO(&fds);
		FD_SET(sock, &fds);

		n = select(numfd, &fds, 0, 0, &wait);
		if (n < 0) {
			int e = errno;

			/* Interrupted by a signal: not a select() failure.  The
			 * deadline check at the top of the loop still bounds us. */
			if (e == EINTR) {
				continue;
			}

			logg("nonblock_recv: select() failure %d: errno=%d: %s\n",
			     select_failures, e, strerror(e));
			if (--select_failures >= 0) {
				continue; /* keep waiting */
			}
			break; /* failed */
		}

		if (n) {
			ssize_t got = recv(sock, buf, len, flags);

			/* select() may report readable even though no data can be
			 * read (e.g. a packet dropped after the wake-up).  On this
			 * non-blocking socket recv() then fails with EAGAIN; treat
			 * that like an empty wake-up instead of a hard failure. */
			if (got >= 0 || !nonblock_would_block(errno)) {
				return got;
			}
		}

		/* Select returned, but there is no work to do... */
		if (--bogus_loops < 0) {
			logg("nonblock_recv: giving up due to excessive bogus loops\n");
			break; /* failed */
		}

	} /* for loop: keep waiting */

	return -1; /* failed */
}

/*
 * nonblock_fcntl(): add O_NONBLOCK to the file status flags of 'sock'.
 *
 * Returns the flags as they were before the change, for restore_fcntl(),
 * or -1 when nothing was saved (F_GETFL failed or is unavailable).
 * A failing F_SETFL is logged but the saved flags are still returned.
 */
static long nonblock_fcntl(int sock)
{
#ifdef	F_GETFL
	long fcntl_flags; /* Save fcntl() flags */

	fcntl_flags = fcntl(sock, F_GETFL, 0);
	if (fcntl_flags == -1) {
		logg("nonblock_fcntl: saving: fcntl(%d, F_GETFL): errno=%d: %s\n",
		     sock, errno, strerror(errno));
	}
	/* POSIX: F_SETFL takes its third argument as int, so pass an int
	 * through the varargs rather than a long. */
	else if (fcntl(sock, F_SETFL, (int) (fcntl_flags | O_NONBLOCK))) {
		logg("nonblock_fcntl: fcntl(%d, F_SETFL, O_NONBLOCK): errno=%d: %s\n",
		     sock, errno, strerror(errno));
	}

	return fcntl_flags;
#else
	/* Nothing saved: -1 makes restore_fcntl() leave the descriptor alone
	 * instead of overwriting its flags with 0. */
	(void) sock;
	return -1;
#endif
}

/*
 * restore_fcntl(): write back file status flags previously returned by
 * nonblock_fcntl().  A value of -1 means nothing was saved, and the
 * descriptor is left untouched.  A failing F_SETFL is logged.
 */
static void restore_fcntl(int sock, long fcntl_flags)
{
#ifdef	F_SETFL
	if (fcntl_flags == -1) {
		return; /* nothing was saved */
	}

	/* POSIX: F_SETFL takes its third argument as int */
	if (fcntl(sock, F_SETFL, (int) fcntl_flags)) {
		logg("restore_fcntl: restoring: fcntl(%d, F_SETFL): errno=%d: %s\n",
		     sock, errno, strerror(errno));
	}
#else
	(void) sock;
	(void) fcntl_flags;
#endif
}

/*
	wait_connect(): wrapper for connect(), with explicit 'secs' timeout

	The socket is switched to non-blocking mode only for the duration of
	this call; the flags saved beforehand are written back before
	returning.  Returns nonblock_connect()'s result: 0 once connected,
	-1 on error or timeout.
*/
int wait_connect(int sock, const struct sockaddr *addr, socklen_t addrlen, int secs)
{
	/* Declarators are evaluated in order: flags are saved first */
	const long saved_flags = nonblock_fcntl(sock);
	const int rc = nonblock_connect(sock, addr, addrlen, secs);

	/* Put the socket back the way we found it */
	restore_fcntl(sock, saved_flags);

	return rc;
}

/*
	wait_recv(): wrapper for recv(), with explicit 'secs' timeout

	The socket is switched to non-blocking mode only for the duration of
	this call; the flags saved beforehand are written back before
	returning.  Returns nonblock_recv()'s result: the value recv()
	returned, or -1 on timeout / select() failure.  errno is preserved
	across the flag restoration so callers see the value recv() left.
*/
ssize_t wait_recv(int sock, void *buf, size_t len, int flags, int secs)
{
	long fcntl_flags; /* Save fcntl() flags */
	ssize_t ret;      /* ssize_t, not int: avoids truncating byte counts */
	int saved_errno;

	/* Temporarily set socket to non-blocking mode */
	fcntl_flags = nonblock_fcntl(sock);

	ret = nonblock_recv(sock, buf, len, flags, secs);
	saved_errno = errno;

	/* Restore socket's default blocking mode */
	restore_fcntl(sock, fcntl_flags);

	/* restore_fcntl() may call fcntl()/logg(), which can change errno */
	errno = saved_errno;

	return ret;
}

#endif /* SO_ERROR */