Browse code

more conversion.

Török Edvin authored on 2009/08/28 02:22:50
Showing 1 changed file
... ...
@@ -28,6 +28,7 @@
28 28
 #include "llvm/ExecutionEngine/JITEventListener.h"
29 29
 #include "llvm/LLVMContext.h"
30 30
 #include "llvm/Module.h"
31
+#include "llvm/PassManager.h"
31 32
 #include "llvm/ModuleProvider.h"
32 33
 #include "llvm/Support/Compiler.h"
33 34
 #include "llvm/Support/CommandLine.h"
... ...
@@ -40,7 +41,10 @@
40 40
 #include "llvm/System/Signals.h"
41 41
 #include "llvm/System/Threading.h"
42 42
 #include "llvm/Target/TargetSelect.h"
43
+#include "llvm/Target/TargetData.h"
43 44
 #include "llvm/Support/TargetFolder.h"
45
+#include "llvm/Transforms/Scalar.h"
46
+#include "llvm/Analysis/Verifier.h"
44 47
 #include <cstdlib>
45 48
 #include <new>
46 49
 
... ...
@@ -81,12 +85,18 @@ private:
81 81
     const Type **TypeMap;
82 82
     Twine BytecodeID;
83 83
     ExecutionEngine *EE;
84
+    TargetFolder Folder;
85
+    IRBuilder<false, TargetFolder> Builder;
86
+    Value **Values;
87
+    FunctionPassManager &PM;
88
+    unsigned numLocals;
89
+    unsigned numArgs;
84 90
 
85 91
     const Type *mapType(uint16_t ty)
86 92
     {
87 93
 	if (!ty)
88 94
 	    return Type::getVoidTy(Context);
89
-	if (ty < 64)
95
+	if (ty <= 64)
90 96
 	    return IntegerType::get(Context, ty);
91 97
 	switch (ty) {
92 98
 	    case 65:
... ...
@@ -110,56 +120,80 @@ private:
110 110
 	}
111 111
     }
112 112
 
113
-    Value *convertOperand(const struct cli_bc_func *func, 
113
+    Value *convertOperand(const struct cli_bc_func *func, const Type *Ty, operand_t operand)
114
+    { // Resolve `operand` to a Value; a constant is materialized at the width of Ty.
115
+	unsigned map[] = {0, 1, 2, 3, 3, 4, 4, 4, 4}; // byte count (0..8) -> width class (0..4) taken by the width-based overload
116
+	if (operand < func->numArgs)
117
+	    return Values[operand]; // function arguments are used directly
118
+	if (operand < func->numValues)
119
+	    return Builder.CreateLoad(Values[operand]); // non-argument values are read through a load
120
+	unsigned bits = Ty->getPrimitiveSizeInBits(), w = bits > 1 ? (bits+7)/8 : 0; // a 1-bit type must select class 0 (i1), not class 1 (i8)
121
+	return convertOperand(func, map[w], operand);
122
+    }
123
+
124
+    Value *convertOperand(const struct cli_bc_func *func,
114 125
 			  const struct cli_bc_inst *inst,  operand_t operand)
115 126
     { // Resolve `operand` of instruction `inst` to a Value.
116
-	if (operand >= func->numValues) {
117
-	    // Constant
118
-	    operand -= func->numValues;
119
-	    // This was already validated by libclamav.
120
-	    assert(operand < func->numConstants && "Constant out of range");
121
-	    uint64_t *c = &func->constants[operand-func->numValues];
122
-	    uint64_t v;
123
-	    const Type *Ty;
124
-	    switch (inst->interp_op%5) {
125
-		case 0:
126
-		case 1:
127
-		    Ty = (inst->interp_op%5) ? Type::getInt8Ty(Context) : 
128
-			Type::getInt1Ty(Context);
129
-		    v = *(uint8_t*)c;
130
-		    break;
131
-		case 2:
132
-		    Ty = Type::getInt16Ty(Context);
133
-		    v = *(uint16_t*)c;
134
-		    break;
135
-		case 3:
136
-		    Ty = Type::getInt32Ty(Context);
137
-		    v = *(uint32_t*)c;
138
-		    break;
139
-		case 4:
140
-		    Ty = Type::getInt64Ty(Context);
141
-		    v = *(uint64_t*)c;
142
-		    break;
143
-	    }
144
-	    return ConstantInt::get(Ty, v);
127
+	return convertOperand(func, inst->interp_op%5, operand); // interp_op%5 is the width class (0..4) passed to the width-based overload
128
+    }
129
+
130
+    Value *convertOperand(const struct cli_bc_func *func,
131
+			  unsigned w, operand_t operand) { // w: width class 0..4 selecting i1, i8, i16, i32 or i64 for constants
132
+	if (operand < func->numArgs)
133
+	    return Values[operand]; // function arguments are used directly
134
+	if (operand < func->numValues)
135
+	    return Builder.CreateLoad(Values[operand]); // non-argument values are read through a load
136
+
137
+	// Constant
138
+	operand -= func->numValues; // operands at or above numValues index into func->constants
139
+	// This was already validated by libclamav.
140
+       	assert(operand < func->numConstants && "Constant out of range");
141
+	uint64_t *c = &func->constants[operand];
142
+	uint64_t v;
143
+	const Type *Ty;
144
+	switch (w) { // NOTE(review): no default case; Ty and v stay uninitialized if w > 4
145
+	    case 0:
146
+	    case 1:
147
+		Ty = w ? Type::getInt8Ty(Context) : 
148
+		    Type::getInt1Ty(Context);
149
+		v = *(uint8_t*)c; // NOTE(review): reads the first byte in memory of the 64-bit slot; presumably assumes little-endian storage
150
+		break;
151
+	    case 2:
152
+		Ty = Type::getInt16Ty(Context);
153
+		v = *(uint16_t*)c;
154
+		break;
155
+	    case 3:
156
+		Ty = Type::getInt32Ty(Context);
157
+		v = *(uint32_t*)c;
158
+		break;
159
+	    case 4:
160
+		Ty = Type::getInt64Ty(Context);
161
+		v = *(uint64_t*)c;
162
+		break;
145 163
 	}
146
-	assert(0 && "Not implemented yet");
164
+	return ConstantInt::get(Ty, v); // integer constant of the selected width
165
+    }
166
+
167
+    void Store(uint16_t dest, Value *V) // Emit a store of V into the Values[] entry of destination `dest`.
168
+    {
169
+	assert(dest >= numArgs && dest < numLocals+numArgs && "Instruction destination out of range"); // only indices in [numArgs, numArgs+numLocals) are accepted
170
+	Builder.CreateStore(V, Values[dest]); // Values[dest] is used as the store pointer
147 171
     }
148 172
 public:
149 173
     LLVMCodegen(const struct cli_bc *bc, Module *M, FunctionMapTy &cFuncs,
150
-		ExecutionEngine *EE)
174
+		ExecutionEngine *EE, FunctionPassManager &PM)
151 175
 	: bc(bc), M(M), Context(M->getContext()), compiledFunctions(cFuncs), 
152
-	BytecodeID("bc"+Twine(bc->id)), EE(EE) {
176
+	BytecodeID("bc"+Twine(bc->id)), EE(EE), // NOTE(review): BytecodeID is a Twine member built from temporaries; LLVM advises against storing Twines - confirm lifetime
177
+	Folder(EE->getTargetData(), Context), Builder(Context, Folder), PM(PM) { // Values, numLocals and numArgs are not initialized here; generate() assigns them per function
153 178
 	    TypeMap = new const Type*[bc->num_types]; // one entry per bytecode type; released with delete [] at the end of generate()
154 179
     }
155 180
 
156
-    void generate() {
181
+    bool generate() {
157 182
 	PrettyStackTraceString Trace(BytecodeID.str().c_str());
158 183
 	convertTypes();
159
-	TargetFolder Folder(EE->getTargetData(), Context);
160
-	IRBuilder<false, TargetFolder> Builder(Context, Folder);
184
+	Function **Functions = new Function*[bc->num_func];
161 185
 	for (unsigned j=0;j<bc->num_func;j++) {
162
-	    PrettyStackTraceString CrashInfo("Generate LLVM IR");
186
+	    PrettyStackTraceString CrashInfo("Generate LLVM IR functions");
163 187
 	    // Create LLVM IR Function
164 188
 	    const struct cli_bc_func *func = &bc->funcs[j];
165 189
 	    std::vector<const Type*> argTypes;
... ...
@@ -169,36 +203,217 @@ public:
169 169
 	    const Type *RetTy = mapType(func->returnType);
170 170
 	    llvm::FunctionType *FTy =  FunctionType::get(RetTy, argTypes,
171 171
 							 false);
172
-	    Function *F = Function::Create(FTy, Function::InternalLinkage, 
172
+	    Functions[j] = Function::Create(FTy, Function::InternalLinkage, 
173 173
 					   BytecodeID+"f"+Twine(j), M);
174
-
174
+	}
175
+	for (unsigned j=0;j<bc->num_func;j++) {
176
+	    PrettyStackTraceString CrashInfo("Generate LLVM IR");
177
+	    const struct cli_bc_func *func = &bc->funcs[j];
175 178
 	    // Create all BasicBlocks
179
+	    Function *F = Functions[j];
176 180
 	    BasicBlock **BB = new BasicBlock*[func->numBB];
177 181
 	    for (unsigned i=0;i<func->numBB;i++) {
178 182
 		BB[i] = BasicBlock::Create(Context, "", F);
179 183
 	    }
180 184
 
185
+	    Values = new Value*[func->numValues];
186
+	    Builder.SetInsertPoint(BB[0]);
187
+	    Function::arg_iterator I = F->arg_begin();
188
+	    for (unsigned i=0;i<func->numArgs; i++) {
189
+		assert(I != F->arg_end());
190
+		Values[i] = &*I;
191
+		++I;
192
+	    }
193
+	    for (unsigned i=func->numArgs;i<func->numValues;i++) {
194
+		Values[i] = Builder.CreateAlloca(mapType(func->types[i]));
195
+	    }
196
+	    numLocals = func->numLocals;
197
+	    numArgs = func->numArgs;
181 198
 	    // Generate LLVM IR for each BB
182 199
 	    for (unsigned i=0;i<func->numBB;i++) {
183 200
 		const struct cli_bc_bb *bb = &func->BB[i];
184 201
 		Builder.SetInsertPoint(BB[i]);
185 202
 		for (unsigned j=0;j<bb->numInsts;j++) {
186
-		    const struct cli_bc_inst *inst = &bb->insts[i];
203
+		    const struct cli_bc_inst *inst = &bb->insts[j];
204
+		    Value *Op0, *Op1, *Op2;
205
+		    // libclamav has already validated this.
206
+		    assert(inst->opcode < OP_INVALID && "Invalid opcode");
207
+		    switch (inst->opcode) {
208
+			case OP_JMP:
209
+			case OP_BRANCH:
210
+			case OP_CALL_API:
211
+			case OP_CALL_DIRECT:
212
+			case OP_ZEXT:
213
+			case OP_SEXT:
214
+			case OP_TRUNC:
215
+			    // these instructions represent their operands differently
216
+			    break;
217
+			default:
218
+			    switch (operand_counts[inst->opcode]) {
219
+				case 1:
220
+				    Op0 = convertOperand(func, inst, inst->u.unaryop);
221
+				    break;
222
+				case 2:
223
+				    Op0 = convertOperand(func, inst, inst->u.binop[0]);
224
+				    Op1 = convertOperand(func, inst, inst->u.binop[1]);
225
+				    break;
226
+				case 3:
227
+				    Op0 = convertOperand(func, inst, inst->u.three[0]);
228
+				    Op1 = convertOperand(func, inst, inst->u.three[1]);
229
+				    Op2 = convertOperand(func, inst, inst->u.three[2]);
230
+				    break;
231
+			    }
232
+		    }
187 233
 
188 234
 		    switch (inst->opcode) {
235
+			case OP_ADD:
236
+			    Store(inst->dest, Builder.CreateAdd(Op0, Op1));
237
+			    break;
238
+			case OP_SUB:
239
+			    Store(inst->dest, Builder.CreateSub(Op0, Op1));
240
+			    break;
241
+			case OP_MUL:
242
+			    Store(inst->dest, Builder.CreateMul(Op0, Op1));
243
+			    break;
244
+			case OP_UDIV:
245
+			    Store(inst->dest, Builder.CreateUDiv(Op0, Op1));
246
+			    break;
247
+			case OP_SDIV:
248
+			    Store(inst->dest, Builder.CreateSDiv(Op0, Op1));
249
+			    break;
250
+			case OP_UREM:
251
+			    Store(inst->dest, Builder.CreateURem(Op0, Op1));
252
+			    break;
253
+			case OP_SREM:
254
+			    Store(inst->dest, Builder.CreateSRem(Op0, Op1));
255
+			    break;
256
+			case OP_SHL:
257
+			    Store(inst->dest, Builder.CreateShl(Op0, Op1));
258
+			    break;
259
+			case OP_LSHR:
260
+			    Store(inst->dest, Builder.CreateLShr(Op0, Op1));
261
+			    break;
262
+			case OP_ASHR:
263
+			    Store(inst->dest, Builder.CreateAShr(Op0, Op1));
264
+			    break;
265
+			case OP_AND:
266
+			    Store(inst->dest, Builder.CreateAnd(Op0, Op1));
267
+			    break;
268
+			case OP_OR:
269
+			    Store(inst->dest, Builder.CreateOr(Op0, Op1));
270
+			    break;
271
+			case OP_XOR:
272
+			    Store(inst->dest, Builder.CreateXor(Op0, Op1));
273
+			    break;
274
+			case OP_TRUNC:
275
+			{
276
+			    Value *Src = convertOperand(func, inst, inst->u.cast.source);
277
+			    const Type *Ty = mapType(func->types[inst->dest]);
278
+			    Store(inst->dest, Builder.CreateTrunc(Src,  Ty));
279
+			    break;
280
+			}
281
+			case OP_ZEXT:
282
+			{
283
+			    Value *Src = convertOperand(func, inst, inst->u.cast.source);
284
+			    const Type *Ty = mapType(func->types[inst->dest]);
285
+			    Store(inst->dest, Builder.CreateZExt(Src,  Ty));
286
+			    break;
287
+			}
288
+			case OP_SEXT:
289
+			{
290
+			    Value *Src = convertOperand(func, inst, inst->u.cast.source);
291
+			    const Type *Ty = mapType(func->types[inst->dest]);
292
+			    Store(inst->dest, Builder.CreateSExt(Src,  Ty));
293
+			    break;
294
+			}
295
+			case OP_BRANCH:
296
+			{
297
+			    Value *Cond = convertOperand(func, inst, inst->u.branch.condition);
298
+			    BasicBlock *True = BB[inst->u.branch.br_true];
299
+			    BasicBlock *False = BB[inst->u.branch.br_false];
300
+			    if (Cond->getType() != Type::getInt1Ty(Context)) {
301
+				errs() << MODULE << "type mismatch in condition\n";
302
+				return false;
303
+			    }
304
+			    Builder.CreateCondBr(Cond, True, False);
305
+			    break;
306
+			}
307
+			case OP_JMP:
308
+			{
309
+			    BasicBlock *Jmp = BB[inst->u.jump];
310
+			    Builder.CreateBr(Jmp);
311
+			    break;
312
+			}
189 313
 			case OP_RET:
190
-			    Value *V = convertOperand(func, inst, inst->u.unaryop);
191
-			    Builder.CreateRet(V);
314
+			    Builder.CreateRet(Op0);
315
+			    break;
316
+			case OP_ICMP_EQ:
317
+			    Store(inst->dest, Builder.CreateICmpEQ(Op0, Op1));
318
+			    break;
319
+			case OP_ICMP_NE:
320
+			    Store(inst->dest, Builder.CreateICmpNE(Op0, Op1));
321
+			    break;
322
+			case OP_ICMP_UGT: // each comparison opcode emits the IRBuilder predicate of the same name
323
+			    Store(inst->dest, Builder.CreateICmpUGT(Op0, Op1)); // unsigned >
324
+			    break;
325
+			case OP_ICMP_UGE:
326
+			    Store(inst->dest, Builder.CreateICmpUGE(Op0, Op1)); // unsigned >=
192 327
 			    break;
328
+			case OP_ICMP_ULT:
329
+			    Store(inst->dest, Builder.CreateICmpULT(Op0, Op1)); // unsigned <
330
+			    break;
331
+			case OP_ICMP_ULE:
332
+			    Store(inst->dest, Builder.CreateICmpULE(Op0, Op1)); // unsigned <=
333
+			    break;
334
+			case OP_ICMP_SGT:
335
+			    Store(inst->dest, Builder.CreateICmpSGT(Op0, Op1)); // signed >
336
+			    break;
337
+			case OP_ICMP_SGE:
338
+			    Store(inst->dest, Builder.CreateICmpSGE(Op0, Op1)); // signed >=
339
+			    break;
340
+			case OP_ICMP_SLT:
341
+			    Store(inst->dest, Builder.CreateICmpSLT(Op0, Op1)); // signed <; NOTE(review): no signed <= case is visible in this hunk
342
+			    break;
343
+			case OP_SELECT:
344
+			    Store(inst->dest, Builder.CreateSelect(Op0, Op1, Op2));
345
+			    break;
346
+			case OP_COPY: // copy operand 0 into the slot named by operand 1
347
+			    Builder.CreateStore(Op0, Values[inst->u.binop[1]]); // store pointer must be the destination slot itself, not its converted (loaded) value
348
+			    break;
349
+			case OP_CALL_DIRECT:
350
+			{
351
+			    Function *DestF = Functions[inst->u.ops.funcid];
352
+			    SmallVector<Value*, 2> args;
353
+			    for (unsigned a=0;a<inst->u.ops.numOps;a++) {
354
+				operand_t op = inst->u.ops.ops[a];
355
+				args.push_back(convertOperand(func, DestF->getFunctionType()->getParamType(a), op));
356
+			    }
357
+			    Store(inst->dest, Builder.CreateCall(DestF, args.begin(), args.end()));
358
+			    break;
359
+			}
360
+			default:
361
+			    assert(0 && "Not implemented yet");
193 362
 		    }
194 363
 		}
195 364
 	    }
196 365
 
366
+	    if (verifyFunction(*F, PrintMessageAction)) {
367
+		errs() << MODULE << "Verification failed\n";
368
+		// verification failed
369
+		return false;
370
+	    }
371
+	    PM.run(*F);
372
+	    delete [] Values;
373
+	}
374
+
375
+	for (unsigned j=0;j<bc->num_func;j++) {
376
+	    const struct cli_bc_func *func = &bc->funcs[j];
197 377
 	    PrettyStackTraceString CrashInfo2("Native machine codegen");
198 378
 	    // Codegen current function as executable machine code.
199
-	    compiledFunctions[func] = EE->getPointerToFunction(F);
379
+	    compiledFunctions[func] = EE->getPointerToFunction(Functions[j]);
200 380
 	}
201
-	delete TypeMap;
381
+	delete [] TypeMap;
382
+	return true;
202 383
     }
203 384
 };
204 385
 }
... ...
@@ -214,10 +429,11 @@ int cli_bytecode_prepare_jit(struct cli_all_bc *bcs)
214 214
   // LLVM itself never throws exceptions, but operator new may throw bad_alloc
215 215
   try {
216 216
     Module *M = new Module("ClamAV jit module", bcs->engine->Context);
217
+    ExistingModuleProvider *MP = new ExistingModuleProvider(M);
217 218
     {
218 219
 	// Create the JIT.
219 220
 	std::string ErrorMsg;
220
-	EngineBuilder builder(M);
221
+	EngineBuilder builder(MP);
221 222
 	builder.setErrorStr(&ErrorMsg);
222 223
 	builder.setEngineKind(EngineKind::JIT);
223 224
 	builder.setOptLevel(CodeGenOpt::Aggressive);
... ...
@@ -233,10 +449,22 @@ int cli_bytecode_prepare_jit(struct cli_all_bc *bcs)
233 233
 	EE->RegisterJITEventListener(createOProfileJITEventListener());
234 234
 	EE->DisableLazyCompilation();
235 235
 
236
+	FunctionPassManager OurFPM(MP);
237
+	// Set up the optimizer pipeline.  Start with registering info about how
238
+	// the target lays out data structures.
239
+	OurFPM.add(new TargetData(*EE->getTargetData()));
240
+	// Promote allocas to registers.
241
+	OurFPM.add(createPromoteMemoryToRegisterPass());
242
+	// Do simple "peephole" optimizations and bit-twiddling optzns.
243
+	OurFPM.add(createInstructionCombiningPass());
244
+	OurFPM.doInitialization();
236 245
 	for (unsigned i=0;i<bcs->count;i++) {
237 246
 	    const struct cli_bc *bc = &bcs->all_bcs[i];
238
-	    LLVMCodegen Codegen(bc, M, bcs->engine->compiledFunctions, EE);
239
-	    Codegen.generate();
247
+	    LLVMCodegen Codegen(bc, M, bcs->engine->compiledFunctions, EE, OurFPM);
248
+	    if (!Codegen.generate()) {
249
+		errs() << MODULE << "JIT codegen failed\n";
250
+		return CL_EBYTECODE;
251
+	    }
240 252
 	}
241 253
 
242 254
 	// compile all functions now, not lazily!
... ...
@@ -283,7 +511,7 @@ int cli_bytecode_done_jit(struct cli_all_bc *bcs)
283 283
 {
284 284
     if (bcs->engine->EE)
285 285
 	delete bcs->engine->EE;
286
-    free(bcs->engine);
286
+    delete bcs->engine;
287 287
     bcs->engine = 0;
288 288
     return 0;
289 289
 }