libclamav/cache.c
60b21251
 #if HAVE_CONFIG_H
 #include "clamav-config.h"
 #endif
a99a6040
 #include <string.h>
 #include <stdlib.h>
 #include <pthread.h>
 #include <assert.h>
60b21251
 
511c2e79
 #include "md5.h"
 #include "mpool.h"
 #include "clamav.h"
 #include "cache.h"
ea6f1300
 #include "fmap.h"
511c2e79
 
 /*
  * When a bucket has no stale slot to reuse, cache_add() overwrites a random
  * slot taken from the last cache_entries / CACHE_PERTURB slots of the bucket.
  */
 #define CACHE_PERTURB 10
 /* i.e. the victim is drawn from the last 1/10th of a bucket */
 /* Pool created by cl_cache_init(); all cache memory is allocated from it. */
 static mpool_t *mempool = NULL;
 /*
  * One bucket of the cache.  cl_cache_init() allocates 256 of them and the
  * first byte of a digest selects the bucket (cache[*md5]).
  */
 static struct CACHE {
     struct CACHE_ENTRY {
 	/* digest bytes 1..15; byte 0 is implied by the bucket index */
 	unsigned char hash[15];
 	/* value of the bucket's lastdb at the time the entry was stored */
 	uint32_t dbver;
 	/* use counter; 0 marks a slot that was never filled */
 	uint32_t hits;
     } *items;
     /* taken around every access to items and lastdb */
     pthread_mutex_t mutex;
     /* most recent engine database version recorded for this bucket */
     uint32_t lastdb;
 } *cache = NULL;
 /* number of slots in each bucket, as passed to cl_cache_init() */
 static unsigned int cache_entries = 0;
 
ea6f1300
 /*
  * Allocate and initialise the global cache.
  *
  * Creates the memory pool, a table of 256 buckets and, for every bucket,
  * `entries` zero-filled slots plus a mutex.
  *
  * Returns 0 on success and 1 on failure.  On failure every mutex that was
  * already initialised is destroyed, the pool is released and the globals
  * are reset to NULL/0, so that the `if(!cache)` guards used by the other
  * functions in this file never see a pointer into a destroyed pool.
  */
 int cl_cache_init(unsigned int entries) {
     unsigned int i;
     unsigned int ready = 0; /* buckets whose mutex has been initialised */

     if(!(mempool = mpool_create())) {
         cli_errmsg("mpool init fail\n");
         return 1;
     }
     if(!(cache = mpool_malloc(mempool, sizeof(struct CACHE) * 256))) {
         cli_errmsg("mpool malloc fail\n");
         goto fail;
     }

     for(i=0; i<256; i++) {
         struct CACHE_ENTRY *e = mpool_calloc(mempool, sizeof(struct CACHE_ENTRY), entries);
         if(!e) {
             cli_errmsg("mpool calloc fail\n");
             goto fail;
         }
         cache[i].items = e;
         cache[i].lastdb = 0;
         if(pthread_mutex_init(&cache[i].mutex, NULL)) {
             cli_errmsg("mutex init fail\n");
             goto fail;
         }
         ready = i + 1;
     }
     cache_entries = entries;
     return 0;

 fail:
     /* ready is 0 whenever cache itself could not be allocated */
     while(ready--)
         pthread_mutex_destroy(&cache[ready].mutex);
     mpool_destroy(mempool);
     mempool = NULL;
     cache = NULL;
     cache_entries = 0;
     return 1;
 }
 
 void cache_swap(struct CACHE_ENTRY *e, unsigned int a) {
     struct CACHE_ENTRY t;
     unsigned int b = a-1;
 
     if(!a || e[a].hits <= e[b].hits)
 	return;
 
     do {
 	if(e[a].hits > e[b].hits)
 	    continue;
 	break;
     } while(b--);
     b++;
 
     memcpy(&t, &e[a], sizeof(t));
     memcpy(&e[a], &e[b], sizeof(t));
     memcpy(&e[b], &t, sizeof(t));
 }
 
ea6f1300
 static void updb(uint32_t db, unsigned int skip) {
511c2e79
     unsigned int i;
     for(i=0; i<256; i++) {
 	if(i==skip) continue;
 	if(pthread_mutex_lock(&cache[i].mutex)) {
 	    cli_errmsg("mutex lock fail\n");
 	    continue;
 	}
 	cache[i].lastdb = db;
 	pthread_mutex_unlock(&cache[i].mutex);	
     }
 }
 
ea6f1300
 /*
  * Look a 16-byte digest up in the cache.
  *
  * md5[0] selects the bucket, md5[1..15] are compared with the stored
  * entries.  An entry only counts when it was stored under the database
  * version the bucket currently records.
  *
  * Returns CL_CLEAN when a valid entry is found (its hit counter is bumped and
  * it is promoted with cache_swap()), and CL_VIRUS in every other case: cache
  * not initialised, lock failure, no entry, or the engine reports a database
  * version different from the one recorded in the bucket.
  *
  * When the engine's version is newer, the bucket is updated and the new
  * version is pushed to all other buckets via updb().  updb() locks every
  * other bucket, so it is called only after this bucket's mutex has been
  * released; holding it across the call lets two threads working on different
  * buckets wait on each other forever.
  */
 static int cache_lookup_hash(unsigned char *md5, cli_ctx *ctx) {
     unsigned int i;
     int ret = CL_VIRUS;
     int propagate = 0;
     uint32_t newdb = 0;
     struct CACHE_ENTRY *e;
     struct CACHE *c;

     if(!cache) return ret;

     c = &cache[*md5];
     e = c->items;
     if(pthread_mutex_lock(&c->mutex)) {
         cli_errmsg("mutex lock fail\n");
         return ret;
     }
     if(c->lastdb < ctx->engine->dbversion[0]) {
         /* newer database: nothing stored in this bucket is valid anymore */
         c->lastdb = ctx->engine->dbversion[0];
         newdb = c->lastdb;
         propagate = 1;
     } else if(c->lastdb == ctx->engine->dbversion[0]) {
         for(i=0; i<cache_entries; i++) {
             if(!e[i].hits) break; /* first never-used slot ends the search */
             if(e[i].dbver == c->lastdb && !memcmp(e[i].hash, md5 + 1, 15)) {
                 e[i].hits++;
                 cache_swap(e, i);
                 ret = CL_CLEAN;
                 cli_warnmsg("cached\n");
                 break;
             }
         }
     }
     pthread_mutex_unlock(&c->mutex);

     if(propagate)
         updb(newdb, *md5);
     return ret;
 }
 
 /*
  * Store a 16-byte digest in the cache.
  *
  * md5[0] selects the bucket, md5[1..15] are stored in the entry.  Nothing is
  * stored when the cache is not initialised, has zero slots per bucket, the
  * bucket cannot be locked, or the bucket's recorded database version differs
  * from the engine's.
  *
  * If the digest is already present its hit counter is bumped and the entry
  * is revalidated for the current database version.  Otherwise the first
  * entry stored under an older version is reused; failing that, a random slot
  * from the tail of the bucket is overwritten.  cache_swap() then moves the
  * entry to its place.
  */
 void cache_add(unsigned char *md5, cli_ctx *ctx) {
     unsigned int i, replace, span;
     struct CACHE_ENTRY *e;
     struct CACHE *c;

     /* with zero slots there is no valid index to write to */
     if(!cache || !cache_entries) return;

     c = &cache[*md5];
     e = c->items;
     if(pthread_mutex_lock(&c->mutex)) {
         cli_errmsg("mutex lock fail\n");
         return;
     }
     if(c->lastdb == ctx->engine->dbversion[0]) {
         replace = cache_entries;
         for(i=0; i<cache_entries; i++) {
             if(!e[i].hits) break;
             if(replace == cache_entries && e[i].dbver < c->lastdb) {
                 replace = i;
             } else if(!memcmp(e[i].hash, md5 + 1, 15)) {
                 /* Refresh the version too: a match left with an old dbver
                  * would be ignored by every lookup. */
                 e[i].dbver = c->lastdb;
                 e[i].hits++;
                 cache_swap(e, i);
                 pthread_mutex_unlock(&c->mutex);
                 return;
             }
         }
         if(replace == cache_entries) {
             /* buckets smaller than CACHE_PERTURB would give a zero divisor */
             span = cache_entries / CACHE_PERTURB;
             if(!span)
                 span = 1;
             replace = cache_entries - 1 - (rand() % span);
         }
         e[replace].hits = 1;
         e[replace].dbver = c->lastdb;
         memcpy(e[replace].hash, md5 + 1, 15);
         cache_swap(e, replace);
     }
     pthread_mutex_unlock(&c->mutex);
     return;
 }
 
ea6f1300
 /*
  * Digest the whole current file map and look the result up in the cache.
  *
  * The map is read in pieces of at most FILEBUFF bytes and fed to the MD5
  * context; the final digest is written to `hash` and handed to
  * cache_lookup_hash(), whose result is returned.  CL_VIRUS is returned early
  * when the cache is not initialised or a piece of the map cannot be
  * obtained.
  */
 int cache_check(unsigned char *hash, cli_ctx *ctx) {
     fmap_t *map = *ctx->fmap;
     size_t remaining = map->len;
     size_t offset = 0;
     cli_md5_ctx md5;

     if(!cache)
         return CL_VIRUS;

     cli_md5_init(&md5);
     while(remaining != 0) {
         size_t chunk = FILEBUFF;
         void *data;

         if(remaining < FILEBUFF)
             chunk = remaining;

         data = fmap_need_off_once(map, offset, chunk);
         if(data == NULL)
             return CL_VIRUS;

         cli_md5_update(&md5, data, chunk);
         offset += chunk;
         remaining -= chunk;
     }
     cli_md5_final(hash, &md5);
     return cache_lookup_hash(hash, ctx);
 }
a99a6040
 
 /*
  * One slot of a cache_set hash table; also a node of the set's doubly
  * linked LRU list.
  */
 struct cache_key {
     /* 16-byte digest identifying the entry (presumably an MD5 - confirm with callers) */
     char digest[16];
     uint32_t size; /* 0 is used to mark an empty hash slot! */
     /* LRU neighbours: lru_next points towards the tail, lru_prev towards the head */
     struct cache_key *lru_next, *lru_prev;
 };
 
 /*
  * Open-addressing hash table of cache_key slots with an LRU list threaded
  * through the occupied slots.
  */
 struct cache_set {
     /* slot array, `capacity` entries, allocated in cacheset_init() */
     struct cache_key *data;
     /* number of slots; must be a power of 2 because indices are masked with capacity-1 */
     size_t capacity;
     size_t maxelements;/* considering load factor */
     /* counter of stored keys, compared against maxelements in cacheset_add() */
     size_t elements;
     /* version the stored keys belong to; changing it via cache_setversion() empties the set */
     size_t version;
     /* evictions take lru_head; lru_addtail() appends at lru_tail */
     struct cache_key *lru_head, *lru_tail;
     /* locked by cacheset_acquire(), unlocked by cacheset_release() */
     pthread_mutex_t mutex;
 };
 
 /*
  * Sentinel version: the value cacheset_init() stores in a fresh set and the
  * value cacheset_lookup() returns when the key is not present.
  */
 #define CACHE_INVALID_VERSION ~0u
 
 /*
  * Initialise an empty cache set.
  *
  * map        - set to initialise
  * maxsize    - number of hash slots; must be a power of 2, because slot
  *              indices are computed by masking with maxsize - 1
  * loadfactor - percentage of the slots that may be occupied
  *
  * Returns 0 (success) or CL_EMEM when the slot array cannot be allocated or
  * the mutex cannot be initialised.  On failure nothing remains allocated
  * and map->data is NULL.
  */
 static int cacheset_init(struct cache_set* map, size_t maxsize, uint8_t loadfactor)
 {
     map->data = cli_calloc(maxsize, sizeof(*map->data));
     if (!map->data)
         return CL_EMEM;
     map->capacity = maxsize;
     map->maxelements = loadfactor*maxsize / 100;
     map->elements = 0;
     map->version = CACHE_INVALID_VERSION;
     map->lru_head = map->lru_tail = NULL;
     if (pthread_mutex_init(&map->mutex, NULL)) {
         cli_errmsg("mutex init fail\n");
         /* do not leak the slot array when the set is unusable */
         free(map->data);
         map->data = NULL;
         return CL_EMEM;
     }
     /* the function previously fell off the end without a return value */
     return 0;
 }
 
 /* Tear a cache set down: destroy its mutex and free its slot array. */
 static void cacheset_destroy(struct cache_set *map)
 {
     pthread_mutex_t *lock = &map->mutex;

     pthread_mutex_destroy(lock);
     free(map->data);
 }
 
 /* Lock the set's mutex; the result of the lock call is not checked. */
 static void cacheset_acquire(struct cache_set *map)
 {
     pthread_mutex_t *lock = &map->mutex;

     (void)pthread_mutex_lock(lock);
 }
 
 /*
  * Switch the set to `version`.  When the version actually changes, every
  * slot is marked empty, the element counter is reset and the LRU list is
  * dropped; when it is unchanged the set is left untouched.
  */
 static void cache_setversion(struct cache_set* map, uint32_t version)
 {
     unsigned slot = 0;

     if (map->version != version) {
         map->version = version;
         /* everything stored so far belongs to the previous version */
         map->elements = 0;
         while (slot < map->capacity) {
             map->data[slot].size = 0;
             slot++;
         }
         map->lru_tail = NULL;
         map->lru_head = NULL;
     }
 }
 
 /*
  * Evict up to `howmany` keys from the head of the LRU list.
  *
  * Each evicted slot is unlinked, marked empty (size = 0) and subtracted from
  * map->elements.  Eviction stops early when the list runs empty.
  *
  * NOTE(review): marking a slot empty in an open-addressing table can hide
  * keys that were inserted by probing past it - confirm whether a "deleted"
  * marker is needed.
  */
 static void cacheset_lru_remove(struct cache_set *map, size_t howmany)
 {
     while (howmany--) {
         /* fetch the head first: the assertion below inspects it */
         struct cache_key *old = map->lru_head;

         assert(old);
         if (!old)
             return; /* nothing left to evict (assertions compiled out) */
         assert(!old->lru_prev);

         /* Remove a key from the head of the list */
         map->lru_head = old->lru_next;
         if (map->lru_head)
             map->lru_head->lru_prev = NULL; /* new head has no predecessor */
         if (old == map->lru_tail)
             map->lru_tail = NULL;
         old->lru_next = old->lru_prev = NULL;

         old->size = 0; /* this slot is now empty */
         if (map->elements)
             map->elements--;
     }
 }
 
 /*
  * Scramble a 32-bit value with a fixed sequence of shift, add and xor steps.
  * All arithmetic is unsigned and wraps modulo 2^32.
  */
 static inline uint32_t hash32shift(uint32_t key)
 {
     key = ~key + (key << 15);
     key ^= key >> 12;
     key *= 5u;    /* key + (key << 2) */
     key ^= key >> 4;
     key *= 2057u; /* key + (key << 3) + (key << 11) */
     key ^= key >> 16;
     return key;
 }
 
 /*
  * Hash `len` bytes starting at `k` into a slot index in [0, SIZE).
  * SIZE must be a power of 2: the result is reduced by masking with SIZE - 1.
  */
 static inline size_t hash(const unsigned char* k,const size_t len,const size_t SIZE)
 {
     size_t acc = 1;
     size_t pos = 0;

     while (pos != len) {
         /* add the next byte, then let the mixing function spread it */
         acc = hash32shift(acc + k[pos]);
         pos++;
     }
     return acc & (SIZE - 1);
 }
 
 /*
  * Find the slot for `key` in the hash table.
  *
  * Only key->digest and key->size identify a key.  The starting slot is
  * derived from the digest alone, so the result does not depend on struct
  * padding or on the key's LRU pointers, which callers leave unset.
  *
  * Returns 1 and stores the slot index in *insert_pos when the key is
  * present; returns 0 and stores the index of the first empty slot of the
  * probe sequence otherwise.
  *
  * The table must contain at least one empty slot, otherwise the probe loop
  * does not terminate.  map->capacity must be a power of 2.
  */
 int cacheset_lookup_internal(struct cache_set *map, const struct cache_key *key,
                              uint32_t *insert_pos)
 {
     uint32_t idx = hash((const unsigned char*)key->digest, sizeof(key->digest),
                         map->capacity);
     uint32_t tries = 0;
     struct cache_key *k = &map->data[idx];

     while (k->size) {
         if (k->size == key->size &&
             !memcmp(k->digest, key->digest, sizeof(k->digest))) {
             /* found key */
             *insert_pos = idx;
             return 1;
         }
         /* step grows by one each time; pre-increment so the first step
          * really moves to another slot */
         idx = (idx + ++tries)&(map->capacity-1);
         k = &map->data[idx];
     }
     /* found empty pos */
     *insert_pos = idx;
     return 0;
 }
 
 /*
  * Unlink `newkey` from the set's LRU list.
  *
  * Neighbours are joined to each other and both ends of the list are kept
  * up to date.  Updating lru_tail matters: if the tail were left pointing at
  * the removed node, appending that node again would make it its own
  * predecessor.  The node's own links are cleared afterwards.
  */
 static inline void lru_remove(struct cache_set *map, struct cache_key *newkey)
 {
     if (newkey->lru_next)
         newkey->lru_next->lru_prev = newkey->lru_prev;
     if (newkey->lru_prev)
         newkey->lru_prev->lru_next = newkey->lru_next;
     if (newkey == map->lru_head)
         map->lru_head = newkey->lru_next;
     if (newkey == map->lru_tail)
         map->lru_tail = newkey->lru_prev;
     newkey->lru_next = NULL;
     newkey->lru_prev = NULL;
 }
 
 /*
  * Append `newkey` at the tail of the set's LRU list.  If the list is empty
  * the node also becomes the head.
  */
 static inline void lru_addtail(struct cache_set *map, struct cache_key *newkey)
 {
     struct cache_key *const old_tail = map->lru_tail;

     if (old_tail != NULL)
         old_tail->lru_next = newkey;
     if (map->lru_head == NULL)
         map->lru_head = newkey;

     /* written after the old tail's link, as the two may be the same node */
     newkey->lru_next = NULL;
     newkey->lru_prev = old_tail;
     map->lru_tail = newkey;
 }
 
 static void cacheset_add(struct cache_set *map, const struct cache_key *key)
 {
     int ret;
     uint32_t pos;
     struct cache_key *newkey;
     if (map->elements >= map->maxelements)
 	cacheset_lru_remove(map, 1);
     assert(map->elements < map->maxelements);
 
     ret = cacheset_lookup_internal(map, key, &pos);
     newkey = &map->data[pos];
     if (ret) {
 	/* was already added, remove from LRU list */
 	lru_remove(map, newkey);
     }
     /* add new key to tail of LRU list */
     lru_addtail(map, newkey);
 
     map->elements++;
 
     assert(pos < map->maxelements);
 
     memcpy(&map->data[pos], key, sizeof(*key));
 }
 
 /*
  * Look `key` up in the set.
  *
  * Returns CACHE_INVALID_VERSION when the key is not stored.  Otherwise the
  * key is moved to the tail of the LRU list and the set's version is
  * returned.
  */
 static int cacheset_lookup(struct cache_set *map, const struct cache_key *key)
 {
     uint32_t slot;
     struct cache_key *hit;

     if (cacheset_lookup_internal(map, key, &slot) == 0)
         return CACHE_INVALID_VERSION;

     /* most recently used keys live at the tail */
     hit = &map->data[slot];
     lru_remove(map, hit);
     lru_addtail(map, hit);

     return map->version;
 }
 
 /* Unlock the set's mutex; the result of the unlock call is not checked. */
 static void cacheset_release(struct cache_set *map)
 {
     pthread_mutex_t *lock = &map->mutex;

     (void)pthread_mutex_unlock(lock);
 }
 
 #if 0
 /*
  * Disabled (#if 0) self-test for the cache_set functions.  It aborts as soon
  * as a lookup returns something other than what the test expects.
  *
  * NOTE(review): `key` is a stack variable whose lru_next/lru_prev members are
  * never initialised, and the return value of cacheset_init() is ignored.
  */
 int main(int argc, char **argv)
 {
     struct cache_key key;
     struct cache_set map;
     /* 256 slots, load factor 80% */
     cacheset_init(&map, 256, 80);
     cacheset_acquire(&map);
     cache_setversion(&map, 10);
 
     /* store three keys of the same size under version 10 */
     key.size = 1024;
     memcpy(key.digest, "1234567890123456", 16);
     cacheset_add(&map, &key);
     memcpy(key.digest, "1234567890123457", 16);
     cacheset_add(&map, &key);
     memcpy(key.digest, "0123456789012345", 16);
     cacheset_add(&map, &key);
 
     /* every stored key must be reported with version 10 (first one twice) */
     key.size = 1024;
     memcpy(key.digest, "1234567890123456", 16);
     if (cacheset_lookup(&map, &key) != 10)
 	abort();
     memcpy(key.digest, "1234567890123456", 16);
     if (cacheset_lookup(&map, &key) != 10)
 	abort();
     memcpy(key.digest, "1234567890123457", 16);
     if (cacheset_lookup(&map, &key) != 10)
 	abort();
     memcpy(key.digest, "0123456789012345", 16);
     if (cacheset_lookup(&map, &key) != 10)
 	abort();
     /* a digest that was never added must not be reported as version 10 */
     memcpy(key.digest, "0123456789012346", 16);
     if (cacheset_lookup(&map, &key) == 10)
 	abort();
 
     /* changing the version empties the set: all lookups must now miss */
     cache_setversion(&map, 1);
     memcpy(key.digest, "1234567890123456", 16);
     if (cacheset_lookup(&map, &key) != CACHE_INVALID_VERSION)
 	abort();
     memcpy(key.digest, "1234567890123456", 16);
     if (cacheset_lookup(&map, &key) != CACHE_INVALID_VERSION)
 	abort();
     memcpy(key.digest, "1234567890123457", 16);
     if (cacheset_lookup(&map, &key) != CACHE_INVALID_VERSION)
 	abort();
     memcpy(key.digest, "0123456789012345", 16);
     if (cacheset_lookup(&map, &key) != CACHE_INVALID_VERSION)
 	abort();
 
     cacheset_release(&map);
 
     cacheset_destroy(&map);
 
     /* second scenario: 8 slots at 50% load, i.e. maxelements = 4 */
     cacheset_init(&map, 8, 50);
     cacheset_acquire(&map);
     cache_setversion(&map, 10);
 
     /* fill the set with four keys (three of size 416, one of size 400) */
     key.size = 416;
     memcpy(key.digest, "1234567890123456", 16);
     cacheset_add(&map, &key);
     memcpy(key.digest, "1234567890123457", 16);
     cacheset_add(&map, &key);
     memcpy(key.digest, "1234567890123459", 16);
     cacheset_add(&map, &key);
     key.size = 400;
     memcpy(key.digest, "1234567890123450", 16);
     cacheset_add(&map, &key);
 
     /* repeated lookups of the first key must keep succeeding */
     key.size = 416;
     memcpy(key.digest, "1234567890123456", 16);
     if (cacheset_lookup(&map, &key) != 10)
 	abort();
     if (cacheset_lookup(&map, &key) != 10)
 	abort();
     if (cacheset_lookup(&map, &key) != 10)
 	abort();
 
     /* same digest, different size: a fifth key, added while the set is at
      * maxelements, so cacheset_add() evicts from the LRU head first */
     key.size = 500;
     cacheset_add(&map, &key);
     /* the test expects this lookup (size still 500) to miss */
     memcpy(key.digest, "1234567890123457", 16);
     if (cacheset_lookup(&map, &key) == 10)
 	abort();
 
     cacheset_release(&map);
     cacheset_destroy(&map);
 
     return 0;
 }
 #endif