/Users/buildslave/jenkins/workspace/clang-stage2-coverage-R/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- InductiveRangeCheckElimination.cpp - -------------------------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // The InductiveRangeCheckElimination pass splits a loop's iteration space into |
10 | | // three disjoint ranges. It does that in a way such that the loop running in |
11 | | // the middle range provably does not need range checks. As an example, it will |
12 | | // convert |
13 | | // |
14 | | // len = < known positive > |
15 | | // for (i = 0; i < n; i++) { |
16 | | // if (0 <= i && i < len) { |
17 | | // do_something(); |
18 | | // } else { |
19 | | // throw_out_of_bounds(); |
20 | | // } |
21 | | // } |
22 | | // |
23 | | // to |
24 | | // |
25 | | // len = < known positive > |
26 | | // limit = smin(n, len) |
27 | | // // no first segment |
28 | | // for (i = 0; i < limit; i++) { |
29 | | // if (0 <= i && i < len) { // this check is fully redundant |
30 | | // do_something(); |
31 | | // } else { |
32 | | // throw_out_of_bounds(); |
33 | | // } |
34 | | // } |
35 | | // for (i = limit; i < n; i++) { |
36 | | // if (0 <= i && i < len) { |
37 | | // do_something(); |
38 | | // } else { |
39 | | // throw_out_of_bounds(); |
40 | | // } |
41 | | // } |
42 | | // |
43 | | //===----------------------------------------------------------------------===// |
44 | | |
45 | | #include "llvm/Transforms/Scalar/InductiveRangeCheckElimination.h" |
46 | | #include "llvm/ADT/APInt.h" |
47 | | #include "llvm/ADT/ArrayRef.h" |
48 | | #include "llvm/ADT/None.h" |
49 | | #include "llvm/ADT/Optional.h" |
50 | | #include "llvm/ADT/SmallPtrSet.h" |
51 | | #include "llvm/ADT/SmallVector.h" |
52 | | #include "llvm/ADT/StringRef.h" |
53 | | #include "llvm/ADT/Twine.h" |
54 | | #include "llvm/Analysis/BranchProbabilityInfo.h" |
55 | | #include "llvm/Analysis/LoopAnalysisManager.h" |
56 | | #include "llvm/Analysis/LoopInfo.h" |
57 | | #include "llvm/Analysis/LoopPass.h" |
58 | | #include "llvm/Analysis/ScalarEvolution.h" |
59 | | #include "llvm/Analysis/ScalarEvolutionExpander.h" |
60 | | #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
61 | | #include "llvm/IR/BasicBlock.h" |
62 | | #include "llvm/IR/CFG.h" |
63 | | #include "llvm/IR/Constants.h" |
64 | | #include "llvm/IR/DerivedTypes.h" |
65 | | #include "llvm/IR/Dominators.h" |
66 | | #include "llvm/IR/Function.h" |
67 | | #include "llvm/IR/IRBuilder.h" |
68 | | #include "llvm/IR/InstrTypes.h" |
69 | | #include "llvm/IR/Instructions.h" |
70 | | #include "llvm/IR/Metadata.h" |
71 | | #include "llvm/IR/Module.h" |
72 | | #include "llvm/IR/PatternMatch.h" |
73 | | #include "llvm/IR/Type.h" |
74 | | #include "llvm/IR/Use.h" |
75 | | #include "llvm/IR/User.h" |
76 | | #include "llvm/IR/Value.h" |
77 | | #include "llvm/Pass.h" |
78 | | #include "llvm/Support/BranchProbability.h" |
79 | | #include "llvm/Support/Casting.h" |
80 | | #include "llvm/Support/CommandLine.h" |
81 | | #include "llvm/Support/Compiler.h" |
82 | | #include "llvm/Support/Debug.h" |
83 | | #include "llvm/Support/ErrorHandling.h" |
84 | | #include "llvm/Support/raw_ostream.h" |
85 | | #include "llvm/Transforms/Scalar.h" |
86 | | #include "llvm/Transforms/Utils/Cloning.h" |
87 | | #include "llvm/Transforms/Utils/LoopSimplify.h" |
88 | | #include "llvm/Transforms/Utils/LoopUtils.h" |
89 | | #include "llvm/Transforms/Utils/ValueMapper.h" |
90 | | #include <algorithm> |
91 | | #include <cassert> |
92 | | #include <iterator> |
93 | | #include <limits> |
94 | | #include <utility> |
95 | | #include <vector> |
96 | | |
97 | | using namespace llvm; |
98 | | using namespace llvm::PatternMatch; |
99 | | |
100 | | static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden, |
101 | | cl::init(64)); |
102 | | |
103 | | static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden, |
104 | | cl::init(false)); |
105 | | |
106 | | static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden, |
107 | | cl::init(false)); |
108 | | |
109 | | static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal", |
110 | | cl::Hidden, cl::init(10)); |
111 | | |
112 | | static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks", |
113 | | cl::Hidden, cl::init(false)); |
114 | | |
115 | | static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch", |
116 | | cl::Hidden, cl::init(true)); |
117 | | |
118 | | static cl::opt<bool> AllowNarrowLatchCondition( |
119 | | "irce-allow-narrow-latch", cl::Hidden, cl::init(true), |
120 | | cl::desc("If set to true, IRCE may eliminate wide range checks in loops " |
121 | | "with narrow latch condition.")); |
122 | | |
123 | | static const char *ClonedLoopTag = "irce.loop.clone"; |
124 | | |
125 | | #define DEBUG_TYPE "irce" |
126 | | |
127 | | namespace { |
128 | | |
129 | | /// An inductive range check is a conditional branch in a loop with |
130 | | /// |
131 | | /// 1. a very cold successor (i.e. the branch jumps to that successor very |
132 | | /// rarely) |
133 | | /// |
134 | | /// and |
135 | | /// |
136 | | /// 2. a condition that is provably true for some contiguous range of values |
137 | | /// taken by the containing loop's induction variable. |
138 | | /// |
139 | | class InductiveRangeCheck { |
140 | | |
141 | | const SCEV *Begin = nullptr; |
142 | | const SCEV *Step = nullptr; |
143 | | const SCEV *End = nullptr; |
144 | | Use *CheckUse = nullptr; |
145 | | bool IsSigned = true; |
146 | | |
147 | | static bool parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, |
148 | | Value *&Index, Value *&Length, |
149 | | bool &IsSigned); |
150 | | |
151 | | static void |
152 | | extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse, |
153 | | SmallVectorImpl<InductiveRangeCheck> &Checks, |
154 | | SmallPtrSetImpl<Value *> &Visited); |
155 | | |
156 | | public: |
157 | 508 | const SCEV *getBegin() const { return Begin; } |
158 | 253 | const SCEV *getStep() const { return Step; } |
159 | 249 | const SCEV *getEnd() const { return End; } |
160 | 0 | bool isSigned() const { return IsSigned; } |
161 | | |
162 | 20 | void print(raw_ostream &OS) const { |
163 | 20 | OS << "InductiveRangeCheck:\n"; |
164 | 20 | OS << " Begin: "; |
165 | 20 | Begin->print(OS); |
166 | 20 | OS << " Step: "; |
167 | 20 | Step->print(OS); |
168 | 20 | OS << " End: "; |
169 | 20 | End->print(OS); |
170 | 20 | OS << "\n CheckUse: "; |
171 | 20 | getCheckUse()->getUser()->print(OS); |
172 | 20 | OS << " Operand: " << getCheckUse()->getOperandNo() << "\n"; |
173 | 20 | } |
174 | | |
175 | | LLVM_DUMP_METHOD |
176 | 0 | void dump() { |
177 | 0 | print(dbgs()); |
178 | 0 | } |
179 | | |
180 | 269 | Use *getCheckUse() const { return CheckUse; } |
181 | | |
182 | | /// Represents a signed integer range [Range.getBegin(), Range.getEnd()). If |
183 | | /// R.getEnd() le R.getBegin(), then R denotes the empty range. |
184 | | |
185 | | class Range { |
186 | | const SCEV *Begin; |
187 | | const SCEV *End; |
188 | | |
189 | | public: |
190 | 292 | Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) { |
191 | 292 | assert(Begin->getType() == End->getType() && "ill-typed range!"); |
192 | 292 | } |
193 | | |
194 | 274 | Type *getType() const { return Begin->getType(); } |
195 | 690 | const SCEV *getBegin() const { return Begin; } |
196 | 454 | const SCEV *getEnd() const { return End; } |
197 | 292 | bool isEmpty(ScalarEvolution &SE, bool IsSigned) const { |
198 | 292 | if (Begin == End) |
199 | 18 | return true; |
200 | 274 | if (IsSigned) |
201 | 166 | return SE.isKnownPredicate(ICmpInst::ICMP_SGE, Begin, End); |
202 | 108 | else |
203 | 108 | return SE.isKnownPredicate(ICmpInst::ICMP_UGE, Begin, End); |
204 | 274 | } |
205 | | }; |
206 | | |
207 | | /// This is the value the condition of the branch needs to evaluate to for the |
208 | | /// branch to take the hot successor (see (1) above). |
209 | 229 | bool getPassingDirection() { return true; } |
210 | | |
211 | | /// Computes a range for the induction variable (IndVar) in which the range |
212 | | /// check is redundant and can be constant-folded away. The induction |
213 | | /// variable is not required to be the canonical {0,+,1} induction variable. |
214 | | Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE, |
215 | | const SCEVAddRecExpr *IndVar, |
216 | | bool IsLatchSigned) const; |
217 | | |
218 | | /// Parse out a set of inductive range checks from \p BI and append them to \p |
219 | | /// Checks. |
220 | | /// |
221 | | /// NB! There may be conditions feeding into \p BI that aren't inductive range |
222 | | /// checks, and hence don't end up in \p Checks. |
223 | | static void |
224 | | extractRangeChecksFromBranch(BranchInst *BI, Loop *L, ScalarEvolution &SE, |
225 | | BranchProbabilityInfo *BPI, |
226 | | SmallVectorImpl<InductiveRangeCheck> &Checks); |
227 | | }; |
228 | | |
229 | | class InductiveRangeCheckElimination { |
230 | | ScalarEvolution &SE; |
231 | | BranchProbabilityInfo *BPI; |
232 | | DominatorTree &DT; |
233 | | LoopInfo &LI; |
234 | | |
235 | | public: |
236 | | InductiveRangeCheckElimination(ScalarEvolution &SE, |
237 | | BranchProbabilityInfo *BPI, DominatorTree &DT, |
238 | | LoopInfo &LI) |
239 | 477 | : SE(SE), BPI(BPI), DT(DT), LI(LI) {} |
240 | | |
241 | | bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop); |
242 | | }; |
243 | | |
244 | | class IRCELegacyPass : public LoopPass { |
245 | | public: |
246 | | static char ID; |
247 | | |
248 | 31 | IRCELegacyPass() : LoopPass(ID) { |
249 | 31 | initializeIRCELegacyPassPass(*PassRegistry::getPassRegistry()); |
250 | 31 | } |
251 | | |
252 | 31 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
253 | 31 | AU.addRequired<BranchProbabilityInfoWrapperPass>(); |
254 | 31 | getLoopAnalysisUsage(AU); |
255 | 31 | } |
256 | | |
257 | | bool runOnLoop(Loop *L, LPPassManager &LPM) override; |
258 | | }; |
259 | | |
260 | | } // end anonymous namespace |
261 | | |
262 | | char IRCELegacyPass::ID = 0; |
263 | | |
264 | 36.0k | INITIALIZE_PASS_BEGIN(IRCELegacyPass, "irce", |
265 | 36.0k | "Inductive range check elimination", false, false) |
266 | 36.0k | INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) |
267 | 36.0k | INITIALIZE_PASS_DEPENDENCY(LoopPass) |
268 | 36.0k | INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination", |
269 | | false, false) |
270 | | |
271 | | /// Parse a single ICmp instruction, `ICI`, into a range check. If `ICI` cannot |
272 | | /// be interpreted as a range check, return false and leave `Index` and `Length` |
273 | | /// untouched. Otherwise set `Index` to the value being range checked and, for |
274 | | /// upper-bound checks, set `Length` to the limit `Index` is checked against. |
275 | | bool |
276 | | InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, |
277 | | ScalarEvolution &SE, Value *&Index, |
278 | 341 | Value *&Length, bool &IsSigned) { |
279 | 341 | auto IsLoopInvariant = [&SE, L](Value *V) { |
280 | 321 | return SE.isLoopInvariant(SE.getSCEV(V), L); |
281 | 321 | }; |
282 | 341 | |
283 | 341 | ICmpInst::Predicate Pred = ICI->getPredicate(); |
284 | 341 | Value *LHS = ICI->getOperand(0); |
285 | 341 | Value *RHS = ICI->getOperand(1); |
286 | 341 | |
287 | 341 | switch (Pred) { |
288 | 341 | default: |
289 | 0 | return false; |
290 | 341 | |
291 | 341 | case ICmpInst::ICMP_SLE: |
292 | 0 | std::swap(LHS, RHS); |
293 | 0 | LLVM_FALLTHROUGH; |
294 | 10 | case ICmpInst::ICMP_SGE: |
295 | 10 | IsSigned = true; |
296 | 10 | if (match(RHS, m_ConstantInt<0>())) { |
297 | 10 | Index = LHS; |
298 | 10 | return true; // Lower. |
299 | 10 | } |
300 | 0 | return false; |
301 | 0 | |
302 | 199 | case ICmpInst::ICMP_SLT: |
303 | 199 | std::swap(LHS, RHS); |
304 | 199 | LLVM_FALLTHROUGH; |
305 | 209 | case ICmpInst::ICMP_SGT: |
306 | 209 | IsSigned = true; |
307 | 209 | if (match(RHS, m_ConstantInt<-1>())) { |
308 | 10 | Index = LHS; |
309 | 10 | return true; // Lower. |
310 | 10 | } |
311 | 199 | |
312 | 199 | if (IsLoopInvariant(LHS)) { |
313 | 173 | Index = RHS; |
314 | 173 | Length = LHS; |
315 | 173 | return true; // Upper. |
316 | 173 | } |
317 | 26 | return false; |
318 | 26 | |
319 | 122 | case ICmpInst::ICMP_ULT: |
320 | 122 | std::swap(LHS, RHS); |
321 | 122 | LLVM_FALLTHROUGH; |
322 | 122 | case ICmpInst::ICMP_UGT: |
323 | 122 | IsSigned = false; |
324 | 122 | if (IsLoopInvariant(LHS)) { |
325 | 120 | Index = RHS; |
326 | 120 | Length = LHS; |
327 | 120 | return true; // Both lower and upper. |
328 | 120 | } |
329 | 2 | return false; |
330 | 0 | } |
331 | 0 | |
332 | 0 | llvm_unreachable("default clause returns!"); |
333 | 0 | } |
334 | | |
335 | | void InductiveRangeCheck::extractRangeChecksFromCond( |
336 | | Loop *L, ScalarEvolution &SE, Use &ConditionUse, |
337 | | SmallVectorImpl<InductiveRangeCheck> &Checks, |
338 | 415 | SmallPtrSetImpl<Value *> &Visited) { |
339 | 415 | Value *Condition = ConditionUse.get(); |
340 | 415 | if (!Visited.insert(Condition).second) |
341 | 0 | return; |
342 | 415 | |
343 | 415 | // TODO: Do the same for OR, XOR, NOT etc? |
344 | 415 | if (match(Condition, m_And(m_Value(), m_Value()))) { |
345 | 42 | extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(0), |
346 | 42 | Checks, Visited); |
347 | 42 | extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(1), |
348 | 42 | Checks, Visited); |
349 | 42 | return; |
350 | 42 | } |
351 | 373 | |
352 | 373 | ICmpInst *ICI = dyn_cast<ICmpInst>(Condition); |
353 | 373 | if (!ICI) |
354 | 32 | return; |
355 | 341 | |
356 | 341 | Value *Length = nullptr, *Index; |
357 | 341 | bool IsSigned; |
358 | 341 | if (!parseRangeCheckICmp(L, ICI, SE, Index, Length, IsSigned)) |
359 | 28 | return; |
360 | 313 | |
361 | 313 | const auto *IndexAddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Index)); |
362 | 313 | bool IsAffineIndex = |
363 | 313 | IndexAddRec && (IndexAddRec->getLoop() == L)311 && IndexAddRec->isAffine()293 ; |
364 | 313 | |
365 | 313 | if (!IsAffineIndex) |
366 | 20 | return; |
367 | 293 | |
368 | 293 | const SCEV *End = nullptr; |
369 | 293 | // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". |
370 | 293 | // We can potentially do much better here. |
371 | 293 | if (Length) |
372 | 273 | End = SE.getSCEV(Length); |
373 | 20 | else { |
374 | 20 | // So far we can only reach this point for a signed range check. This may |
375 | 20 | // change in the future, in which case we will need to pick the unsigned max |
376 | 20 | // for an unsigned range check. |
377 | 20 | unsigned BitWidth = cast<IntegerType>(IndexAddRec->getType())->getBitWidth(); |
378 | 20 | const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); |
379 | 20 | End = SIntMax; |
380 | 20 | } |
381 | 293 | |
382 | 293 | InductiveRangeCheck IRC; |
383 | 293 | IRC.End = End; |
384 | 293 | IRC.Begin = IndexAddRec->getStart(); |
385 | 293 | IRC.Step = IndexAddRec->getStepRecurrence(SE); |
386 | 293 | IRC.CheckUse = &ConditionUse; |
387 | 293 | IRC.IsSigned = IsSigned; |
388 | 293 | Checks.push_back(IRC); |
389 | 293 | } |
390 | | |
391 | | void InductiveRangeCheck::extractRangeChecksFromBranch( |
392 | | BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI, |
393 | 1.43k | SmallVectorImpl<InductiveRangeCheck> &Checks) { |
394 | 1.43k | if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch()1.15k ) |
395 | 757 | return; |
396 | 681 | |
397 | 681 | BranchProbability LikelyTaken(15, 16); |
398 | 681 | |
399 | 681 | if (!SkipProfitabilityChecks && BPI658 && |
400 | 681 | BPI->getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken658 ) |
401 | 350 | return; |
402 | 331 | |
403 | 331 | SmallPtrSet<Value *, 8> Visited; |
404 | 331 | InductiveRangeCheck::extractRangeChecksFromCond(L, SE, BI->getOperandUse(0), |
405 | 331 | Checks, Visited); |
406 | 331 | } |
407 | | |
408 | | // Add metadata to the loop L to disable loop optimizations. Callers need to |
409 | | // confirm that optimizing loop L is not beneficial. |
410 | 203 | static void DisableAllLoopOptsOnLoop(Loop &L) { |
411 | 203 | // We do not care about any existing loopID related metadata for L, since we |
412 | 203 | // are setting all loop metadata to false. |
413 | 203 | LLVMContext &Context = L.getHeader()->getContext(); |
414 | 203 | // Reserve first location for self reference to the LoopID metadata node. |
415 | 203 | MDNode *Dummy = MDNode::get(Context, {}); |
416 | 203 | MDNode *DisableUnroll = MDNode::get( |
417 | 203 | Context, {MDString::get(Context, "llvm.loop.unroll.disable")}); |
418 | 203 | Metadata *FalseVal = |
419 | 203 | ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0)); |
420 | 203 | MDNode *DisableVectorize = MDNode::get( |
421 | 203 | Context, |
422 | 203 | {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal}); |
423 | 203 | MDNode *DisableLICMVersioning = MDNode::get( |
424 | 203 | Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")}); |
425 | 203 | MDNode *DisableDistribution= MDNode::get( |
426 | 203 | Context, |
427 | 203 | {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal}); |
428 | 203 | MDNode *NewLoopID = |
429 | 203 | MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize, |
430 | 203 | DisableLICMVersioning, DisableDistribution}); |
431 | 203 | // Set operand 0 to refer to the loop id itself. |
432 | 203 | NewLoopID->replaceOperandWith(0, NewLoopID); |
433 | 203 | L.setLoopID(NewLoopID); |
434 | 203 | } |
435 | | |
436 | | namespace { |
437 | | |
438 | | // Keeps track of the structure of a loop. This is similar to llvm::Loop, |
439 | | // except that it is more lightweight and can track the state of a loop through |
440 | | // changing and potentially invalid IR. This structure also formalizes the |
441 | | // kinds of loops we can deal with -- ones that have a single latch that is also |
442 | | // an exiting block *and* have a canonical induction variable. |
443 | | struct LoopStructure { |
444 | | const char *Tag = ""; |
445 | | |
446 | | BasicBlock *Header = nullptr; |
447 | | BasicBlock *Latch = nullptr; |
448 | | |
449 | | // `Latch's terminator instruction is `LatchBr', and its `LatchBrExitIdx'th |
450 | | // successor is `LatchExit', the exit block of the loop. |
451 | | BranchInst *LatchBr = nullptr; |
452 | | BasicBlock *LatchExit = nullptr; |
453 | | unsigned LatchBrExitIdx = std::numeric_limits<unsigned>::max(); |
454 | | |
455 | | // The loop represented by this instance of LoopStructure is semantically |
456 | | // equivalent to: |
457 | | // |
458 | | // intN_ty inc = IndVarIncreasing ? 1 : -1; |
459 | | // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT; |
460 | | // |
461 | | // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarBase) |
462 | | // ... body ... |
463 | | |
464 | | Value *IndVarBase = nullptr; |
465 | | Value *IndVarStart = nullptr; |
466 | | Value *IndVarStep = nullptr; |
467 | | Value *LoopExitAt = nullptr; |
468 | | bool IndVarIncreasing = false; |
469 | | bool IsSignedPredicate = true; |
470 | | |
471 | 783 | LoopStructure() = default; |
472 | | |
473 | 203 | template <typename M> LoopStructure map(M Map) const { |
474 | 203 | LoopStructure Result; |
475 | 203 | Result.Tag = Tag; |
476 | 203 | Result.Header = cast<BasicBlock>(Map(Header)); |
477 | 203 | Result.Latch = cast<BasicBlock>(Map(Latch)); |
478 | 203 | Result.LatchBr = cast<BranchInst>(Map(LatchBr)); |
479 | 203 | Result.LatchExit = cast<BasicBlock>(Map(LatchExit)); |
480 | 203 | Result.LatchBrExitIdx = LatchBrExitIdx; |
481 | 203 | Result.IndVarBase = Map(IndVarBase); |
482 | 203 | Result.IndVarStart = Map(IndVarStart); |
483 | 203 | Result.IndVarStep = Map(IndVarStep); |
484 | 203 | Result.LoopExitAt = Map(LoopExitAt); |
485 | 203 | Result.IndVarIncreasing = IndVarIncreasing; |
486 | 203 | Result.IsSignedPredicate = IsSignedPredicate; |
487 | 203 | return Result; |
488 | 203 | } |
489 | | |
490 | | static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &, |
491 | | BranchProbabilityInfo *BPI, |
492 | | Loop &, const char *&); |
493 | | }; |
494 | | |
495 | | /// This class is used to constrain loops to run within a given iteration space. |
496 | | /// The algorithm this class implements is given a Loop and a range [Begin, |
497 | | /// End). The algorithm then tries to break out a "main loop" out of the loop |
498 | | /// it is given in a way that the "main loop" runs with the induction variable |
499 | | /// in a subset of [Begin, End). The algorithm emits appropriate pre and post |
500 | | /// loops to run any remaining iterations. The pre loop runs any iterations in |
501 | | /// which the induction variable is < Begin, and the post loop runs any |
502 | | /// iterations in which the induction variable is >= End. |
503 | | class LoopConstrainer { |
504 | | // The representation of a clone of the original loop we started out with. |
505 | | struct ClonedLoop { |
506 | | // The cloned blocks |
507 | | std::vector<BasicBlock *> Blocks; |
508 | | |
509 | | // `Map` maps values in the clonee into values in the cloned version |
510 | | ValueToValueMapTy Map; |
511 | | |
512 | | // An instance of `LoopStructure` for the cloned loop |
513 | | LoopStructure Structure; |
514 | | }; |
515 | | |
516 | | // Result of rewriting the range of a loop. See changeIterationSpaceEnd for |
517 | | // more details on what these fields mean. |
518 | | struct RewrittenRangeInfo { |
519 | | BasicBlock *PseudoExit = nullptr; |
520 | | BasicBlock *ExitSelector = nullptr; |
521 | | std::vector<PHINode *> PHIValuesAtPseudoExit; |
522 | | PHINode *IndVarEnd = nullptr; |
523 | | |
524 | 575 | RewrittenRangeInfo() = default; |
525 | | }; |
526 | | |
527 | | // Calculated subranges we restrict the iteration space of the main loop to. |
528 | | // See the implementation of `calculateSubRanges' for more details on how |
529 | | // these fields are computed. `LowLimit` is None if there is no restriction |
530 | | // on low end of the restricted iteration space of the main loop. `HighLimit` |
531 | | // is None if there is no restriction on high end of the restricted iteration |
532 | | // space of the main loop. |
533 | | |
534 | | struct SubRanges { |
535 | | Optional<const SCEV *> LowLimit; |
536 | | Optional<const SCEV *> HighLimit; |
537 | | }; |
538 | | |
539 | | // Compute a safe set of limits for the main loop to run in -- effectively the |
540 | | // intersection of `Range' and the iteration space of the original loop. |
541 | | // Return None if unable to compute the set of subranges. |
542 | | Optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const; |
543 | | |
544 | | // Clone `OriginalLoop' and return the result in CLResult. The IR after |
545 | | // running `cloneLoop' is well formed except for the PHI nodes in CLResult -- |
546 | | // the PHI nodes say that there is an incoming edge from `OriginalPreheader` |
547 | | // but there is no such edge. |
548 | | void cloneLoop(ClonedLoop &CLResult, const char *Tag) const; |
549 | | |
550 | | // Create the appropriate loop structure needed to describe a cloned copy of |
551 | | // `Original`. The clone is described by `VM`. |
552 | | Loop *createClonedLoopStructure(Loop *Original, Loop *Parent, |
553 | | ValueToValueMapTy &VM, bool IsSubloop); |
554 | | |
555 | | // Rewrite the iteration space of the loop denoted by (LS, Preheader). The |
556 | | // iteration space of the rewritten loop ends at ExitLoopAt. The start of the |
557 | | // iteration space is not changed. `ExitLoopAt' is assumed to be slt |
558 | | // `OriginalHeaderCount'. |
559 | | // |
560 | | // If there are iterations left to execute, control is made to jump to |
561 | | // `ContinuationBlock', otherwise they take the normal loop exit. The |
562 | | // returned `RewrittenRangeInfo' object is populated as follows: |
563 | | // |
564 | | // .PseudoExit is a basic block that unconditionally branches to |
565 | | // `ContinuationBlock'. |
566 | | // |
567 | | // .ExitSelector is a basic block that decides, on exit from the loop, |
568 | | // whether to branch to the "true" exit or to `PseudoExit'. |
569 | | // |
570 | | // .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value |
571 | | // for each PHINode in the loop header on taking the pseudo exit. |
572 | | // |
573 | | // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate |
574 | | // preheader because it is made to branch to the loop header only |
575 | | // conditionally. |
576 | | RewrittenRangeInfo |
577 | | changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader, |
578 | | Value *ExitLoopAt, |
579 | | BasicBlock *ContinuationBlock) const; |
580 | | |
581 | | // The loop denoted by `LS' has `OldPreheader' as its preheader. This |
582 | | // function creates a new preheader for `LS' and returns it. |
583 | | BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader, |
584 | | const char *Tag) const; |
585 | | |
586 | | // `ContinuationBlockAndPreheader' was the continuation block for some call to |
587 | | // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'. |
588 | | // This function rewrites the PHI nodes in `LS.Header' to start with the |
589 | | // correct value. |
590 | | void rewriteIncomingValuesForPHIs( |
591 | | LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader, |
592 | | const LoopConstrainer::RewrittenRangeInfo &RRI) const; |
593 | | |
594 | | // Even though we do not preserve any passes at this time, we at least need to |
595 | | // keep the parent loop structure consistent. The `LPPassManager' seems to |
596 | | // verify this after running a loop pass. This function adds the list of |
597 | | // blocks denoted by BBs to this loop's parent loop if required. |
598 | | void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs); |
599 | | |
600 | | // Some global state. |
601 | | Function &F; |
602 | | LLVMContext &Ctx; |
603 | | ScalarEvolution &SE; |
604 | | DominatorTree &DT; |
605 | | LoopInfo &LI; |
606 | | function_ref<void(Loop *, bool)> LPMAddNewLoop; |
607 | | |
608 | | // Information about the original loop we started out with. |
609 | | Loop &OriginalLoop; |
610 | | |
611 | | const SCEV *LatchTakenCount = nullptr; |
612 | | BasicBlock *OriginalPreheader = nullptr; |
613 | | |
614 | | // The preheader of the main loop. This may or may not be different from |
615 | | // `OriginalPreheader'. |
616 | | BasicBlock *MainLoopPreheader = nullptr; |
617 | | |
618 | | // The range we need to run the main loop in. |
619 | | InductiveRangeCheck::Range Range; |
620 | | |
621 | | // The structure of the main loop (see comment at the beginning of this class |
622 | | // for a definition) |
623 | | LoopStructure MainLoopStructure; |
624 | | |
625 | | public: |
626 | | LoopConstrainer(Loop &L, LoopInfo &LI, |
627 | | function_ref<void(Loop *, bool)> LPMAddNewLoop, |
628 | | const LoopStructure &LS, ScalarEvolution &SE, |
629 | | DominatorTree &DT, InductiveRangeCheck::Range R) |
630 | | : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), |
631 | | SE(SE), DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), |
632 | 188 | Range(R), MainLoopStructure(LS) {} |
633 | | |
634 | | // Entry point for the algorithm. Returns true on success. |
635 | | bool run(); |
636 | | }; |
637 | | |
638 | | } // end anonymous namespace |
639 | | |
640 | | /// Given a loop with a decreasing induction variable, is it possible to |
641 | | /// safely calculate the bounds of a new loop using the given Predicate. |
642 | | static bool isSafeDecreasingBound(const SCEV *Start, |
643 | | const SCEV *BoundSCEV, const SCEV *Step, |
644 | | ICmpInst::Predicate Pred, |
645 | | unsigned LatchBrExitIdx, |
646 | 40 | Loop *L, ScalarEvolution &SE) { |
647 | 40 | if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT29 && |
648 | 40 | Pred != ICmpInst::ICMP_ULT16 && Pred != ICmpInst::ICMP_UGT10 ) |
649 | 0 | return false; |
650 | 40 | |
651 | 40 | if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) |
652 | 0 | return false; |
653 | 40 | |
654 | 40 | assert(SE.isKnownNegative(Step) && "expecting negative step"); |
655 | 40 | |
656 | 40 | LLVM_DEBUG(dbgs() << "irce: isSafeDecreasingBound with:\n"); |
657 | 40 | LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n"); |
658 | 40 | LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n"); |
659 | 40 | LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n"); |
660 | 40 | LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred) |
661 | 40 | << "\n"); |
662 | 40 | LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n"); |
663 | 40 | |
664 | 40 | bool IsSigned = ICmpInst::isSigned(Pred); |
665 | 40 | // The predicate that we need to check that the induction variable lies |
666 | 40 | // within bounds. |
667 | 40 | ICmpInst::Predicate BoundPred = |
668 | 40 | IsSigned ? CmpInst::ICMP_SGT24 : CmpInst::ICMP_UGT16 ; |
669 | 40 | |
670 | 40 | if (LatchBrExitIdx == 1) |
671 | 23 | return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); |
672 | 17 | |
673 | 17 | assert(LatchBrExitIdx == 0 && |
674 | 17 | "LatchBrExitIdx should be either 0 or 1"); |
675 | 17 | |
676 | 17 | const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); |
677 | 17 | unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); |
678 | 17 | APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)11 : |
679 | 17 | APInt::getMinValue(BitWidth)6 ; |
680 | 17 | const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne); |
681 | 17 | |
682 | 17 | const SCEV *MinusOne = |
683 | 17 | SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType())); |
684 | 17 | |
685 | 17 | return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) && |
686 | 17 | SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit); |
687 | 17 | |
688 | 17 | } |
689 | | |
690 | | /// Given a loop with an increasing induction variable, is it possible to |
691 | | /// safely calculate the bounds of a new loop using the given Predicate. Returns false for predicates other than SLT/SGT/ULT/UGT or when \p BoundSCEV is not available at entry to \p L; otherwise the answer comes from loop-entry guard queries on \p Start, \p BoundSCEV and \p Step. |
692 | | static bool isSafeIncreasingBound(const SCEV *Start, |
693 | | const SCEV *BoundSCEV, const SCEV *Step, |
694 | | ICmpInst::Predicate Pred, |
695 | | unsigned LatchBrExitIdx, |
696 | 171 | Loop *L, ScalarEvolution &SE) { |
697 | 171 | if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT68 && |
698 | 171 | Pred != ICmpInst::ICMP_ULT68 && Pred != ICmpInst::ICMP_UGT12 ) |
699 | 0 | return false; |
700 | 171 | |
701 | 171 | if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) |
702 | 0 | return false; |
703 | 171 | |
704 | 171 | LLVM_DEBUG(dbgs() << "irce: isSafeIncreasingBound with:\n"); |
705 | 171 | LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n"); |
706 | 171 | LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n"); |
707 | 171 | LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n"); |
708 | 171 | LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred) |
709 | 171 | << "\n"); |
710 | 171 | LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n"); |
711 | 171 | |
712 | 171 | bool IsSigned = ICmpInst::isSigned(Pred); |
713 | 171 | // The predicate that we need to check that the induction variable lies |
714 | 171 | // within bounds. |
715 | 171 | ICmpInst::Predicate BoundPred = |
716 | 171 | IsSigned ? CmpInst::ICMP_SLT103 : CmpInst::ICMP_ULT68 ; |
717 | 171 | /* When LatchBrExitIdx is 1, Start <BoundPred> BoundSCEV being guarded on loop entry is all that is required. */ |
718 | 171 | if (LatchBrExitIdx == 1) |
719 | 159 | return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); |
720 | 12 | |
721 | 12 | assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1"); |
722 | 12 | /* Otherwise require both Start <BoundPred> BoundSCEV + Step and BoundSCEV <BoundPred> Limit on loop entry, where Limit is the signed or unsigned maximum value minus (Step - 1). */ |
723 | 12 | const SCEV *StepMinusOne = |
724 | 12 | SE.getMinusSCEV(Step, SE.getOne(Step->getType())); |
725 | 12 | unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); |
726 | 12 | APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)0 : |
727 | 12 | APInt::getMaxValue(BitWidth); |
728 | 12 | const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne); |
729 | 12 | |
730 | 12 | return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start, |
731 | 12 | SE.getAddExpr(BoundSCEV, Step)) && |
732 | 12 | SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit)); |
733 | 12 | } |
734 | | /* Inspects loop L and, on success, returns a LoopStructure recording its header, latch, latch branch and exit, the induction variable's start/step/base and direction, the exit bound and the signedness of the latch predicate, with FailureReason reset to nullptr. For any unsupported loop shape it stores a short description in FailureReason and returns None. May insert instructions before the preheader terminator (an adjusted bound and the expanded induction-variable start). */ |
735 | | Optional<LoopStructure> |
736 | | LoopStructure::parseLoopStructure(ScalarEvolution &SE, |
737 | | BranchProbabilityInfo *BPI, Loop &L, |
738 | 242 | const char *&FailureReason) { |
739 | 242 | if (!L.isLoopSimplifyForm()) { |
740 | 3 | FailureReason = "loop not in LoopSimplify form"; |
741 | 3 | return None; |
742 | 3 | } |
743 | 239 | |
744 | 239 | BasicBlock *Latch = L.getLoopLatch(); |
745 | 239 | assert(Latch && "Simplified loops only have one latch!"); |
746 | 239 | |
747 | 239 | if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) { |
748 | 14 | FailureReason = "loop has already been cloned"; |
749 | 14 | return None; |
750 | 14 | } |
751 | 225 | |
752 | 225 | if (!L.isLoopExiting(Latch)) { |
753 | 0 | FailureReason = "no loop latch"; |
754 | 0 | return None; |
755 | 0 | } |
756 | 225 | |
757 | 225 | BasicBlock *Header = L.getHeader(); |
758 | 225 | BasicBlock *Preheader = L.getLoopPreheader(); |
759 | 225 | if (!Preheader) { |
760 | 0 | FailureReason = "no preheader"; |
761 | 0 | return None; |
762 | 0 | } |
763 | 225 | |
764 | 225 | BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator()); |
765 | 225 | if (!LatchBr || LatchBr->isUnconditional()) { |
766 | 0 | FailureReason = "latch terminator not conditional branch"; |
767 | 0 | return None; |
768 | 0 | } |
769 | 225 | /* LatchBrExitIdx is the index of the latch successor that is not the header. */ |
770 | 225 | unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1187 : 038 ; |
771 | 225 | |
772 | 225 | BranchProbability ExitProbability = |
773 | 225 | BPI ? BPI->getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx) |
774 | 225 | : BranchProbability::getZero()0 ; |
775 | 225 | |
776 | 225 | if (!SkipProfitabilityChecks && |
777 | 225 | ExitProbability > BranchProbability(1, MaxExitProbReciprocal)214 ) { |
778 | 1 | FailureReason = "short running loop, not profitable"; |
779 | 1 | return None; |
780 | 1 | } |
781 | 224 | |
782 | 224 | ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition()); |
783 | 224 | if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) { |
784 | 0 | FailureReason = "latch terminator branch not conditional on integral icmp"; |
785 | 0 | return None; |
786 | 0 | } |
787 | 224 | |
788 | 224 | const SCEV *LatchCount = SE.getExitCount(&L, Latch); |
789 | 224 | if (isa<SCEVCouldNotCompute>(LatchCount)) { |
790 | 7 | FailureReason = "could not compute latch count"; |
791 | 7 | return None; |
792 | 7 | } |
793 | 217 | |
794 | 217 | ICmpInst::Predicate Pred = ICI->getPredicate(); |
795 | 217 | Value *LeftValue = ICI->getOperand(0); |
796 | 217 | const SCEV *LeftSCEV = SE.getSCEV(LeftValue); |
797 | 217 | IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType()); |
798 | 217 | |
799 | 217 | Value *RightValue = ICI->getOperand(1); |
800 | 217 | const SCEV *RightSCEV = SE.getSCEV(RightValue); |
801 | 217 | |
802 | 217 | // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence. |
803 | 217 | if (!isa<SCEVAddRecExpr>(LeftSCEV)) { |
804 | 0 | if (isa<SCEVAddRecExpr>(RightSCEV)) { |
805 | 0 | std::swap(LeftSCEV, RightSCEV); |
806 | 0 | std::swap(LeftValue, RightValue); |
807 | 0 | Pred = ICmpInst::getSwappedPredicate(Pred); |
808 | 0 | } else { |
809 | 0 | FailureReason = "no add recurrences in the icmp"; |
810 | 0 | return None; |
811 | 0 | } |
812 | 217 | } |
813 | 217 | /* HasNoSignedWrap: true if AR already carries NSW, or if sign-extending AR to twice its bit width yields an add recurrence whose start and step equal the sign-extended start and step of AR; as a last resort the NSW flag is re-read. */ |
814 | 217 | auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) { |
815 | 33 | if (AR->getNoWrapFlags(SCEV::FlagNSW)) |
816 | 7 | return true; |
817 | 26 | |
818 | 26 | IntegerType *Ty = cast<IntegerType>(AR->getType()); |
819 | 26 | IntegerType *WideTy = |
820 | 26 | IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2); |
821 | 26 | |
822 | 26 | const SCEVAddRecExpr *ExtendAfterOp = |
823 | 26 | dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy)); |
824 | 26 | if (ExtendAfterOp) { |
825 | 22 | const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy); |
826 | 22 | const SCEV *ExtendedStep = |
827 | 22 | SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy); |
828 | 22 | |
829 | 22 | bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart && |
830 | 22 | ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep; |
831 | 22 | |
832 | 22 | if (NoSignedWrap) |
833 | 22 | return true; |
834 | 4 | } |
835 | 4 | |
836 | 4 | // We may have proved this when computing the sign extension above. |
837 | 4 | return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; |
838 | 4 | }; |
839 | 217 | |
840 | 217 | // `ICI` is interpreted as taking the backedge if the *next* value of the |
841 | 217 | // induction variable satisfies some constraint. |
842 | 217 | |
843 | 217 | const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV); |
844 | 217 | if (!IndVarBase->isAffine()) { |
845 | 0 | FailureReason = "LHS in icmp not induction variable"; |
846 | 0 | return None; |
847 | 0 | } |
848 | 217 | const SCEV* StepRec = IndVarBase->getStepRecurrence(SE); |
849 | 217 | if (!isa<SCEVConstant>(StepRec)) { |
850 | 0 | FailureReason = "LHS in icmp not induction variable"; |
851 | 0 | return None; |
852 | 0 | } |
853 | 217 | ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue(); |
854 | 217 | |
855 | 217 | if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)33 ) { |
856 | 4 | FailureReason = "LHS in icmp needs nsw for equality predicates"; |
857 | 4 | return None; |
858 | 4 | } |
859 | 213 | |
860 | 213 | assert(!StepCI->isZero() && "Zero step?"); |
861 | 213 | bool IsIncreasing = !StepCI->isNegative(); |
862 | 213 | bool IsSignedPredicate; |
863 | 213 | const SCEV *StartNext = IndVarBase->getStart(); |
864 | 213 | const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); |
865 | 213 | const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); |
866 | 213 | const SCEV *Step = SE.getSCEV(StepCI); |
867 | 213 | /* IndVarStart is the start of the compared add recurrence minus one step, because the icmp is interpreted as testing the *next* induction value (see the note above). */ |
868 | 213 | ConstantInt *One = ConstantInt::get(IndVarTy, 1); |
869 | 213 | if (IsIncreasing) { |
870 | 171 | bool DecreasedRightValueByOne = false; |
871 | 171 | if (StepCI->isOne()) { |
872 | 159 | // Try to turn eq/ne predicates to those we can work with. |
873 | 159 | if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 115 ) |
874 | 15 | // while (++i != len) { while (++i < len) { |
875 | 15 | // ... ---> ... |
876 | 15 | // } } |
877 | 15 | // If both parts are known non-negative, it is profitable to use |
878 | 15 | // unsigned comparison in increasing loop. This allows us to make the |
879 | 15 | // comparison check against "RightSCEV + 1" more optimistic. |
880 | 15 | if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) && |
881 | 15 | isKnownNonNegativeInLoop(RightSCEV, &L, SE)5 ) |
882 | 3 | Pred = ICmpInst::ICMP_ULT; |
883 | 12 | else |
884 | 12 | Pred = ICmpInst::ICMP_SLT; |
885 | 144 | else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 06 ) { |
886 | 6 | // while (true) { while (true) { |
887 | 6 | // if (++i == len) ---> if (++i > len - 1) |
888 | 6 | // break; break; |
889 | 6 | // ... ... |
890 | 6 | // } } |
891 | 6 | if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && |
892 | 6 | cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/false)) { |
893 | 6 | Pred = ICmpInst::ICMP_UGT; |
894 | 6 | RightSCEV = SE.getMinusSCEV(RightSCEV, |
895 | 6 | SE.getOne(RightSCEV->getType())); |
896 | 6 | DecreasedRightValueByOne = true; |
897 | 6 | } else if (0 cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/true)0 ) { |
898 | 0 | Pred = ICmpInst::ICMP_SGT; |
899 | 0 | RightSCEV = SE.getMinusSCEV(RightSCEV, |
900 | 0 | SE.getOne(RightSCEV->getType())); |
901 | 0 | DecreasedRightValueByOne = true; |
902 | 0 | } |
903 | 6 | } |
904 | 159 | } |
905 | 171 | /* For an increasing induction variable, accept a less-than predicate when LatchBrExitIdx is 1, or a greater-than predicate when it is 0. */ |
906 | 171 | bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT68 ); |
907 | 171 | bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); |
908 | 171 | bool FoundExpectedPred = |
909 | 171 | (LTPred && LatchBrExitIdx == 1159 ) || (12 GTPred12 && LatchBrExitIdx == 012 ); |
910 | 171 | |
911 | 171 | if (!FoundExpectedPred) { |
912 | 0 | FailureReason = "expected icmp slt semantically, found something else"; |
913 | 0 | return None; |
914 | 0 | } |
915 | 171 | |
916 | 171 | IsSignedPredicate = ICmpInst::isSigned(Pred); |
917 | 171 | if (!IsSignedPredicate && !AllowUnsignedLatchCondition68 ) { |
918 | 0 | FailureReason = "unsigned latch conditions are explicitly prohibited"; |
919 | 0 | return None; |
920 | 0 | } |
921 | 171 | |
922 | 171 | if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred, |
923 | 171 | LatchBrExitIdx, &L, SE)) { |
924 | 6 | FailureReason = "Unsafe loop bounds"; |
925 | 6 | return None; |
926 | 6 | } |
927 | 165 | if (LatchBrExitIdx == 0) { |
928 | 12 | // We need to increase the right value unless we have already decreased |
929 | 12 | // it virtually when we replaced EQ with UGT or SGT above. |
930 | 12 | if (!DecreasedRightValueByOne) { |
931 | 6 | IRBuilder<> B(Preheader->getTerminator()); |
932 | 6 | RightValue = B.CreateAdd(RightValue, One); |
933 | 6 | } |
934 | 153 | } else { |
935 | 153 | assert(!DecreasedRightValueByOne && |
936 | 153 | "Right value can be decreased only for LatchBrExitIdx == 0!"); |
937 | 153 | } |
938 | 165 | } else { |
939 | 42 | bool IncreasedRightValueByOne = false; |
940 | 42 | if (StepCI->isMinusOne()) { |
941 | 36 | // Try to turn eq/ne predicates to those we can work with. |
942 | 36 | if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 13 ) |
943 | 3 | // while (--i != len) { while (--i > len) { |
944 | 3 | // ... ---> ... |
945 | 3 | // } } |
946 | 3 | // We intentionally don't turn the predicate into UGT even if we know |
947 | 3 | // that both operands are non-negative, because it will only pessimize |
948 | 3 | // our check against "RightSCEV - 1". |
949 | 3 | Pred = ICmpInst::ICMP_SGT; |
950 | 33 | else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 05 ) { |
951 | 5 | // while (true) { while (true) { |
952 | 5 | // if (--i == len) ---> if (--i < len + 1) |
953 | 5 | // break; break; |
954 | 5 | // ... ... |
955 | 5 | // } } |
956 | 5 | if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && |
957 | 5 | cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)0 ) { |
958 | 0 | Pred = ICmpInst::ICMP_ULT; |
959 | 0 | RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); |
960 | 0 | IncreasedRightValueByOne = true; |
961 | 5 | } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) { |
962 | 3 | Pred = ICmpInst::ICMP_SLT; |
963 | 3 | RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); |
964 | 3 | IncreasedRightValueByOne = true; |
965 | 3 | } |
966 | 5 | } |
967 | 36 | } |
968 | 42 | /* For a decreasing induction variable, accept a greater-than predicate when LatchBrExitIdx is 1, or a less-than predicate when it is 0. */ |
969 | 42 | bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT31 ); |
970 | 42 | bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT29 ); |
971 | 42 | |
972 | 42 | bool FoundExpectedPred = |
973 | 42 | (GTPred && LatchBrExitIdx == 123 ) || (19 LTPred19 && LatchBrExitIdx == 017 ); |
974 | 42 | |
975 | 42 | if (!FoundExpectedPred) { |
976 | 2 | FailureReason = "expected icmp sgt semantically, found something else"; |
977 | 2 | return None; |
978 | 2 | } |
979 | 40 | |
980 | 40 | IsSignedPredicate = |
981 | 40 | Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT29 ; |
982 | 40 | |
983 | 40 | if (!IsSignedPredicate && !AllowUnsignedLatchCondition16 ) { |
984 | 0 | FailureReason = "unsigned latch conditions are explicitly prohibited"; |
985 | 0 | return None; |
986 | 0 | } |
987 | 40 | |
988 | 40 | if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred, |
989 | 40 | LatchBrExitIdx, &L, SE)) { |
990 | 1 | FailureReason = "Unsafe bounds"; |
991 | 1 | return None; |
992 | 1 | } |
993 | 39 | |
994 | 39 | if (LatchBrExitIdx == 0) { |
995 | 16 | // We need to decrease the right value unless we have already increased |
996 | 16 | // it virtually when we replaced EQ with ULT or SLT above. |
997 | 16 | if (!IncreasedRightValueByOne) { |
998 | 14 | IRBuilder<> B(Preheader->getTerminator()); |
999 | 14 | RightValue = B.CreateSub(RightValue, One); |
1000 | 14 | } |
1001 | 23 | } else { |
1002 | 23 | assert(!IncreasedRightValueByOne && |
1003 | 23 | "Right value can be increased only for LatchBrExitIdx == 0!"); |
1004 | 23 | } |
1005 | 39 | } |
1006 | 213 | BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); |
1007 | 204 | |
1008 | 204 | assert(SE.getLoopDisposition(LatchCount, &L) == |
1009 | 204 | ScalarEvolution::LoopInvariant && |
1010 | 204 | "loop variant exit count doesn't make sense!"); |
1011 | 204 | |
1012 | 204 | assert(!L.contains(LatchExit) && "expected an exit block!"); |
1013 | 204 | const DataLayout &DL = Preheader->getModule()->getDataLayout(); |
1014 | 204 | Value *IndVarStartV = |
1015 | 204 | SCEVExpander(SE, DL, "irce") |
1016 | 204 | .expandCodeFor(IndVarStart, IndVarTy, Preheader->getTerminator()); |
1017 | 204 | IndVarStartV->setName("indvar.start"); |
1018 | 204 | /* IndVarStartV is IndVarStart expanded to IR at the preheader terminator; the fields below record everything gathered above. */ |
1019 | 204 | LoopStructure Result; |
1020 | 204 | |
1021 | 204 | Result.Tag = "main"; |
1022 | 204 | Result.Header = Header; |
1023 | 204 | Result.Latch = Latch; |
1024 | 204 | Result.LatchBr = LatchBr; |
1025 | 204 | Result.LatchExit = LatchExit; |
1026 | 204 | Result.LatchBrExitIdx = LatchBrExitIdx; |
1027 | 204 | Result.IndVarStart = IndVarStartV; |
1028 | 204 | Result.IndVarStep = StepCI; |
1029 | 204 | Result.IndVarBase = LeftValue; |
1030 | 204 | Result.IndVarIncreasing = IsIncreasing; |
1031 | 204 | Result.LoopExitAt = RightValue; |
1032 | 204 | Result.IsSignedPredicate = IsSignedPredicate; |
1033 | 204 | |
1034 | 204 | FailureReason = nullptr; |
1035 | 204 | |
1036 | 204 | return Result; |
1037 | 213 | } |
1038 | | |
1039 | | /// If the type of \p S matches with \p Ty, return \p S. Otherwise, return |
1040 | | /// signed or unsigned extension of \p S to type \p Ty. The extension is a sign extension when \p Signed is true and a zero extension otherwise. |
1041 | | static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE, |
1042 | 882 | bool Signed) { |
1043 | 882 | return Signed ? SE.getNoopOrSignExtend(S, Ty)532 : SE.getNoopOrZeroExtend(S, Ty)350 ; |
1044 | 882 | } |
1045 | | /* Computes the SubRanges for the main loop: LowLimit and HighLimit are Range's begin and end clamped to the span of induction-variable values, and each is left unset when the corresponding comparison against that span is provable. Returns None if Range's type is narrower than the latch-count type, or differs from it while AllowNarrowLatchCondition is off. */ |
1046 | | Optional<LoopConstrainer::SubRanges> |
1047 | 188 | LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { |
1048 | 188 | IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); |
1049 | 188 | |
1050 | 188 | auto *RTy = cast<IntegerType>(Range.getType()); |
1051 | 188 | |
1052 | 188 | // We only support wide range checks and narrow latches. |
1053 | 188 | if (!AllowNarrowLatchCondition && RTy != Ty0 ) |
1054 | 0 | return None; |
1055 | 188 | if (RTy->getBitWidth() < Ty->getBitWidth()) |
1056 | 0 | return None; |
1057 | 188 | |
1058 | 188 | LoopConstrainer::SubRanges Result; |
1059 | 188 | |
1060 | 188 | // I think we can be more aggressive here and make this nuw / nsw if the |
1061 | 188 | // addition that feeds into the icmp for the latch's terminating branch is nuw |
1062 | 188 | // / nsw. In any case, a wrapping 2's complement addition is safe. |
1063 | 188 | const SCEV *Start = NoopOrExtend(SE.getSCEV(MainLoopStructure.IndVarStart), |
1064 | 188 | RTy, SE, IsSignedPredicate); |
1065 | 188 | const SCEV *End = NoopOrExtend(SE.getSCEV(MainLoopStructure.LoopExitAt), RTy, |
1066 | 188 | SE, IsSignedPredicate); |
1067 | 188 | |
1068 | 188 | bool Increasing = MainLoopStructure.IndVarIncreasing; |
1069 | 188 | |
1070 | 188 | // We compute `Smallest` and `Greatest` such that [Smallest, Greatest), or |
1071 | 188 | // [Smallest, GreatestSeen] is the range of values the induction variable |
1072 | 188 | // takes. |
1073 | 188 | |
1074 | 188 | const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr; |
1075 | 188 | |
1076 | 188 | const SCEV *One = SE.getOne(RTy); |
1077 | 188 | if (Increasing) { |
1078 | 149 | Smallest = Start; |
1079 | 149 | Greatest = End; |
1080 | 149 | // No overflow, because the range [Smallest, GreatestSeen] is not empty. |
1081 | 149 | GreatestSeen = SE.getMinusSCEV(End, One); |
1082 | 149 | } else { |
1083 | 39 | // These two computations may sign-overflow. Here is why that is okay: |
1084 | 39 | // |
1085 | 39 | // We know that the induction variable does not sign-overflow on any |
1086 | 39 | // iteration except the last one, and it starts at `Start` and ends at |
1087 | 39 | // `End`, decrementing by one every time. |
1088 | 39 | // |
1089 | 39 | // * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the |
1090 | 39 | // induction variable is decreasing we know that the smallest value |
1091 | 39 | // the loop body is actually executed with is `INT_SMIN` == `Smallest`. |
1092 | 39 | // |
1093 | 39 | // * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`. In |
1094 | 39 | // that case, `Clamp` will always return `Smallest` and |
1095 | 39 | // [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`) |
1096 | 39 | // will be an empty range. Returning an empty range is always safe. |
1097 | 39 | |
1098 | 39 | Smallest = SE.getAddExpr(End, One); |
1099 | 39 | Greatest = SE.getAddExpr(Start, One); |
1100 | 39 | GreatestSeen = Start; |
1101 | 39 | } |
1102 | 188 | /* Clamp(S) = max(Smallest, min(Greatest, S)), using signed or unsigned min/max according to IsSignedPredicate. */ |
1103 | 205 | auto Clamp = [this, Smallest, Greatest, IsSignedPredicate](const SCEV *S) { |
1104 | 205 | return IsSignedPredicate |
1105 | 205 | ? SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S))133 |
1106 | 205 | : SE.getUMaxExpr(Smallest, SE.getUMinExpr(Greatest, S))72 ; |
1107 | 205 | }; |
1108 | 188 | |
1109 | 188 | // In some cases we can prove that we don't need a pre or post loop. |
1110 | 188 | ICmpInst::Predicate PredLE = |
1111 | 188 | IsSignedPredicate ? ICmpInst::ICMP_SLE114 : ICmpInst::ICMP_ULE74 ; |
1112 | 188 | ICmpInst::Predicate PredLT = |
1113 | 188 | IsSignedPredicate ? ICmpInst::ICMP_SLT114 : ICmpInst::ICMP_ULT74 ; |
1114 | 188 | |
1115 | 188 | bool ProvablyNoPreloop = |
1116 | 188 | SE.isKnownPredicate(PredLE, Range.getBegin(), Smallest); |
1117 | 188 | if (!ProvablyNoPreloop) |
1118 | 25 | Result.LowLimit = Clamp(Range.getBegin()); |
1119 | 188 | |
1120 | 188 | bool ProvablyNoPostLoop = |
1121 | 188 | SE.isKnownPredicate(PredLT, GreatestSeen, Range.getEnd()); |
1122 | 188 | if (!ProvablyNoPostLoop) |
1123 | 180 | Result.HighLimit = Clamp(Range.getEnd()); |
1124 | 188 | |
1125 | 188 | return Result; |
1126 | 188 | } |
1127 | | /* Clones every block of OriginalLoop into F (names suffixed with "." + Tag), recording the clones in Result.Blocks and Result.Map. It then tags the cloned latch terminator with ClonedLoopTag, derives Result.Structure from MainLoopStructure, remaps the cloned instructions, and adds an incoming value from each cloned block to the PHIs of successors outside the loop. */ |
1128 | | void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, |
1129 | 203 | const char *Tag) const { |
1130 | 445 | for (BasicBlock *BB : OriginalLoop.getBlocks()) { |
1131 | 445 | BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F); |
1132 | 445 | Result.Blocks.push_back(Clone); |
1133 | 445 | Result.Map[BB] = Clone; |
1134 | 445 | } |
1135 | 203 | /* GetClonedValue maps a value to its clone; a value with no entry in Result.Map is returned unchanged. */ |
1136 | 1.84k | auto GetClonedValue = [&Result](Value *V) { |
1137 | 1.84k | assert(V && "null values not in domain!"); |
1138 | 1.84k | auto It = Result.Map.find(V); |
1139 | 1.84k | if (It == Result.Map.end()) |
1140 | 812 | return V; |
1141 | 1.03k | return static_cast<Value *>(It->second); |
1142 | 1.03k | }; |
1143 | 203 | /* Attach ClonedLoopTag metadata (an empty node) to the cloned latch's terminator. */ |
1144 | 203 | auto *ClonedLatch = |
1145 | 203 | cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch())); |
1146 | 203 | ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag, |
1147 | 203 | MDNode::get(Ctx, {})); |
1148 | 203 | |
1149 | 203 | Result.Structure = MainLoopStructure.map(GetClonedValue); |
1150 | 203 | Result.Structure.Tag = Tag; |
1151 | 203 | |
1152 | 648 | for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i445 ) { |
1153 | 445 | BasicBlock *ClonedBB = Result.Blocks[i]; |
1154 | 445 | BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i]; |
1155 | 445 | |
1156 | 445 | assert(Result.Map[OriginalBB] == ClonedBB && "invariant!"); |
1157 | 445 | |
1158 | 445 | for (Instruction &I : *ClonedBB) |
1159 | 1.91k | RemapInstruction(&I, Result.Map, |
1160 | 1.91k | RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); |
1161 | 445 | |
1162 | 445 | // Exit blocks will now have one more predecessor and their PHI nodes need |
1163 | 445 | // to be edited to reflect that. No phi nodes need to be introduced because |
1164 | 445 | // the loop is in LCSSA. |
1165 | 445 | |
1166 | 870 | for (auto *SBB : successors(OriginalBB)) { |
1167 | 870 | if (OriginalLoop.contains(SBB)) |
1168 | 461 | continue; // not an exit block |
1169 | 409 | |
1170 | 409 | for (PHINode &PN : SBB->phis()) { |
1171 | 20 | Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB); |
1172 | 20 | PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB); |
1173 | 20 | } |
1174 | 409 | } |
1175 | 445 | } |
1176 | 203 | } |
1177 | | /* Rewrites loop LS so that it stops once the induction variable reaches ExitSubloopAt. Preheader now enters the loop only if the start value passes that test and otherwise branches to a new pseudo-exit block. The latch's exit edge goes to a new exit-selector block, which branches to the pseudo exit while iterations remain under the original bound (LS.LoopExitAt) and to the original latch exit otherwise. The pseudo exit branches to ContinuationBlock. Returns the new blocks together with PHIs carrying the header PHI values and the induction variable at the pseudo exit. */ |
1178 | | LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( |
1179 | | const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt, |
1180 | 203 | BasicBlock *ContinuationBlock) const { |
1181 | 203 | // We start with a loop with a single latch: |
1182 | 203 | // |
1183 | 203 | // +--------------------+ |
1184 | 203 | // | | |
1185 | 203 | // | preheader | |
1186 | 203 | // | | |
1187 | 203 | // +--------+-----------+ |
1188 | 203 | // | ----------------\ |
1189 | 203 | // | / | |
1190 | 203 | // +--------v----v------+ | |
1191 | 203 | // | | | |
1192 | 203 | // | header | | |
1193 | 203 | // | | | |
1194 | 203 | // +--------------------+ | |
1195 | 203 | // | |
1196 | 203 | // ..... | |
1197 | 203 | // | |
1198 | 203 | // +--------------------+ | |
1199 | 203 | // | | | |
1200 | 203 | // | latch >----------/ |
1201 | 203 | // | | |
1202 | 203 | // +-------v------------+ |
1203 | 203 | // | |
1204 | 203 | // | |
1205 | 203 | // | +--------------------+ |
1206 | 203 | // | | | |
1207 | 203 | // +---> original exit | |
1208 | 203 | // | | |
1209 | 203 | // +--------------------+ |
1210 | 203 | // |
1211 | 203 | // We change the control flow to look like |
1212 | 203 | // |
1213 | 203 | // |
1214 | 203 | // +--------------------+ |
1215 | 203 | // | | |
1216 | 203 | // | preheader >-------------------------+ |
1217 | 203 | // | | | |
1218 | 203 | // +--------v-----------+ | |
1219 | 203 | // | /-------------+ | |
1220 | 203 | // | / | | |
1221 | 203 | // +--------v--v--------+ | | |
1222 | 203 | // | | | | |
1223 | 203 | // | header | | +--------+ | |
1224 | 203 | // | | | | | | |
1225 | 203 | // +--------------------+ | | +-----v-----v-----------+ |
1226 | 203 | // | | | | |
1227 | 203 | // | | | .pseudo.exit | |
1228 | 203 | // | | | | |
1229 | 203 | // | | +-----------v-----------+ |
1230 | 203 | // | | | |
1231 | 203 | // ..... | | | |
1232 | 203 | // | | +--------v-------------+ |
1233 | 203 | // +--------------------+ | | | | |
1234 | 203 | // | | | | | ContinuationBlock | |
1235 | 203 | // | latch >------+ | | | |
1236 | 203 | // | | | +----------------------+ |
1237 | 203 | // +---------v----------+ | |
1238 | 203 | // | | |
1239 | 203 | // | | |
1240 | 203 | // | +---------------^-----+ |
1241 | 203 | // | | | |
1242 | 203 | // +-----> .exit.selector | |
1243 | 203 | // | | |
1244 | 203 | // +----------v----------+ |
1245 | 203 | // | |
1246 | 203 | // +--------------------+ | |
1247 | 203 | // | | | |
1248 | 203 | // | original exit <----+ |
1249 | 203 | // | | |
1250 | 203 | // +--------------------+ |
1251 | 203 | |
1252 | 203 | RewrittenRangeInfo RRI; |
1253 | 203 | |
1254 | 203 | BasicBlock *BBInsertLocation = LS.Latch->getNextNode(); |
1255 | 203 | RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector", |
1256 | 203 | &F, BBInsertLocation); |
1257 | 203 | RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F, |
1258 | 203 | BBInsertLocation); |
1259 | 203 | |
1260 | 203 | BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator()); |
1261 | 203 | bool Increasing = LS.IndVarIncreasing; |
1262 | 203 | bool IsSignedPredicate = LS.IsSignedPredicate; |
1263 | 203 | |
1264 | 203 | IRBuilder<> B(PreheaderJump); |
1265 | 203 | auto *RangeTy = Range.getBegin()->getType(); |
1266 | 609 | auto NoopOrExt = [&](Value *V) { |
1267 | 609 | if (V->getType() == RangeTy) |
1268 | 551 | return V; |
1269 | 58 | return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())34 |
1270 | 58 | : B.CreateZExt(V, RangeTy, "wide." + V->getName())24 ; |
1271 | 58 | }; |
1272 | 203 | |
1273 | 203 | // EnterLoopCond - is it okay to start executing this `LS'? |
1274 | 203 | Value *EnterLoopCond = nullptr; |
1275 | 203 | auto Pred = |
1276 | 203 | Increasing |
1277 | 203 | ? (IsSignedPredicate 161 ? ICmpInst::ICMP_SLT105 : ICmpInst::ICMP_ULT56 ) |
1278 | 203 | : (IsSignedPredicate 42 ? ICmpInst::ICMP_SGT26 : ICmpInst::ICMP_UGT16 ); |
1279 | 203 | Value *IndVarStart = NoopOrExt(LS.IndVarStart); |
1280 | 203 | EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt); |
1281 | 203 | |
1282 | 203 | B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); |
1283 | 203 | PreheaderJump->eraseFromParent(); |
1284 | 203 | /* Redirect the latch's exit edge to the exit selector and make the latch branch test the induction variable against ExitSubloopAt (negated when the exit is successor 0). */ |
1285 | 203 | LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); |
1286 | 203 | B.SetInsertPoint(LS.LatchBr); |
1287 | 203 | Value *IndVarBase = NoopOrExt(LS.IndVarBase); |
1288 | 203 | Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt); |
1289 | 203 | |
1290 | 203 | Value *CondForBranch = LS.LatchBrExitIdx == 1 |
1291 | 203 | ? TakeBackedgeLoopCond177 |
1292 | 203 | : B.CreateNot(TakeBackedgeLoopCond)26 ; |
1293 | 203 | |
1294 | 203 | LS.LatchBr->setCondition(CondForBranch); |
1295 | 203 | |
1296 | 203 | B.SetInsertPoint(RRI.ExitSelector); |
1297 | 203 | |
1298 | 203 | // IterationsLeft - are there any more iterations left, given the original |
1299 | 203 | // upper bound on the induction variable? If not, we branch to the "real" |
1300 | 203 | // exit. |
1301 | 203 | Value *LoopExitAt = NoopOrExt(LS.LoopExitAt); |
1302 | 203 | Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt); |
1303 | 203 | B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); |
1304 | 203 | |
1305 | 203 | BranchInst *BranchToContinuation = |
1306 | 203 | BranchInst::Create(ContinuationBlock, RRI.PseudoExit); |
1307 | 203 | |
1308 | 203 | // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of |
1309 | 203 | // each of the PHI nodes in the loop header. This feeds into the initial |
1310 | 203 | // value of the same PHI nodes if/when we continue execution. |
1311 | 211 | for (PHINode &PN : LS.Header->phis()) { |
1312 | 211 | PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy", |
1313 | 211 | BranchToContinuation); |
1314 | 211 | |
1315 | 211 | NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader); |
1316 | 211 | NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch), |
1317 | 211 | RRI.ExitSelector); |
1318 | 211 | RRI.PHIValuesAtPseudoExit.push_back(NewPHI); |
1319 | 211 | } |
1320 | 203 | /* IndVarEnd is the induction variable's value at the pseudo exit: IndVarStart when arriving from Preheader, IndVarBase when arriving from the exit selector. */ |
1321 | 203 | RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end", |
1322 | 203 | BranchToContinuation); |
1323 | 203 | RRI.IndVarEnd->addIncoming(IndVarStart, Preheader); |
1324 | 203 | RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector); |
1325 | 203 | |
1326 | 203 | // The latch exit now has a branch from `RRI.ExitSelector' instead of |
1327 | 203 | // `LS.Latch'. The PHI nodes need to be updated to reflect that. |
1328 | 203 | LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector); |
1329 | 203 | |
1330 | 203 | return RRI; |
1331 | 203 | } |
1332 | | /* For each PHI in LS.Header, sets the incoming value for the edge from ContinuationBlock to the entry of RRI.PHIValuesAtPseudoExit at the same position (the vector is indexed in header-PHI order), then sets LS.IndVarStart to RRI.IndVarEnd. */ |
1333 | | void LoopConstrainer::rewriteIncomingValuesForPHIs( |
1334 | | LoopStructure &LS, BasicBlock *ContinuationBlock, |
1335 | 203 | const LoopConstrainer::RewrittenRangeInfo &RRI) const { |
1336 | 203 | unsigned PHIIndex = 0; |
1337 | 203 | for (PHINode &PN : LS.Header->phis()) |
1338 | 211 | PN.setIncomingValueForBlock(ContinuationBlock, |
1339 | 211 | RRI.PHIValuesAtPseudoExit[PHIIndex++]); |
1340 | 203 | |
1341 | 203 | LS.IndVarStart = RRI.IndVarEnd; |
1342 | 203 | } |
1343 | | /* Creates a new block named Tag in F, placed before LS.Header, whose only instruction is an unconditional branch to LS.Header. PHI uses of OldPreheader in LS.Header are retargeted to the new block, which is returned. */ |
1344 | | BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, |
1345 | | BasicBlock *OldPreheader, |
1346 | 203 | const char *Tag) const { |
1347 | 203 | BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); |
1348 | 203 | BranchInst::Create(LS.Header, Preheader); |
1349 | 203 | |
1350 | 203 | LS.Header->replacePhiUsesWith(OldPreheader, Preheader); |
1351 | 203 | |
1352 | 203 | return Preheader; |
1353 | 203 | } |
1354 | | /* Adds every block in BBs to the parent loop of OriginalLoop; does nothing when OriginalLoop has no parent loop. */ |
1355 | 186 | void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { |
1356 | 186 | Loop *ParentLoop = OriginalLoop.getParentLoop(); |
1357 | 186 | if (!ParentLoop) |
1358 | 169 | return; |
1359 | 17 | |
1360 | 17 | for (BasicBlock *BB : BBs) |
1361 | 51 | ParentLoop->addBasicBlockToLoop(BB, LI); |
1362 | 17 | } |
1363 | | /* Allocates a new Loop mirroring Original and returns it. The new loop is added as a child of Parent when Parent is non-null and as a top-level loop otherwise, and is passed to LPMAddNewLoop. It receives the VM-mapped clone of every block whose innermost loop is Original; subloops are handled by recursing with the new loop as their parent. */ |
1364 | | Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, |
1365 | | ValueToValueMapTy &VM, |
1366 | 209 | bool IsSubloop) { |
1367 | 209 | Loop &New = *LI.AllocateLoop(); |
1368 | 209 | if (Parent) |
1369 | 23 | Parent->addChildLoop(&New); |
1370 | 186 | else |
1371 | 186 | LI.addTopLevelLoop(&New); |
1372 | 209 | LPMAddNewLoop(&New, IsSubloop); |
1373 | 209 | |
1374 | 209 | // Add all of the blocks in Original to the new loop. |
1375 | 209 | for (auto *BB : Original->blocks()) |
1376 | 451 | if (LI.getLoopFor(BB) == Original) |
1377 | 445 | New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI); |
1378 | 209 | |
1379 | 209 | // Add all of the subloops to the new loop. |
1380 | 209 | for (Loop *SubLoop : *Original) |
1381 | 6 | createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true); |
1382 | 209 | |
1383 | 209 | return &New; |
1384 | 209 | } |
1385 | | |
1386 | 188 | bool LoopConstrainer::run() { |
1387 | 188 | BasicBlock *Preheader = nullptr; |
1388 | 188 | LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch); |
1389 | 188 | Preheader = OriginalLoop.getLoopPreheader(); |
1390 | 188 | assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr && |
1391 | 188 | "preconditions!"); |
1392 | 188 | |
1393 | 188 | OriginalPreheader = Preheader; |
1394 | 188 | MainLoopPreheader = Preheader; |
1395 | 188 | |
1396 | 188 | bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; |
1397 | 188 | Optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate); |
1398 | 188 | if (!MaybeSR.hasValue()) { |
1399 | 0 | LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n"); |
1400 | 0 | return false; |
1401 | 0 | } |
1402 | 188 | |
1403 | 188 | SubRanges SR = MaybeSR.getValue(); |
1404 | 188 | bool Increasing = MainLoopStructure.IndVarIncreasing; |
1405 | 188 | IntegerType *IVTy = |
1406 | 188 | cast<IntegerType>(Range.getBegin()->getType()); |
1407 | 188 | |
1408 | 188 | SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); |
1409 | 188 | Instruction *InsertPt = OriginalPreheader->getTerminator(); |
1410 | 188 | |
1411 | 188 | // It would have been better to make `PreLoop' and `PostLoop' |
1412 | 188 | // `Optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy |
1413 | 188 | // constructor. |
1414 | 188 | ClonedLoop PreLoop, PostLoop; |
1415 | 188 | bool NeedsPreLoop = |
1416 | 188 | Increasing ? SR.LowLimit.hasValue()149 : SR.HighLimit.hasValue()39 ; |
1417 | 188 | bool NeedsPostLoop = |
1418 | 188 | Increasing ? SR.HighLimit.hasValue()149 : SR.LowLimit.hasValue()39 ; |
1419 | 188 | |
1420 | 188 | Value *ExitPreLoopAt = nullptr; |
1421 | 188 | Value *ExitMainLoopAt = nullptr; |
1422 | 188 | const SCEVConstant *MinusOneS = |
1423 | 188 | cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */)); |
1424 | 188 | |
1425 | 188 | if (NeedsPreLoop) { |
1426 | 61 | const SCEV *ExitPreLoopAtSCEV = nullptr; |
1427 | 61 | |
1428 | 61 | if (Increasing) |
1429 | 22 | ExitPreLoopAtSCEV = *SR.LowLimit; |
1430 | 39 | else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE, |
1431 | 39 | IsSignedPredicate)) |
1432 | 39 | ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); |
1433 | 0 | else { |
1434 | 0 | LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " |
1435 | 0 | << "preloop exit limit. HighLimit = " |
1436 | 0 | << *(*SR.HighLimit) << "\n"); |
1437 | 0 | return false; |
1438 | 0 | } |
1439 | 61 | |
1440 | 61 | if (!isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt, SE)) { |
1441 | 0 | LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" |
1442 | 0 | << " preloop exit limit " << *ExitPreLoopAtSCEV |
1443 | 0 | << " at block " << InsertPt->getParent()->getName() |
1444 | 0 | << "\n"); |
1445 | 0 | return false; |
1446 | 0 | } |
1447 | 61 | |
1448 | 61 | ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt); |
1449 | 61 | ExitPreLoopAt->setName("exit.preloop.at"); |
1450 | 61 | } |
1451 | 188 | |
1452 | 188 | if (NeedsPostLoop) { |
1453 | 144 | const SCEV *ExitMainLoopAtSCEV = nullptr; |
1454 | 144 | |
1455 | 144 | if (Increasing) |
1456 | 141 | ExitMainLoopAtSCEV = *SR.HighLimit; |
1457 | 3 | else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE, |
1458 | 3 | IsSignedPredicate)) |
1459 | 3 | ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); |
1460 | 0 | else { |
1461 | 0 | LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " |
1462 | 0 | << "mainloop exit limit. LowLimit = " |
1463 | 0 | << *(*SR.LowLimit) << "\n"); |
1464 | 0 | return false; |
1465 | 0 | } |
1466 | 144 | |
1467 | 144 | if (!isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt, SE)) { |
1468 | 2 | LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" |
1469 | 2 | << " main loop exit limit " << *ExitMainLoopAtSCEV |
1470 | 2 | << " at block " << InsertPt->getParent()->getName() |
1471 | 2 | << "\n"); |
1472 | 2 | return false; |
1473 | 2 | } |
1474 | 142 | |
1475 | 142 | ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt); |
1476 | 142 | ExitMainLoopAt->setName("exit.mainloop.at"); |
1477 | 142 | } |
1478 | 188 | |
1479 | 188 | // We clone these ahead of time so that we don't have to deal with changing |
1480 | 188 | // and temporarily invalid IR as we transform the loops. |
1481 | 188 | if (186 NeedsPreLoop186 ) |
1482 | 61 | cloneLoop(PreLoop, "preloop"); |
1483 | 186 | if (NeedsPostLoop) |
1484 | 142 | cloneLoop(PostLoop, "postloop"); |
1485 | 186 | |
1486 | 186 | RewrittenRangeInfo PreLoopRRI; |
1487 | 186 | |
1488 | 186 | if (NeedsPreLoop) { |
1489 | 61 | Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header, |
1490 | 61 | PreLoop.Structure.Header); |
1491 | 61 | |
1492 | 61 | MainLoopPreheader = |
1493 | 61 | createPreheader(MainLoopStructure, Preheader, "mainloop"); |
1494 | 61 | PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader, |
1495 | 61 | ExitPreLoopAt, MainLoopPreheader); |
1496 | 61 | rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader, |
1497 | 61 | PreLoopRRI); |
1498 | 61 | } |
1499 | 186 | |
1500 | 186 | BasicBlock *PostLoopPreheader = nullptr; |
1501 | 186 | RewrittenRangeInfo PostLoopRRI; |
1502 | 186 | |
1503 | 186 | if (NeedsPostLoop) { |
1504 | 142 | PostLoopPreheader = |
1505 | 142 | createPreheader(PostLoop.Structure, Preheader, "postloop"); |
1506 | 142 | PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader, |
1507 | 142 | ExitMainLoopAt, PostLoopPreheader); |
1508 | 142 | rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader, |
1509 | 142 | PostLoopRRI); |
1510 | 142 | } |
1511 | 186 | |
1512 | 186 | BasicBlock *NewMainLoopPreheader = |
1513 | 186 | MainLoopPreheader != Preheader ? MainLoopPreheader61 : nullptr125 ; |
1514 | 186 | BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit, |
1515 | 186 | PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit, |
1516 | 186 | PostLoopRRI.ExitSelector, NewMainLoopPreheader}; |
1517 | 186 | |
1518 | 186 | // Some of the above may be nullptr, filter them out before passing to |
1519 | 186 | // addToParentLoopIfNeeded. |
1520 | 186 | auto NewBlocksEnd = |
1521 | 186 | std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr); |
1522 | 186 | |
1523 | 186 | addToParentLoopIfNeeded(makeArrayRef(std::begin(NewBlocks), NewBlocksEnd)); |
1524 | 186 | |
1525 | 186 | DT.recalculate(F); |
1526 | 186 | |
1527 | 186 | // We need to first add all the pre and post loop blocks into the loop |
1528 | 186 | // structures (as part of createClonedLoopStructure), and then update the |
1529 | 186 | // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating |
1530 | 186 | // LI when LoopSimplifyForm is generated. |
1531 | 186 | Loop *PreL = nullptr, *PostL = nullptr; |
1532 | 186 | if (!PreLoop.Blocks.empty()) { |
1533 | 61 | PreL = createClonedLoopStructure(&OriginalLoop, |
1534 | 61 | OriginalLoop.getParentLoop(), PreLoop.Map, |
1535 | 61 | /* IsSubLoop */ false); |
1536 | 61 | } |
1537 | 186 | |
1538 | 186 | if (!PostLoop.Blocks.empty()) { |
1539 | 142 | PostL = |
1540 | 142 | createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(), |
1541 | 142 | PostLoop.Map, /* IsSubLoop */ false); |
1542 | 142 | } |
1543 | 186 | |
1544 | 186 | // This function canonicalizes the loop into Loop-Simplify and LCSSA forms. |
1545 | 389 | auto CanonicalizeLoop = [&] (Loop *L, bool IsOriginalLoop) { |
1546 | 389 | formLCSSARecursively(*L, DT, &LI, &SE); |
1547 | 389 | simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true); |
1548 | 389 | // Pre/post loops are slow paths, we do not need to perform any loop |
1549 | 389 | // optimizations on them. |
1550 | 389 | if (!IsOriginalLoop) |
1551 | 203 | DisableAllLoopOptsOnLoop(*L); |
1552 | 389 | }; |
1553 | 186 | if (PreL) |
1554 | 61 | CanonicalizeLoop(PreL, false); |
1555 | 186 | if (PostL) |
1556 | 142 | CanonicalizeLoop(PostL, false); |
1557 | 186 | CanonicalizeLoop(&OriginalLoop, true); |
1558 | 186 | |
1559 | 186 | return true; |
1560 | 188 | } |
1561 | | |
1562 | | /// Computes and returns a range of values for the induction variable (IndVar) |
1563 | | /// in which the range check can be safely elided. If it cannot compute such a |
1564 | | /// range, returns None. |
1565 | | Optional<InductiveRangeCheck::Range> |
1566 | | InductiveRangeCheck::computeSafeIterationSpace( |
1567 | | ScalarEvolution &SE, const SCEVAddRecExpr *IndVar, |
1568 | 255 | bool IsLatchSigned) const { |
1569 | 255 | // We can deal when types of latch check and range checks don't match in case |
1570 | 255 | // if latch check is more narrow. |
1571 | 255 | auto *IVType = cast<IntegerType>(IndVar->getType()); |
1572 | 255 | auto *RCType = cast<IntegerType>(getBegin()->getType()); |
1573 | 255 | if (IVType->getBitWidth() > RCType->getBitWidth()) |
1574 | 2 | return None; |
1575 | 253 | // IndVar is of the form "A + B * I" (where "I" is the canonical induction |
1576 | 253 | // variable, that may or may not exist as a real llvm::Value in the loop) and |
1577 | 253 | // this inductive range check is a range check on the "C + D * I" ("C" is |
1578 | 253 | // getBegin() and "D" is getStep()). We rewrite the value being range |
1579 | 253 | // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA". |
1580 | 253 | // |
1581 | 253 | // The actual inequalities we solve are of the form |
1582 | 253 | // |
1583 | 253 | // 0 <= M + 1 * IndVar < L given L >= 0 (i.e. N == 1) |
1584 | 253 | // |
1585 | 253 | // Here L stands for upper limit of the safe iteration space. |
1586 | 253 | // The inequality is satisfied by (0 - M) <= IndVar < (L - M). To avoid |
1587 | 253 | // overflows when calculating (0 - M) and (L - M) we, depending on type of |
1588 | 253 | // IV's iteration space, limit the calculations by borders of the iteration |
1589 | 253 | // space. For example, if IndVar is unsigned, (0 - M) overflows for any M > 0. |
1590 | 253 | // If we figured out that "anything greater than (-M) is safe", we strengthen |
1591 | 253 | // this to "everything greater than 0 is safe", assuming that values between |
1592 | 253 | // -M and 0 just do not exist in unsigned iteration space, and we don't want |
1593 | 253 | // to deal with overflown values. |
1594 | 253 | |
1595 | 253 | if (!IndVar->isAffine()) |
1596 | 0 | return None; |
1597 | 253 | |
1598 | 253 | const SCEV *A = NoopOrExtend(IndVar->getStart(), RCType, SE, IsLatchSigned); |
1599 | 253 | const SCEVConstant *B = dyn_cast<SCEVConstant>( |
1600 | 253 | NoopOrExtend(IndVar->getStepRecurrence(SE), RCType, SE, IsLatchSigned)); |
1601 | 253 | if (!B) |
1602 | 0 | return None; |
1603 | 253 | assert(!B->isZero() && "Recurrence with zero step?"); |
1604 | 253 | |
1605 | 253 | const SCEV *C = getBegin(); |
1606 | 253 | const SCEVConstant *D = dyn_cast<SCEVConstant>(getStep()); |
1607 | 253 | if (D != B) |
1608 | 4 | return None; |
1609 | 249 | |
1610 | 249 | assert(!D->getValue()->isZero() && "Recurrence with zero step?"); |
1611 | 249 | unsigned BitWidth = RCType->getBitWidth(); |
1612 | 249 | const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); |
1613 | 249 | |
1614 | 249 | // Subtract Y from X so that it does not go through border of the IV |
1615 | 249 | // iteration space. Mathematically, it is equivalent to: |
1616 | 249 | // |
1617 | 249 | // ClampedSubtract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX). [1] |
1618 | 249 | // |
1619 | 249 | // In [1], 'X - Y' is a mathematical subtraction (result is not bounded to |
1620 | 249 | // any width of bit grid). But after we take min/max, the result is |
1621 | 249 | // guaranteed to be within [INT_MIN, INT_MAX]. |
1622 | 249 | // |
1623 | 249 | // In [1], INT_MAX and INT_MIN are respectively signed and unsigned max/min |
1624 | 249 | // values, depending on type of latch condition that defines IV iteration |
1625 | 249 | // space. |
1626 | 498 | auto ClampedSubtract = [&](const SCEV *X, const SCEV *Y) { |
1627 | 498 | // FIXME: The current implementation assumes that X is in [0, SINT_MAX]. |
1628 | 498 | // This is required to ensure that SINT_MAX - X does not overflow signed and |
1629 | 498 | // that X - Y does not overflow unsigned if Y is negative. Can we lift this |
1630 | 498 | // restriction and make it work for negative X either? |
1631 | 498 | if (IsLatchSigned) { |
1632 | 296 | // X is a number from signed range, Y is interpreted as signed. |
1633 | 296 | // Even if Y is SINT_MAX, (X - Y) does not reach SINT_MIN. So the only |
1634 | 296 | // thing we should care about is that we didn't cross SINT_MAX. |
1635 | 296 | // So, if Y is positive, we subtract Y safely. |
1636 | 296 | // Rule 1: Y > 0 ---> Y. |
1637 | 296 | // If 0 <= -Y <= (SINT_MAX - X), we subtract Y safely. |
1638 | 296 | // Rule 2: Y >=s (X - SINT_MAX) ---> Y. |
1639 | 296 | // If 0 <= (SINT_MAX - X) < -Y, we can only subtract (X - SINT_MAX). |
1640 | 296 | // Rule 3: Y <s (X - SINT_MAX) ---> (X - SINT_MAX). |
1641 | 296 | // It gives us smax(Y, X - SINT_MAX) to subtract in all cases. |
1642 | 296 | const SCEV *XMinusSIntMax = SE.getMinusSCEV(X, SIntMax); |
1643 | 296 | return SE.getMinusSCEV(X, SE.getSMaxExpr(Y, XMinusSIntMax), |
1644 | 296 | SCEV::FlagNSW); |
1645 | 296 | } else |
1646 | 202 | // X is a number from unsigned range, Y is interpreted as signed. |
1647 | 202 | // Even if Y is SINT_MIN, (X - Y) does not reach UINT_MAX. So the only |
1648 | 202 | // thing we should care about is that we didn't cross zero. |
1649 | 202 | // So, if Y is negative, we subtract Y safely. |
1650 | 202 | // Rule 1: Y <s 0 ---> Y. |
1651 | 202 | // If 0 <= Y <= X, we subtract Y safely. |
1652 | 202 | // Rule 2: Y <=s X ---> Y. |
1653 | 202 | // If 0 <= X < Y, we should stop at 0 and can only subtract X. |
1654 | 202 | // Rule 3: Y >s X ---> X. |
1655 | 202 | // It gives us smin(X, Y) to subtract in all cases. |
1656 | 202 | return SE.getMinusSCEV(X, SE.getSMinExpr(X, Y), SCEV::FlagNUW); |
1657 | 498 | }; |
1658 | 249 | const SCEV *M = SE.getMinusSCEV(C, A); |
1659 | 249 | const SCEV *Zero = SE.getZero(M->getType()); |
1660 | 249 | |
1661 | 249 | // This function returns SCEV equal to 1 if X is non-negative 0 otherwise. |
1662 | 249 | auto SCEVCheckNonNegative = [&](const SCEV *X) { |
1663 | 249 | const Loop *L = IndVar->getLoop(); |
1664 | 249 | const SCEV *One = SE.getOne(X->getType()); |
1665 | 249 | // Can we trivially prove that X is a non-negative or negative value? |
1666 | 249 | if (isKnownNonNegativeInLoop(X, L, SE)) |
1667 | 205 | return One; |
1668 | 44 | else if (isKnownNegativeInLoop(X, L, SE)) |
1669 | 8 | return Zero; |
1670 | 36 | // If not, we will have to figure it out during the execution. |
1671 | 36 | // Function smax(smin(X, 0), -1) + 1 equals to 1 if X >= 0 and 0 if X < 0. |
1672 | 36 | const SCEV *NegOne = SE.getNegativeSCEV(One); |
1673 | 36 | return SE.getAddExpr(SE.getSMaxExpr(SE.getSMinExpr(X, Zero), NegOne), One); |
1674 | 36 | }; |
1675 | 249 | // FIXME: Current implementation of ClampedSubtract implicitly assumes that |
1676 | 249 | // X is non-negative (in sense of a signed value). We need to re-implement |
1677 | 249 | // this function in a way that it will correctly handle negative X as well. |
1678 | 249 | // We use it twice: for X = 0 everything is fine, but for X = getEnd() we can |
1679 | 249 | // end up with a negative X and produce wrong results. So currently we ensure |
1680 | 249 | // that if getEnd() is negative then both ends of the safe range are zero. |
1681 | 249 | // Note that this may pessimize elimination of unsigned range checks against |
1682 | 249 | // negative values. |
1683 | 249 | const SCEV *REnd = getEnd(); |
1684 | 249 | const SCEV *EndIsNonNegative = SCEVCheckNonNegative(REnd); |
1685 | 249 | |
1686 | 249 | const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), EndIsNonNegative); |
1687 | 249 | const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), EndIsNonNegative); |
1688 | 249 | return InductiveRangeCheck::Range(Begin, End); |
1689 | 249 | } |
1690 | | |
1691 | | static Optional<InductiveRangeCheck::Range> |
1692 | | IntersectSignedRange(ScalarEvolution &SE, |
1693 | | const Optional<InductiveRangeCheck::Range> &R1, |
1694 | 148 | const InductiveRangeCheck::Range &R2) { |
1695 | 148 | if (R2.isEmpty(SE, /* IsSigned */ true)) |
1696 | 8 | return None; |
1697 | 140 | if (!R1.hasValue()) |
1698 | 114 | return R2; |
1699 | 26 | auto &R1Value = R1.getValue(); |
1700 | 26 | // We never return empty ranges from this function, and R1 is supposed to be |
1701 | 26 | // a result of intersection. Thus, R1 is never empty. |
1702 | 26 | assert(!R1Value.isEmpty(SE, /* IsSigned */ true) && |
1703 | 26 | "We should never have empty R1!"); |
1704 | 26 | |
1705 | 26 | // TODO: we could widen the smaller range and have this work; but for now we |
1706 | 26 | // bail out to keep things simple. |
1707 | 26 | if (R1Value.getType() != R2.getType()) |
1708 | 0 | return None; |
1709 | 26 | |
1710 | 26 | const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin()); |
1711 | 26 | const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd()); |
1712 | 26 | |
1713 | 26 | // If the resulting range is empty, just return None. |
1714 | 26 | auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd); |
1715 | 26 | if (Ret.isEmpty(SE, /* IsSigned */ true)) |
1716 | 0 | return None; |
1717 | 26 | return Ret; |
1718 | 26 | } |
1719 | | |
1720 | | static Optional<InductiveRangeCheck::Range> |
1721 | | IntersectUnsignedRange(ScalarEvolution &SE, |
1722 | | const Optional<InductiveRangeCheck::Range> &R1, |
1723 | 101 | const InductiveRangeCheck::Range &R2) { |
1724 | 101 | if (R2.isEmpty(SE, /* IsSigned */ false)) |
1725 | 10 | return None; |
1726 | 91 | if (!R1.hasValue()) |
1727 | 74 | return R2; |
1728 | 17 | auto &R1Value = R1.getValue(); |
1729 | 17 | // We never return empty ranges from this function, and R1 is supposed to be |
1730 | 17 | // a result of intersection. Thus, R1 is never empty. |
1731 | 17 | assert(!R1Value.isEmpty(SE, /* IsSigned */ false) && |
1732 | 17 | "We should never have empty R1!"); |
1733 | 17 | |
1734 | 17 | // TODO: we could widen the smaller range and have this work; but for now we |
1735 | 17 | // bail out to keep things simple. |
1736 | 17 | if (R1Value.getType() != R2.getType()) |
1737 | 0 | return None; |
1738 | 17 | |
1739 | 17 | const SCEV *NewBegin = SE.getUMaxExpr(R1Value.getBegin(), R2.getBegin()); |
1740 | 17 | const SCEV *NewEnd = SE.getUMinExpr(R1Value.getEnd(), R2.getEnd()); |
1741 | 17 | |
1742 | 17 | // If the resulting range is empty, just return None. |
1743 | 17 | auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd); |
1744 | 17 | if (Ret.isEmpty(SE, /* IsSigned */ false)) |
1745 | 0 | return None; |
1746 | 17 | return Ret; |
1747 | 17 | } |
1748 | | |
1749 | | PreservedAnalyses IRCEPass::run(Loop &L, LoopAnalysisManager &AM, |
1750 | | LoopStandardAnalysisResults &AR, |
1751 | 225 | LPMUpdater &U) { |
1752 | 225 | Function *F = L.getHeader()->getParent(); |
1753 | 225 | const auto &FAM = |
1754 | 225 | AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); |
1755 | 225 | auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); |
1756 | 225 | InductiveRangeCheckElimination IRCE(AR.SE, BPI, AR.DT, AR.LI); |
1757 | 225 | auto LPMAddNewLoop = [&U](Loop *NL, bool IsSubloop) { |
1758 | 99 | if (!IsSubloop) |
1759 | 96 | U.addSiblingLoops(NL); |
1760 | 99 | }; |
1761 | 225 | bool Changed = IRCE.run(&L, LPMAddNewLoop); |
1762 | 225 | if (!Changed) |
1763 | 136 | return PreservedAnalyses::all(); |
1764 | 89 | |
1765 | 89 | return getLoopPassPreservedAnalyses(); |
1766 | 89 | } |
1767 | | |
1768 | 252 | bool IRCELegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { |
1769 | 252 | if (skipLoop(L)) |
1770 | 0 | return false; |
1771 | 252 | |
1772 | 252 | ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); |
1773 | 252 | BranchProbabilityInfo &BPI = |
1774 | 252 | getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); |
1775 | 252 | auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); |
1776 | 252 | auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); |
1777 | 252 | InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI); |
1778 | 252 | auto LPMAddNewLoop = [&LPM](Loop *NL, bool /* IsSubLoop */) { |
1779 | 110 | LPM.addLoop(*NL); |
1780 | 110 | }; |
1781 | 252 | return IRCE.run(L, LPMAddNewLoop); |
1782 | 252 | } |
1783 | | |
1784 | | bool InductiveRangeCheckElimination::run( |
1785 | 477 | Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop) { |
1786 | 477 | if (L->getBlocks().size() >= LoopSizeCutoff) { |
1787 | 0 | LLVM_DEBUG(dbgs() << "irce: giving up constraining loop, too large\n"); |
1788 | 0 | return false; |
1789 | 0 | } |
1790 | 477 | |
1791 | 477 | BasicBlock *Preheader = L->getLoopPreheader(); |
1792 | 477 | if (!Preheader) { |
1793 | 0 | LLVM_DEBUG(dbgs() << "irce: loop has no preheader, leaving\n"); |
1794 | 0 | return false; |
1795 | 0 | } |
1796 | 477 | |
1797 | 477 | LLVMContext &Context = Preheader->getContext(); |
1798 | 477 | SmallVector<InductiveRangeCheck, 16> RangeChecks; |
1799 | 477 | |
1800 | 477 | for (auto BBI : L->getBlocks()) |
1801 | 1.44k | if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator())) |
1802 | 1.43k | InductiveRangeCheck::extractRangeChecksFromBranch(TBI, L, SE, BPI, |
1803 | 1.43k | RangeChecks); |
1804 | 477 | |
1805 | 477 | if (RangeChecks.empty()) |
1806 | 235 | return false; |
1807 | 242 | |
1808 | 242 | auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) { |
1809 | 12 | OS << "irce: looking at loop "; L->print(OS); |
1810 | 12 | OS << "irce: loop has " << RangeChecks.size() |
1811 | 12 | << " inductive range checks: \n"; |
1812 | 12 | for (InductiveRangeCheck &IRC : RangeChecks) |
1813 | 20 | IRC.print(OS); |
1814 | 12 | }; |
1815 | 242 | |
1816 | 242 | LLVM_DEBUG(PrintRecognizedRangeChecks(dbgs())); |
1817 | 242 | |
1818 | 242 | if (PrintRangeChecks) |
1819 | 12 | PrintRecognizedRangeChecks(errs()); |
1820 | 242 | |
1821 | 242 | const char *FailureReason = nullptr; |
1822 | 242 | Optional<LoopStructure> MaybeLoopStructure = |
1823 | 242 | LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason); |
1824 | 242 | if (!MaybeLoopStructure.hasValue()) { |
1825 | 38 | LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: " |
1826 | 38 | << FailureReason << "\n";); |
1827 | 38 | return false; |
1828 | 38 | } |
1829 | 204 | LoopStructure LS = MaybeLoopStructure.getValue(); |
1830 | 204 | const SCEVAddRecExpr *IndVar = |
1831 | 204 | cast<SCEVAddRecExpr>(SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep))); |
1832 | 204 | |
1833 | 204 | Optional<InductiveRangeCheck::Range> SafeIterRange; |
1834 | 204 | Instruction *ExprInsertPt = Preheader->getTerminator(); |
1835 | 204 | |
1836 | 204 | SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate; |
1837 | 204 | // Basing on the type of latch predicate, we interpret the IV iteration range |
1838 | 204 | // as signed or unsigned range. We use different min/max functions (signed or |
1839 | 204 | // unsigned) when intersecting this range with safe iteration ranges implied |
1840 | 204 | // by range checks. |
1841 | 204 | auto IntersectRange = |
1842 | 204 | LS.IsSignedPredicate ? IntersectSignedRange124 : IntersectUnsignedRange80 ; |
1843 | 204 | |
1844 | 204 | IRBuilder<> B(ExprInsertPt); |
1845 | 255 | for (InductiveRangeCheck &IRC : RangeChecks) { |
1846 | 255 | auto Result = IRC.computeSafeIterationSpace(SE, IndVar, |
1847 | 255 | LS.IsSignedPredicate); |
1848 | 255 | if (Result.hasValue()) { |
1849 | 249 | auto MaybeSafeIterRange = |
1850 | 249 | IntersectRange(SE, SafeIterRange, Result.getValue()); |
1851 | 249 | if (MaybeSafeIterRange.hasValue()) { |
1852 | 231 | assert( |
1853 | 231 | !MaybeSafeIterRange.getValue().isEmpty(SE, LS.IsSignedPredicate) && |
1854 | 231 | "We should never return empty ranges!"); |
1855 | 231 | RangeChecksToEliminate.push_back(IRC); |
1856 | 231 | SafeIterRange = MaybeSafeIterRange.getValue(); |
1857 | 231 | } |
1858 | 249 | } |
1859 | 255 | } |
1860 | 204 | |
1861 | 204 | if (!SafeIterRange.hasValue()) |
1862 | 16 | return false; |
1863 | 188 | |
1864 | 188 | LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, |
1865 | 188 | SafeIterRange.getValue()); |
1866 | 188 | bool Changed = LC.run(); |
1867 | 188 | |
1868 | 188 | if (Changed) { |
1869 | 186 | auto PrintConstrainedLoopInfo = [L]() { |
1870 | 152 | dbgs() << "irce: in function "; |
1871 | 152 | dbgs() << L->getHeader()->getParent()->getName() << ": "; |
1872 | 152 | dbgs() << "constrained "; |
1873 | 152 | L->print(dbgs()); |
1874 | 152 | }; |
1875 | 186 | |
1876 | 186 | LLVM_DEBUG(PrintConstrainedLoopInfo()); |
1877 | 186 | |
1878 | 186 | if (PrintChangedLoops) |
1879 | 152 | PrintConstrainedLoopInfo(); |
1880 | 186 | |
1881 | 186 | // Optimize away the now-redundant range checks. |
1882 | 186 | |
1883 | 229 | for (InductiveRangeCheck &IRC : RangeChecksToEliminate) { |
1884 | 229 | ConstantInt *FoldedRangeCheck = IRC.getPassingDirection() |
1885 | 229 | ? ConstantInt::getTrue(Context) |
1886 | 229 | : ConstantInt::getFalse(Context)0 ; |
1887 | 229 | IRC.getCheckUse()->set(FoldedRangeCheck); |
1888 | 229 | } |
1889 | 186 | } |
1890 | 188 | |
1891 | 188 | return Changed; |
1892 | 188 | } |
1893 | | |
1894 | 0 | Pass *llvm::createInductiveRangeCheckEliminationPass() { |
1895 | 0 | return new IRCELegacyPass(); |
1896 | 0 | } |