/*
 *  Copyright (C) 2013-2019 Cisco Systems, Inc. and/or its affiliates. All rights reserved.
 *  Copyright (C) 2007-2013 Sourcefire, Inc.
 *
 *  Authors: Tomasz Kojm, Török Edvin
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License version 2 as
 *  published by the Free Software Foundation.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 *  MA 02110-1301, USA.
 */

#if HAVE_CONFIG_H
#include "clamav-config.h"
#endif

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef HAVE_UNISTD_H
#include <unistd.h>
#endif
#include <sys/types.h>
#include <dirent.h>
#ifndef _WIN32
#include <sys/socket.h>
#ifdef HAVE_SYS_SELECT_H
#include <sys/select.h>
#endif
#ifdef HAVE_FD_PASSING
#ifdef HAVE_SYS_UIO_H
#include <sys/uio.h>
#endif
#endif

#include <sys/time.h>
#endif
#include <pthread.h>
#include <time.h>
#include <errno.h>
#include <stddef.h>
#include <limits.h>

#include "libclamav/clamav.h"
#include "libclamav/str.h"
#include "libclamav/others.h"

#include "shared/optparser.h"
#include "shared/output.h"
#include "shared/misc.h"

#include "others.h"
#include "scanner.h"
#include "server.h"
#include "session.h"
#include "thrmgr.h"

/* Value stored in the "enabled" column of the FILDES entry in the command
 * table.  Descriptor passing is gated on HAVE_FD_PASSING everywhere else in
 * this file, so the same macro is tested here. */
#ifdef HAVE_FD_PASSING
#define FEATURE_FDPASSING 1
#else
#define FEATURE_FDPASSING 0
#endif

/*
 * Table of the command keywords understood by parse_command().
 *
 * parse_command() walks the entries in order and stops at the first keyword
 * that is a prefix of the request, so a keyword that is a prefix of another
 * one must be listed after the longer keyword.  print_commands() prints the
 * keyword of every entry.
 */
static struct {
    /* keyword text compared against the start of the request */
    const char *cmd;
    /* length of cmd without its terminating NUL */
    const size_t len;
    /* identifier parse_command() returns when this entry matches */
    enum commands cmdtype;
    /* non-zero: text must follow the keyword; zero: nothing may follow it */
    int need_arg;
    /* zero: parse_command() rejects the keyword for old-style requests */
    int support_old;
    /* NOTE(review): not read by any code visible in this file */
    int enabled;
} commands[] = {
    {CMD1, sizeof(CMD1) - 1, COMMAND_SCAN, 1, 1, 0},
    {CMD3, sizeof(CMD3) - 1, COMMAND_SHUTDOWN, 0, 1, 0},
    {CMD4, sizeof(CMD4) - 1, COMMAND_RELOAD, 0, 1, 0},
    {CMD5, sizeof(CMD5) - 1, COMMAND_PING, 0, 1, 0},
    {CMD6, sizeof(CMD6) - 1, COMMAND_CONTSCAN, 1, 1, 0},
    /* must come before VERSION because the two keywords share a common
     * prefix and the first matching entry wins */
    {CMD18, sizeof(CMD18) - 1, COMMAND_COMMANDS, 0, 0, 1},
    {CMD7, sizeof(CMD7) - 1, COMMAND_VERSION, 0, 1, 1},
    {CMD8, sizeof(CMD8) - 1, COMMAND_STREAM, 0, 1, 1},
    {CMD10, sizeof(CMD10) - 1, COMMAND_END, 0, 0, 1},
    {CMD11, sizeof(CMD11) - 1, COMMAND_SHUTDOWN, 0, 1, 1},
    {CMD13, sizeof(CMD13) - 1, COMMAND_MULTISCAN, 1, 1, 1},
    {CMD14, sizeof(CMD14) - 1, COMMAND_FILDES, 0, 1, FEATURE_FDPASSING},
    {CMD15, sizeof(CMD15) - 1, COMMAND_STATS, 0, 0, 1},
    {CMD16, sizeof(CMD16) - 1, COMMAND_IDSESSION, 0, 0, 1},
    {CMD17, sizeof(CMD17) - 1, COMMAND_INSTREAM, 0, 0, 1},
    {CMD19, sizeof(CMD19) - 1, COMMAND_DETSTATSCLEAR, 0, 1, 1},
    {CMD20, sizeof(CMD20) - 1, COMMAND_DETSTATS, 0, 1, 1},
    {CMD21, sizeof(CMD21) - 1, COMMAND_ALLMATCHSCAN, 1, 0, 1}};

/*
 * Translate a request line into a command identifier.
 *
 * cmd       NUL-terminated request text.
 * argument  out: set to NULL, or to the text starting one character past the
 *           keyword when the matched command takes an argument.
 * oldstyle  non-zero when the request was sent in the old-style form.
 *
 * The first table entry whose keyword is a prefix of cmd decides the result;
 * later entries are not tried.  Returns the matching command identifier, or
 * COMMAND_UNKNOWN when nothing matches, the argument is missing, unexpected
 * text follows the keyword, or the command is not allowed as old-style.
 */
enum commands parse_command(const char *cmd, const char **argument, int oldstyle)
{
    const size_t count = sizeof(commands) / sizeof(commands[0]);
    size_t idx;

    *argument = NULL;

    for (idx = 0; idx < count; idx++) {
        const char *keyword = commands[idx].cmd;
        const char *rest;

        if (strncmp(cmd, keyword, commands[idx].len) != 0)
            continue;

        rest = cmd + commands[idx].len;

        if (commands[idx].need_arg) {
            if (*rest == '\0') {
                /* keyword only, nothing to scan */
                logg("$Command %s missing argument!\n", keyword);
                return COMMAND_UNKNOWN;
            }
            /* skip the single separator character after the keyword */
            *argument = rest + 1;
        } else if (*rest != '\0') {
            /* this command takes no argument, yet more text follows */
            logg("$Command %s has trailing garbage!\n", keyword);
            return COMMAND_UNKNOWN;
        }

        if (oldstyle && !commands[idx].support_old) {
            logg("$Command sent as old-style when not supported: %s\n", keyword);
            return COMMAND_UNKNOWN;
        }

        return commands[idx].cmdtype;
    }

    return COMMAND_UNKNOWN;
}

/*
 * Send a reply of the form "[id: ][path: ]status<term>" to the client.
 *
 * The "id: " prefix is written only when conn->id is non-zero and the
 * "path: " part only when path is not NULL.  Returns whatever mdprintf()
 * returns.
 */
int conn_reply_single(const client_conn_t *conn, const char *path, const char *status)
{
    const int sock = conn->sd;
    const char eol = conn->term;

    if (!conn->id) {
        if (!path)
            return mdprintf(sock, "%s%c", status, eol);
        return mdprintf(sock, "%s: %s%c", path, status, eol);
    }

    if (!path)
        return mdprintf(sock, "%u: %s%c", conn->id, status, eol);
    return mdprintf(sock, "%u: %s: %s%c", conn->id, path, status, eol);
}

/*
 * Send a reply of the form "[id: ][path: ]msg status<term>" to the client.
 *
 * The "id: " prefix is written only when conn->id is non-zero and the
 * "path: " part only when path is not NULL.  Returns whatever mdprintf()
 * returns.
 */
int conn_reply(const client_conn_t *conn, const char *path,
               const char *msg, const char *status)
{
    const int sock = conn->sd;
    const char eol = conn->term;

    if (!conn->id) {
        if (!path)
            return mdprintf(sock, "%s %s%c", msg, status, eol);
        return mdprintf(sock, "%s: %s %s%c", path, msg, status, eol);
    }

    if (!path)
        return mdprintf(sock, "%u: %s %s%c", conn->id, msg, status, eol);
    return mdprintf(sock, "%u: %s: %s %s%c", conn->id, path, msg, status, eol);
}

/*
 * Report a detection to the client as "[id: ]file: virname FOUND<term>".
 *
 * The "id: " prefix is written only when conn->id is non-zero.  Returns
 * whatever mdprintf() returns.
 */
int conn_reply_virus(const client_conn_t *conn, const char *file,
                     const char *virname)
{
    const int sock = conn->sd;
    const char eol = conn->term;

    if (!conn->id)
        return mdprintf(sock, "%s: %s FOUND%c", file, virname, eol);

    return mdprintf(sock, "%u: %s: %s FOUND%c", conn->id, file, virname, eol);
}

int conn_reply_error(const client_conn_t *conn, const char *msg)
{
    return conn_reply(conn, NULL, msg, "ERROR");
}

#define BUFFSIZE 1024
/*
 * Send an error reply whose status is the text for the current errno
 * followed by ". ERROR".
 *
 * errno is read on entry, before any other call is made.  Returns the result
 * of conn_reply().
 */
int conn_reply_errno(const client_conn_t *conn, const char *path,
                     const char *msg)
{
    static const char suffix[] = ". ERROR";
    /* room for the errno text plus the suffix and its NUL */
    char status[BUFFSIZE + sizeof(suffix)];

    cli_strerror(errno, status, BUFFSIZE - 1);
    strcat(status, suffix);

    return conn_reply(conn, path, msg, status);
}

/*
 * Give back the multiscan controller slot that command() reserved in the
 * thread pool by incrementing thr_multiscan.
 */
static void release_multiscan_slot(client_conn_t *conn)
{
    pthread_mutex_lock(&conn->thrpool->pool_mutex);
    conn->thrpool->thr_multiscan--;
    pthread_mutex_unlock(&conn->thrpool->pool_mutex);
}

/*
 * Execute the command stored in conn->cmdtype.
 *
 * conn   connection describing the request (socket, engine, options, target
 *        file name or descriptor, job group and thread pool).
 * virus  out: set to non-zero when a scan reported a detection.  It is not
 *        written on every path (for example on early error returns).
 *
 * returns
 *  -1 on fatal error (shutdown)
 *  0 on ok
 *  >0 errors encountered
 */
int command(client_conn_t *conn, int *virus)
{
    int desc                        = conn->sd;
    struct cl_engine *engine        = conn->engine;
    struct cl_scan_options *options = conn->options;
    const struct optstruct *opts    = conn->opts;
    enum scan_type type             = TYPE_INIT;
    int maxdirrec;
    int ret   = 0;
    int flags = CLI_FTW_STD;

    struct scan_cb_data scandata;
    struct cli_ftw_cbdata data;
    unsigned ok, error, total;
    STATBUF sb;
    jobgroup_t *group = NULL;

    /* the client went away before this job got to run */
    if (thrmgr_group_need_terminate(conn->group)) {
        logg("$Client disconnected while command was active\n");
        if (conn->scanfd != -1)
            close(conn->scanfd);
        return 1;
    }
    thrmgr_setactiveengine(engine);

    data.data = &scandata;
    memset(&scandata, 0, sizeof(scandata));
    scandata.id            = conn->id;
    scandata.group         = conn->group;
    scandata.odesc         = desc;
    scandata.conn          = conn;
    scandata.options       = options;
    scandata.engine        = engine;
    scandata.opts          = opts;
    scandata.thr_pool      = conn->thrpool;
    scandata.toplevel_path = conn->filename;

    switch (conn->cmdtype) {
        case COMMAND_SCAN:
            thrmgr_setactivetask(NULL, "SCAN");
            type = TYPE_SCAN;
            break;
        case COMMAND_CONTSCAN:
            thrmgr_setactivetask(NULL, "CONTSCAN");
            type = TYPE_CONTSCAN;
            break;
        case COMMAND_MULTISCAN: {
            int multiscan, max, alive = 0;

            /* use MULTISCAN only for directories (bb #1869) */
            if (CLAMSTAT(conn->filename, &sb) == 0 &&
                !S_ISDIR(sb.st_mode)) {
                thrmgr_setactivetask(NULL, "CONTSCAN");
                type = TYPE_CONTSCAN;
                break;
            }

            /* reserve a controller slot, unless that would leave no thread
             * free to do the actual scanning */
            pthread_mutex_lock(&conn->thrpool->pool_mutex);
            multiscan = conn->thrpool->thr_multiscan;
            max       = conn->thrpool->thr_max;
            if (multiscan + 1 < max)
                conn->thrpool->thr_multiscan = multiscan + 1;
            else {
                alive = conn->thrpool->thr_alive;
                ret   = -1;
            }
            pthread_mutex_unlock(&conn->thrpool->pool_mutex);
            if (ret) {
                /* multiscan has 1 control thread, so there needs to be at
                 * least 1 thread that is not a multiscan control thread to
                 * scan and make progress. */
                logg("^Not enough threads for multiscan. Max: %d, Alive: %d, Multiscan: %d+1\n",
                     max, alive, multiscan);
                conn_reply(conn, conn->filename, "Not enough threads for multiscan. Increase MaxThreads.", "ERROR");
                return 1;
            }
            flags &= ~CLI_FTW_NEED_STAT;
            thrmgr_setactivetask(NULL, "MULTISCAN");
            type           = TYPE_MULTISCAN;
            scandata.group = group = thrmgr_group_new();
            if (!group) {
                /* the slot was already reserved above; hand it back so the
                 * counter does not drift upwards on every failure */
                release_multiscan_slot(conn);
                if (optget(opts, "ExitOnOOM")->enabled)
                    return -1;
                else
                    return 1;
            }
            break;
        }
        case COMMAND_MULTISCANFILE:
            thrmgr_setactivetask(NULL, "MULTISCANFILE");
            scandata.group    = NULL;
            scandata.type     = TYPE_SCAN;
            scandata.thr_pool = NULL;
            /* TODO: check ret value */
            ret            = scan_callback(NULL, conn->filename, conn->filename, visit_file, &data); /* callback freed it */
            conn->filename = NULL;
            *virus         = scandata.infected;
            if (ret == CL_BREAK) {
                thrmgr_group_terminate(conn->group);
                return 1;
            }
            return scandata.errors > 0 ? scandata.errors : 0;
        case COMMAND_FILDES:
            thrmgr_setactivetask(NULL, "FILDES");
#ifdef HAVE_FD_PASSING
            if (conn->scanfd == -1) {
                conn_reply_error(conn, "FILDES: didn't receive file descriptor.");
                return 1;
            } else {
                ret = scanfd(conn, NULL, engine, options, opts, desc, 0);
                if (ret == CL_VIRUS) {
                    *virus = 1;
                    ret    = 0;
                } else if (ret == CL_EMEM) {
                    if (optget(opts, "ExitOnOOM")->enabled)
                        ret = -1;
                    else
                        ret = 1;
                } else if (ret == CL_ETIMEOUT) {
                    thrmgr_group_terminate(conn->group);
                    ret = 1;
                } else
                    ret = 0;
                logg("$Closed fd %d\n", conn->scanfd);
                close(conn->scanfd);
            }
            return ret;
#else
            conn_reply_error(conn, "FILDES support not compiled in.");
            close(conn->scanfd);
            return 0;
#endif
        case COMMAND_STATS:
            thrmgr_setactivetask(NULL, "STATS");
            if (conn->group)
                mdprintf(desc, "%u: ", conn->id);
            thrmgr_printstats(desc, conn->term);
            return 0;
        case COMMAND_STREAM:
            thrmgr_setactivetask(NULL, "STREAM");
            ret = scanstream(desc, NULL, engine, options, opts, conn->term);
            if (ret == CL_VIRUS)
                *virus = 1;
            if (ret == CL_EMEM) {
                if (optget(opts, "ExitOnOOM")->enabled)
                    return -1;
                else
                    return 1;
            }
            return 0;
        case COMMAND_INSTREAMSCAN:
            thrmgr_setactivetask(NULL, "INSTREAM");
            ret = scanfd(conn, NULL, engine, options, opts, desc, 1);
            if (ret == CL_VIRUS) {
                *virus = 1;
                ret    = 0;
            } else if (ret == CL_EMEM) {
                if (optget(opts, "ExitOnOOM")->enabled)
                    ret = -1;
                else
                    ret = 1;
            } else if (ret == CL_ETIMEOUT) {
                thrmgr_group_terminate(conn->group);
                ret = 1;
            } else
                ret = 0;
            if (ftruncate(conn->scanfd, 0) == -1) {
                /* not serious, we're going to close it and unlink it anyway */
                logg("*ftruncate failed: %d\n", errno);
            }
            close(conn->scanfd);
            conn->scanfd = -1;
            cli_unlink(conn->filename);
            return ret;
        case COMMAND_ALLMATCHSCAN:
            if (!optget(opts, "AllowAllMatchScan")->enabled) {
                logg("$Rejecting ALLMATCHSCAN command.\n");
                conn_reply(conn, conn->filename, "ALLMATCHSCAN command disabled by clamd configuration.", "ERROR");
                return 1;
            }
            thrmgr_setactivetask(NULL, "ALLMATCHSCAN");
            /* NOTE(review): this writes through conn->options; confirm that
             * the options object is not shared with other connections */
            scandata.options->general |= CL_SCAN_GENERAL_ALLMATCHES;
            type = TYPE_SCAN;
            break;
        default:
            logg("!Invalid command dispatched: %d\n", conn->cmdtype);
            return 1;
    }

    /* common path for the file-tree scans selected above */
    scandata.type = type;
    maxdirrec     = optget(opts, "MaxDirectoryRecursion")->numarg;
    if (optget(opts, "FollowDirectorySymlinks")->enabled)
        flags |= CLI_FTW_FOLLOW_DIR_SYMLINK;
    if (optget(opts, "FollowFileSymlinks")->enabled)
        flags |= CLI_FTW_FOLLOW_FILE_SYMLINK;

    /* remember the device of the top-level path when crossing file systems
     * is disabled */
    if (!optget(opts, "CrossFilesystems")->enabled)
        if (CLAMSTAT(conn->filename, &sb) == 0)
            scandata.dev = sb.st_dev;

    ret = cli_ftw(conn->filename, flags, maxdirrec ? maxdirrec : INT_MAX, scan_callback, &data, scan_pathchk);
    if (ret == CL_EMEM) {
        /* this thread stops acting as multiscan controller, so its slot
         * must be returned on this exit as well */
        if (scandata.group && type == TYPE_MULTISCAN)
            release_multiscan_slot(conn);
        if (optget(opts, "ExitOnOOM")->enabled)
            return -1;
        else
            return 1;
    }
    if (scandata.group && type == TYPE_MULTISCAN) {
        thrmgr_group_waitforall(group, &ok, &error, &total);
        release_multiscan_slot(conn);
    } else {
        error = scandata.errors;
        total = scandata.total;
        ok    = total - error - scandata.infected;
    }

    /* nothing infected and not everything failed: report a single OK */
    if (ok + error == total && (error != total)) {
        if (conn_reply_single(conn, conn->filename, "OK") == -1)
            ret = CL_ETIMEOUT;
    }
    *virus = total - (ok + error);

    if (ret == CL_ETIMEOUT)
        thrmgr_group_terminate(conn->group);
    return error;
}

/*
 * Hand a command over to the thread pool.
 *
 * A heap copy of conn is made for the worker, an engine reference is taken
 * for it, and command-specific state (the scan descriptor, or a private copy
 * of the file name) is moved into the copy before it is queued.
 *
 * conn      connection the request arrived on; its scanfd is reset to -1
 *           when the descriptor is moved into the copy.
 * cmd       command to run.
 * argument  command argument (the path for the scan commands).
 *
 * Returns 0 when the job was queued, -1 when an allocation or the engine
 * reference failed, -2 for an invalid command or a failed dispatch, and 1
 * when FILDES was requested without a descriptor (an error reply has been
 * sent).  On any non-zero return everything set up for the copy is released
 * again, including the descriptor and file name it took ownership of.
 */
static int dispatch_command(client_conn_t *conn, enum commands cmd, const char *argument)
{
    int ret = 0;
    int bulk;
    /* set once dup_conn->filename points at memory allocated in here */
    int filename_owned      = 0;
    client_conn_t *dup_conn = malloc(sizeof(*dup_conn));

    if (!dup_conn) {
        logg("!Can't allocate memory for client_conn\n");
        return -1;
    }
    memcpy(dup_conn, conn, sizeof(*conn));
    dup_conn->cmdtype = cmd;
    if (cl_engine_addref(dup_conn->engine)) {
        logg("!cl_engine_addref() failed\n");
        free(dup_conn);
        return -1;
    }
    dup_conn->scanfd = -1;
    bulk             = 1;
    switch (cmd) {
        case COMMAND_FILDES:
            if (conn->scanfd == -1) {
                conn_reply_error(dup_conn, "No file descriptor received.");
                ret = 1;
            }
            dup_conn->scanfd = conn->scanfd;
            /* consume FD */
            conn->scanfd = -1;
            break;
        case COMMAND_SCAN:
        case COMMAND_CONTSCAN:
        case COMMAND_MULTISCAN:
        case COMMAND_ALLMATCHSCAN:
            dup_conn->filename = cli_strdup_to_utf8(argument);
            if (!dup_conn->filename) {
                logg("!Failed to allocate memory for filename\n");
                ret = -1;
            } else {
                filename_owned = 1;
            }
            break;
        case COMMAND_INSTREAMSCAN:
            dup_conn->scanfd = conn->scanfd;
            conn->scanfd     = -1;
            break;
        case COMMAND_STREAM:
        case COMMAND_STATS:
            /* not a scan command, don't queue to bulk */
            bulk = 0;
            /* just dispatch the command */
            break;
        default:
            logg("!Invalid command dispatch: %d\n", cmd);
            ret = -2;
            break;
    }
    if (!dup_conn->group)
        bulk = 0;
    if (!ret && !thrmgr_group_dispatch(dup_conn->thrpool, dup_conn->group, dup_conn, bulk)) {
        logg("!thread dispatch failed\n");
        ret = -2;
    }
    if (ret) {
        /* No worker will ever see dup_conn, so release what it owns.  The
         * descriptor was already detached from conn and would otherwise be
         * unreachable; the file name is freed only when it was allocated
         * above, because otherwise the pointer is the one copied from conn. */
        if (dup_conn->scanfd != -1)
            close(dup_conn->scanfd);
        if (filename_owned)
            free(dup_conn->filename);
        cl_engine_free(dup_conn->engine);
        free(dup_conn);
    }
    return ret;
}

/*
 * Write the version banner to desc.
 *
 * When the engine reports a non-zero database version the banner is
 * "ClamAV <version>/<db version>/<db time><term>", otherwise it is
 * "ClamAV <version><term>".  Returns the result of mdprintf().
 */
static int print_ver(int desc, char term, const struct cl_engine *engine)
{
    uint32_t ver;
    char timestr[32];
    const char *tstr;
    time_t t;
    size_t len;

    ver = cl_engine_get_num(engine, CL_ENGINE_DB_VERSION, NULL);
    if (!ver)
        return mdprintf(desc, "ClamAV %s%c", get_version(), term);

    t    = cl_engine_get_num(engine, CL_ENGINE_DB_TIME, NULL);
    tstr = cli_ctime(&t, timestr, sizeof(timestr));
    if (!tstr)
        tstr = "";

    len = strlen(tstr);
    if (tstr != timestr) {
        /* the text is not in our buffer; copy it so it can be edited */
        if (len >= sizeof(timestr))
            len = sizeof(timestr) - 1;
        memmove(timestr, tstr, len);
        timestr[len] = '\0';
    }

    /* cut the trailing \n, but only if there really is one */
    if (len > 0 && timestr[len - 1] == '\n')
        timestr[len - 1] = '\0';

    return mdprintf(desc, "ClamAV %s/%u/%s%c", get_version(), (unsigned int)ver, timestr, term);
}

/*
 * Answer the COMMANDS request.
 *
 * If the library version string differs from the daemon version string an
 * error line is sent and nothing else.  Otherwise the version banner
 * (terminated with '|') is followed by " COMMANDS:", the keyword of every
 * entry in the command table, and finally term.
 */
static void print_commands(int desc, char term, const struct cl_engine *engine)
{
    const char *lib_version    = cl_retver();
    const char *daemon_version = get_version();
    const unsigned count       = sizeof(commands) / sizeof(commands[0]);
    unsigned idx               = 0;

    if (strcmp(lib_version, daemon_version) != 0) {
        mdprintf(desc, "ENGINE VERSION MISMATCH: %s != %s. ERROR%c",
                 lib_version, daemon_version, term);
        return;
    }

    print_ver(desc, '|', engine);
    mdprintf(desc, " COMMANDS:");

    while (idx < count) {
        mdprintf(desc, " %s", commands[idx].cmd);
        idx++;
    }

    mdprintf(desc, "%c", term);
}

/*
 * Run a command directly when it is cheap, otherwise queue it for a worker.
 *
 * returns:
 *  <0 for error
 *     -1 out of memory
 *     -2 other
 *   0 for async dispatched
 *   1 for command completed (connection can be closed)
 *
 * NOTE(review): the INSTREAM and IDSESSION failure paths return libclamav
 * status codes (the cli_gentempfd() result and CL_EMEM) instead of the
 * values listed above; confirm how callers treat them.
 */
int execute_or_dispatch_command(client_conn_t *conn, enum commands cmd, const char *argument)
{
    const int sock                 = conn->sd;
    const char eol                 = conn->term;
    const struct cl_engine *engine = conn->engine;
    int in_session;

    /* Only commands that cannot block for long (no disk I/O, at most one
     * atomic message sent) are executed here on the recvloop thread; every
     * other command is dispatched. */

    if (conn->group) {
        /* inside an IDSESSION only a subset of the commands is accepted */
        switch (cmd) {
            case COMMAND_FILDES:
            case COMMAND_SCAN:
            case COMMAND_END:
            case COMMAND_INSTREAM:
            case COMMAND_INSTREAMSCAN:
            case COMMAND_VERSION:
            case COMMAND_PING:
            case COMMAND_STATS:
            case COMMAND_COMMANDS:
                break;
            default:
                conn_reply_error(conn, "Command invalid inside IDSESSION.");
                logg("$SESSION: command is not valid inside IDSESSION: %d\n", cmd);
                conn->group = NULL;
                return 1;
        }
    }

    /* replies sent from here carry an "id: " prefix while a session is open,
     * and the connection is then kept open (return 0) */
    in_session = (conn->group != NULL);

    switch (cmd) {
        case COMMAND_SHUTDOWN:
            pthread_mutex_lock(&exit_mutex);
            progexit = 1;
            pthread_mutex_unlock(&exit_mutex);
            return 1;

        case COMMAND_RELOAD:
            pthread_mutex_lock(&reload_mutex);
            reload = 1;
            pthread_mutex_unlock(&reload_mutex);
            mdprintf(sock, "RELOADING%c", eol);
            /* only the flag is set here; the reload itself happens before
             * the connection is closed */
            return 1;

        case COMMAND_PING:
            if (in_session)
                mdprintf(sock, "%u: PONG%c", conn->id, eol);
            else
                mdprintf(sock, "PONG%c", eol);
            return in_session ? 0 : 1;

        case COMMAND_VERSION:
            if (in_session)
                mdprintf(sock, "%u: ", conn->id);
            print_ver(sock, eol, engine);
            return in_session ? 0 : 1;

        case COMMAND_COMMANDS:
            if (in_session)
                mdprintf(sock, "%u: ", conn->id);
            print_commands(sock, eol, engine);
            return in_session ? 0 : 1;

        case COMMAND_DETSTATSCLEAR:
        case COMMAND_DETSTATS:
            /* TODO: tell client this command has been removed */
            return 1;

        case COMMAND_INSTREAM: {
            /* create the temporary file the incoming stream is stored in */
            const char *tmpdir = optget(conn->opts, "TemporaryDirectory")->strarg;
            int rc             = cli_gentempfd(tmpdir, &conn->filename, &conn->scanfd);

            if (rc != CL_SUCCESS)
                return rc;
            conn->quota = optget(conn->opts, "StreamMaxLength")->numarg;
            conn->mode  = MODE_STREAM;
            return 0;
        }

        case COMMAND_STREAM:
        case COMMAND_MULTISCAN:
        case COMMAND_CONTSCAN:
        case COMMAND_STATS:
        case COMMAND_FILDES:
        case COMMAND_SCAN:
        case COMMAND_INSTREAMSCAN:
        case COMMAND_ALLMATCHSCAN:
            return dispatch_command(conn, cmd, argument);

        case COMMAND_IDSESSION:
            conn->group = thrmgr_group_new();
            return conn->group ? 0 : CL_EMEM;

        case COMMAND_END:
            /* END without a preceding IDSESSION is treated as unknown */
            if (!in_session)
                conn_reply_single(conn, NULL, "UNKNOWN COMMAND");
            /* with a session: need to close connection if we were last in
             * group */
            return 1;

        /*case COMMAND_UNKNOWN:*/
        default:
            conn_reply_single(conn, NULL, "UNKNOWN COMMAND");
            return 1;
    }
}