Browse code

Initial support for __match_count.

Török Edvin authored on 2009/09/22 05:44:32
Showing 4 changed files
... ...
@@ -98,6 +98,7 @@ struct cli_bc_ctx {
98 98
     size_t file_size;
99 99
     off_t off;
100 100
     int fd;
101
+    const uint32_t *lsigcnt;
101 102
 };
102 103
 struct cli_all_bc;
103 104
 int cli_vm_execute(const struct cli_bc *bc, struct cli_bc_ctx *ctx, const struct cli_bc_func *func, const struct cli_bc_inst *inst);
... ...
@@ -21,6 +21,7 @@
21 21
  */
22 22
 #define DEBUG_TYPE "clamavjit"
23 23
 #include "llvm/ADT/DenseMap.h"
24
+#include "llvm/ADT/BitVector.h"
24 25
 #include "llvm/CallingConv.h"
25 26
 #include "llvm/DerivedTypes.h"
26 27
 #include "llvm/Function.h"
... ...
@@ -74,6 +75,7 @@ struct cli_bcengine {
74 74
 namespace {
75 75
 
76 76
 static sys::ThreadLocal<const jmp_buf> ExceptionReturn;
77
+static sys::ThreadLocal<const jmp_buf> MatchCounts;
77 78
 
78 79
 void do_shutdown() {
79 80
     llvm_shutdown();
... ...
@@ -187,7 +189,7 @@ private:
187 187
     ExecutionEngine *EE;
188 188
     TargetFolder Folder;
189 189
     IRBuilder<false, TargetFolder> Builder;
190
-    std::vector<GlobalVariable*> globals;
190
+    std::vector<Value*> globals;
191 191
     Value **Values;
192 192
     FunctionPassManager &PM;
193 193
     unsigned numLocals;
... ...
@@ -234,11 +236,13 @@ private:
234 234
 	    operand &= 0x7fffffff;
235 235
 	    assert(operand < globals.size() && "Global index out of range");
236 236
 	    // Global
237
-	    GlobalVariable *GV = globals[operand];
238
-	    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(GV->getInitializer())) {
239
-		return CE;
237
+	    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(globals[operand])) {
238
+		if (ConstantExpr *CE = dyn_cast<ConstantExpr>(GV->getInitializer())) {
239
+		    return CE;
240
+		}
241
+		return GV;
240 242
 	    }
241
-	    return GV;
243
+	    return globals[operand];
242 244
 	}
243 245
 	// Constant
244 246
 	operand -= func->numValues;
... ...
@@ -298,11 +302,13 @@ private:
298 298
     {
299 299
         if (isa<PointerType>(Ty)) {
300 300
           Constant *idxs[2] = {
301
-	      ConstantInt::get(Type::getInt32Ty(Context), 0), 
301
+	      ConstantInt::get(Type::getInt32Ty(Context), 0),
302 302
 	      ConstantInt::get(Type::getInt32Ty(Context), components[c++])
303 303
 	  };
304
-          GlobalVariable *GV = globals[components[c++]];
305
-          return ConstantExpr::getInBoundsGetElementPtr(GV, idxs, 2);
304
+	  unsigned idx = components[c++];
305
+	  assert(idx < globals.size());
306
+	  GlobalVariable *GV = cast<GlobalVariable>(globals[idx]);
307
+	  return ConstantExpr::getInBoundsGetElementPtr(GV, idxs, 2);
306 308
         }
307 309
 	if (isa<IntegerType>(Ty)) {
308 310
 	    return ConstantInt::get(Ty, components[c++]);
... ...
@@ -354,15 +360,31 @@ public:
354 354
 	const Type *HiddenCtx = PointerType::getUnqual(Type::getInt8Ty(Context));
355 355
 
356 356
 	globals.reserve(bc->num_globals);
357
+	// Fake GV for __match_counts, we'll replace this with loads from ctx!
358
+	const Type *MatchesTy = PointerType::getUnqual(Type::getInt32Ty(Context));//uint32*
359
+	BitVector FakeGVs;
360
+	FakeGVs.resize(bc->num_globals);
361
+
357 362
 	for (unsigned i=0;i<bc->num_globals;i++) {
358 363
 	    const Type *Ty = mapType(bc->globaltys[i]);
359 364
 
360 365
 	    // TODO: validate number of components against type_components
361 366
 	    unsigned c = 0;
367
+	    GlobalVariable *GV;
368
+	    if (isa<PointerType>(Ty)) {
369
+		switch (bc->globals[i][1]) {
370
+		default: break;
371
+		case GLOBAL_MATCH_COUNTS:
372
+		    assert(Ty == MatchesTy);
373
+		    FakeGVs.set(i);
374
+		    globals.push_back(0);
375
+		    continue;
376
+		}
377
+	    }
362 378
 	    Constant *C = buildConstant(Ty, bc->globals[i], c);
363
-	    GlobalVariable *GV = new GlobalVariable(*M, Ty, true,
364
-						    GlobalValue::InternalLinkage,
365
-						    C, "glob"+Twine(i));
379
+	    GV = new GlobalVariable(*M, Ty, true,
380
+				    GlobalValue::InternalLinkage,
381
+				    C, "glob"+Twine(i));
366 382
 	    globals.push_back(GV);
367 383
 	}
368 384
 
... ...
@@ -417,6 +439,28 @@ public:
417 417
 	    }
418 418
 	    numLocals = func->numLocals;
419 419
 	    numArgs = func->numArgs;
420
+
421
+	    if (FakeGVs.any()) {
422
+		Argument *Ctx = F->arg_begin();
423
+		struct cli_bc_ctx *N = 0;
424
+		unsigned offset = (char*)&((struct cli_bc_ctx*)0)->lsigcnt - (char*)NULL;
425
+		Constant *Idx = ConstantInt::get(Type::getInt32Ty(Context), offset);
426
+		Value *GEP = Builder.CreateInBoundsGEP(Ctx, Idx);
427
+		Value *Cast = Builder.CreateBitCast(GEP, PointerType::getUnqual(MatchesTy));
428
+		Value *__MatchesCount = Builder.CreateLoad(Cast);
429
+
430
+		for (unsigned i=0;i<bc->num_globals;i++) {
431
+		    if (!FakeGVs[i])
432
+			continue;
433
+		    switch (bc->globals[i][1]) {
434
+			case GLOBAL_MATCH_COUNTS:
435
+			    Constant *C = ConstantInt::get(Type::getInt32Ty(Context), bc->globals[i][0]);
436
+			    globals[i] = Builder.CreateInBoundsGEP(__MatchesCount, C);
437
+			    break;
438
+		    }
439
+		}
440
+	    }
441
+
420 442
 	    // Generate LLVM IR for each BB
421 443
 	    for (unsigned i=0;i<func->numBB;i++) {
422 444
 		const struct cli_bc_bb *bb = &func->BB[i];
... ...
@@ -823,7 +867,7 @@ int cli_bytecode_prepare_jit(struct cli_all_bc *bcs)
823 823
 	    const struct cli_bc *bc = &bcs->all_bcs[i];
824 824
 	    if (bc->state == bc_skip)
825 825
 		continue;
826
-	    LLVMCodegen Codegen(bc, M, bcs->engine->compiledFunctions, EE, 
826
+	    LLVMCodegen Codegen(bc, M, bcs->engine->compiledFunctions, EE,
827 827
 				OurFPM, apiFuncs);
828 828
 	    if (!Codegen.generate()) {
829 829
 		errs() << MODULE << "JIT codegen failed\n";
... ...
@@ -61,11 +61,13 @@ struct cli_lsig_tdb {
61 61
 #endif
62 62
 };
63 63
 
64
+struct cli_bc;
64 65
 struct cli_ac_lsig {
65 66
     uint32_t id;
66 67
     char *logic;
67 68
     const char *virname;
68 69
     struct cli_lsig_tdb tdb;
70
+    const struct cli_bc *bc;
69 71
 };
70 72
 
71 73
 struct cli_matcher {
... ...
@@ -848,7 +848,7 @@ static int lsigattribs(char *attribs, struct cli_lsig_tdb *tdb)
848 848
   } while(0);
849 849
 
850 850
 #define LDB_TOKENS 67
851
-static int load_oneldb(char *buffer, int chkpua, int chkign, struct cl_engine *engine, unsigned int options, const char *dbname, unsigned line, unsigned *sigs)
851
+static int load_oneldb(char *buffer, int chkpua, int chkign, struct cl_engine *engine, unsigned int options, const char *dbname, unsigned line, unsigned *sigs, struct cli_bc *bc)
852 852
 {
853 853
     const char *sig, *virname, *offset, *logic;
854 854
     struct cli_ac_lsig **newtable, *lsig;
... ...
@@ -932,6 +932,7 @@ static int load_oneldb(char *buffer, int chkpua, int chkign, struct cl_engine *e
932 932
 	mpool_free(engine->mempool, lsig);
933 933
 	return CL_EMEM;
934 934
     }
935
+    lsig->bc = bc;
935 936
     newtable[root->ac_lsigs - 1] = lsig;
936 937
     root->ac_lsigtable = newtable;
937 938
 
... ...
@@ -990,7 +991,7 @@ static int cli_loadldb(FILE *fs, struct cl_engine *engine, unsigned int *signo,
990 990
 	ret = load_oneldb(buffer,
991 991
 			  engine->pua_cats && (options & CL_DB_PUA_MODE) && (options & (CL_DB_PUA_INCLUDE | CL_DB_PUA_EXCLUDE)),
992 992
 			  !!engine->ignored,
993
-			  engine, options, dbname, line, &sigs);
993
+			  engine, options, dbname, line, &sigs, NULL);
994 994
 	if (ret)
995 995
 	    break;
996 996
     }
... ...
@@ -1016,6 +1017,11 @@ static int cli_loadcbc(FILE *fs, struct cl_engine *engine, unsigned int *signo,
1016 1016
     int rc;
1017 1017
     struct cli_all_bc *bcs = &engine->bcs;
1018 1018
     struct cli_bc *bc;
1019
+    unsigned sigs = 0;
1020
+
1021
+    if((rc = cli_initroots(engine, options)))
1022
+	return rc;
1023
+
1019 1024
     if(!(engine->dconf->bytecode & BYTECODE_ENGINE_MASK)) {
1020 1025
 	return CL_SUCCESS;
1021 1026
     }
... ...
@@ -1031,10 +1037,18 @@ static int cli_loadcbc(FILE *fs, struct cl_engine *engine, unsigned int *signo,
1031 1031
 	fprintf(stderr,"Unable to load %s bytecode: %s\n", dbname, cl_strerror(rc));
1032 1032
 	return rc;
1033 1033
     }
1034
+    sigs += 2;/* the bytecode itself and the logical sig */
1034 1035
     if (bc->lsig) {
1035 1036
 	cli_dbgmsg("Bytecode %s has logical signature: %s\n", dbname, bc->lsig);
1036
-      	
1037
+	rc = load_oneldb(bc->lsig, 0, 0, engine, options, dbname, 0, &sigs, bc);
1038
+	if (rc != CL_SUCCESS) {
1039
+	    fprintf(stderr,"Problem parsing logical signature %s for bytecode %s: %s\n",
1040
+		    bc->lsig, dbname, cl_strerror(rc));
1041
+	    return rc;
1042
+	}
1037 1043
     }
1044
+    if (signo)
1045
+	*signo += sigs;
1038 1046
     return CL_SUCCESS;
1039 1047
 }
1040 1048