Browse code

Add APIcall support to JIT.

Török Edvin authored on 2009/09/02 21:29:55
Showing 2 changed files
... ...
@@ -479,7 +479,7 @@ static int parseTypes(struct cli_bc *bc, unsigned char *buffer)
479 479
 	return CL_BREAK;
480 480
     }
481 481
     add_static_types(bc);
482
-    for (i=(BC_START_TID - 64);i<bc->num_types;i++) {
482
+    for (i=(BC_START_TID - 65);i<bc->num_types-1;i++) {
483 483
 	struct cli_bc_type *ty = &bc->types[i];
484 484
 	uint8_t t = readFixedNumber(buffer, &offset, len, &ok, 1);
485 485
 	if (!ok) {
... ...
@@ -495,6 +495,10 @@ static int parseTypes(struct cli_bc *bc, unsigned char *buffer)
495 495
 		    cli_errmsg("Error parsing type %u\n", i);
496 496
 		    return CL_EMALFDB;
497 497
 		}
498
+		if (!ty->numElements) {
499
+		    cli_errmsg("Function with no return type? %u\n", i);
500
+		    return CL_EMALFDB;
501
+		}
498 502
 		break;
499 503
 	    case 2:
500 504
 	    case 3:
... ...
@@ -550,7 +554,7 @@ static int parseTypes(struct cli_bc *bc, unsigned char *buffer)
550 550
 static int types_equal(const struct cli_bc *bc, uint16_t *apity2ty, uint16_t tid, uint16_t apitid)
551 551
 {
552 552
     unsigned i;
553
-    const struct cli_bc_type *ty = &bc->types[tid - 63];
553
+    const struct cli_bc_type *ty = &bc->types[tid - 64];
554 554
     const struct cli_bc_type *apity = &cli_apicall_types[apitid];
555 555
     /* If we've already verified type equality, return.
556 556
      * Since we need to check equality of recursive types, we assume types are
... ...
@@ -1,5 +1,5 @@
1 1
 /*
2
- *  Load, and verify ClamAV bytecode.
2
+ *  JIT compile ClamAV bytecode.
3 3
  *
4 4
  *  Copyright (C) 2009 Sourcefire, Inc.
5 5
  *
... ...
@@ -89,23 +89,12 @@ void llvm_error_handler(void *user_data, const std::string &reason)
89 89
     jit_exception_handler();
90 90
 }
91 91
 
92
-class VISIBILITY_HIDDEN LLVMCodegen {
92
+class LLVMTypeMapper {
93 93
 private:
94
-    const struct cli_bc *bc;
95
-    Module *M;
94
+    std::vector<PATypeHolder> TypeMap;
96 95
     LLVMContext &Context;
97
-    FunctionMapTy &compiledFunctions;
98
-    const Type **TypeMap;
99
-    Twine BytecodeID;
100
-    ExecutionEngine *EE;
101
-    TargetFolder Folder;
102
-    IRBuilder<false, TargetFolder> Builder;
103
-    Value **Values;
104
-    FunctionPassManager &PM;
105
-    unsigned numLocals;
106
-    unsigned numArgs;
107
-
108
-    const Type *mapType(uint16_t ty)
96
+    unsigned numTypes;
97
+    const Type *getStatic(uint16_t ty)
109 98
     {
110 99
 	if (!ty)
111 100
 	    return Type::getVoidTy(Context);
... ...
@@ -121,17 +110,83 @@ private:
121 121
 	    case 68:
122 122
 		return PointerType::getUnqual(Type::getInt64Ty(Context));
123 123
 	}
124
+	llvm_unreachable("getStatic");
125
+    }
126
+public:
127
+    LLVMTypeMapper(LLVMContext &Context, const struct cli_bc_type *types,
128
+		   unsigned count) : Context(Context), numTypes(count)
129
+    {
130
+	TypeMap.reserve(count);
131
+	// During recursive type construction pointers to Type* may be
132
+	// invalidated, so we must use a TypeHolder to an Opaque type as a
133
+	// start.
134
+	for (unsigned i=0;i<count;i++) {
135
+	    TypeMap.push_back(OpaqueType::get(Context));
136
+	}
137
+	std::vector<const Type*> Elts;
138
+	for (unsigned i=0;i<count;i++) {
139
+	    const struct cli_bc_type *type = &types[i];
140
+	    Elts.clear();
141
+	    unsigned n = type->kind == DArrayType ? 1 : type->numElements;
142
+	    for (unsigned j=0;j<n;j++) {
143
+		Elts.push_back(get(type->containedTypes[j]));
144
+	    }
145
+	    const Type *Ty;
146
+	    switch (type->kind) {
147
+		case DFunctionType:
148
+		{
149
+		    assert(Elts.size() > 0 && "Function with no return type?");
150
+		    const Type *RetTy = Elts[0];
151
+		    Elts.erase(Elts.begin());
152
+		    Ty = FunctionType::get(RetTy, Elts, false);
153
+		    break;
154
+		}
155
+		case DPointerType:
156
+		    Ty = PointerType::getUnqual(Elts[0]);
157
+		    break;
158
+		case DStructType:
159
+		    Ty = StructType::get(Context, Elts);
160
+		    break;
161
+		case DPackedStructType:
162
+		    Ty = StructType::get(Context, Elts, true);
163
+		    break;
164
+		case DArrayType:
165
+		    Ty = ArrayType::get(Elts[0], type->numElements);
166
+		    break;
167
+	    }
168
+	    // Make the opaque type a concrete type, doing recursive type
169
+	    // unification if needed.
170
+	    cast<OpaqueType>(TypeMap[i].get())->refineAbstractTypeTo(Ty);
171
+	}
172
+    }
173
+
174
+    const Type *get(uint16_t ty)
175
+    {
176
+	if (ty < 69)
177
+	    return getStatic(ty);
124 178
 	ty -= 69;
125
-	// This was validated by libclamav already.
126
-	assert(ty < bc->num_types && "Out of range type ID");
127
-	return TypeMap[ty];
179
+	assert(ty < numTypes && "TypeID out of range");
180
+	return TypeMap[ty].get();
128 181
     }
182
+};
129 183
 
130
-    void convertTypes() {
131
-	for (unsigned j=0;j<bc->num_types;j++) {
132 184
 
133
-	}
134
-    }
185
+class VISIBILITY_HIDDEN LLVMCodegen {
186
+private:
187
+    const struct cli_bc *bc;
188
+    Module *M;
189
+    LLVMContext &Context;
190
+    LLVMTypeMapper *TypeMap;
191
+    Function **apiFuncs;
192
+    FunctionMapTy &compiledFunctions;
193
+    Twine BytecodeID;
194
+    ExecutionEngine *EE;
195
+    TargetFolder Folder;
196
+    IRBuilder<false, TargetFolder> Builder;
197
+    Value **Values;
198
+    FunctionPassManager &PM;
199
+    unsigned numLocals;
200
+    unsigned numArgs;
135 201
 
136 202
     Value *convertOperand(const struct cli_bc_func *func, const Type *Ty, operand_t operand)
137 203
     {
... ...
@@ -167,7 +222,7 @@ private:
167 167
 	switch (w) {
168 168
 	    case 0:
169 169
 	    case 1:
170
-		Ty = w ? Type::getInt8Ty(Context) : 
170
+		Ty = w ? Type::getInt8Ty(Context) :
171 171
 		    Type::getInt1Ty(Context);
172 172
 		v = *(uint8_t*)c;
173 173
 		break;
... ...
@@ -205,18 +260,23 @@ private:
205 205
 	Builder.CreateCondBr(FailCond, Fail, OkBB);
206 206
 	Builder.SetInsertPoint(OkBB);
207 207
     }
208
+
209
+    const Type* mapType(uint16_t typeID)
210
+    {
211
+	return TypeMap->get(typeID);
212
+    }
208 213
 public:
209 214
     LLVMCodegen(const struct cli_bc *bc, Module *M, FunctionMapTy &cFuncs,
210
-		ExecutionEngine *EE, FunctionPassManager &PM)
215
+		ExecutionEngine *EE, FunctionPassManager &PM, Function **apiFuncs)
211 216
 	: bc(bc), M(M), Context(M->getContext()), compiledFunctions(cFuncs), 
212 217
 	BytecodeID("bc"+Twine(bc->id)), EE(EE), 
213
-	Folder(EE->getTargetData(), Context), Builder(Context, Folder), PM(PM) {
214
-	    TypeMap = new const Type*[bc->num_types];
215
-    }
218
+	Folder(EE->getTargetData(), Context), Builder(Context, Folder), PM(PM), 
219
+	apiFuncs(apiFuncs) 
220
+    {}
216 221
 
217 222
     bool generate() {
218 223
 	PrettyStackTraceString Trace(BytecodeID.str().c_str());
219
-	convertTypes();
224
+	TypeMap = new LLVMTypeMapper(Context, bc->types + 4, bc->num_types - 5);
220 225
 
221 226
 	FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context),
222 227
 						    false);
... ...
@@ -246,6 +306,7 @@ public:
246 246
 	for (unsigned j=0;j<bc->num_func;j++) {
247 247
 	    PrettyStackTraceString CrashInfo("Generate LLVM IR");
248 248
 	    const struct cli_bc_func *func = &bc->funcs[j];
249
+
249 250
 	    // Create all BasicBlocks
250 251
 	    Function *F = Functions[j];
251 252
 	    BasicBlock **BB = new BasicBlock*[func->numBB];
... ...
@@ -447,6 +508,19 @@ public:
447 447
 			    Store(inst->dest, Builder.CreateCall(DestF, args.begin(), args.end()));
448 448
 			    break;
449 449
 			}
450
+			case OP_CALL_API:
451
+			{
452
+			    assert(inst->u.ops.funcid < cli_apicall_maxapi && "APICall out of range");
453
+			    const struct cli_apicall *api = &cli_apicalls[inst->u.ops.funcid];
454
+			    std::vector<Value*> args;
455
+			    Function *DestF = apiFuncs[inst->u.ops.funcid];
456
+			    for (unsigned a=0;a<inst->u.ops.numOps;a++) {
457
+				operand_t op = inst->u.ops.ops[a];
458
+				args.push_back(convertOperand(func, DestF->getFunctionType()->getParamType(a), op));
459
+			    }
460
+			    Store(inst->dest, Builder.CreateCall(DestF, args.begin(), args.end()));
461
+			    break;
462
+			}
450 463
 			default:
451 464
 			    errs() << "JIT doesn't implement opcode " <<
452 465
 				inst->opcode << " yet!\n";
... ...
@@ -466,7 +540,7 @@ public:
466 466
 	}
467 467
 
468 468
 	DEBUG(M->dump());
469
-	delete [] TypeMap;
469
+	delete TypeMap;
470 470
 	FunctionType *Callable = FunctionType::get(Type::getInt32Ty(Context),false);
471 471
 	for (unsigned j=0;j<bc->num_func;j++) {
472 472
 	    const struct cli_bc_func *func = &bc->funcs[j];
... ...
@@ -556,9 +630,31 @@ int cli_bytecode_prepare_jit(struct cli_all_bc *bcs)
556 556
 	// Promote allocas to registers.
557 557
 	OurFPM.add(createPromoteMemoryToRegisterPass());
558 558
 	OurFPM.doInitialization();
559
+
560
+	LLVMTypeMapper apiMap(bcs->engine->Context, cli_apicall_types, cli_apicall_maxtypes);
561
+	Function **apiFuncs = new Function *[cli_apicall_maxapi];
562
+	for (unsigned i=0;i<cli_apicall_maxapi;i++) {
563
+	    const struct cli_apicall *api = &cli_apicalls[i];
564
+	    const FunctionType *FTy = cast<FunctionType>(apiMap.get(69+api->type));
565
+	    Function *F = Function::Create(FTy, Function::ExternalLinkage,
566
+					   api->name, M);
567
+	    void *dest;
568
+	    switch (api->kind) {
569
+		case 0:
570
+		    dest = (void*)cli_apicalls0[api->idx];
571
+		    break;
572
+		case 1:
573
+		    dest = (void*)cli_apicalls1[api->idx];
574
+		    break;
575
+	    }
576
+	    EE->addGlobalMapping(F, dest);
577
+	    apiFuncs[i] = F;
578
+	}
579
+
559 580
 	for (unsigned i=0;i<bcs->count;i++) {
560 581
 	    const struct cli_bc *bc = &bcs->all_bcs[i];
561
-	    LLVMCodegen Codegen(bc, M, bcs->engine->compiledFunctions, EE, OurFPM);
582
+	    LLVMCodegen Codegen(bc, M, bcs->engine->compiledFunctions, EE, 
583
+				OurFPM, apiFuncs);
562 584
 	    if (!Codegen.generate()) {
563 585
 		errs() << MODULE << "JIT codegen failed\n";
564 586
 		return CL_EBYTECODE;
... ...
@@ -574,6 +670,7 @@ int cli_bytecode_prepare_jit(struct cli_all_bc *bcs)
574 574
 	    if (!Fn->isDeclaration())
575 575
 		EE->getPointerToFunction(Fn);
576 576
 	}
577
+	delete [] apiFuncs;
577 578
     }
578 579
     return -1;
579 580
   } catch (std::bad_alloc &badalloc) {