/* This file is part of libmspack.
 * (C) 2003-2011 Stuart Caie.
 *
 * KWAJ is a format very similar to SZDD. KWAJ method 3 (LZH) was
 * written by Jeff Johnson.
 *
 * libmspack is free software; you can redistribute it and/or modify it under
 * the terms of the GNU Lesser General Public License (LGPL) version 2.1
 *
 * For further details, see the file COPYING.LIB distributed with libmspack
 */

/* KWAJ decompression implementation */

#include <system.h>
#include <kwaj.h>
#include <mszip.h>

/* prototypes */
static struct mskwajd_header *kwajd_open(
    struct mskwaj_decompressor *base, const char *filename);
static void kwajd_close(
    struct mskwaj_decompressor *base, struct mskwajd_header *hdr);
static int kwajd_read_headers(
    struct mspack_system *sys, struct mspack_file *fh,
    struct mskwajd_header *hdr);
static int kwajd_extract(
    struct mskwaj_decompressor *base, struct mskwajd_header *hdr,
    const char *filename);
static int kwajd_decompress(
    struct mskwaj_decompressor *base, const char *input, const char *output);
static int kwajd_error(
    struct mskwaj_decompressor *base);

static struct kwajd_stream *lzh_init(
    struct mspack_system *sys, struct mspack_file *in, struct mspack_file *out);
static int lzh_decompress(
    struct kwajd_stream *kwaj);
static void lzh_free(
    struct kwajd_stream *kwaj);
static int lzh_read_lens(
    struct kwajd_stream *kwaj,
    unsigned int type, unsigned int numsyms,
    unsigned char *lens);
static int lzh_read_input(
    struct kwajd_stream *kwaj);


/***************************************
 * MSPACK_CREATE_KWAJ_DECOMPRESSOR
 ***************************************
 * constructor
 */
struct mskwaj_decompressor *
    mspack_create_kwaj_decompressor(struct mspack_system *sys)
{
  struct mskwaj_decompressor_p *self = NULL;

  if (!sys) sys = mspack_default_system;
  if (!mspack_valid_system(sys)) return NULL;

  if ((self = (struct mskwaj_decompressor_p *) sys->alloc(sys, sizeof(struct mskwaj_decompressor_p)))) {
    self->base.open       = &kwajd_open;
    self->base.close      = &kwajd_close;
    self->base.extract    = &kwajd_extract;
    self->base.decompress = &kwajd_decompress;
    self->base.last_error = &kwajd_error;
    self->system          = sys;
    self->error           = MSPACK_ERR_OK;
  }
  return (struct mskwaj_decompressor *) self;
}

/***************************************
 * MSPACK_DESTROY_KWAJ_DECOMPRESSOR
 ***************************************
 * destructor
 */
void mspack_destroy_kwaj_decompressor(struct mskwaj_decompressor *base)
{
    struct mskwaj_decompressor_p *self = (struct mskwaj_decompressor_p *) base;

    /* destroying a NULL decompressor is a no-op */
    if (!self) return;

    /* the object was allocated through its own mspack_system, so it is
     * released through that same system */
    self->system->free(self);
}

/***************************************
 * KWAJD_OPEN
 ***************************************
 * opens a KWAJ file without decompressing, reads header
 */
static struct mskwajd_header *kwajd_open(struct mskwaj_decompressor *base,
                                         const char *filename)
{
    struct mskwaj_decompressor_p *self = (struct mskwaj_decompressor_p *) base;
    struct mskwajd_header *hdr;
    struct mspack_system *sys;
    struct mspack_file *fh;

    if (!self) return NULL;
    sys = self->system;

    fh  = sys->open(sys, filename, MSPACK_SYS_OPEN_READ);
    hdr = (struct mskwajd_header *) sys->alloc(sys, sizeof(struct mskwajd_header_p));
    if (fh && hdr) {
        ((struct mskwajd_header_p *) hdr)->fh = fh;
        self->error = kwajd_read_headers(sys, fh, hdr);
    }
    else {
        if (!fh)  self->error = MSPACK_ERR_OPEN;
        if (!hdr) self->error = MSPACK_ERR_NOMEMORY;
    }
    
    if (self->error) {
        if (fh) sys->close(fh);
        sys->free(hdr);
        hdr = NULL;
    }

    return hdr;
}

/***************************************
 * KWAJD_CLOSE
 ***************************************
 * closes a KWAJ file
 */
static void kwajd_close(struct mskwaj_decompressor *base,
                        struct mskwajd_header *hdr)
{
    struct mskwaj_decompressor_p *self = (struct mskwaj_decompressor_p *) base;
    struct mskwajd_header_p *hdr_p = (struct mskwajd_header_p *) hdr;

    if (!self || !self->system) return;

    /* close the file handle associated */
    self->system->close(hdr_p->fh);

    /* free the memory associated */
    self->system->free(hdr);

    self->error = MSPACK_ERR_OK;
}

/***************************************
 * KWAJD_READ_HEADERS
 ***************************************
 * reads the headers of a KWAJ format file
 */
static int kwajd_read_headers(struct mspack_system *sys,
                              struct mspack_file *fh,
                              struct mskwajd_header *hdr)
{
    unsigned char buf[16];
    int i;

    /* read in the header */
    if (sys->read(fh, &buf[0], kwajh_SIZEOF) != kwajh_SIZEOF) {
        return MSPACK_ERR_READ;
    }

    /* check for "KWAJ" signature */
    if (((unsigned int) EndGetI32(&buf[kwajh_Signature1]) != 0x4A41574B) ||
        ((unsigned int) EndGetI32(&buf[kwajh_Signature2]) != 0xD127F088))
    {
        return MSPACK_ERR_SIGNATURE;
    }

    /* basic header fields */
    hdr->comp_type    = EndGetI16(&buf[kwajh_CompMethod]);
    hdr->data_offset  = EndGetI16(&buf[kwajh_DataOffset]);
    hdr->headers      = EndGetI16(&buf[kwajh_Flags]);
    hdr->length       = 0;
    hdr->filename     = NULL;
    hdr->extra        = NULL;
    hdr->extra_length = 0;

    /* optional headers */

    /* 4 bytes: length of unpacked file */
    if (hdr->headers & MSKWAJ_HDR_HASLENGTH) {
        if (sys->read(fh, &buf[0], 4) != 4) return MSPACK_ERR_READ;
        hdr->length = EndGetI32(&buf[0]);
    }

    /* 2 bytes: unknown purpose */
    if (hdr->headers & MSKWAJ_HDR_HASUNKNOWN1) {
        if (sys->read(fh, &buf[0], 2) != 2) return MSPACK_ERR_READ;
    }

    /* 2 bytes: length of section, then [length] bytes: unknown purpose */
    if (hdr->headers & MSKWAJ_HDR_HASUNKNOWN2) {
        if (sys->read(fh, &buf[0], 2) != 2) return MSPACK_ERR_READ;
        i = EndGetI16(&buf[0]);
        if (sys->seek(fh, (off_t)i, MSPACK_SYS_SEEK_CUR)) return MSPACK_ERR_SEEK;
    }

    /* filename and extension */
    if (hdr->headers & (MSKWAJ_HDR_HASFILENAME | MSKWAJ_HDR_HASFILEEXT)) {
        int len;
        /* allocate memory for maximum length filename */
        char *fn = (char *) sys->alloc(sys, (size_t) 13);
        if (!(hdr->filename = fn)) return MSPACK_ERR_NOMEMORY;

        /* copy filename if present */
        if (hdr->headers & MSKWAJ_HDR_HASFILENAME) {
            /* read and copy up to 9 bytes of a null terminated string */
            if ((len = sys->read(fh, &buf[0], 9)) < 2) return MSPACK_ERR_READ;
            for (i = 0; i < len; i++) if (!(*fn++ = buf[i])) break;
            /* if string was 9 bytes with no null terminator, reject it */
            if (i == 9 && buf[8] != '\0') return MSPACK_ERR_DATAFORMAT;
            /* seek to byte after string ended in file */
            if (sys->seek(fh, (off_t)(i + 1 - len), MSPACK_SYS_SEEK_CUR))
                return MSPACK_ERR_SEEK;
            fn--; /* remove the null terminator */
        }

        /* copy extension if present */
        if (hdr->headers & MSKWAJ_HDR_HASFILEEXT) {
            *fn++ = '.';
            /* read and copy up to 4 bytes of a null terminated string */
            if ((len = sys->read(fh, &buf[0], 4)) < 2) return MSPACK_ERR_READ;
            for (i = 0; i < len; i++) if (!(*fn++ = buf[i])) break;
            /* if string was 4 bytes with no null terminator, reject it */
            if (i == 4 && buf[3] != '\0') return MSPACK_ERR_DATAFORMAT;
            /* seek to byte after string ended in file */
            if (sys->seek(fh, (off_t)(i + 1 - len), MSPACK_SYS_SEEK_CUR))
                return MSPACK_ERR_SEEK;
            fn--; /* remove the null terminator */
        }
        *fn = '\0';
    }

    /* 2 bytes: extra text length then [length] bytes of extra text data */
    if (hdr->headers & MSKWAJ_HDR_HASEXTRATEXT) {
        if (sys->read(fh, &buf[0], 2) != 2) return MSPACK_ERR_READ;
        i = EndGetI16(&buf[0]);
        hdr->extra = (char *) sys->alloc(sys, (size_t)i+1);
        if (! hdr->extra) return MSPACK_ERR_NOMEMORY;
        if (sys->read(fh, hdr->extra, i) != i) return MSPACK_ERR_READ;
        hdr->extra[i] = '\0';
        hdr->extra_length = i;
    }
    return MSPACK_ERR_OK;
}

/***************************************
 * KWAJD_EXTRACT
 ***************************************
 * decompresses a KWAJ file
 */
static int kwajd_extract(struct mskwaj_decompressor *base,
                         struct mskwajd_header *hdr, const char *filename)
{
    struct mskwaj_decompressor_p *self = (struct mskwaj_decompressor_p *) base;
    struct mspack_system *sys;
    struct mspack_file *fh, *outfh;

    if (!self) return MSPACK_ERR_ARGS;
    if (!hdr) return self->error = MSPACK_ERR_ARGS;

    sys = self->system;
    fh = ((struct mskwajd_header_p *) hdr)->fh;

    /* seek to the compressed data */
    if (sys->seek(fh, hdr->data_offset, MSPACK_SYS_SEEK_START)) {
        return self->error = MSPACK_ERR_SEEK;
    }

    /* open file for output */
    if (!(outfh = sys->open(sys, filename, MSPACK_SYS_OPEN_WRITE))) {
        return self->error = MSPACK_ERR_OPEN;
    }

    self->error = MSPACK_ERR_OK;

    /* decompress based on format */
    if (hdr->comp_type == MSKWAJ_COMP_NONE ||
        hdr->comp_type == MSKWAJ_COMP_XOR)
    {
        /* NONE is a straight copy. XOR is a copy xored with 0xFF */
        unsigned char *buf = (unsigned char *) sys->alloc(sys, (size_t) KWAJ_INPUT_SIZE);
        if (buf) {
            int read, i;
            while ((read = sys->read(fh, buf, KWAJ_INPUT_SIZE)) > 0) {
                if (hdr->comp_type == MSKWAJ_COMP_XOR) {
                    for (i = 0; i < read; i++) buf[i] ^= 0xFF;
                }
                if (sys->write(outfh, buf, read) != read) {
                    self->error = MSPACK_ERR_WRITE;
                    break;
                }
            }
            if (read < 0) self->error = MSPACK_ERR_READ;
            sys->free(buf);
        }
        else {
            self->error = MSPACK_ERR_NOMEMORY;
        }
    }
    else if (hdr->comp_type == MSKWAJ_COMP_SZDD) {
        self->error = lzss_decompress(sys, fh, outfh, KWAJ_INPUT_SIZE,
                                      LZSS_MODE_EXPAND);
    }
    else if (hdr->comp_type == MSKWAJ_COMP_LZH) {
        struct kwajd_stream *lzh = lzh_init(sys, fh, outfh);
        self->error = (lzh) ? lzh_decompress(lzh) : MSPACK_ERR_NOMEMORY;
        lzh_free(lzh);
    }
    else if (hdr->comp_type == MSKWAJ_COMP_MSZIP) {
        struct mszipd_stream *zip = mszipd_init(sys,fh,outfh,KWAJ_INPUT_SIZE,0);
        self->error = (zip) ? mszipd_decompress_kwaj(zip) : MSPACK_ERR_NOMEMORY;
        mszipd_free(zip);
    }
    else {
        self->error = MSPACK_ERR_DATAFORMAT;
    }

    /* close output file */
    sys->close(outfh);

    return self->error;
}

/***************************************
 * KWAJD_DECOMPRESS
 ***************************************
 * unpacks directly from input to output
 */
static int kwajd_decompress(struct mskwaj_decompressor *base,
                            const char *input, const char *output)
{
    struct mskwaj_decompressor_p *self = (struct mskwaj_decompressor_p *) base;
    struct mskwajd_header *hdr;
    int error;

    if (!self) return MSPACK_ERR_ARGS;

    if (!(hdr = kwajd_open(base, input))) return self->error;
    error = kwajd_extract(base, hdr, output);
    kwajd_close(base, hdr);
    return self->error = error;
}

/***************************************
 * KWAJD_ERROR
 ***************************************
 * returns the last error that occurred
 */
static int kwajd_error(struct mskwaj_decompressor *base)
{
    struct mskwaj_decompressor_p *self = (struct mskwaj_decompressor_p *) base;
    return (self) ? self->error : MSPACK_ERR_ARGS;
}

/***************************************
 * LZH_INIT, LZH_DECOMPRESS, LZH_FREE
 ***************************************
 * unpacks KWAJ method 3 files
 */

/* import bit-reading macros and code */
/* Configuration for readbits.h: the bit-reader state lives in a
 * struct kwajd_stream named "lzh". BITS_ORDER_MSB presumably selects
 * most-significant-bit-first reading (confirm against readbits.h).
 * BITS_NO_READ_INPUT suppresses readbits.h's default input refill; the
 * READ_BYTES below refills from lzh_read_input() instead, which supplies
 * zero padding after end of input (see the end-of-stream comment
 * further down). */
#define BITS_TYPE struct kwajd_stream
#define BITS_VAR lzh
#define BITS_ORDER_MSB
#define BITS_NO_READ_INPUT
/* Pulls one byte into the bit buffer, refilling the input buffer via
 * lzh_read_input() when it is exhausted. Requires locals i_ptr, i_end
 * and err in the calling function, and returns from it on a read error. */
#define READ_BYTES do {                                 \
    if (i_ptr >= i_end) {                               \
        if ((err = lzh_read_input(lzh))) return err;    \
        i_ptr = lzh->i_ptr;                             \
        i_end = lzh->i_end;                             \
    }                                                   \
    INJECT_BITS(*i_ptr++, 8);                           \
} while (0)
#include <readbits.h>

/* import huffman-reading macros and code */
/* Configuration for readhuff.h: every table uses KWAJ_TABLEBITS lookup
 * bits, table "X" has KWAJ_X_SYMS symbols, and its decode table and code
 * lengths are stored in lzh->X_table[] and lzh->X_len[]. An invalid
 * Huffman code makes the calling function return MSPACK_ERR_DATAFORMAT. */
#define TABLEBITS(tbl)      KWAJ_TABLEBITS
#define MAXSYMBOLS(tbl)     KWAJ_##tbl##_SYMS
#define HUFF_TABLE(tbl,idx) lzh->tbl##_table[idx]
#define HUFF_LEN(tbl,idx)   lzh->tbl##_len[idx]
#define HUFF_ERROR          return MSPACK_ERR_DATAFORMAT
#include <readhuff.h>

/* In the KWAJ LZH format, there is no special 'eof' marker; the stream
 * just ends. Depending on how many bits are left in the final byte when
 * the stream ends, there may be enough to start another literal or
 * match. The only easy way to detect that we've come to the end is to
 * guard all bit-reading. We allow fake bits to be read once we reach
 * the end of the stream, but after each READ_BITS / READ_HUFFSYM we check
 * whether any of those fake bits were consumed. This isn't how the
 * default readbits.h read_input() works (it simply lets 2 fake bytes in
 * then stops), so we implement our own.
 */
/* READ_BITS that ends decoding successfully once padding bits are consumed */
#define READ_BITS_SAFE(val, n) do {                     \
    READ_BITS(val, n);                                  \
    if (lzh->input_end && bits_left < lzh->input_end)   \
        return MSPACK_ERR_OK;                           \
} while (0)

/* READ_HUFFSYM that ends decoding successfully once padding bits are consumed */
#define READ_HUFFSYM_SAFE(tbl, val) do {                \
    READ_HUFFSYM(tbl, val);                             \
    if (lzh->input_end && bits_left < lzh->input_end)   \
        return MSPACK_ERR_OK;                           \
} while (0)

/* Reads the code lengths of table tbl, encoded with method `type`, via
 * lzh_read_lens(), then builds tbl's decode table. Returns from the
 * calling function on any error. Wrapped in do/while(0) so that it
 * expands to exactly one statement, which makes it safe as the body of
 * an unbraced if/else. */
#define BUILD_TREE(tbl, type) do {                                        \
    STORE_BITS;                                                           \
    err = lzh_read_lens(lzh, (type), MAXSYMBOLS(tbl), &HUFF_LEN(tbl,0));  \
    if (err) return err;                                                  \
    RESTORE_BITS;                                                         \
    if (make_decode_table(MAXSYMBOLS(tbl), TABLEBITS(tbl),                \
        &HUFF_LEN(tbl,0), &HUFF_TABLE(tbl,0)))                            \
        return MSPACK_ERR_DATAFORMAT;                                     \
} while (0)

/* Writes lzh->window[pos] (local `pos` of the caller) to the output file;
 * returns MSPACK_ERR_WRITE from the calling function on failure. */
#define WRITE_BYTE do {                                                 \
    if (lzh->sys->write(lzh->output, &lzh->window[pos], 1) != 1)        \
        return MSPACK_ERR_WRITE;                                        \
} while (0)

static struct kwajd_stream *lzh_init(struct mspack_system *sys,
    struct mspack_file *in, struct mspack_file *out)
{
    /* Allocates an LZH stream bound to the given system and file handles.
     * Returns NULL if any argument is NULL or the allocation fails. The
     * bit-reader state is not set up here; lzh_decompress() does that. */
    struct kwajd_stream *stream = NULL;

    if (sys && in && out) {
        stream = (struct kwajd_stream *) sys->alloc(sys, sizeof(struct kwajd_stream));
        if (stream) {
            stream->sys    = sys;
            stream->input  = in;
            stream->output = out;
        }
    }
    return stream;
}

/* Decodes a KWAJ method 3 (LZH) stream from lzh->input to lzh->output.
 *
 * Stream layout as read here: six 4-bit table-encoding types (only the
 * first five are used), then the code lengths of the MATCHLEN1,
 * MATCHLEN2, LITLEN, OFFSET and LITERAL Huffman tables, then a sequence
 * of matches and literal runs until the input ends.
 *
 * Every output byte is also stored in lzh->window, a ring buffer indexed
 * modulo 4096. Bytes are written to the output one at a time.
 *
 * Returns MSPACK_ERR_OK when a *_SAFE read runs into end-of-input padding.
 * Otherwise it returns an error from input refill, output write, code
 * length reading or decode-table construction.
 */
static int lzh_decompress(struct kwajd_stream *lzh)
{
    register unsigned int bit_buffer;
    register int bits_left, i;
    register unsigned short sym;
    unsigned char *i_ptr, *i_end, lit_run = 0;
    int j, pos = 0, len, offset, err;
    unsigned int types[6];

    /* reset global state */
    INIT_BITS;
    RESTORE_BITS;
    memset(&lzh->window[0], LZSS_WINDOW_FILL, (size_t) LZSS_WINDOW_SIZE);

    /* read 6 encoding types (for byte alignment) but only 5 are needed */
    for (i = 0; i < 6; i++) READ_BITS_SAFE(types[i], 4);

    /* read huffman table symbol lengths and build huffman trees */
    BUILD_TREE(MATCHLEN1, types[0]);
    BUILD_TREE(MATCHLEN2, types[1]);
    BUILD_TREE(LITLEN,    types[2]);
    BUILD_TREE(OFFSET,    types[3]);
    BUILD_TREE(LITERAL,   types[4]);

    while (!lzh->input_end) {
        /* After a literal run shorter than 32, the next length code comes
         * from MATCHLEN2; otherwise it comes from MATCHLEN1. */
        if (lit_run) READ_HUFFSYM_SAFE(MATCHLEN2, len);
        else         READ_HUFFSYM_SAFE(MATCHLEN1, len);

        if (len > 0) {
            /* match: length is symbol + 2; offset is a Huffman-coded high
             * part (shifted left 6) combined with 6 raw low bits */
            len += 2;
            lit_run = 0; /* not the end of a literal run */
            READ_HUFFSYM_SAFE(OFFSET, j); offset = j << 6;
            READ_BITS_SAFE(j, 6);         offset |= j;

            /* copy match as output and into the ring buffer */
            while (len-- > 0) {
                lzh->window[pos] = lzh->window[(pos+4096-offset) & 4095];
                WRITE_BYTE;
                pos++; pos &= 4095;
            }
        }
        else {
            /* literal run: LITLEN symbol + 1 literals follow */
            READ_HUFFSYM_SAFE(LITLEN, len); len++;
            lit_run = (len == 32) ? 0 : 1; /* end of a literal run? */
            while (len-- > 0) {
                READ_HUFFSYM_SAFE(LITERAL, j);
                /* copy as output and into the ring buffer */
                lzh->window[pos] = j;
                WRITE_BYTE;
                pos++; pos &= 4095;
            }
        }
    }
    return MSPACK_ERR_OK;
}

static void lzh_free(struct kwajd_stream *lzh)
{
    /* Releases a stream from lzh_init() through the system that allocated
     * it; a NULL stream, or one without a system, is ignored. */
    if (lzh && lzh->sys) {
        lzh->sys->free(lzh);
    }
}

static int lzh_read_lens(struct kwajd_stream *lzh,
                         unsigned int type, unsigned int numsyms,
                         unsigned char *lens)
{
    register unsigned int bit_buffer;
    register int bits_left;
    unsigned char *i_ptr, *i_end;
    unsigned int i, c, sel;
    int err;

    RESTORE_BITS;
    switch (type) {
    case 0:
        i = numsyms; c = (i==16)?4: (i==32)?5: (i==64)?6: (i==256)?8 :0;
        for (i = 0; i < numsyms; i++) lens[i] = c;
        break;

    case 1:
        READ_BITS_SAFE(c, 4); lens[0] = c;
        for (i = 1; i < numsyms; i++) {
                   READ_BITS_SAFE(sel, 1); if (sel == 0)  lens[i] = c;
            else { READ_BITS_SAFE(sel, 1); if (sel == 0)  lens[i] = ++c;
            else { READ_BITS_SAFE(c, 4);                  lens[i] = c; }}
        }
        break;

    case 2:
        READ_BITS_SAFE(c, 4); lens[0] = c;
        for (i = 1; i < numsyms; i++) {
            READ_BITS_SAFE(sel, 2);
            if (sel == 3) READ_BITS_SAFE(c, 4); else c += (char) sel-1;
            lens[i] = c;
        }
        break;

    case 3:
        for (i = 0; i < numsyms; i++) {
            READ_BITS_SAFE(c, 4); lens[i] = c;
        }
        break;
    }
    STORE_BITS;
    return MSPACK_ERR_OK;
}

static int lzh_read_input(struct kwajd_stream *lzh) {
    /* Refills lzh->inbuf from the input file. Once the file is exhausted,
     * every call supplies a single zero byte instead and adds 8 to
     * lzh->input_end, counting the padding bits handed out so the *_SAFE
     * readers can detect that they were consumed. */
    int nread = 0;

    if (!lzh->input_end) {
        nread = lzh->sys->read(lzh->input, &lzh->inbuf[0], KWAJ_INPUT_SIZE);
        if (nread < 0) return MSPACK_ERR_READ;
    }

    if (nread == 0) {
        /* end of input (now or earlier): hand out one byte of padding */
        lzh->input_end += 8;
        lzh->inbuf[0] = 0;
        nread = 1;
    }

    /* update i_ptr and i_end */
    lzh->i_ptr = &lzh->inbuf[0];
    lzh->i_end = &lzh->inbuf[nread];
    return MSPACK_ERR_OK;
}