Browse code

ctx param to APIs

Török Edvin authored on 2009/09/04 22:24:52
Showing 6 changed files
... ...
@@ -335,7 +335,8 @@ libclamav_la_SOURCES = \
335 335
 	type_desc.h \
336 336
 	bytecode_api.c \
337 337
 	bytecode_api_decl.c \
338
-	bytecode_api.h
338
+	bytecode_api.h \
339
+	bytecode_api_impl.h
339 340
 
340 341
 if !LINK_TOMMATH
341 342
 libclamav_la_SOURCES += bignum.c \
... ...
@@ -3,6 +3,8 @@
3 3
  *
4 4
  *  Copyright (C) 2009 Sourcefire, Inc.
5 5
  *
6
+ *  Authors: Török Edvin
7
+ *
6 8
  *  This program is free software; you can redistribute it and/or modify
7 9
  *  it under the terms of the GNU General Public License version 2 as
8 10
  *  published by the Free Software Foundation.
... ...
@@ -21,13 +23,14 @@
21 21
 #include "cltypes.h"
22 22
 #include "type_desc.h"
23 23
 #include "bytecode_api.h"
24
+#include "bytecode_api_impl.h"
24 25
 
25
-int32_t cli_bcapi_test0(struct foo* s, uint32_t u)
26
+uint32_t cli_bcapi_test0(struct cli_bc_ctx *ctx, struct foo* s, uint32_t u)
26 27
 {
27 28
     return (s && s->nxt == s && u == 0xdeadbeef) ? 0x12345678 : 0x55;
28 29
 }
29 30
 
30
-int32_t cli_bcapi_test1(int32_t a, int32_t b)
31
+uint32_t cli_bcapi_test1(struct cli_bc_ctx *ctx, uint32_t a, uint32_t b)
31 32
 {
32 33
     return (a==0xf00dbeef && b==0xbeeff00d) ? 0x12345678 : 0x55;
33 34
 }
... ...
@@ -30,5 +30,5 @@ struct foo {
30 30
     struct foo *nxt;
31 31
 };
32 32
 
33
-int32_t test0(struct foo*, uint32_t);
34
-int32_t test1(int32_t, int32_t);
33
+uint32_t test0(struct foo*, uint32_t);
34
+uint32_t test1(uint32_t, uint32_t);
... ...
@@ -22,6 +22,7 @@
22 22
 #include "cltypes.h"
23 23
 #include "type_desc.h"
24 24
 #include "bytecode_api.h"
25
+#include "bytecode_api_impl.h"
25 26
 
26 27
 uint32_t cli_bcapi_test0(struct cli_bc_ctx *ctx, struct foo*, uint32_t);
27 28
 uint32_t cli_bcapi_test1(struct cli_bc_ctx *ctx, uint32_t, uint32_t);
... ...
@@ -21,6 +21,7 @@
21 21
  */
22 22
 #define DEBUG_TYPE "clamavjit"
23 23
 #include "llvm/ADT/DenseMap.h"
24
+#include "llvm/CallingConv.h"
24 25
 #include "llvm/DerivedTypes.h"
25 26
 #include "llvm/Function.h"
26 27
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
... ...
@@ -114,7 +115,7 @@ private:
114 114
     }
115 115
 public:
116 116
     LLVMTypeMapper(LLVMContext &Context, const struct cli_bc_type *types,
117
-		   unsigned count) : Context(Context), numTypes(count)
117
+		   unsigned count, const Type *Hidden=0) : Context(Context), numTypes(count)
118 118
     {
119 119
 	TypeMap.reserve(count);
120 120
 	// During recursive type construction pointers to Type* may be
... ...
@@ -137,7 +138,10 @@ public:
137 137
 		{
138 138
 		    assert(Elts.size() > 0 && "Function with no return type?");
139 139
 		    const Type *RetTy = Elts[0];
140
-		    Elts.erase(Elts.begin());
140
+		    if (Hidden)
141
+			Elts[0] = Hidden;
142
+		    else
143
+			Elts.erase(Elts.begin());
141 144
 		    Ty = FunctionType::get(RetTy, Elts, false);
142 145
 		    break;
143 146
 		}
... ...
@@ -281,10 +285,10 @@ private:
281 281
 public:
282 282
     LLVMCodegen(const struct cli_bc *bc, Module *M, FunctionMapTy &cFuncs,
283 283
 		ExecutionEngine *EE, FunctionPassManager &PM, Function **apiFuncs)
284
-	: bc(bc), M(M), Context(M->getContext()), compiledFunctions(cFuncs), 
285
-	BytecodeID("bc"+Twine(bc->id)), EE(EE), 
286
-	Folder(EE->getTargetData(), Context), Builder(Context, Folder), PM(PM), 
287
-	apiFuncs(apiFuncs) 
284
+	: bc(bc), M(M), Context(M->getContext()), compiledFunctions(cFuncs),
285
+	BytecodeID("bc"+Twine(bc->id)), EE(EE),
286
+	Folder(EE->getTargetData(), Context), Builder(Context, Folder), PM(PM),
287
+	apiFuncs(apiFuncs)
288 288
     {}
289 289
 
290 290
     bool generate() {
... ...
@@ -300,21 +304,26 @@ public:
300 300
 	FHandler->addFnAttr(Attribute::NoInline);
301 301
 	EE->addGlobalMapping(FHandler, (void*)jit_exception_handler); 
302 302
 
303
+	// The hidden ctx param to all functions
304
+	const Type *HiddenCtx = PointerType::getUnqual(Type::getInt8Ty(Context));
305
+
303 306
 	Function **Functions = new Function*[bc->num_func];
304 307
 	for (unsigned j=0;j<bc->num_func;j++) {
305 308
 	    PrettyStackTraceString CrashInfo("Generate LLVM IR functions");
306 309
 	    // Create LLVM IR Function
307 310
 	    const struct cli_bc_func *func = &bc->funcs[j];
308 311
 	    std::vector<const Type*> argTypes;
312
+	    argTypes.push_back(HiddenCtx);
309 313
 	    for (unsigned a=0;a<func->numArgs;a++) {
310 314
 		argTypes.push_back(mapType(func->types[a]));
311 315
 	    }
312 316
 	    const Type *RetTy = mapType(func->returnType);
313 317
 	    FunctionType *FTy =  FunctionType::get(RetTy, argTypes,
314 318
 							 false);
315
-	    Functions[j] = Function::Create(FTy, Function::InternalLinkage, 
319
+	    Functions[j] = Function::Create(FTy, Function::InternalLinkage,
316 320
 					   BytecodeID+"f"+Twine(j), M);
317 321
 	    Functions[j]->setDoesNotThrow();
322
+	    Functions[j]->setCallingConv(CallingConv::Fast);
318 323
 	}
319 324
 	const Type *I32Ty = Type::getInt32Ty(Context);
320 325
 	for (unsigned j=0;j<bc->num_func;j++) {
... ...
@@ -332,6 +341,8 @@ public:
332 332
 	    Values = new Value*[func->numValues];
333 333
 	    Builder.SetInsertPoint(BB[0]);
334 334
 	    Function::arg_iterator I = F->arg_begin();
335
+	    assert(F->arg_size() == func->numArgs + 1 && "Mismatched args");
336
+	    ++I;
335 337
 	    for (unsigned i=0;i<func->numArgs; i++) {
336 338
 		assert(I != F->arg_end());
337 339
 		Values[i] = &*I;
... ...
@@ -524,11 +535,14 @@ public:
524 524
 			{
525 525
 			    Function *DestF = Functions[inst->u.ops.funcid];
526 526
 			    SmallVector<Value*, 2> args;
527
+			    args.push_back(&*F->arg_begin()); // pass hidden arg
527 528
 			    for (unsigned a=0;a<inst->u.ops.numOps;a++) {
528 529
 				operand_t op = inst->u.ops.ops[a];
529
-				args.push_back(convertOperand(func, DestF->getFunctionType()->getParamType(a), op));
530
+				args.push_back(convertOperand(func, DestF->getFunctionType()->getParamType(a+1), op));
530 531
 			    }
531
-			    Store(inst->dest, Builder.CreateCall(DestF, args.begin(), args.end()));
532
+			    CallInst *CI = Builder.CreateCall(DestF, args.begin(), args.end());
533
+			    CI->setCallingConv(CallingConv::Fast);
534
+			    Store(inst->dest, CI);
532 535
 			    break;
533 536
 			}
534 537
 			case OP_CALL_API:
... ...
@@ -537,9 +551,10 @@ public:
537 537
 			    const struct cli_apicall *api = &cli_apicalls[inst->u.ops.funcid];
538 538
 			    std::vector<Value*> args;
539 539
 			    Function *DestF = apiFuncs[inst->u.ops.funcid];
540
+			    args.push_back(&*F->arg_begin()); // pass hidden arg
540 541
 			    for (unsigned a=0;a<inst->u.ops.numOps;a++) {
541 542
 				operand_t op = inst->u.ops.ops[a];
542
-				args.push_back(convertOperand(func, DestF->getFunctionType()->getParamType(a), op));
543
+				args.push_back(convertOperand(func, DestF->getFunctionType()->getParamType(a+1), op));
543 544
 			    }
544 545
 			    Store(inst->dest, Builder.CreateCall(DestF, args.begin(), args.end()));
545 546
 			    break;
... ...
@@ -601,16 +616,38 @@ public:
601 601
 
602 602
 	DEBUG(M->dump());
603 603
 	delete TypeMap;
604
-	FunctionType *Callable = FunctionType::get(Type::getInt32Ty(Context),false);
604
+	std::vector<const Type*> args;
605
+	args.push_back(PointerType::getUnqual(Type::getInt8Ty(Context)));
606
+	FunctionType *Callable = FunctionType::get(Type::getInt32Ty(Context),
607
+						   args, false);
605 608
 	for (unsigned j=0;j<bc->num_func;j++) {
606 609
 	    const struct cli_bc_func *func = &bc->funcs[j];
607 610
 	    PrettyStackTraceString CrashInfo2("Native machine codegen");
608
-	    // Codegen current function as executable machine code.
609
-	    void *code = EE->getPointerToFunction(Functions[j]);
610 611
 
611 612
 	    // If prototype matches, add to callable functions
612
-	    if (Functions[j]->getFunctionType() == Callable)
613
+	    if (Functions[j]->getFunctionType() == Callable) {
614
+		// All functions have the Fast calling convention, however
615
+		// entrypoint can only be C, emit wrapper
616
+		Function *F = Function::Create(Functions[j]->getFunctionType(),
617
+					       Function::ExternalLinkage,
618
+					       Functions[j]->getName()+"_wrap", M);
619
+		F->setDoesNotThrow();
620
+		BasicBlock *BB = BasicBlock::Create(Context, "", F);
621
+		std::vector<Value*> args;
622
+		for (Function::arg_iterator J=F->arg_begin(),
623
+		     JE=F->arg_end(); J != JE; ++JE) {
624
+		    args.push_back(&*J);
625
+		}
626
+		CallInst *CI = CallInst::Create(Functions[j], args.begin(), args.end(), "", BB);
627
+		CI->setCallingConv(CallingConv::Fast);
628
+		ReturnInst::Create(Context, CI, BB);
629
+
630
+		if (verifyFunction(*F, PrintMessageAction));
631
+		// Codegen current function as executable machine code.
632
+		void *code = EE->getPointerToFunction(F);
633
+
613 634
 		compiledFunctions[func] = code;
635
+	    }
614 636
 	}
615 637
 	delete [] Functions;
616 638
 	return true;
... ...
@@ -631,7 +668,7 @@ int cli_vm_execute_jit(const struct cli_all_bc *bcs, struct cli_bc_ctx *ctx,
631 631
     if (setjmp(env) == 0) {
632 632
 	// setup exception handler to longjmp back here
633 633
 	ExceptionReturn.set(&env);
634
-	uint32_t result = ((uint32_t (*)(void))code)();
634
+	uint32_t result = ((uint32_t (*)(struct cli_bc_ctx *))code)(ctx);
635 635
 	*(uint32_t*)ctx->values = result;
636 636
 	return 0;
637 637
     }
... ...
@@ -693,7 +730,10 @@ int cli_bytecode_prepare_jit(struct cli_all_bc *bcs)
693 693
 	OurFPM.add(createDeadCodeEliminationPass());
694 694
 	OurFPM.doInitialization();
695 695
 
696
-	LLVMTypeMapper apiMap(bcs->engine->Context, cli_apicall_types, cli_apicall_maxtypes);
696
+	//TODO: create a wrapper that calls pthread_getspecific
697
+	const Type *HiddenCtx = PointerType::getUnqual(Type::getInt8Ty(bcs->engine->Context));
698
+
699
+	LLVMTypeMapper apiMap(bcs->engine->Context, cli_apicall_types, cli_apicall_maxtypes, HiddenCtx);
697 700
 	Function **apiFuncs = new Function *[cli_apicall_maxapi];
698 701
 	for (unsigned i=0;i<cli_apicall_maxapi;i++) {
699 702
 	    const struct cli_apicall *api = &cli_apicalls[i];
... ...
@@ -40,8 +40,8 @@ struct cli_bc_type {
40 40
     unsigned align;
41 41
 };
42 42
 
43
-typedef int32_t (*cli_apicall_int2)(struct cli_bc_ctx *, int32_t, int32_t);
44
-typedef int32_t (*cli_apicall_pointer)(struct cli_bc_ctx *, void*, uint32_t);
43
+typedef uint32_t (*cli_apicall_int2)(struct cli_bc_ctx *, uint32_t, uint32_t);
44
+typedef uint32_t (*cli_apicall_pointer)(struct cli_bc_ctx *, void*, uint32_t);
45 45
 
46 46
 struct cli_apicall {
47 47
     const char *name;