libclamav/c++/ClamBCRTChecks.cpp
a45651d4
 /*
  *  Compile LLVM bytecode to ClamAV bytecode.
  *
e1cbc270
  *  Copyright (C) 2013-2019 Cisco Systems, Inc. and/or its affiliates. All rights reserved.
26f3bad1
  *  Copyright (C) 2009-2013 Sourcefire, Inc.
a45651d4
  *
0ea799f8
  *  Authors: Török Edvin, Kevin Lin
a45651d4
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License version 2 as
  *  published by the Free Software Foundation.
  *
  *  This program is distributed in the hope that it will be useful,
  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  *  GNU General Public License for more details.
  *
  *  You should have received a copy of the GNU General Public License
  *  along with this program; if not, write to the Free Software
  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
  *  MA 02110-1301, USA.
  */
 #define DEBUG_TYPE "clambc-rtcheck"
 #include "ClamBCModule.h"
daad92ac
 #include "ClamBCDiagnostics.h"
0ea799f8
 #include "llvm30_compat.h" /* libclamav-specific */
556eaf04
 #include "llvm/ADT/DenseSet.h"
a45651d4
 #include "llvm/ADT/PostOrderIterator.h"
556eaf04
 #include "llvm/ADT/SCCIterator.h"
 #include "llvm/Analysis/CallGraph.h"
0ea799f8
 #if LLVM_VERSION < 32
 #include "llvm/Analysis/DebugInfo.h"
4a40b53a
 #elif LLVM_VERSION < 35
0ea799f8
 #include "llvm/DebugInfo.h"
4a40b53a
 #else
 #include "llvm/IR/DebugInfo.h"
0ea799f8
 #endif
4a40b53a
 #if LLVM_VERSION < 35
a45651d4
 #include "llvm/Analysis/Dominators.h"
4a40b53a
 #include "llvm/Analysis/Verifier.h"
 #else
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Verifier.h"
 #endif
a45651d4
 #include "llvm/Analysis/ConstantFolding.h"
0ea799f8
 #if LLVM_VERSION < 29
 //#include "llvm/Analysis/LiveValues.h" (unused)
a45651d4
 #include "llvm/Analysis/PointerTracking.h"
0ea799f8
 #else
 #include "llvm/Analysis/ValueTracking.h"
 #include "PointerTracking.h" /* included from old LLVM source */
0c79cc55
 #endif
a45651d4
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Config/config.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
4a40b53a
 #if LLVM_VERSION < 35
a45651d4
 #include "llvm/Support/DataFlow.h"
 #include "llvm/Support/InstIterator.h"
 #include "llvm/Support/GetElementPtrTypeIterator.h"
4a40b53a
 #else
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #endif
a45651d4
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Support/Debug.h"
34d8b8cf
 #if LLVM_VERSION < 32
 #include "llvm/Target/TargetData.h"
 #elif LLVM_VERSION < 33
 #include "llvm/DataLayout.h"
 #else
 #include "llvm/IR/DataLayout.h"
 #endif
 #if LLVM_VERSION < 33
 #include "llvm/DerivedTypes.h"
 #include "llvm/Instructions.h"
 #include "llvm/IntrinsicInst.h"
 #include "llvm/Intrinsics.h"
 #include "llvm/LLVMContext.h"
 #include "llvm/Module.h"
 #else
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
4a40b53a
 #endif
 
 #if LLVM_VERSION < 33
 #include "llvm/Support/InstVisitor.h"
 #elif LLVM_VERSION < 35
34d8b8cf
 #include "llvm/InstVisitor.h"
4a40b53a
 #else
 #include "llvm/IR/InstVisitor.h"
34d8b8cf
 #endif
 
0ea799f8
 // Expands to the header of a pass's default constructor whose member
 // initializer list starts with FunctionPass(ID); the user appends any
 // further member initializers and the body (see the PtrVerifier constructor).
 #define DEFINEPASS(passname) passname() : FunctionPass(ID)
0c79cc55
 
a45651d4
 using namespace llvm;
0ea799f8
 #if LLVM_VERSION < 29
7cd9337a
 /* Compatibility shim for LLVM < 2.9. Newer LLVM releases provide a standalone
  * GetUnderlyingObject(Value*, DataLayout*) function (used by the rest of this
  * file); older releases only have the Value::getUnderlyingObject() member,
  * which this wrapper forwards to. The TD argument exists only to match the
  * newer signature and is ignored. */
0c79cc55
 static Value *GetUnderlyingObject(Value *P, TargetData *TD)
 {
     return P->getUnderlyingObject();
 }
 #endif
59f1b78b
 
0ea799f8
 namespace llvm {
   class PtrVerifier;
deafd8fa
 #if LLVM_VERSION >= 29
59f1b78b
   void initializePtrVerifierPass(PassRegistry&);
 #endif
a45651d4
 
   class PtrVerifier : public FunctionPass {
556eaf04
   private:
0ea799f8
       DenseSet<Function*> badFunctions;
       std::vector<Instruction*> delInst;
4a40b53a
 #if LLVM_VERSION < 35
0ea799f8
       CallGraphNode *rootNode;
4a40b53a
 #else
       CallGraph *CG;
 #endif
a45651d4
   public:
0ea799f8
       static char ID;
4a40b53a
 #if LLVM_VERSION < 35
7d05cb15
       DEFINEPASS(PtrVerifier), rootNode(0), PT(), TD(), SE(), expander(),
4a40b53a
 #else
       DEFINEPASS(PtrVerifier), CG(0), PT(), TD(), SE(), expander(),
 #endif
7d05cb15
           DT(), AbrtBB(), Changed(false), valid(false), EP() {
           // Every pointer member starts out null and both flags false. The
           // analysis pointers, the SCEV expander, EP and the call-graph
           // handle (rootNode / CG) are filled in by runOnFunction().
           // When LLVM_VERSION >= 29 the pass is also registered with the
           // global PassRegistry here.
deafd8fa
 #if LLVM_VERSION >= 29
0ea799f8
           initializePtrVerifierPass(*PassRegistry::getPassRegistry());
59f1b78b
 #endif
556eaf04
       }
 
0ea799f8
       /// Verify, and instrument with runtime bounds checks, a single function.
       ///
       /// When no call-graph handle is cached yet (rootNode / CG is null), every
       /// SCC of the module call graph is scanned and all functions taking part
       /// in recursion are recorded in badFunctions. Then every load, store,
       /// memory intrinsic and direct call in a reachable block of F is checked
       /// with validateAccess(). If any check fails (or F is recursive), an
       /// error is reported through ClamBCModule::stop() and the entry block of
       /// F is replaced by a call to abort().
       ///
       /// @return true if the IR of F was modified.
       virtual bool runOnFunction(Function &F) {
           /*
 #ifndef CLAMBC_COMPILER
           // Bytecode was already verified and had stack protector applied.
           // We get called again because ALL bytecode functions loaded are part of
           // the same module.
           if (F.hasFnAttr(Attribute::StackProtectReq))
               return false;
 #endif
           */

           DEBUG(errs() << "Running on " << F.getName() << "\n");
           DEBUG(F.dump());
           Changed = false;
           BaseMap.clear();
           BoundsMap.clear();
           delInst.clear();
           AbrtBB = 0;
           valid = true;

 #if LLVM_VERSION < 35
           if (!rootNode) {
               rootNode = getAnalysis<CallGraph>().getRoot();
 #else
           if (!CG) {
               CG = &getAnalysis<CallGraphWrapperPass>().getCallGraph();
 #endif
               // No recursive functions for now.
               // In the future we may insert runtime checks for stack depth.
 #if LLVM_VERSION < 35
               for (scc_iterator<CallGraphNode*> SCCI = scc_begin(rootNode),
                        E = scc_end(rootNode); SCCI != E; ++SCCI) {
 #else
               for (scc_iterator<CallGraph*> SCCI = scc_begin(CG); !SCCI.isAtEnd(); ++SCCI) {
 #endif
                   const std::vector<CallGraphNode*> &nextSCC = *SCCI;
                   if (nextSCC.size() > 1 || SCCI.hasLoop()) {
                       errs() << "INVALID: Recursion detected, callgraph SCC components: ";
                       for (std::vector<CallGraphNode*>::const_iterator I = nextSCC.begin(),
                                E = nextSCC.end(); I != E; ++I) {
                           Function *FF = (*I)->getFunction();
                           if (FF) {
                               errs() << FF->getName() << ", ";
                               badFunctions.insert(FF);
                           }
                       }
                       if (SCCI.hasLoop())
                           errs() << "(self-loop)";
                       errs() << "\n";
                   }
                   // we could also have recursion via function pointers, but we don't
                   // allow calls to unknown functions, see runOnFunction() below
               }
           }

           // EP: first instruction of the entry block after its allocas and
           // PHIs; used as insertion point for values that have no defining
           // instruction (see getInsertPoint()).
           BasicBlock::iterator It = F.getEntryBlock().begin();
           while (isa<AllocaInst>(It) || isa<PHINode>(It)) ++It;
           EP = &*It;
 #if LLVM_VERSION < 32
           TD = &getAnalysis<TargetData>();
 #elif LLVM_VERSION < 35
           TD = &getAnalysis<DataLayout>();
 #else
           DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
           TD = DLP ? &DLP->getDataLayout() : 0;
 #endif
           SE = &getAnalysis<ScalarEvolution>();
           PT = &getAnalysis<PointerTracking>();
 #if LLVM_VERSION < 35
           DT = &getAnalysis<DominatorTree>();
 #else
           DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
 #endif
           expander = new SCEVExpander(*SE OPT("SCEVexpander"));

           // Collect the instructions to check first: validating them splits
           // blocks and inserts code, which must not disturb this walk.
           std::vector<Instruction*> insns;

           BasicBlock *LastBB = 0;
           bool SkipBB = false;
           for (inst_iterator I=inst_begin(F),E=inst_end(F); I != E;++I) {
               Instruction *II = &*I;
               /* only appears in the libclamav version */
               // Skip every instruction of blocks that are not in the dominator
               // tree (unreachable from entry). The flag is sticky for the whole
               // block; previously only the block's first instruction was skipped.
               if (II->getParent() != LastBB) {
                   LastBB = II->getParent();
                   SkipBB = (DT->getNode(LastBB) == 0);
               }
               if (SkipBB)
                   continue;
               /* end-block */
               if (isa<LoadInst>(II) || isa<StoreInst>(II) || isa<MemIntrinsic>(II))
                   insns.push_back(II);
               else if (CallInst *CI = dyn_cast<CallInst>(II)) {
                   Value *V = CI->getCalledValue()->stripPointerCasts();
                   Function *Callee = dyn_cast<Function>(V);
                   if (!Callee) {
                       printLocation(CI, true);
                       errs() << "Could not determine call target\n";
                       valid = 0;
                       continue;
                   }
                   // this statement disable checks on user-defined CallInst
                   //if (!Callee->isDeclaration())
                   //continue;
                   insns.push_back(CI);
               }
           }

           for (unsigned Idx = 0; Idx < insns.size(); ++Idx) {
               Instruction *II = insns[Idx];
               DEBUG(dbgs() << "checking " << *II << "\n");
               if (LoadInst *LI = dyn_cast<LoadInst>(II)) {
                   constType *Ty = LI->getType();
                   valid &= validateAccess(LI->getPointerOperand(),
                                           TD->getTypeAllocSize(Ty), LI);
               } else if (StoreInst *SI = dyn_cast<StoreInst>(II)) {
                   constType *Ty = SI->getOperand(0)->getType();
                   valid &= validateAccess(SI->getPointerOperand(),
                                           TD->getTypeAllocSize(Ty), SI);
               } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) {
                   valid &= validateAccess(MI->getDest(), MI->getLength(), MI);
                   if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
                       valid &= validateAccess(MTI->getSource(), MI->getLength(), MI);
                   }
               } else if (CallInst *CI = dyn_cast<CallInst>(II)) {
                   Value *V = CI->getCalledValue()->stripPointerCasts();
                   Function *Callee = cast<Function>(V);
                   constFunctionType *FTy = Callee->getFunctionType();
                   CallSite CS(CI);
                   if (Callee->getName().equals("memcmp") && FTy->getNumParams() == 3) {
                       valid &= validateAccess(CS.getArgument(0), CS.getArgument(2), CI);
                       valid &= validateAccess(CS.getArgument(1), CS.getArgument(2), CI);
                       continue;
                   }
                   unsigned i;
 #ifdef CLAMBC_COMPILER
                   i = 0;
 #else
                   i = 1;// skip hidden ctx*
 #endif
                   // Every pointer parameter must be followed by an integer
                   // parameter holding its size.
                   for (;i<FTy->getNumParams();i++) {
                       if (isa<PointerType>(FTy->getParamType(i))) {
                           Value *Ptr = CS.getArgument(i);
                           if (i+1 >= FTy->getNumParams()) {
                               printLocation(CI, false);
                               errs() << "Call to external function with pointer parameter last"
                                   " cannot be analyzed\n";
                               errs() << *CI << "\n";
                               valid = 0;
                               break;
                           }
                           Value *Size = CS.getArgument(i+1);
                           if (!Size->getType()->isIntegerTy()) {
                               printLocation(CI, false);
                               errs() << "Pointer argument must be followed by integer argument"
                                   " representing its size\n";
                               errs() << *CI << "\n";
                               valid = 0;
                               break;
                           }
                           valid &= validateAccess(Ptr, Size, CI);
                       }
                   }
               }
           }
           if (badFunctions.count(&F))
               valid = 0;

           // bb#9967 - deleting obsolete termination instructions.
           // This must run before the entry block is wiped below: a terminator
           // recorded by insertCheck() may live in the entry block, and erasing
           // it again afterwards would touch freed memory.
           for (unsigned i = 0; i < delInst.size(); ++i) {
               delInst[i]->eraseFromParent();
               Changed = true;
           }
           delInst.clear();

           if (!valid) {
               DEBUG(F.dump());
               ClamBCModule::stop("Verification found errors!", &F);
               // replace function with call to abort
               std::vector<constType*>args;
               FunctionType* abrtTy = FunctionType::get(Type::getVoidTy(F.getContext()),args,false);
               Constant *func_abort = F.getParent()->getOrInsertFunction("abort", abrtTy);

               BasicBlock *BB = &F.getEntryBlock();
               // The entry block is about to lose its terminator: detach it from
               // its successors so their PHI nodes do not keep entries for an
               // edge that no longer exists.
               TerminatorInst *OldTI = BB->getTerminator();
               for (unsigned s = 0, se = OldTI->getNumSuccessors(); s != se; ++s)
                   OldTI->getSuccessor(s)->removePredecessor(BB);

               Instruction *I = &*BB->begin();
               Instruction *UI = new UnreachableInst(F.getContext(), I);
               CallInst *AbrtC = CallInst::Create(func_abort, "", UI);
               AbrtC->setCallingConv(CallingConv::C);
               AbrtC->setTailCall(true);
 #if LLVM_VERSION < 32
               AbrtC->setDoesNotReturn(true);
               AbrtC->setDoesNotThrow(true);
 #else
               AbrtC->setDoesNotReturn();
               AbrtC->setDoesNotThrow();
 #endif
               // remove all instructions from entry
               BasicBlock::iterator BBI = I, BBE=BB->end();
               while (BBI != BBE) {
                   if (!BBI->use_empty())
                       BBI->replaceAllUsesWith(UndefValue::get(BBI->getType()));
                   BB->getInstList().erase(BBI++);
               }
               Changed = true;
           }

           delete expander;
           expander = 0;
           return Changed;
       }
 
0ea799f8
       virtual void releaseMemory() {
           badFunctions.clear();
       }
556eaf04
 
0ea799f8
       /// Declare the analyses runOnFunction() fetches with getAnalysis():
       /// target data layout, dominator tree, scalar evolution, pointer
       /// tracking and the call graph. The LLVM_VERSION conditionals pick the
       /// pass class that provides each analysis in that LLVM release.
       /// No analyses are declared as preserved.
       virtual void getAnalysisUsage(AnalysisUsage &AU) const {
34d8b8cf
 #if LLVM_VERSION < 32
0ea799f8
           AU.addRequired<TargetData>();
4a40b53a
 #elif LLVM_VERSION < 35
0ea799f8
           AU.addRequired<DataLayout>();
4a40b53a
 #else
           AU.addRequired<DataLayoutPass>();
34d8b8cf
 #endif
4a40b53a
 #if LLVM_VERSION < 35
0ea799f8
           AU.addRequired<DominatorTree>();
4a40b53a
 #else
           AU.addRequired<DominatorTreeWrapperPass>();
 #endif
0ea799f8
           AU.addRequired<ScalarEvolution>();
           AU.addRequired<PointerTracking>();
4a40b53a
 #if LLVM_VERSION < 35
0ea799f8
           AU.addRequired<CallGraph>();
4a40b53a
 #else
           AU.addRequired<CallGraphWrapperPass>();
 #endif
0ea799f8
       }
a45651d4
 
0ea799f8
       /// Result of the most recent runOnFunction(): false if verification
       /// found errors in that function.
       bool isValid() const { return valid; }
a45651d4
   private:
0ea799f8
       PointerTracking *PT;
34d8b8cf
 #if LLVM_VERSION < 32
0ea799f8
       TargetData *TD;
4a40b53a
 #elif LLVM_VERSION < 35
0ea799f8
       DataLayout *TD;
4a40b53a
 #else
       const DataLayout *TD;
34d8b8cf
 #endif
0ea799f8
       ScalarEvolution *SE;
       SCEVExpander *expander;
       DominatorTree *DT;
       DenseMap<Value*, Value*> BaseMap;
       DenseMap<Value*, Value*> BoundsMap;
       BasicBlock *AbrtBB;
       bool Changed;
       bool valid;
       Instruction *EP;
 
       /// Return the instruction before which code consuming V should be
       /// inserted.
       ///
       /// If V is an instruction, this is the first non-PHI instruction after
       /// V's definition. Any PHI nodes that follow V are skipped because PHIs
       /// must stay grouped at the top of their block; inserting a cast between
       /// them would produce invalid IR. For any other value (arguments,
       /// globals), this is the function-level insertion point EP.
       Instruction *getInsertPoint(Value *V)
       {
           if (Instruction *I = dyn_cast<Instruction>(V)) {
               BasicBlock::iterator It = I;
               ++It;
               while (isa<PHINode>(It))
                   ++It;
               return &*It;
           }
           return EP;
       }
 
0ea799f8
       /// Return the base pointer (as i8*) that Ptr is derived from, caching
       /// the answer in BaseMap.
       ///
       /// Pointer casts are stripped and GetUnderlyingObject() is used to walk
       /// to the underlying object. For PHI and select pointers, a parallel
       /// PHI/select of the bases is built right after the original. Any other
       /// pointer is its own base, bitcast to i8* when needed.
       Value *getPointerBase(Value *Ptr)
       {
           if (BaseMap.count(Ptr))
               return BaseMap[Ptr];
           Value *P = Ptr->stripPointerCasts();
           if (BaseMap.count(P)) {
               // Copy the value out before indexing BaseMap again: operator[]
               // on a missing key may grow the map and invalidate references
               // into it.
               Value *Cached = BaseMap[P];
               return BaseMap[Ptr] = Cached;
           }
           Value *P2 = GetUnderlyingObject(P, TD);
           if (P2 != P) {
               Value *V = getPointerBase(P2);
               return BaseMap[Ptr] = V;
           }

           constType *P8Ty =
               PointerType::getUnqual(Type::getInt8Ty(Ptr->getContext()));
           if (PHINode *PN = dyn_cast<PHINode>(Ptr)) {
               BasicBlock::iterator It = PN;
               ++It;
               PHINode *newPN = PHINode::Create(P8Ty, HINT(PN->getNumIncomingValues()) ".verif.base", &*It);
               Changed = true;
               // Register the new PHI before recursing so that cycles through
               // this PHI terminate.
               BaseMap[Ptr] = newPN;

               for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
                   Value *Inc = PN->getIncomingValue(i);
                   Value *V = getPointerBase(Inc);
                   newPN->addIncoming(V, PN->getIncomingBlock(i));
               }
               return newPN;
           }
           if (SelectInst *SI = dyn_cast<SelectInst>(Ptr)) {
               BasicBlock::iterator It = SI;
               ++It;
               Value *TrueB = getPointerBase(SI->getTrueValue());
               Value *FalseB = getPointerBase(SI->getFalseValue());
               if (TrueB && FalseB) {
                   SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
                                                          FalseB, ".select.base", &*It);
                   Changed = true;
                   return BaseMap[Ptr] = NewSI;
               }
           }
           Value *Base = Ptr;
           if (Base->getType() != P8Ty) {
               if (Constant *C = dyn_cast<Constant>(Base))
                   Base = ConstantExpr::getPointerCast(C, P8Ty);
               else {
                   Instruction *I = getInsertPoint(Base);
                   Base = new BitCastInst(Base, P8Ty, "", I);
                   Changed = true;
               }
           }
           // Cache under the queried pointer, so a repeated query hits the cache
           // instead of emitting another bitcast. Also cache under the cast itself.
           BaseMap[Base] = Base;
           return BaseMap[Ptr] = Base;
       }
0ea799f8
 
       /// Return the formal argument of F at position Idx.
       ///
       /// When Idx is not a valid argument index, a diagnostic naming F is
       /// printed and NULL is returned.
       Value* getValAtIdx(Function *F, unsigned Idx) {
           if (Idx >= F->arg_size()) {
               // The requested position lies past the last formal parameter.
               errs() << "Idx is outside of Function parameters\n";
               errs() << *F << "\n";
               return NULL;
           }

           Function::arg_iterator Arg = F->arg_begin();
           Function::arg_iterator ArgEnd = F->arg_end();
           unsigned Remaining = Idx;
           while (Remaining > 0) {
               // Defensive only: Idx < arg_size() means the end is never hit.
               if (Arg == ArgEnd) {
                   errs() << "Idx is outside of Function parameters\n";
                   errs() << *F << "\n";
                   break;
               }
               ++Arg;
               --Remaining;
           }
           return &*Arg;
       }
0ea799f8
 
       Value* getPointerBounds(Value *Base) {
           if (BoundsMap.count(Base))
               return BoundsMap[Base];
0b5550c1
           constType *I64Ty =
0ea799f8
               Type::getInt64Ty(Base->getContext());
 
46744356
 #ifndef CLAMBC_COMPILER
           // first arg is hidden ctx
           if (Argument *A = dyn_cast<Argument>(Base)) {
               if (A->getArgNo() == 0) {
                   constType *Ty = cast<PointerType>(A->getType())->getElementType();
                   return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
               } else if (Base->getType()->isPointerTy()) {
0ea799f8
                   Function *F = A->getParent();
                   const FunctionType *FT = F->getFunctionType();
 
                   bool checks = true;
                   // last argument check
                   if (A->getArgNo() == (FT->getNumParams()-1)) {
                       //printDiagnostic("pointer argument cannot be last argument", F);
                       errs() << "pointer argument cannot be last argument\n";
                       errs() << *F << "\n";
                       checks = false;
                   }
 
                   // argument after pointer MUST be a integer (unsigned probably too)
                   if (checks && !FT->getParamType(A->getArgNo()+1)->isIntegerTy()) {
                       //printDiagnostic("argument following pointer argument is not an integer", F);
                       errs() << "argument following pointer argument is not an integer\n";
                       errs() << *F << "\n";
                       checks = false;
                   }
 
                   if (checks)
                       return BoundsMap[Base] = getValAtIdx(F, A->getArgNo()+1);
               }
           }
           if (LoadInst *LI = dyn_cast<LoadInst>(Base)) {
0b5550c1
               Value *V = GetUnderlyingObject(LI->getPointerOperand()->stripPointerCasts(), TD);
0ea799f8
               if (Argument *A = dyn_cast<Argument>(V)) {
                   if (A->getArgNo() == 0) {
                       // pointers from hidden ctx are trusted to be at least the
                       // size they say they are
0b5550c1
                       constType *Ty = cast<PointerType>(LI->getType())->getElementType();
0ea799f8
                       return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
                   }
               }
           }
46744356
 #else
           if (Base->getType()->isPointerTy()) {
               if (Argument *A = dyn_cast<Argument>(Base)) {
                   Function *F = A->getParent();
                   const FunctionType *FT = F->getFunctionType();
 
                   bool checks = true;
                   // last argument check
                   if (A->getArgNo() == (FT->getNumParams()-1)) {
                       //printDiagnostic("pointer argument cannot be last argument", F);
                       errs() << "pointer argument cannot be last argument\n";
                       errs() << *F << "\n";
                       checks = false;
                   }
 
                   // argument after pointer MUST be a integer (unsigned probably too)
                   if (checks && !FT->getParamType(A->getArgNo()+1)->isIntegerTy()) {
                       //printDiagnostic("argument following pointer argument is not an integer", F);
                       errs() << "argument following pointer argument is not an integer\n";
                       errs() << *F << "\n";
                       checks = false;
                   }
 
                   if (checks)
                       return BoundsMap[Base] = getValAtIdx(F, A->getArgNo()+1);
               }
           }
556eaf04
 #endif
0ea799f8
           if (PHINode *PN = dyn_cast<PHINode>(Base)) {
               BasicBlock::iterator It = PN;
               ++It;
               PHINode *newPN = PHINode::Create(I64Ty, HINT(PN->getNumIncomingValues()) ".verif.bounds", &*It);
               Changed = true;
               BoundsMap[Base] = newPN;
 
               bool good = true;
               for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
                   Value *Inc = PN->getIncomingValue(i);
                   Value *B = getPointerBounds(Inc);
                   if (!B) {
                       good = false;
                       B = ConstantInt::get(newPN->getType(), 0);
                       DEBUG(dbgs() << "bounds not found while solving phi node: " << *Inc
                             << "\n");
                   }
                   newPN->addIncoming(B, PN->getIncomingBlock(i));
               }
               if (!good)
                   newPN = 0;
               return BoundsMap[Base] = newPN;
           }
           if (SelectInst *SI = dyn_cast<SelectInst>(Base)) {
               BasicBlock::iterator It = SI;
               ++It;
               Value *TrueB = getPointerBounds(SI->getTrueValue());
               Value *FalseB = getPointerBounds(SI->getFalseValue());
               if (TrueB && FalseB) {
                   SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
                                                          FalseB, ".select.bounds", &*It);
                   Changed = true;
                   return BoundsMap[Base] = NewSI;
               }
           }
a45651d4
 
0ea799f8
           constType *Ty;
           Value *V = PT->computeAllocationCountValue(Base, Ty);
           if (!V) {
               Base = Base->stripPointerCasts();
               if (CallInst *CI = dyn_cast<CallInst>(Base)) {
                   Function *F = CI->getCalledFunction();
0b5550c1
                   constFunctionType *FTy = F->getFunctionType();
0ea799f8
                   // last operand is always size for this API call kind
                   if (F->isDeclaration() && FTy->getNumParams() > 0) {
                       CallSite CS(CI);
                       if (FTy->getParamType(FTy->getNumParams()-1)->isIntegerTy())
                           V = CS.getArgument(FTy->getNumParams()-1);
                   }
556eaf04
               }
0ea799f8
               if (!V)
                   return BoundsMap[Base] = 0;
           } else {
               unsigned size = TD->getTypeAllocSize(Ty);
               if (size > 1) {
                   Constant *C = cast<Constant>(V);
                   C = ConstantExpr::getMul(C,
                                            ConstantInt::get(Type::getInt32Ty(C->getContext()),
                                                             size));
                   V = C;
               }
           }
           if (V->getType() != I64Ty) {
               if (Constant *C = dyn_cast<Constant>(V))
                   V = ConstantExpr::getZExt(C, I64Ty);
               else {
                   Instruction *I = getInsertPoint(V);
                   V = new ZExtInst(V, I64Ty, "", I);
               }
           }
           return BoundsMap[Base] = V;
a45651d4
       }
0ea799f8
 
       /// Find debug-location metadata (of kind MDDbgKind) for instruction I.
       ///
       /// I's own metadata is used when present, with Approximate = false.
       /// Otherwise the search walks backwards over the instructions before I
       /// in its block, then over each block reached by repeatedly following
       /// unique predecessors. The first hit is returned with Approximate = true.
       /// A zero MDDbgKind disables that fallback search.
       /// NOTE(review): if the targeted LLVM assigns "dbg" the fixed kind ID 0,
       /// the fallback never runs. Confirm against the LLVM versions in use.
       ///
       /// @return the metadata node, or 0 when none was found.
       MDNode *getLocation(Instruction *I, bool &Approximate, unsigned MDDbgKind)
       {
           Approximate = false;
           MDNode *Exact = I->getMetadata(MDDbgKind);
           if (Exact || !MDDbgKind)
               return Exact;

           Approximate = true;
           BasicBlock *Block = I->getParent();
           BasicBlock::iterator Limit = I;
           for (;;) {
               BasicBlock::iterator Cur = Limit;
               while (Cur != Block->begin()) {
                   --Cur;
                   if (MDNode *Dbg = Cur->getMetadata(MDDbgKind))
                       return Dbg;
               }
               Block = Block->getUniquePredecessor();
               if (!Block)
                   return 0;
               Limit = Block->end();
           }
       }
0ea799f8
 
       bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I,
                        bool strict)
       {
           if (isa<SCEVCouldNotCompute>(Idx) && isa<SCEVCouldNotCompute>(Limit)) {
               errs() << "Could not compute the index and the limit!: \n" << *I << "\n";
               return false;
           }
           if (isa<SCEVCouldNotCompute>(Idx)) {
               errs() << "Could not compute index: \n" << *I << "\n";
               return false;
           }
           if (isa<SCEVCouldNotCompute>(Limit)) {
               errs() << "Could not compute limit: " << *I << "\n";
               return false;
           }
           BasicBlock *BB = I->getParent();
           BasicBlock::iterator It = I;
           BasicBlock *newBB = SplitBlock(BB, &*It, this);
           PHINode *PN;
           unsigned MDDbgKind = I->getContext().getMDKindID("dbg");
           //verifyFunction(*BB->getParent());
           if (!AbrtBB) {
0b5550c1
               std::vector<constType*>args;
0ea799f8
               FunctionType* abrtTy = FunctionType::get(Type::getVoidTy(BB->getContext()),args,false);
               args.push_back(Type::getInt32Ty(BB->getContext()));
               FunctionType* rterrTy = FunctionType::get(Type::getInt32Ty(BB->getContext()),args,false);
               Constant *func_abort = BB->getParent()->getParent()->getOrInsertFunction("abort", abrtTy);
               Constant *func_rterr = BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error",
                                                                                        rterrTy);
               AbrtBB = BasicBlock::Create(BB->getContext(), "rterr.trig", BB->getParent());
               
               PN = PHINode::Create(Type::getInt32Ty(BB->getContext()),HINT(1) "",
                                    AbrtBB);
               if (MDDbgKind) {
                   CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB);
                   RtErrCall->setCallingConv(CallingConv::C);
                   RtErrCall->setTailCall(true);
34d8b8cf
 #if LLVM_VERSION < 32
0ea799f8
                   RtErrCall->setDoesNotThrow(true);
34d8b8cf
 #else
0ea799f8
                   RtErrCall->setDoesNotThrow();
34d8b8cf
 #endif
0ea799f8
               }
               CallInst* AbrtC = CallInst::Create(func_abort, "", AbrtBB);
               AbrtC->setCallingConv(CallingConv::C);
               AbrtC->setTailCall(true);
34d8b8cf
 #if LLVM_VERSION < 32
0ea799f8
               AbrtC->setDoesNotReturn(true);
               AbrtC->setDoesNotThrow(true);
34d8b8cf
 #else
0ea799f8
               AbrtC->setDoesNotReturn();
               AbrtC->setDoesNotThrow();
34d8b8cf
 #endif
0ea799f8
               new UnreachableInst(BB->getContext(), AbrtBB);
               DT->addNewBlock(AbrtBB, BB);
               //verifyFunction(*BB->getParent());
           } else {
               PN = cast<PHINode>(AbrtBB->begin());
           }
           unsigned locationid = 0;
           bool Approximate;
           if (MDNode *Dbg = getLocation(I, Approximate, MDDbgKind)) {
               DILocation Loc(Dbg);
               locationid = Loc.getLineNumber() << 8;
               unsigned col = Loc.getColumnNumber();
               if (col > 254)
                   col = 254;
               if (Approximate)
                   col = 255;
               locationid |= col;
           }
           PN->addIncoming(ConstantInt::get(Type::getInt32Ty(BB->getContext()),
                                            locationid), BB);
           TerminatorInst *TI = BB->getTerminator();
           Value *IdxV = expander->expandCodeFor(Idx, Limit->getType(), TI);
           Value *LimitV = expander->expandCodeFor(Limit, Limit->getType(), TI);
           if (isa<Instruction>(IdxV) &&
               !DT->dominates(cast<Instruction>(IdxV)->getParent(),I->getParent())) {
               printLocation(I, true);
               errs() << "basic block with value [ " << IdxV->getName();
               errs() << " ] with limit [ " << LimitV->getName();
               errs() << " ] does not dominate" << *I << "\n";
               return false;
           }
           if (isa<Instruction>(LimitV) &&
               !DT->dominates(cast<Instruction>(LimitV)->getParent(),I->getParent())) {
               printLocation(I, true);
               errs() << "basic block with limit [" << LimitV->getName();
               errs() << " ] on value [ " << IdxV->getName();
               errs() << " ] does not dominate" << *I << "\n";
               return false;
           }
           Value *Cond = new ICmpInst(TI, strict ?
                                      ICmpInst::ICMP_ULT :
                                      ICmpInst::ICMP_ULE, IdxV, LimitV);
           BranchInst::Create(newBB, AbrtBB, Cond, TI);
           //TI->eraseFromParent();
           delInst.push_back(TI);
           // Update dominator info
           BasicBlock *DomBB =
               DT->findNearestCommonDominator(BB, DT->getNode(AbrtBB)->getIDom()->getBlock());
           DT->changeImmediateDominator(AbrtBB, DomBB);
           return true;
556eaf04
       }
0ea799f8
 
       /*
        * Bring two SCEV expressions to a common width, updating both
        * references in place.
        *
        * A top-level zero-extension on either side is peeled off first.
        * Then both sides are zero-extended (or left as-is when already wide
        * enough) to the wider of their effective SCEV types.  On a tie the
        * left-hand type is used.
        */
       static void MakeCompatible(ScalarEvolution *SE, const SCEV*& LHS, const SCEV*& RHS)
       {
           // Look through an outer zext so the narrower originals are compared.
           const SCEVZeroExtendExpr *LeftZext = dyn_cast<SCEVZeroExtendExpr>(LHS);
           if (LeftZext)
               LHS = LeftZext->getOperand();
           const SCEVZeroExtendExpr *RightZext = dyn_cast<SCEVZeroExtendExpr>(RHS);
           if (RightZext)
               RHS = RightZext->getOperand();

           constType *LeftTy = SE->getEffectiveSCEVType(LHS->getType());
           constType *RightTy = SE->getEffectiveSCEVType(RHS->getType());
           constType *WideTy =
               (SE->getTypeSizeInBits(RightTy) > SE->getTypeSizeInBits(LeftTy)) ?
               RightTy : LeftTy;

           LHS = SE->getNoopOrZeroExtend(LHS, WideTy);
           RHS = SE->getNoopOrZeroExtend(RHS, WideTy);
       }
0ea799f8
 
       /*
        * Return true if a conditional branch on ICI, reached directly or
        * through a chain of 'or' / 'and' instructions, guarantees that the
        * pointer tested by ICI is non-null whenever instruction I executes.
        *
        * Meaning of 'equal':
        *   true  - the value being false implies the pointer is non-null
        *           (e.g. "icmp eq %p, null");
        *   false - the value being true implies the pointer is non-null
        *           (e.g. "icmp ne %p, null").
        * Either way, the non-null successor of a branch on that value is
        * getSuccessor(equal).
        *
        * Propagation through boolean operators:
        *   (isNull | X)    false => pointer non-null  -> only 'or'  when equal
        *   (isNonNull & X) true  => pointer non-null  -> only 'and' when !equal
        * In both cases 'equal' keeps its meaning for the combined value.  The
        * other two combinations carry no nullness information.
        *
        * A successor is only trusted when its single predecessor is the
        * branch's own block.  Otherwise it could also be entered along a path
        * on which the pointer is null, and "S dominates I" would prove
        * nothing.
        */
       bool checkCond(Instruction *ICI, Instruction *I, bool equal)
       {
 #if LLVM_VERSION < 35
           for (Value::use_iterator JU=ICI->use_begin(),JUE=ICI->use_end();
                JU != JUE; ++JU) {
 #else
           // Since LLVM 3.5 use_iterator yields Use&, whose conversion to
           // Value* is the *used* value (ICI itself), not the user.
           for (Value::user_iterator JU=ICI->user_begin(),JUE=ICI->user_end();
                JU != JUE; ++JU) {
 #endif
               Value *JU_V = *JU;
               if (BranchInst *BI = dyn_cast<BranchInst>(JU_V)) {
                   if (!BI->isConditional())
                       continue;
                   BasicBlock *S = BI->getSuccessor(equal);
                   if (S->getSinglePredecessor() == BI->getParent() &&
                       DT->dominates(S, I->getParent()))
                       return true;
                   continue;
               }
               if (BinaryOperator *BO = dyn_cast<BinaryOperator>(JU_V)) {
                   if (equal && BO->getOpcode() == Instruction::Or &&
                       checkCond(BO, I, equal))
                       return true;
                   if (!equal && BO->getOpcode() == Instruction::And &&
                       checkCond(BO, I, equal))
                       return true;
               }
           }
           return false;
       }
a45651d4
 
0ea799f8
       /*
        * Return true if the pointer produced by CI is compared against a null
        * constant (CI on the left-hand side, possibly through pointer casts),
        * and a branch on that comparison guarantees, via checkCond, that CI
        * is non-null when I executes.
        *
        * Only comparisons that actually determine nullness are used:
        *   eq / ule null  -> true exactly when the pointer is null
        *   ne / ugt null  -> true exactly when the pointer is non-null
        * Any other predicate (uge/ult null are constant, signed compares are
        * not a null test) is ignored instead of being treated as "ne".
        */
       bool checkCondition(Instruction *CI, Instruction *I)
       {
 #if LLVM_VERSION < 35
           for (Value::use_iterator U=CI->use_begin(),UE=CI->use_end();
                U != UE; ++U) {
 #else
           // Since LLVM 3.5 use_iterator yields Use&, whose conversion to
           // Value* is the *used* value (CI itself), not the user.
           for (Value::user_iterator U=CI->user_begin(),UE=CI->user_end();
                U != UE; ++U) {
 #endif
               Value *U_V = *U;
               ICmpInst *ICI = dyn_cast<ICmpInst>(U_V);
               if (!ICI)
                   continue;
               if (ICI->getOperand(0)->stripPointerCasts() != CI ||
                   !isa<ConstantPointerNull>(ICI->getOperand(1)))
                   continue;

               bool equal;
               switch (ICI->getPredicate()) {
               case ICmpInst::ICMP_EQ:
               case ICmpInst::ICMP_ULE: // p <=u null  <=>  p == null
                   equal = true;
                   break;
               case ICmpInst::ICMP_NE:
               case ICmpInst::ICMP_UGT: // p >u null   <=>  p != null
                   equal = false;
                   break;
               default:
                   continue; // no usable nullness information
               }
               if (checkCond(ICI, I, equal))
                   return true;
           }
           return false;
       }
 
       bool validateAccess(Value *Pointer, Value *Length, Instruction *I)
       {
           // get base
           Value *Base = getPointerBase(Pointer);
 
           Value *SBase = Base->stripPointerCasts();
           // get bounds
           Value *Bounds = getPointerBounds(SBase);
           if (!Bounds) {
               printLocation(I, true);
               errs() << "no bounds for base ";
               printValue(SBase);
               errs() << " while checking access to ";
               printValue(Pointer);
               errs() << " of length ";
               printValue(Length);
               errs() << "\n";
 
               return false;
           }
 
           // checks if a NULL pointer check (returned from function) is made:
           if (CallInst *CI = dyn_cast<CallInst>(Base->stripPointerCasts())) {
               // by checking if use is in the same block (i.e. no branching decisions)
               if (I->getParent() == CI->getParent()) {
                   printLocation(I, true);
                   errs() << "no null pointer check of pointer ";
                   printValue(Base, false, true);
                   errs() << " obtained by function call";
                   errs() << " before use in same block\n";
                   return false;
               }
               // by checking if a conditional contains the values in question somewhere
               // between their usage
               if (!checkCondition(CI, I)) {
                   printLocation(I, true);
                   errs() << "no null pointer check of pointer ";
                   printValue(Base, false, true);
                   errs() << " obtained by function call";
                   errs() << " before use\n";
                   return false;
               }
           }
 
0b5550c1
       constType *I64Ty =
a45651d4
           Type::getInt64Ty(Base->getContext());
0ea799f8
       const SCEV *SLen = SE->getSCEV(Length);
       const SCEV *OffsetP = SE->getMinusSCEV(SE->getSCEV(Pointer),
                                              SE->getSCEV(Base));
       SLen = SE->getNoopOrZeroExtend(SLen, I64Ty);
       OffsetP = SE->getNoopOrZeroExtend(OffsetP, I64Ty);
       const SCEV *Limit = SE->getSCEV(Bounds);
       Limit = SE->getNoopOrZeroExtend(Limit, I64Ty);
 
       DEBUG(dbgs() << "Checking access to " << *Pointer << " of length " <<
             *Length << "\n");
       if (OffsetP == Limit) {
daad92ac
           printLocation(I, true);
           errs() << "OffsetP == Limit: " << *OffsetP << "\n";
           errs() << " while checking access to ";
           printValue(Pointer);
           errs() << " of length ";
           printValue(Length);
           errs() << "\n";
           return false;
0ea799f8
       }
a45651d4
 
0ea799f8
       if (SLen == Limit) {
a45651d4
           if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OffsetP)) {
0ea799f8
               if (SC->isZero())
                   return true;
a45651d4
           }
           errs() << "SLen == Limit: " << *SLen << "\n";
           errs() << " while checking access to " << *Pointer << " of length "
0ea799f8
                  << *Length << " at " << *I << "\n";
daad92ac
           return false;
0ea799f8
       }
 
       bool valid = true;
       SLen = SE->getAddExpr(OffsetP, SLen);
       // check that offset + slen <= limit;
       // umax(offset+slen, limit) == limit is a sufficient (but not necessary
       // condition)
       const SCEV *MaxL = SE->getUMaxExpr(SLen, Limit);
       if (MaxL != Limit) {
a45651d4
           DEBUG(dbgs() << "MaxL != Limit: " << *MaxL << ", " << *Limit << "\n");
daad92ac
           valid &= insertCheck(SLen, Limit, I, false);
0ea799f8
       }
a45651d4
 
0ea799f8
       //TODO: nullpointer check
       const SCEV *Max = SE->getUMaxExpr(OffsetP, Limit);
       if (Max == Limit)
daad92ac
           return valid;
0ea799f8
       DEBUG(dbgs() << "Max != Limit: " << *Max << ", " << *Limit << "\n");
 
       // check that offset < limit
       valid &= insertCheck(OffsetP, Limit, I, true);
       return valid;
       }
 
       /*
        * Overload for fixed-size accesses: wrap 'size' in a 32-bit integer
        * constant and defer to the Value*-length overload.
        */
       bool validateAccess(Value *Pointer, unsigned size, Instruction *I)
       {
           Value *Len = ConstantInt::get(Type::getInt32Ty(Pointer->getContext()), size);
           return validateAccess(Pointer, Len, I);
       }
 
a45651d4
   };
0ea799f8
     // Out-of-line definition of the pass's static ID member; its address
     // is what identifies the pass, the value itself is unused.
     char PtrVerifier::ID = 0;
a45651d4
 
0ea799f8
 } /* end namespace llvm */
deafd8fa
 /*
  * Pass registration for PtrVerifier using LLVM's INITIALIZE_PASS_* macros,
  * listing the analyses it depends on.  Only compiled for LLVM 2.9 and newer.
  * Several dependency passes were renamed across LLVM releases, hence the
  * LLVM_VERSION conditionals; the macro sequence must stay in this order
  * (BEGIN, dependencies, END).
  *
  * NOTE(review): BEGIN is given empty strings while END supplies the pass
  * argument "clambc-rtchecks" and description — presumably only END's
  * strings are used; confirm against the LLVM versions supported.
  */
 #if LLVM_VERSION >= 29
59f1b78b
 INITIALIZE_PASS_BEGIN(PtrVerifier, "", "", false, false)
34d8b8cf
 /* Target data layout analysis (TargetData -> DataLayout -> DataLayoutPass). */
 #if LLVM_VERSION < 32
59f1b78b
 INITIALIZE_PASS_DEPENDENCY(TargetData)
4a40b53a
 #elif LLVM_VERSION < 35
34d8b8cf
 INITIALIZE_PASS_DEPENDENCY(DataLayout)
4a40b53a
 #else
 INITIALIZE_PASS_DEPENDENCY(DataLayoutPass)
34d8b8cf
 #endif
4a40b53a
 /* Dominator tree (wrapped in DominatorTreeWrapperPass from LLVM 3.5). */
 #if LLVM_VERSION < 35
59f1b78b
 INITIALIZE_PASS_DEPENDENCY(DominatorTree)
4a40b53a
 #else
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 #endif
59f1b78b
 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
34d8b8cf
 /* Call graph: registered via INITIALIZE_AG_DEPENDENCY before LLVM 3.4
  * (presumably an analysis group there — confirm), a plain pass in 3.4,
  * and CallGraphWrapperPass from 3.5. */
 #if LLVM_VERSION < 34
59f1b78b
 INITIALIZE_AG_DEPENDENCY(CallGraph)
4a40b53a
 #elif LLVM_VERSION < 35
34d8b8cf
 INITIALIZE_PASS_DEPENDENCY(CallGraph)
4a40b53a
 #else
 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
34d8b8cf
 #endif
59f1b78b
 /* NOTE(review): PointerTracking is listed unconditionally for LLVM >= 2.9;
  * confirm where its initializer is declared for these versions. */
 INITIALIZE_PASS_DEPENDENCY(PointerTracking)
0b5550c1
 INITIALIZE_PASS_END(PtrVerifier, "clambc-rtchecks", "ClamBC RTchecks", false, false)
59f1b78b
 #endif
 
a45651d4
 
 /*
  * Create a new PtrVerifier (the ClamBC runtime-check pass) and return it
  * as a generic llvm::Pass.  The caller takes ownership of the returned
  * object (presumably by handing it to an LLVM PassManager — confirm at
  * the call sites).
  */
 llvm::Pass *createClamBCRTChecks()
 {
     llvm::Pass *Verifier = new PtrVerifier();
     return Verifier;
 }