Coverage Report

Created: 2019-07-24 05:18

/Users/buildslave/jenkins/workspace/clang-stage2-coverage-R/llvm/tools/polly/lib/Support/SCEVValidator.cpp
Line
Count
Source (jump to first uncovered line)
1
2
#include "polly/Support/SCEVValidator.h"
3
#include "polly/ScopDetection.h"
4
#include "llvm/Analysis/RegionInfo.h"
5
#include "llvm/Analysis/ScalarEvolution.h"
6
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
7
#include "llvm/Support/Debug.h"
8
9
using namespace llvm;
10
using namespace polly;
11
12
#define DEBUG_TYPE "polly-scev-validator"
13
14
namespace SCEVType {
/// Classification assigned to a SCEV while it is being validated.
///
/// Each SCEV is given one of the kinds INT, PARAM, IV or INVALID. The
/// enumerators are ordered on purpose: a subexpression of a SCEV of kind X
/// may only have a kind that is smaller than or equal to X.
enum TYPE {
  /// An integer value.
  INT,

  /// An expression that is constant during the execution of the SCoP, but
  /// that may depend on parameters unknown at compile time.
  PARAM,

  /// An expression that may change during the execution of the SCoP.
  IV,

  /// An invalid expression.
  INVALID
};
} // namespace SCEVType
36
37
/// The result the validator returns for a SCEV expression.
38
class ValidatorResult {
39
  /// The type of the expression
40
  SCEVType::TYPE Type;
41
42
  /// The set of Parameters in the expression.
43
  ParameterSetTy Parameters;
44
45
public:
46
  /// The copy constructor
47
7.21k
  ValidatorResult(const ValidatorResult &Source) {
48
7.21k
    Type = Source.Type;
49
7.21k
    Parameters = Source.Parameters;
50
7.21k
  }
51
52
  /// Construct a result with a certain type and no parameters.
53
123k
  ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54
123k
    assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55
123k
  }
56
57
  /// Construct a result with a certain type and a single parameter.
58
17.9k
  ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59
17.9k
    Parameters.insert(Expr);
60
17.9k
  }
61
62
  /// Get the type of the ValidatorResult.
63
1.10k
  SCEVType::TYPE getType() { return Type; }
64
65
  /// Is the analyzed SCEV constant during the execution of the SCoP.
66
166
  bool isConstant() { return Type == SCEVType::INT || 
Type == SCEVType::PARAM110
; }
67
68
  /// Is the analyzed SCEV valid.
69
118k
  bool isValid() { return Type != SCEVType::INVALID; }
70
71
  /// Is the analyzed SCEV of Type IV.
72
4.40k
  bool isIV() { return Type == SCEVType::IV; }
73
74
  /// Is the analyzed SCEV of Type INT.
75
42.4k
  bool isINT() { return Type == SCEVType::INT; }
76
77
  /// Is the analyzed SCEV of Type PARAM.
78
14.8k
  bool isPARAM() { return Type == SCEVType::PARAM; }
79
80
  /// Get the parameters of this validator result.
81
12.6k
  const ParameterSetTy &getParameters() { return Parameters; }
82
83
  /// Add the parameters of Source to this result.
84
41.9k
  void addParamsFrom(const ValidatorResult &Source) {
85
41.9k
    Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
86
41.9k
  }
87
88
  /// Merge a result.
89
  ///
90
  /// This means to merge the parameters and to set the Type to the most
91
  /// specific Type that matches both.
92
14.0k
  void merge(const ValidatorResult &ToMerge) {
93
14.0k
    Type = std::max(Type, ToMerge.Type);
94
14.0k
    addParamsFrom(ToMerge);
95
14.0k
  }
96
97
  void print(raw_ostream &OS) {
98
    switch (Type) {
99
    case SCEVType::INT:
100
      OS << "SCEVType::INT";
101
      break;
102
    case SCEVType::PARAM:
103
      OS << "SCEVType::PARAM";
104
      break;
105
    case SCEVType::IV:
106
      OS << "SCEVType::IV";
107
      break;
108
    case SCEVType::INVALID:
109
      OS << "SCEVType::INVALID";
110
      break;
111
    }
112
  }
113
};
114
115
0
/// Stream a ValidatorResult by delegating to its print() method; the stream
/// is returned so that further output can be chained.
raw_ostream &operator<<(raw_ostream &Stream, class ValidatorResult &Result) {
  Result.print(Stream);
  return Stream;
}
119
120
202
/// Return true if @p Call neither reads nor writes memory and every one of
/// its argument operands is a ConstantInt.
bool polly::isConstCall(llvm::CallInst *Call) {
  // A call that may touch memory is rejected right away.
  const bool TouchesMemory = Call->mayReadOrWriteMemory();
  if (TouchesMemory)
    return false;

  // The first argument that is not a ConstantInt disqualifies the call.
  for (auto &Arg : Call->arg_operands()) {
    if (isa<ConstantInt>(&Arg))
      continue;
    return false;
  }

  return true;
}
130
131
/// Check if a SCEV is valid in a SCoP.
132
struct SCEVValidator
133
    : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
134
private:
135
  const Region *R;
136
  Loop *Scope;
137
  ScalarEvolution &SE;
138
  InvariantLoadsSetTy *ILS;
139
140
public:
141
  SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
142
                InvariantLoadsSetTy *ILS)
143
61.0k
      : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
144
145
85.3k
  class ValidatorResult visitConstant(const SCEVConstant *Constant) {
146
85.3k
    return ValidatorResult(SCEVType::INT);
147
85.3k
  }
148
149
  class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
150
1.10k
                                                      const SCEV *Operand) {
151
1.10k
    ValidatorResult Op = visit(Operand);
152
1.10k
    auto Type = Op.getType();
153
1.10k
154
1.10k
    // If unsigned operations are allowed return the operand, otherwise
155
1.10k
    // check if we can model the expression without unsigned assumptions.
156
1.10k
    if (PollyAllowUnsignedOperations || 
Type == SCEVType::INVALID0
)
157
1.10k
      return Op;
158
0
159
0
    if (Type == SCEVType::IV)
160
0
      return ValidatorResult(SCEVType::INVALID);
161
0
    return ValidatorResult(SCEVType::PARAM, Expr);
162
0
  }
163
164
283
  class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
165
283
    return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
166
283
  }
167
168
823
  class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
169
823
    return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
170
823
  }
171
172
2.69k
  class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
173
2.69k
    return visit(Expr->getOperand());
174
2.69k
  }
175
176
3.26k
  class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
177
3.26k
    ValidatorResult Return(SCEVType::INT);
178
3.26k
179
9.98k
    for (int i = 0, e = Expr->getNumOperands(); i < e; 
++i6.71k
) {
180
6.79k
      ValidatorResult Op = visit(Expr->getOperand(i));
181
6.79k
      Return.merge(Op);
182
6.79k
183
6.79k
      // Early exit.
184
6.79k
      if (!Return.isValid())
185
72
        break;
186
6.79k
    }
187
3.26k
188
3.26k
    return Return;
189
3.26k
  }
190
191
4.39k
  class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
192
4.39k
    ValidatorResult Return(SCEVType::INT);
193
4.39k
194
4.39k
    bool HasMultipleParams = false;
195
4.39k
196
14.0k
    for (int i = 0, e = Expr->getNumOperands(); i < e; 
++i9.67k
) {
197
9.73k
      ValidatorResult Op = visit(Expr->getOperand(i));
198
9.73k
199
9.73k
      if (Op.isINT())
200
4.13k
        continue;
201
5.59k
202
5.59k
      if (Op.isPARAM() && 
Return.isPARAM()5.13k
) {
203
1.19k
        HasMultipleParams = true;
204
1.19k
        continue;
205
1.19k
      }
206
4.40k
207
4.40k
      if ((Op.isIV() || 
Op.isPARAM()4.11k
) &&
!Return.isINT()4.23k
) {
208
58
        LLVM_DEBUG(
209
58
            dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
210
58
                   << "\tExpr: " << *Expr << "\n"
211
58
                   << "\tPrevious expression type: " << Return << "\n"
212
58
                   << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
213
58
                   << "\n");
214
58
215
58
        return ValidatorResult(SCEVType::INVALID);
216
58
      }
217
4.34k
218
4.34k
      Return.merge(Op);
219
4.34k
    }
220
4.39k
221
4.39k
    
if (4.33k
HasMultipleParams4.33k
&&
Return.isValid()955
)
222
955
      return ValidatorResult(SCEVType::PARAM, Expr);
223
3.37k
224
3.37k
    return Return;
225
3.37k
  }
226
227
30.4k
  class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
228
30.4k
    if (!Expr->isAffine()) {
229
124
      LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine");
230
124
      return ValidatorResult(SCEVType::INVALID);
231
124
    }
232
30.3k
233
30.3k
    ValidatorResult Start = visit(Expr->getStart());
234
30.3k
    ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
235
30.3k
236
30.3k
    if (!Start.isValid())
237
1.11k
      return Start;
238
29.2k
239
29.2k
    if (!Recurrence.isValid())
240
0
      return Recurrence;
241
29.2k
242
29.2k
    auto *L = Expr->getLoop();
243
29.2k
    if (R->contains(L) && 
(28.4k
!Scope28.4k
||
!L->contains(Scope)28.4k
)) {
244
16
      LLVM_DEBUG(
245
16
          dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
246
16
                    "non-affine subregion or has a non-synthesizable exit "
247
16
                    "value.");
248
16
      return ValidatorResult(SCEVType::INVALID);
249
16
    }
250
29.1k
251
29.1k
    if (R->contains(L)) {
252
28.4k
      if (Recurrence.isINT()) {
253
27.7k
        ValidatorResult Result(SCEVType::IV);
254
27.7k
        Result.addParamsFrom(Start);
255
27.7k
        return Result;
256
27.7k
      }
257
736
258
736
      LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
259
736
                           "recurrence part");
260
736
      return ValidatorResult(SCEVType::INVALID);
261
736
    }
262
726
263
726
    assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
264
726
265
726
    // Directly generate ValidatorResult for Expr if 'start' is zero.
266
726
    if (Expr->getStart()->isZero())
267
523
      return ValidatorResult(SCEVType::PARAM, Expr);
268
203
269
203
    // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
270
203
    // if 'start' is not zero.
271
203
    const SCEV *ZeroStartExpr = SE.getAddRecExpr(
272
203
        SE.getConstant(Expr->getStart()->getType(), 0),
273
203
        Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
274
203
275
203
    ValidatorResult ZeroStartResult =
276
203
        ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
277
203
    ZeroStartResult.addParamsFrom(Start);
278
203
279
203
    return ZeroStartResult;
280
203
  }
281
282
1.36k
  class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
283
1.36k
    ValidatorResult Return(SCEVType::INT);
284
1.36k
285
4.17k
    for (int i = 0, e = Expr->getNumOperands(); i < e; 
++i2.80k
) {
286
2.80k
      ValidatorResult Op = visit(Expr->getOperand(i));
287
2.80k
288
2.80k
      if (!Op.isValid())
289
0
        return Op;
290
2.80k
291
2.80k
      Return.merge(Op);
292
2.80k
    }
293
1.36k
294
1.36k
    return Return;
295
1.36k
  }
296
297
46
  class ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
298
46
    ValidatorResult Return(SCEVType::INT);
299
46
300
149
    for (int i = 0, e = Expr->getNumOperands(); i < e; 
++i103
) {
301
103
      ValidatorResult Op = visit(Expr->getOperand(i));
302
103
303
103
      if (!Op.isValid())
304
0
        return Op;
305
103
306
103
      Return.merge(Op);
307
103
    }
308
46
309
46
    return Return;
310
46
  }
311
312
33
  class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
313
33
    // We do not support unsigned max operations. If 'Expr' is constant during
314
33
    // Scop execution we treat this as a parameter, otherwise we bail out.
315
99
    for (int i = 0, e = Expr->getNumOperands(); i < e; 
++i66
) {
316
66
      ValidatorResult Op = visit(Expr->getOperand(i));
317
66
318
66
      if (!Op.isConstant()) {
319
0
        LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
320
0
        return ValidatorResult(SCEVType::INVALID);
321
0
      }
322
66
    }
323
33
324
33
    return ValidatorResult(SCEVType::PARAM, Expr);
325
33
  }
326
327
6
  class ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
328
6
    // We do not support unsigned min operations. If 'Expr' is constant during
329
6
    // Scop execution we treat this as a parameter, otherwise we bail out.
330
18
    for (int i = 0, e = Expr->getNumOperands(); i < e; 
++i12
) {
331
12
      ValidatorResult Op = visit(Expr->getOperand(i));
332
12
333
12
      if (!Op.isConstant()) {
334
0
        LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
335
0
        return ValidatorResult(SCEVType::INVALID);
336
0
      }
337
12
    }
338
6
339
6
    return ValidatorResult(SCEVType::PARAM, Expr);
340
6
  }
341
342
2.44k
  ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
343
2.44k
    if (R->contains(I)) {
344
111
      LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
345
111
                           "within the region\n");
346
111
      return ValidatorResult(SCEVType::INVALID);
347
111
    }
348
2.33k
349
2.33k
    return ValidatorResult(SCEVType::PARAM, S);
350
2.33k
  }
351
352
98
  ValidatorResult visitCallInstruction(Instruction *I, const SCEV *S) {
353
98
    assert(I->getOpcode() == Instruction::Call && "Call instruction expected");
354
98
355
98
    if (R->contains(I)) {
356
64
      auto Call = cast<CallInst>(I);
357
64
358
64
      if (!isConstCall(Call))
359
30
        return ValidatorResult(SCEVType::INVALID, S);
360
68
    }
361
68
    return ValidatorResult(SCEVType::PARAM, S);
362
68
  }
363
364
4.99k
  ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
365
4.99k
    if (R->contains(I) && 
ILS2.84k
) {
366
2.84k
      ILS->insert(cast<LoadInst>(I));
367
2.84k
      return ValidatorResult(SCEVType::PARAM, S);
368
2.84k
    }
369
2.15k
370
2.15k
    return visitGenericInst(I, S);
371
2.15k
  }
372
373
  ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
374
                                const SCEV *DivExpr,
375
758
                                Instruction *SDiv = nullptr) {
376
758
377
758
    // First check if we might be able to model the division, thus if the
378
758
    // divisor is constant. If so, check the dividend, otherwise check if
379
758
    // the whole division can be seen as a parameter.
380
758
    if (isa<SCEVConstant>(Divisor) && 
!Divisor->isZero()709
)
381
706
      return visit(Dividend);
382
52
383
52
    // For signed divisions use the SDiv instruction to check for a parameter
384
52
    // division, for unsigned divisions check the operands.
385
52
    if (SDiv)
386
8
      return visitGenericInst(SDiv, DivExpr);
387
44
388
44
    ValidatorResult LHS = visit(Dividend);
389
44
    ValidatorResult RHS = visit(Divisor);
390
44
    if (LHS.isConstant() && RHS.isConstant())
391
44
      return ValidatorResult(SCEVType::PARAM, DivExpr);
392
0
393
0
    LLVM_DEBUG(
394
0
        dbgs() << "INVALID: unsigned division of non-constant expressions");
395
0
    return ValidatorResult(SCEVType::INVALID);
396
0
  }
397
398
410
  ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
399
410
    if (!PollyAllowUnsignedOperations)
400
0
      return ValidatorResult(SCEVType::INVALID);
401
410
402
410
    auto *Dividend = Expr->getLHS();
403
410
    auto *Divisor = Expr->getRHS();
404
410
    return visitDivision(Dividend, Divisor, Expr);
405
410
  }
406
407
348
  ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
408
348
    assert(SDiv->getOpcode() == Instruction::SDiv &&
409
348
           "Assumed SDiv instruction!");
410
348
411
348
    auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
412
348
    auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
413
348
    return visitDivision(Dividend, Divisor, Expr, SDiv);
414
348
  }
415
416
212
  ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
417
212
    assert(SRem->getOpcode() == Instruction::SRem &&
418
212
           "Assumed SRem instruction!");
419
212
420
212
    auto *Divisor = SRem->getOperand(1);
421
212
    auto *CI = dyn_cast<ConstantInt>(Divisor);
422
212
    if (!CI || 
CI->isZeroValue()209
)
423
3
      return visitGenericInst(SRem, S);
424
209
425
209
    auto *Dividend = SRem->getOperand(0);
426
209
    auto *DividendSCEV = SE.getSCEV(Dividend);
427
209
    return visit(DividendSCEV);
428
209
  }
429
430
17.0k
  ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
431
17.0k
    Value *V = Expr->getValue();
432
17.0k
433
17.0k
    if (!Expr->getType()->isIntegerTy() && 
!Expr->getType()->isPointerTy()230
) {
434
0
      LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
435
0
      return ValidatorResult(SCEVType::INVALID);
436
0
    }
437
17.0k
438
17.0k
    if (isa<UndefValue>(V)) {
439
27
      LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
440
27
      return ValidatorResult(SCEVType::INVALID);
441
27
    }
442
16.9k
443
16.9k
    if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
444
6.03k
      switch (I->getOpcode()) {
445
6.03k
      case Instruction::IntToPtr:
446
30
        return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
447
6.03k
      case Instruction::PtrToInt:
448
67
        return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
449
6.03k
      case Instruction::Load:
450
4.99k
        return visitLoadInstruction(I, Expr);
451
6.03k
      case Instruction::SDiv:
452
348
        return visitSDivInstruction(I, Expr);
453
6.03k
      case Instruction::SRem:
454
212
        return visitSRemInstruction(I, Expr);
455
6.03k
      case Instruction::Call:
456
98
        return visitCallInstruction(I, Expr);
457
6.03k
      default:
458
280
        return visitGenericInst(I, Expr);
459
10.9k
      }
460
10.9k
    }
461
10.9k
462
10.9k
    return ValidatorResult(SCEVType::PARAM, Expr);
463
10.9k
  }
464
};
465
466
class SCEVHasIVParams {
467
  bool HasIVParams = false;
468
469
public:
470
11.7k
  SCEVHasIVParams() {}
471
472
26.9k
  bool follow(const SCEV *S) {
473
26.9k
    const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
474
26.9k
    if (!Unknown)
475
25.7k
      return true;
476
1.19k
477
1.19k
    CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
478
1.19k
479
1.19k
    if (!Call)
480
1.19k
      return true;
481
6
482
6
    if (isConstCall(Call)) {
483
6
      HasIVParams = true;
484
6
      return false;
485
6
    }
486
0
487
0
    return true;
488
0
  }
489
490
26.9k
  bool isDone() { return HasIVParams; }
491
11.7k
  bool hasIVParams() { return HasIVParams; }
492
};
493
494
/// Check whether a SCEV refers to an SSA name defined inside a region.
495
class SCEVInRegionDependences {
496
  const Region *R;
497
  Loop *Scope;
498
  const InvariantLoadsSetTy &ILS;
499
  bool AllowLoops;
500
  bool HasInRegionDeps = false;
501
502
public:
503
  SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
504
                          const InvariantLoadsSetTy &ILS)
505
24.7k
      : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
506
507
72.5k
  bool follow(const SCEV *S) {
508
72.5k
    if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
509
22.2k
      Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
510
22.2k
511
22.2k
      CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
512
22.2k
513
22.2k
      if (Call && 
isConstCall(Call)125
)
514
20
        return false;
515
22.2k
516
22.2k
      if (Inst) {
517
12.6k
        // When we invariant load hoist a load, we first make sure that there
518
12.6k
        // can be no dependences created by it in the Scop region. So, we should
519
12.6k
        // not consider scalar dependences to `LoadInst`s that are invariant
520
12.6k
        // load hoisted.
521
12.6k
        //
522
12.6k
        // If this check is not present, then we create data dependences which
523
12.6k
        // are strictly not necessary by tracking the invariant load as a
524
12.6k
        // scalar.
525
12.6k
        LoadInst *LI = dyn_cast<LoadInst>(Inst);
526
12.6k
        if (LI && 
ILS.count(LI) > 08.35k
)
527
2.02k
          return false;
528
20.2k
      }
529
20.2k
530
20.2k
      // Return true when Inst is defined inside the region R.
531
20.2k
      if (!Inst || 
!R->contains(Inst)10.6k
)
532
11.5k
        return true;
533
8.69k
534
8.69k
      HasInRegionDeps = true;
535
8.69k
      return false;
536
8.69k
    }
537
50.2k
538
50.2k
    if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
539
13.8k
      if (AllowLoops)
540
0
        return true;
541
13.8k
542
13.8k
      auto *L = AddRec->getLoop();
543
13.8k
      if (R->contains(L) && 
!L->contains(Scope)13.5k
) {
544
25
        HasInRegionDeps = true;
545
25
        return false;
546
25
      }
547
50.2k
    }
548
50.2k
549
50.2k
    return true;
550
50.2k
  }
551
61.7k
  bool isDone() { return false; }
552
24.7k
  bool hasDependences() { return HasInRegionDeps; }
553
};
554
555
namespace polly {
556
/// Find all loops referenced in SCEVAddRecExprs.
557
class SCEVFindLoops {
558
  SetVector<const Loop *> &Loops;
559
560
public:
561
15.9k
  SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
562
563
42.0k
  bool follow(const SCEV *S) {
564
42.0k
    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
565
9.43k
      Loops.insert(AddRec->getLoop());
566
42.0k
    return true;
567
42.0k
  }
568
42.0k
  bool isDone() { return false; }
569
};
570
571
15.9k
void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
572
15.9k
  SCEVFindLoops FindLoops(Loops);
573
15.9k
  SCEVTraversal<SCEVFindLoops> ST(FindLoops);
574
15.9k
  ST.visitAll(Expr);
575
15.9k
}
576
577
/// Find all values referenced in SCEVUnknowns.
578
class SCEVFindValues {
579
  ScalarEvolution &SE;
580
  SetVector<Value *> &Values;
581
582
public:
583
  SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
584
19.8k
      : SE(SE), Values(Values) {}
585
586
46.6k
  bool follow(const SCEV *S) {
587
46.6k
    const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
588
46.6k
    if (!Unknown)
589
39.6k
      return true;
590
6.96k
591
6.96k
    Values.insert(Unknown->getValue());
592
6.96k
    Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
593
6.96k
    if (!Inst || 
(2.96k
Inst->getOpcode() != Instruction::SRem2.96k
&&
594
2.96k
                  
Inst->getOpcode() != Instruction::SDiv2.92k
))
595
6.84k
      return false;
596
118
597
118
    auto *Dividend = SE.getSCEV(Inst->getOperand(1));
598
118
    if (!isa<SCEVConstant>(Dividend))
599
2
      return false;
600
116
601
116
    auto *Divisor = SE.getSCEV(Inst->getOperand(0));
602
116
    SCEVFindValues FindValues(SE, Values);
603
116
    SCEVTraversal<SCEVFindValues> ST(FindValues);
604
116
    ST.visitAll(Dividend);
605
116
    ST.visitAll(Divisor);
606
116
607
116
    return false;
608
116
  }
609
39.6k
  bool isDone() { return false; }
610
};
611
612
void findValues(const SCEV *Expr, ScalarEvolution &SE,
613
19.7k
                SetVector<Value *> &Values) {
614
19.7k
  SCEVFindValues FindValues(SE, Values);
615
19.7k
  SCEVTraversal<SCEVFindValues> ST(FindValues);
616
19.7k
  ST.visitAll(Expr);
617
19.7k
}
618
619
11.7k
bool hasIVParams(const SCEV *Expr) {
620
11.7k
  SCEVHasIVParams HasIVParams;
621
11.7k
  SCEVTraversal<SCEVHasIVParams> ST(HasIVParams);
622
11.7k
  ST.visitAll(Expr);
623
11.7k
  return HasIVParams.hasIVParams();
624
11.7k
}
625
626
bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
627
                               llvm::Loop *Scope, bool AllowLoops,
628
24.7k
                               const InvariantLoadsSetTy &ILS) {
629
24.7k
  SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
630
24.7k
  SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
631
24.7k
  ST.visitAll(Expr);
632
24.7k
  return InRegionDeps.hasDependences();
633
24.7k
}
634
635
bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
636
48.3k
                  ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
637
48.3k
  if (isa<SCEVCouldNotCompute>(Expr))
638
0
    return false;
639
48.3k
640
48.3k
  SCEVValidator Validator(R, Scope, SE, ILS);
641
48.3k
  LLVM_DEBUG({
642
48.3k
    dbgs() << "\n";
643
48.3k
    dbgs() << "Expr: " << *Expr << "\n";
644
48.3k
    dbgs() << "Region: " << R->getNameStr() << "\n";
645
48.3k
    dbgs() << " -> ";
646
48.3k
  });
647
48.3k
648
48.3k
  ValidatorResult Result = Validator.visit(Expr);
649
48.3k
650
48.3k
  LLVM_DEBUG({
651
48.3k
    if (Result.isValid())
652
48.3k
      dbgs() << "VALID\n";
653
48.3k
    dbgs() << "\n";
654
48.3k
  });
655
48.3k
656
48.3k
  return Result.isValid();
657
48.3k
}
658
659
static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
660
54
                         ScalarEvolution &SE, ParameterSetTy &Params) {
661
54
  auto *E = SE.getSCEV(V);
662
54
  if (isa<SCEVCouldNotCompute>(E))
663
0
    return false;
664
54
665
54
  SCEVValidator Validator(R, Scope, SE, nullptr);
666
54
  ValidatorResult Result = Validator.visit(E);
667
54
  if (!Result.isValid())
668
0
    return false;
669
54
670
54
  auto ResultParams = Result.getParameters();
671
54
  Params.insert(ResultParams.begin(), ResultParams.end());
672
54
673
54
  return true;
674
54
}
675
676
bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope,
677
                        ScalarEvolution &SE, ParameterSetTy &Params,
678
87
                        bool OrExpr) {
679
87
  if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
680
27
    return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
681
27
                              true) &&
682
27
           isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
683
60
  } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
684
10
    auto Opcode = BinOp->getOpcode();
685
10
    if (Opcode == Instruction::And || 
Opcode == Instruction::Or4
)
686
6
      return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
687
6
                                false) &&
688
6
             isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
689
6
                                false);
690
54
    /* Fall through */
691
54
  }
692
54
693
54
  if (!OrExpr)
694
0
    return false;
695
54
696
54
  return isAffineExpr(V, R, Scope, SE, Params);
697
54
}
698
699
/// Return the parameters the SCEVValidator finds in @p Expr.
///
/// A SCEVCouldNotCompute yields an empty set. @p Expr is expected to be
/// valid, which is checked by an assertion.
ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope,
                                     const SCEV *Expr, ScalarEvolution &SE) {
  if (isa<SCEVCouldNotCompute>(Expr))
    return ParameterSetTy();

  // The validator is given a local load set that is dropped on return.
  InvariantLoadsSetTy ScratchLoads;
  SCEVValidator Checker(R, Scope, SE, &ScratchLoads);
  ValidatorResult Outcome = Checker.visit(Expr);
  assert(Outcome.isValid() && "Requested parameters for an invalid SCEV!");

  return Outcome.getParameters();
}
711
712
std::pair<const SCEVConstant *, const SCEV *>
713
19.7k
extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
714
19.7k
  auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
715
19.7k
716
19.7k
  if (auto *Constant = dyn_cast<SCEVConstant>(S))
717
8.94k
    return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
718
10.7k
719
10.7k
  auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
720
10.7k
  if (AddRec) {
721
4.56k
    auto *StartExpr = AddRec->getStart();
722
4.56k
    if (StartExpr->isZero()) {
723
3.35k
      auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
724
3.35k
      auto *LeftOverAddRec =
725
3.35k
          SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
726
3.35k
                           AddRec->getNoWrapFlags());
727
3.35k
      return std::make_pair(StepPair.first, LeftOverAddRec);
728
3.35k
    }
729
1.20k
    return std::make_pair(ConstPart, S);
730
1.20k
  }
731
6.22k
732
6.22k
  if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
733
254
    SmallVector<const SCEV *, 4> LeftOvers;
734
254
    auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
735
254
    auto *Factor = Op0Pair.first;
736
254
    if (SE.isKnownNegative(Factor)) {
737
87
      Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
738
87
      LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
739
167
    } else {
740
167
      LeftOvers.push_back(Op0Pair.second);
741
167
    }
742
254
743
372
    for (unsigned u = 1, e = Add->getNumOperands(); u < e; 
u++118
) {
744
256
      auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
745
256
      // TODO: Use something smarter than equality here, e.g., gcd.
746
256
      if (Factor == OpUPair.first)
747
114
        LeftOvers.push_back(OpUPair.second);
748
142
      else if (Factor == SE.getNegativeSCEV(OpUPair.first))
749
4
        LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
750
138
      else
751
138
        return std::make_pair(ConstPart, S);
752
256
    }
753
254
754
254
    auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
755
116
    return std::make_pair(Factor, NewAdd);
756
5.97k
  }
757
5.97k
758
5.97k
  auto *Mul = dyn_cast<SCEVMulExpr>(S);
759
5.97k
  if (!Mul)
760
5.32k
    return std::make_pair(ConstPart, S);
761
649
762
649
  SmallVector<const SCEV *, 4> LeftOvers;
763
649
  for (auto *Op : Mul->operands())
764
1.34k
    if (isa<SCEVConstant>(Op))
765
583
      ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
766
765
    else
767
765
      LeftOvers.push_back(Op);
768
649
769
649
  return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
770
649
}
771
772
/// If @p Expr wraps a PHI node with exactly one incoming value that does not
/// come from an error block inside @p R, return the SCEV of that value;
/// otherwise return @p Expr unchanged.
const SCEV *tryForwardThroughPHI(const SCEV *Expr, Region &R,
                                 ScalarEvolution &SE, LoopInfo &LI,
                                 const DominatorTree &DT) {
  auto *Unknown = dyn_cast<SCEVUnknown>(Expr);
  if (!Unknown)
    return Expr;

  auto *PHI = dyn_cast<PHINode>(Unknown->getValue());
  if (!PHI)
    return Expr;

  Value *Final = nullptr;
  for (unsigned Idx = 0; Idx < PHI->getNumIncomingValues(); Idx++) {
    BasicBlock *Incoming = PHI->getIncomingBlock(Idx);

    // Incoming edges from error blocks inside the region are ignored.
    if (isErrorBlock(*Incoming, R, LI, DT) && R.contains(Incoming))
      continue;

    // A second remaining incoming value: nothing to forward.
    if (Final)
      return Expr;
    Final = PHI->getIncomingValue(Idx);
  }

  return Final ? SE.getSCEV(Final) : Expr;
}
797
798
/// Return the single incoming value of @p PHI that does not come from an
/// error block, or nullptr if there is none or more than one.
Value *getUniqueNonErrorValue(PHINode *PHI, Region *R, LoopInfo &LI,
                              const DominatorTree &DT) {
  Value *Unique = nullptr;
  for (unsigned Idx = 0; Idx < PHI->getNumIncomingValues(); Idx++) {
    BasicBlock *Pred = PHI->getIncomingBlock(Idx);
    if (isErrorBlock(*Pred, *R, LI, DT))
      continue;

    // A second non-error incoming value means the value is not unique.
    if (Unique)
      return nullptr;
    Unique = PHI->getIncomingValue(Idx);
  }

  return Unique;
}
812
} // namespace polly