@@ -1235,43 +1235,65 @@ llvm::Value *CodeGenFunction::EmitLoadOfCountedByField(
12351235 return nullptr ;
12361236}
12371237
1238- void CodeGenFunction::EmitBoundsCheck (const Expr *E, const Expr *Base,
1239- llvm::Value *Index, QualType IndexType,
1238+ void CodeGenFunction::EmitBoundsCheck (const Expr *ArrayExpr,
1239+ const Expr *ArrayExprBase,
1240+ llvm::Value *IndexVal, QualType IndexType,
12401241 bool Accessed) {
12411242 assert (SanOpts.has (SanitizerKind::ArrayBounds) &&
12421243 " should not be called unless adding bounds checks" );
12431244 const LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel =
12441245 getLangOpts ().getStrictFlexArraysLevel ();
1245- QualType IndexedType ;
1246- llvm::Value *Bound =
1247- getArrayIndexingBound ( *this , Base, IndexedType , StrictFlexArraysLevel);
1246+ QualType ArrayExprBaseType ;
1247+ llvm::Value *BoundsVal = getArrayIndexingBound (
1248+ *this , ArrayExprBase, ArrayExprBaseType , StrictFlexArraysLevel);
12481249
1249- EmitBoundsCheckImpl (E, Bound, Index, IndexType, IndexedType, Accessed);
1250+ EmitBoundsCheckImpl (ArrayExpr, ArrayExprBaseType, IndexVal, IndexType,
1251+ BoundsVal, getContext ().getSizeType (), Accessed);
12501252}
12511253
1252- void CodeGenFunction::EmitBoundsCheckImpl (const Expr *E, llvm::Value *Bound,
1253- llvm::Value *Index,
1254+ void CodeGenFunction::EmitBoundsCheckImpl (const Expr *ArrayExpr,
1255+ QualType ArrayBaseType,
1256+ llvm::Value *IndexVal,
12541257 QualType IndexType,
1255- QualType IndexedType, bool Accessed) {
1256- if (!Bound)
1258+ llvm::Value *BoundsVal,
1259+ QualType BoundsType, bool Accessed) {
1260+ if (!BoundsVal)
12571261 return ;
12581262
12591263 auto CheckKind = SanitizerKind::SO_ArrayBounds;
12601264 auto CheckHandler = SanitizerHandler::OutOfBounds;
12611265 SanitizerDebugLocation SanScope (this , {CheckKind}, CheckHandler);
12621266
1267+ // All hail the C implicit type conversion rules!!!
12631268 bool IndexSigned = IndexType->isSignedIntegerOrEnumerationType ();
1264- llvm::Value *IndexVal = Builder.CreateIntCast (Index, SizeTy, IndexSigned);
1265- llvm::Value *BoundVal = Builder.CreateIntCast (Bound, SizeTy, false );
1269+ bool BoundsSigned = BoundsType->isSignedIntegerOrEnumerationType ();
1270+
1271+ const ASTContext &Ctx = getContext ();
1272+ llvm::Type *Ty = ConvertType (
1273+ Ctx.getTypeSize (IndexType) >= Ctx.getTypeSize (BoundsType) ? IndexType
1274+ : BoundsType);
1275+
1276+ llvm::Value *IndexInst = Builder.CreateIntCast (IndexVal, Ty, IndexSigned);
1277+ llvm::Value *BoundsInst = Builder.CreateIntCast (BoundsVal, Ty, false );
12661278
12671279 llvm::Constant *StaticData[] = {
1268- EmitCheckSourceLocation (E ->getExprLoc ()),
1269- EmitCheckTypeDescriptor (IndexedType ),
1270- EmitCheckTypeDescriptor (IndexType)
1280+ EmitCheckSourceLocation (ArrayExpr ->getExprLoc ()),
1281+ EmitCheckTypeDescriptor (ArrayBaseType ),
1282+ EmitCheckTypeDescriptor (IndexType),
12711283 };
1272- llvm::Value *Check = Accessed ? Builder.CreateICmpULT (IndexVal, BoundVal)
1273- : Builder.CreateICmpULE (IndexVal, BoundVal);
1274- EmitCheck (std::make_pair (Check, CheckKind), CheckHandler, StaticData, Index);
1284+
1285+ llvm::Value *Check = Accessed ? Builder.CreateICmpULT (IndexInst, BoundsInst)
1286+ : Builder.CreateICmpULE (IndexInst, BoundsInst);
1287+
1288+ if (BoundsSigned) {
1289+ // Don't allow a negative bounds.
1290+ llvm::Value *Cmp = Builder.CreateICmpSGT (
1291+ BoundsVal, llvm::ConstantInt::get (BoundsVal->getType (), 0 ));
1292+ Check = Builder.CreateAnd (Cmp, Check);
1293+ }
1294+
1295+ EmitCheck (std::make_pair (Check, CheckKind), CheckHandler, StaticData,
1296+ IndexInst);
12751297}
12761298
12771299llvm::MDNode *CodeGenFunction::buildAllocToken (QualType AllocType) {
@@ -4608,9 +4630,10 @@ static std::optional<int64_t> getOffsetDifferenceInBits(CodeGenFunction &CGF,
46084630// / i.e. "a.b.count", so we shouldn't need the full force of EmitLValue or
46094631// / similar to emit the correct GEP.
46104632void CodeGenFunction::EmitCountedByBoundsChecking (
4611- const Expr *E, llvm::Value *Idx, Address Addr, QualType IdxTy,
4612- QualType ArrayTy, bool Accessed, bool FlexibleArray) {
4613- const auto *ME = dyn_cast<MemberExpr>(E->IgnoreImpCasts ());
4633+ const Expr *ArrayExpr, QualType ArrayType, Address ArrayInst,
4634+ QualType IndexType, llvm::Value *IndexVal, bool Accessed,
4635+ bool FlexibleArray) {
4636+ const auto *ME = dyn_cast<MemberExpr>(ArrayExpr->IgnoreImpCasts ());
46144637 if (!ME || !ME->getMemberDecl ()->getType ()->isCountAttributedType ())
46154638 return ;
46164639
@@ -4627,11 +4650,11 @@ void CodeGenFunction::EmitCountedByBoundsChecking(
46274650
46284651 if (std::optional<int64_t > Diff =
46294652 getOffsetDifferenceInBits (*this , CountFD, FD)) {
4630- if (!Addr .isValid ()) {
4653+ if (!ArrayInst .isValid ()) {
46314654 // An invalid Address indicates we're checking a pointer array access.
46324655 // Emit the checked L-Value here.
4633- LValue LV = EmitCheckedLValue (E , TCK_MemberAccess);
4634- Addr = LV.getAddress ();
4656+ LValue LV = EmitCheckedLValue (ArrayExpr , TCK_MemberAccess);
4657+ ArrayInst = LV.getAddress ();
46354658 }
46364659
46374660 // FIXME: The 'static_cast' is necessary, otherwise the result turns into a
@@ -4640,17 +4663,19 @@ void CodeGenFunction::EmitCountedByBoundsChecking(
46404663
46414664 // Create a GEP with the byte offset between the counted object and the
46424665 // count and use that to load the count value.
4643- Addr = Builder.CreatePointerBitCastOrAddrSpaceCast (Addr, Int8PtrTy, Int8Ty);
4666+ ArrayInst = Builder.CreatePointerBitCastOrAddrSpaceCast (ArrayInst,
4667+ Int8PtrTy, Int8Ty);
46444668
4645- llvm::Type *CountTy = ConvertType (CountFD->getType ());
4646- llvm::Value *Res =
4647- Builder.CreateInBoundsGEP (Int8Ty, Addr .emitRawPointer (*this ),
4669+ llvm::Type *BoundsType = ConvertType (CountFD->getType ());
4670+ llvm::Value *BoundsVal =
4671+ Builder.CreateInBoundsGEP (Int8Ty, ArrayInst .emitRawPointer (*this ),
46484672 Builder.getInt32 (*Diff), " .counted_by.gep" );
4649- Res = Builder.CreateAlignedLoad (CountTy, Res , getIntAlign (),
4650- " .counted_by.load" );
4673+ BoundsVal = Builder.CreateAlignedLoad (BoundsType, BoundsVal , getIntAlign (),
4674+ " .counted_by.load" );
46514675
46524676 // Now emit the bounds checking.
4653- EmitBoundsCheckImpl (E, Res, Idx, IdxTy, ArrayTy, Accessed);
4677+ EmitBoundsCheckImpl (ArrayExpr, ArrayType, IndexVal, IndexType, BoundsVal,
4678+ CountFD->getType (), Accessed);
46544679 }
46554680}
46564681
@@ -4796,9 +4821,9 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
47964821 auto *Idx = EmitIdxAfterBase (/* Promote*/ true );
47974822
47984823 if (SanOpts.has (SanitizerKind::ArrayBounds))
4799- EmitCountedByBoundsChecking (Array, Idx , ArrayLV.getAddress (),
4800- E->getIdx ()->getType (), Array-> getType () ,
4801- Accessed, /* FlexibleArray=*/ true );
4824+ EmitCountedByBoundsChecking (Array, Array-> getType () , ArrayLV.getAddress (),
4825+ E->getIdx ()->getType (), Idx, Accessed ,
4826+ /* FlexibleArray=*/ true );
48024827
48034828 // Propagate the alignment from the array itself to the result.
48044829 QualType arrayType = Array->getType ();
@@ -4850,8 +4875,8 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
48504875
48514876 if (const auto *CE = dyn_cast_if_present<CastExpr>(Base);
48524877 CE && CE->getCastKind () == CK_LValueToRValue)
4853- EmitCountedByBoundsChecking (CE, Idx , Address::invalid (),
4854- E->getIdx ()->getType (), ptrType , Accessed,
4878+ EmitCountedByBoundsChecking (CE, ptrType , Address::invalid (),
4879+ E->getIdx ()->getType (), Idx , Accessed,
48554880 /* FlexibleArray=*/ false );
48564881 }
48574882 }
0 commit comments