/*
 *  OpenVPN -- An application to securely tunnel IP networks
 *             over a single TCP/UDP port, with support for SSL/TLS-based
 *             session authentication and key exchange,
 *             packet encryption, packet authentication, and
 *             packet compression.
 *
 *  Copyright (C) 2002-2010 OpenVPN Technologies, Inc. <sales@openvpn.net>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License version 2
 *  as published by the Free Software Foundation.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program (see the file COPYING included with this
 *  distribution); if not, write to the Free Software Foundation, Inc.,
 *  59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

/*
 * These routines are designed to catch replay attacks,
 * where a man-in-the-middle captures packets and then
 * attempts to replay them back later.
 */

#ifdef USE_CRYPTO

#ifndef PACKET_ID_H
#define PACKET_ID_H

#include "circ_list.h"
#include "buffer.h"
#include "error.h"
#include "otime.h"

/*
 * Enables OpenVPN to be compiled in special packet_id test mode.
 */
/*#define PID_TEST*/

#if 1
/*
 * Production wire types: the members of a struct packet_id_net are
 * converted to these types for network transmission.
 */
typedef uint32_t packet_id_type;
typedef uint32_t net_time_t;

/*
 * In TLS mode, when a packet ID gets to this level,
 * start thinking about triggering a new
 * SSL/TLS handshake.
 */
#define PACKET_ID_WRAP_TRIGGER 0xFF000000

/*
 * Byte-order conversion macros.  Every macro argument is parenthesized so
 * that an expression argument (e.g. "t >> 1") is evaluated completely
 * before the cast to net_time_t, instead of the cast binding only to the
 * expression's first operand.
 */

/* convert a packet_id_type from host to network order */
#define htonpid(x) htonl(x)

/* convert a packet_id_type from network to host order */
#define ntohpid(x) ntohl(x)

/* convert a time_t in host order to a net_time_t in network order */
#define htontime(x) htonl((net_time_t)(x))

/* convert a net_time_t in network order to a time_t in host order */
#define ntohtime(x) ((time_t)ntohl(x))

#else

/*
 * DEBUGGING ONLY.
 * Make packet_id_type and net_time_t small
 * to test wraparound logic and corner cases.
 */

typedef uint8_t packet_id_type;
typedef uint16_t net_time_t;

#define PACKET_ID_WRAP_TRIGGER 0x80

#define htonpid(x) (x)
#define ntohpid(x) (x)
#define htontime(x) htons((net_time_t)(x))
#define ntohtime(x) ((time_t)ntohs(x))

#endif

/*
 * printf() support for packet IDs: cast a packet_id_type value to
 * packet_id_print_type and format it with packet_id_format.
 */
typedef unsigned int packet_id_print_type;
#define packet_id_format "%u"

/*
 * Limits on how far back, in sequence numbers, a packet that arrives
 * out of order may be and still be considered.
 */
#define MIN_SEQ_BACKTRACK      0
#define MAX_SEQ_BACKTRACK      65536
#define DEFAULT_SEQ_BACKTRACK  64

/*
 * Limits on how far back, in seconds, a packet that arrives out of
 * order may be and still be considered.
 */
#define MIN_TIME_BACKTRACK     0
#define MAX_TIME_BACKTRACK     600
#define DEFAULT_TIME_BACKTRACK 15

/*
 * Period, in seconds, between reap passes over the sequence-number array.
 * A reap pass expires sequence numbers that could no longer be accepted
 * anyway, because accepting them would violate TIME_BACKTRACK.
 */
#define SEQ_REAP_INTERVAL      5

/*
 * Instantiate the circular-list type "struct seq_list" (CIRC_LIST macro from
 * circ_list.h) with time_t elements.  It backs packet_id_rec.seq_list below.
 * NOTE(review): presumably each slot records when a given sequence number
 * was seen -- confirm against packet_id.c.
 */
CIRC_LIST (seq_list, time_t);

/*
 * This is the data structure we keep on the receiving side,
 * to check that no packet-id (i.e. sequence number + optional timestamp)
 * is accepted more than once.
 *
 * Incoming ids are checked with packet_id_test() and recorded with
 * packet_id_add().  packet_id_reap_test() calls packet_id_reap() once
 * last_reap is at least SEQ_REAP_INTERVAL seconds old.
 */
struct packet_id_rec
{
  time_t last_reap;           /* last call of packet_id_reap */
  time_t time;                /* highest time stamp received */
  packet_id_type id;          /* highest sequence number received */
  int seq_backtrack;          /* set from --replay-window */
  int time_backtrack;         /* set from --replay-window */
  bool initialized;           /* true if packet_id_init was called */
  struct seq_list *seq_list;  /* packet-id "memory" */
};

/*
 * File used to persist the receive-side time/id across sessions.
 *
 * Persistence is "enabled" while fd >= 0 (see packet_id_persist_enabled).
 * time/id hold the values to be saved; the *_last_written fields
 * presumably record what is already on disk, so that an unchanged state
 * is not rewritten -- confirm in packet_id_persist_save().
 */
struct packet_id_persist
{
  const char *filename;
  int fd;                  /* open file descriptor, or negative when disabled */
  time_t time;             /* time stamp */
  packet_id_type id;       /* sequence number */
  time_t time_last_written;
  packet_id_type id_last_written;
};

/*
 * Record image of the persisted time/id.
 * NOTE(review): if this struct is written to disk as raw bytes, the file
 * layout depends on the platform's sizeof(time_t) and struct padding, so
 * persist files may not be portable between 32- and 64-bit time_t builds.
 * Confirm against packet_id_persist_load/save.
 */
struct packet_id_persist_file_image
{
  time_t time;             /* time stamp */
  packet_id_type id;       /* sequence number */
};

/*
 * Keep a record of our current packet-id state
 * on the sending side.
 *
 * Advanced by packet_id_alloc_outgoing(): id is the last sequence number
 * handed out, and time is the timestamp latched on the first allocation
 * (zero until then).
 */
struct packet_id_send
{
  packet_id_type id;
  time_t time;
};

/*
 * Communicate packet-id over the wire.
 * A short packet-id is just a 32 bit
 * sequence number.  A long packet-id
 * includes a timestamp as well.
 *
 * Long packet-ids are used as IVs for
 * CFB/OFB ciphers.
 *
 * This data structure is always sent
 * over the net in network byte order,
 * by calling htonpid, ntohpid,
 * htontime, and ntohtime on the
 * data elements to change them
 * to and from standard sizes.
 *
 * In addition, time is converted to a net_time_t (32 bits in the
 * production build) before sending.  The wire format always carries a
 * fixed-size timestamp, while the host's time_t may be 64 bits on some
 * platforms.
 */
struct packet_id_net
{
  packet_id_type id;
  time_t time; /* converted to net_time_t before transmission */
};

/*
 * Complete packet-id state for one direction pair: the sending-side
 * counter and the receiving-side replay record.
 */
struct packet_id
{
  struct packet_id_send send;
  struct packet_id_rec rec;
};

/*
 * Initialize p.  seq_backtrack and time_backtrack size the receive-side
 * replay window (see the *_SEQ_BACKTRACK / *_TIME_BACKTRACK limits).
 * NOTE(review): whether the arguments are range-checked here is not visible
 * from this header -- confirm in packet_id.c.
 */
void packet_id_init (struct packet_id *p, int seq_backtrack, int time_backtrack);
/* Release what packet_id_init() set up (presumably rec.seq_list). */
void packet_id_free (struct packet_id *p);

/* should we accept an incoming packet id?  Does not modify p. */
bool packet_id_test (const struct packet_id_rec *p,
		     const struct packet_id_net *pin);

/* change our current state to reflect an accepted packet id */
void packet_id_add (struct packet_id_rec *p,
		    const struct packet_id_net *pin);

/* expire sequence numbers that now fall outside TIME_BACKTRACK */
void packet_id_reap (struct packet_id_rec *p);

/*
 * packet ID persistence
 *
 * A packet_id_persist is "enabled" while it holds an open file
 * descriptor (see packet_id_persist_enabled).
 */

/* initialize the packet_id_persist structure in a disabled state */
void packet_id_persist_init (struct packet_id_persist *p);

/* close the file descriptor if it is open, and switch to disabled state */
void packet_id_persist_close (struct packet_id_persist *p);

/* load persisted rec packet_id (time and id) only once from file, and set state to enabled */
void packet_id_persist_load (struct packet_id_persist *p, const char *filename);

/* save persisted rec packet_id (time and id) to file (only if enabled state) */
void packet_id_persist_save (struct packet_id_persist *p);

/* transfer packet_id_persist -> packet_id */
void packet_id_persist_load_obj (const struct packet_id_persist *p, struct packet_id* pid);

/* return an ascii string representing a packet_id_persist object, allocated from gc */
const char *packet_id_persist_print (const struct packet_id_persist *p, struct gc_arena *gc);

/*
 * Read/write a packet ID to/from the buffer.  Short form is sequence number
 * only.  Long form is sequence number and timestamp.
 *
 * The encoded size is packet_id_size(long_form).  The bool result
 * presumably reports success (e.g. enough buffer space/data) -- confirm in
 * packet_id.c.  For packet_id_write, "prepend" presumably selects writing
 * in front of the existing buffer contents rather than appending.
 */

bool packet_id_read (struct packet_id_net *pin, struct buffer *buf, bool long_form);
bool packet_id_write (const struct packet_id_net *pin, struct buffer *buf, bool long_form, bool prepend);

/*
 * Inline functions.
 */

/*
 * Report whether persistence is active for p, i.e. whether it currently
 * owns a valid (non-negative) file descriptor.
 */
static inline bool
packet_id_persist_enabled (const struct packet_id_persist *p)
{
  const int descriptor = p->fd;
  return !(descriptor < 0);
}

/* transfer packet_id -> packet_id_persist */
static inline void
packet_id_persist_save_obj (struct packet_id_persist *p, const struct packet_id* pid)
{
  if (packet_id_persist_enabled (p) && pid->rec.time)
    {
      p->time = pid->rec.time;
      p->id = pid->rec.id;
    }
}

/*
 * Return a printable representation of pin.  The string is presumably
 * allocated from gc (as with packet_id_persist_print), and print_timestamp
 * presumably controls whether pin->time is included -- confirm in
 * packet_id.c.
 */
const char* packet_id_net_print(const struct packet_id_net *pin, bool print_timestamp, struct gc_arena *gc);

#ifdef PID_TEST
/* Interactive self-test; only declared when PID_TEST is defined.
 * "(void)" makes this a real prototype: an empty "()" list leaves the
 * parameters unspecified, so bad calls would not be diagnosed. */
void packet_id_interactive_test(void);
#endif

/*
 * Number of bytes a packet ID occupies on the wire: the sequence number,
 * plus the network timestamp when long_form is set.
 */
static inline int
packet_id_size (bool long_form)
{
  int bytes = (int) sizeof (packet_id_type);

  if (long_form)
    bytes += (int) sizeof (net_time_t);
  return bytes;
}

static inline bool
packet_id_close_to_wrapping (const struct packet_id_send *p)
{
  return p->id >= PACKET_ID_WRAP_TRIGGER;
}

/*
 * Allocate the next outgoing packet id into *pin.
 *
 * Sequence numbers run from 1 to the largest packet_id_type value
 * (2^32-1 with the production type).  The timestamp is taken from "now"
 * on the first allocation and then reused.  Wrapping back to 0 is only
 * allowed in long form (ASSERT).  In that case the timestamp is refreshed
 * from "now" and numbering restarts at 1.
 */
static inline void
packet_id_alloc_outgoing (struct packet_id_send *p, struct packet_id_net *pin, bool long_form)
{
  if (p->time == 0)
    p->time = now;

  p->id++;
  pin->id = p->id;

  if (pin->id == 0)
    {
      /* Wrapped: presumably only tolerable when the refreshed timestamp
         travels with the id and keeps (time, id) distinct. */
      ASSERT (long_form);
      p->time = now;
      p->id = 1;
      pin->id = 1;
    }

  pin->time = p->time;
}

/*
 * Return true if the timestamp "remote" lies within max_delta seconds of
 * the local clock ("now"), in either direction.
 *
 * The distance is computed in unsigned long long arithmetic.  Once the two
 * values are ordered, subtracting their unsigned conversions yields the
 * exact distance.  This avoids two problems:
 *  - signed time_t overflow (undefined behavior), which can occur when
 *    time_t is 32 bits and remote is far from local time, e.g. a large wire
 *    value that became negative through ntohtime's cast;
 *  - truncation of the distance to unsigned int when time_t is 64 bits.
 */
static inline bool
check_timestamp_delta (time_t remote, unsigned int max_delta)
{
  const time_t local_now = now;
  unsigned long long delta;

  if (local_now >= remote)
    delta = (unsigned long long) local_now - (unsigned long long) remote;
  else
    delta = (unsigned long long) remote - (unsigned long long) local_now;
  return delta <= max_delta;
}

/*
 * Run packet_id_reap() on p once SEQ_REAP_INTERVAL seconds or more have
 * passed since p->last_reap; otherwise do nothing.
 */
static inline void
packet_id_reap_test (struct packet_id_rec *p)
{
  const time_t reap_due = p->last_reap + SEQ_REAP_INTERVAL;

  if (now >= reap_due)
    packet_id_reap (p);
}

#endif /* PACKET_ID_H */
#endif /* USE_CRYPTO */