/* * Load, and verify ClamAV bytecode. * * Copyright (C) 2009 Sourcefire, Inc. * * Authors: Török Edvin * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as * published by the Free Software Foundation. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, * MA 02110-1301, USA. */ #define DEBUG_TYPE "clamavjit" #include "llvm/ADT/DenseMap.h" #include "llvm/DerivedTypes.h" #include "llvm/Function.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/JIT.h" #include "llvm/ExecutionEngine/JITEventListener.h" #include "llvm/LLVMContext.h" #include "llvm/Module.h" #include "llvm/PassManager.h" #include "llvm/ModuleProvider.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/DataTypes.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/System/Signals.h" #include "llvm/System/Threading.h" #include "llvm/Target/TargetSelect.h" #include "llvm/Target/TargetData.h" #include "llvm/Support/TargetFolder.h" #include "llvm/Analysis/Verifier.h" #include "llvm/Transforms/Scalar.h" #include "llvm/System/ThreadLocal.h" #include #include #include #include "llvm/Config/config.h" #ifndef LLVM_MULTITHREADED #error "Multithreading support must be available to LLVM!" #endif #include "clamav.h" #include "clambc.h" #include "bytecode_priv.h" #include "bytecode.h" #define MODULE "libclamav JIT: " using namespace llvm; typedef DenseMap FunctionMapTy; struct cli_bcengine { ExecutionEngine *EE; LLVMContext Context; FunctionMapTy compiledFunctions; }; namespace { static sys::ThreadLocal ExceptionReturn; void do_shutdown() { llvm_shutdown(); } static void NORETURN jit_exception_handler(void) { longjmp(*const_cast(ExceptionReturn.get()), 1); } void llvm_error_handler(void *user_data, const std::string &reason) { errs() << reason; jit_exception_handler(); } class VISIBILITY_HIDDEN LLVMCodegen { private: const struct cli_bc *bc; Module *M; LLVMContext &Context; FunctionMapTy &compiledFunctions; const Type **TypeMap; Twine BytecodeID; ExecutionEngine *EE; TargetFolder Folder; IRBuilder Builder; Value **Values; FunctionPassManager &PM; unsigned numLocals; unsigned numArgs; const Type *mapType(uint16_t ty) { if (!ty) return Type::getVoidTy(Context); if (ty <= 64) return IntegerType::get(Context, ty); switch (ty) { case 65: return PointerType::getUnqual(Type::getInt8Ty(Context)); case 66: return PointerType::getUnqual(Type::getInt16Ty(Context)); case 67: return PointerType::getUnqual(Type::getInt32Ty(Context)); case 68: return PointerType::getUnqual(Type::getInt64Ty(Context)); } ty -= 69; // This was validated by libclamav already. assert(ty < bc->num_types && "Out of range type ID"); return TypeMap[ty]; } void convertTypes() { for (unsigned j=0;jnum_types;j++) { } } Value *convertOperand(const struct cli_bc_func *func, const Type *Ty, operand_t operand) { unsigned map[] = {0, 1, 2, 3, 3, 4, 4, 4, 4}; if (operand < func->numArgs) return Values[operand]; if (operand < func->numValues) return Builder.CreateLoad(Values[operand]); unsigned w = (Ty->getPrimitiveSizeInBits()+7)/8; return convertOperand(func, map[w], operand); } Value *convertOperand(const struct cli_bc_func *func, const struct cli_bc_inst *inst, operand_t operand) { return convertOperand(func, inst->interp_op%5, operand); } Value *convertOperand(const struct cli_bc_func *func, unsigned w, operand_t operand) { if (operand < func->numArgs) return Values[operand]; if (operand < func->numValues) return Builder.CreateLoad(Values[operand]); // Constant operand -= func->numValues; // This was already validated by libclamav. assert(operand < func->numConstants && "Constant out of range"); uint64_t *c = &func->constants[operand]; uint64_t v; const Type *Ty; switch (w) { case 0: case 1: Ty = w ? Type::getInt8Ty(Context) : Type::getInt1Ty(Context); v = *(uint8_t*)c; break; case 2: Ty = Type::getInt16Ty(Context); v = *(uint16_t*)c; break; case 3: Ty = Type::getInt32Ty(Context); v = *(uint32_t*)c; break; case 4: Ty = Type::getInt64Ty(Context); v = *(uint64_t*)c; break; } return ConstantInt::get(Ty, v); } void Store(uint16_t dest, Value *V) { assert(dest >= numArgs && dest < numLocals+numArgs && "Instruction destination out of range"); Builder.CreateStore(V, Values[dest]); } // Insert code that calls \arg FHandler if \arg FailCond is true. void InsertVerify(Value *FailCond, BasicBlock *&Fail, Function *FHandler, Function *F) { if (!Fail) { Fail = BasicBlock::Create(Context, "fail", F); CallInst::Create(FHandler,"",Fail); new UnreachableInst(Context, Fail); } BasicBlock *OkBB = BasicBlock::Create(Context, "", F); Builder.CreateCondBr(FailCond, Fail, OkBB); Builder.SetInsertPoint(OkBB); } public: LLVMCodegen(const struct cli_bc *bc, Module *M, FunctionMapTy &cFuncs, ExecutionEngine *EE, FunctionPassManager &PM) : bc(bc), M(M), Context(M->getContext()), compiledFunctions(cFuncs), BytecodeID("bc"+Twine(bc->id)), EE(EE), Folder(EE->getTargetData(), Context), Builder(Context, Folder), PM(PM) { TypeMap = new const Type*[bc->num_types]; } bool generate() { PrettyStackTraceString Trace(BytecodeID.str().c_str()); convertTypes(); FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), false); Function *FHandler = Function::Create(FTy, Function::InternalLinkage, "clamjit.fail", M); FHandler->setDoesNotReturn(); FHandler->setDoesNotThrow(); FHandler->addFnAttr(Attribute::NoInline); EE->addGlobalMapping(FHandler, (void*)jit_exception_handler); Function **Functions = new Function*[bc->num_func]; for (unsigned j=0;jnum_func;j++) { PrettyStackTraceString CrashInfo("Generate LLVM IR functions"); // Create LLVM IR Function const struct cli_bc_func *func = &bc->funcs[j]; std::vector argTypes; for (unsigned a=0;anumArgs;a++) { argTypes.push_back(mapType(func->types[a])); } const Type *RetTy = mapType(func->returnType); FunctionType *FTy = FunctionType::get(RetTy, argTypes, false); Functions[j] = Function::Create(FTy, Function::InternalLinkage, BytecodeID+"f"+Twine(j), M); Functions[j]->setDoesNotThrow(); } for (unsigned j=0;jnum_func;j++) { PrettyStackTraceString CrashInfo("Generate LLVM IR"); const struct cli_bc_func *func = &bc->funcs[j]; // Create all BasicBlocks Function *F = Functions[j]; BasicBlock **BB = new BasicBlock*[func->numBB]; for (unsigned i=0;inumBB;i++) { BB[i] = BasicBlock::Create(Context, "", F); } BasicBlock *Fail = 0; Values = new Value*[func->numValues]; Builder.SetInsertPoint(BB[0]); Function::arg_iterator I = F->arg_begin(); for (unsigned i=0;inumArgs; i++) { assert(I != F->arg_end()); Values[i] = &*I; ++I; } for (unsigned i=func->numArgs;inumValues;i++) { Values[i] = Builder.CreateAlloca(mapType(func->types[i])); } numLocals = func->numLocals; numArgs = func->numArgs; // Generate LLVM IR for each BB for (unsigned i=0;inumBB;i++) { const struct cli_bc_bb *bb = &func->BB[i]; Builder.SetInsertPoint(BB[i]); for (unsigned j=0;jnumInsts;j++) { const struct cli_bc_inst *inst = &bb->insts[j]; Value *Op0, *Op1, *Op2; // libclamav has already validated this. assert(inst->opcode < OP_INVALID && "Invalid opcode"); switch (inst->opcode) { case OP_JMP: case OP_BRANCH: case OP_CALL_API: case OP_CALL_DIRECT: case OP_ZEXT: case OP_SEXT: case OP_TRUNC: // these instructions represents operands differently break; default: switch (operand_counts[inst->opcode]) { case 1: Op0 = convertOperand(func, inst, inst->u.unaryop); break; case 2: Op0 = convertOperand(func, inst, inst->u.binop[0]); Op1 = convertOperand(func, inst, inst->u.binop[1]); break; case 3: Op0 = convertOperand(func, inst, inst->u.three[0]); Op1 = convertOperand(func, inst, inst->u.three[1]); Op2 = convertOperand(func, inst, inst->u.three[2]); break; } } switch (inst->opcode) { case OP_ADD: Store(inst->dest, Builder.CreateAdd(Op0, Op1)); break; case OP_SUB: Store(inst->dest, Builder.CreateSub(Op0, Op1)); break; case OP_MUL: Store(inst->dest, Builder.CreateMul(Op0, Op1)); break; case OP_UDIV: { Value *Bad = Builder.CreateICmpEQ(Op1, ConstantInt::get(Op1->getType(), 0)); InsertVerify(Bad, Fail, FHandler, F); Store(inst->dest, Builder.CreateUDiv(Op0, Op1)); break; } case OP_SDIV: { //TODO: also verify Op0 == -1 && Op1 = INT_MIN Value *Bad = Builder.CreateICmpEQ(Op1, ConstantInt::get(Op1->getType(), 0)); InsertVerify(Bad, Fail, FHandler, F); Store(inst->dest, Builder.CreateSDiv(Op0, Op1)); break; } case OP_UREM: { Value *Bad = Builder.CreateICmpEQ(Op1, ConstantInt::get(Op1->getType(), 0)); InsertVerify(Bad, Fail, FHandler, F); Store(inst->dest, Builder.CreateURem(Op0, Op1)); break; } case OP_SREM: { //TODO: also verify Op0 == -1 && Op1 = INT_MIN Value *Bad = Builder.CreateICmpEQ(Op1, ConstantInt::get(Op1->getType(), 0)); InsertVerify(Bad, Fail, FHandler, F); Store(inst->dest, Builder.CreateSRem(Op0, Op1)); break; } case OP_SHL: Store(inst->dest, Builder.CreateShl(Op0, Op1)); break; case OP_LSHR: Store(inst->dest, Builder.CreateLShr(Op0, Op1)); break; case OP_ASHR: Store(inst->dest, Builder.CreateAShr(Op0, Op1)); break; case OP_AND: Store(inst->dest, Builder.CreateAnd(Op0, Op1)); break; case OP_OR: Store(inst->dest, Builder.CreateOr(Op0, Op1)); break; case OP_XOR: Store(inst->dest, Builder.CreateXor(Op0, Op1)); break; case OP_TRUNC: { Value *Src = convertOperand(func, inst, inst->u.cast.source); const Type *Ty = mapType(func->types[inst->dest]); Store(inst->dest, Builder.CreateTrunc(Src, Ty)); break; } case OP_ZEXT: { Value *Src = convertOperand(func, inst, inst->u.cast.source); const Type *Ty = mapType(func->types[inst->dest]); Store(inst->dest, Builder.CreateZExt(Src, Ty)); break; } case OP_SEXT: { Value *Src = convertOperand(func, inst, inst->u.cast.source); const Type *Ty = mapType(func->types[inst->dest]); Store(inst->dest, Builder.CreateSExt(Src, Ty)); break; } case OP_BRANCH: { Value *Cond = convertOperand(func, inst, inst->u.branch.condition); BasicBlock *True = BB[inst->u.branch.br_true]; BasicBlock *False = BB[inst->u.branch.br_false]; if (Cond->getType() != Type::getInt1Ty(Context)) { errs() << MODULE << "type mismatch in condition\n"; return false; } Builder.CreateCondBr(Cond, True, False); break; } case OP_JMP: { BasicBlock *Jmp = BB[inst->u.jump]; Builder.CreateBr(Jmp); break; } case OP_RET: Builder.CreateRet(Op0); break; case OP_ICMP_EQ: Store(inst->dest, Builder.CreateICmpEQ(Op0, Op1)); break; case OP_ICMP_NE: Store(inst->dest, Builder.CreateICmpNE(Op0, Op1)); break; case OP_ICMP_UGT: Store(inst->dest, Builder.CreateICmpUGT(Op0, Op1)); break; case OP_ICMP_UGE: Store(inst->dest, Builder.CreateICmpUGE(Op0, Op1)); break; case OP_ICMP_ULT: Store(inst->dest, Builder.CreateICmpULT(Op0, Op1)); break; case OP_ICMP_ULE: Store(inst->dest, Builder.CreateICmpULE(Op0, Op1)); break; case OP_ICMP_SGT: Store(inst->dest, Builder.CreateICmpSGT(Op0, Op1)); break; case OP_ICMP_SGE: Store(inst->dest, Builder.CreateICmpSGE(Op0, Op1)); break; case OP_ICMP_SLT: Store(inst->dest, Builder.CreateICmpSLT(Op0, Op1)); break; case OP_SELECT: Store(inst->dest, Builder.CreateSelect(Op0, Op1, Op2)); break; case OP_COPY: Builder.CreateStore(Op0, Op1); break; case OP_CALL_DIRECT: { Function *DestF = Functions[inst->u.ops.funcid]; SmallVector args; for (unsigned a=0;au.ops.numOps;a++) { operand_t op = inst->u.ops.ops[a]; args.push_back(convertOperand(func, DestF->getFunctionType()->getParamType(a), op)); } Store(inst->dest, Builder.CreateCall(DestF, args.begin(), args.end())); break; } default: errs() << "JIT doesn't implement opcode " << inst->opcode << " yet!\n"; return false; } } } if (verifyFunction(*F, PrintMessageAction)) { errs() << MODULE << "Verification failed\n"; // verification failed return false; } PM.run(*F); delete [] Values; delete [] BB; } DEBUG(M->dump()); delete [] TypeMap; FunctionType *Callable = FunctionType::get(Type::getInt32Ty(Context),false); for (unsigned j=0;jnum_func;j++) { const struct cli_bc_func *func = &bc->funcs[j]; PrettyStackTraceString CrashInfo2("Native machine codegen"); // Codegen current function as executable machine code. void *code = EE->getPointerToFunction(Functions[j]); // If prototype matches, add to callable functions if (Functions[j]->getFunctionType() == Callable) compiledFunctions[func] = code; } delete [] Functions; return true; } }; } int cli_vm_execute_jit(const struct cli_all_bc *bcs, struct cli_bc_ctx *ctx, const struct cli_bc_func *func) { jmp_buf env; void *code = bcs->engine->compiledFunctions[func]; if (!code) { errs() << MODULE << "Unable to find compiled function\n"; return CL_EBYTECODE; } // execute; if (setjmp(env) == 0) { // setup exception handler to longjmp back here ExceptionReturn.set(&env); uint32_t result = ((uint32_t (*)(void))code)(); *(uint32_t*)ctx->values = result; return 0; } errs() << "\n"; errs().changeColor(raw_ostream::RED, true) << MODULE << "*** JITed code intercepted runtime error!\n"; errs().resetColor(); return CL_EBYTECODE; } int cli_bytecode_prepare_jit(struct cli_all_bc *bcs) { if (!bcs->engine) return CL_EBYTECODE; jmp_buf env; // setup exception handler to longjmp back here ExceptionReturn.set(&env); if (setjmp(env) != 0) { errs() << "\n"; errs().changeColor(raw_ostream::RED, true) << MODULE << "*** FATAL error encountered during bytecode generation\n"; errs().resetColor(); return CL_EBYTECODE; } // LLVM itself never throws exceptions, but operator new may throw bad_alloc try { Module *M = new Module("ClamAV jit module", bcs->engine->Context); ExistingModuleProvider *MP = new ExistingModuleProvider(M); { // Create the JIT. std::string ErrorMsg; EngineBuilder builder(MP); builder.setErrorStr(&ErrorMsg); builder.setEngineKind(EngineKind::JIT); builder.setOptLevel(CodeGenOpt::Aggressive); ExecutionEngine *EE = bcs->engine->EE = builder.create(); if (!EE) { if (!ErrorMsg.empty()) errs() << MODULE << "error creating execution engine: " << ErrorMsg << "\n"; else errs() << MODULE << "JIT not registered?\n"; return CL_EBYTECODE; } EE->RegisterJITEventListener(createOProfileJITEventListener()); // Due to LLVM PR4816 only X86 supports non-lazy compilation, disable // for now. // EE->DisableLazyCompilation(); EE->DisableSymbolSearching(); FunctionPassManager OurFPM(MP); // Set up the optimizer pipeline. Start with registering info about how // the target lays out data structures. OurFPM.add(new TargetData(*EE->getTargetData())); // Promote allocas to registers. OurFPM.add(createPromoteMemoryToRegisterPass()); OurFPM.doInitialization(); for (unsigned i=0;icount;i++) { const struct cli_bc *bc = &bcs->all_bcs[i]; LLVMCodegen Codegen(bc, M, bcs->engine->compiledFunctions, EE, OurFPM); if (!Codegen.generate()) { errs() << MODULE << "JIT codegen failed\n"; return CL_EBYTECODE; } } for (unsigned i=0;icount;i++) { bcs->all_bcs[i].state = bc_jit; } // compile all functions now, not lazily! for (Module::iterator I = M->begin(), E = M->end(); I != E; ++I) { Function *Fn = &*I; if (!Fn->isDeclaration()) EE->getPointerToFunction(Fn); } } return -1; } catch (std::bad_alloc &badalloc) { errs() << MODULE << badalloc.what() << "\n"; return CL_EMEM; } catch (...) { errs() << MODULE << "Unexpected unknown exception occurred.\n"; return CL_EBYTECODE; } } int bytecode_init(void) { llvm_install_error_handler(llvm_error_handler); sys::PrintStackTraceOnErrorSignal(); atexit(do_shutdown); llvm_start_multithreaded(); // If we have a native target, initialize it to ensure it is linked in and // usable by the JIT. InitializeNativeTarget(); return 0; } // Called once when loading a new set of BC files int cli_bytecode_init_jit(struct cli_all_bc *bcs) { bcs->engine = new(std::nothrow) struct cli_bcengine; if (!bcs->engine) return CL_EMEM; return 0; } int cli_bytecode_done_jit(struct cli_all_bc *bcs) { if (bcs->engine) { if (bcs->engine->EE) delete bcs->engine->EE; delete bcs->engine; bcs->engine = 0; } return 0; } void cli_bytecode_debug(int argc, char **argv) { cl::ParseCommandLineOptions(argc, argv); } int have_clamjit=1;