Coverage Report

Created: 2019-07-24 05:18

/Users/buildslave/jenkins/workspace/clang-stage2-coverage-R/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- HexagonLoopIdiomRecognition.cpp ------------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#define DEBUG_TYPE "hexagon-lir"
10
11
#include "llvm/ADT/APInt.h"
12
#include "llvm/ADT/DenseMap.h"
13
#include "llvm/ADT/SetVector.h"
14
#include "llvm/ADT/SmallPtrSet.h"
15
#include "llvm/ADT/SmallSet.h"
16
#include "llvm/ADT/SmallVector.h"
17
#include "llvm/ADT/StringRef.h"
18
#include "llvm/ADT/Triple.h"
19
#include "llvm/Analysis/AliasAnalysis.h"
20
#include "llvm/Analysis/InstructionSimplify.h"
21
#include "llvm/Analysis/LoopInfo.h"
22
#include "llvm/Analysis/LoopPass.h"
23
#include "llvm/Analysis/MemoryLocation.h"
24
#include "llvm/Analysis/ScalarEvolution.h"
25
#include "llvm/Analysis/ScalarEvolutionExpander.h"
26
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
27
#include "llvm/Analysis/TargetLibraryInfo.h"
28
#include "llvm/Transforms/Utils/Local.h"
29
#include "llvm/Analysis/ValueTracking.h"
30
#include "llvm/IR/Attributes.h"
31
#include "llvm/IR/BasicBlock.h"
32
#include "llvm/IR/Constant.h"
33
#include "llvm/IR/Constants.h"
34
#include "llvm/IR/DataLayout.h"
35
#include "llvm/IR/DebugLoc.h"
36
#include "llvm/IR/DerivedTypes.h"
37
#include "llvm/IR/Dominators.h"
38
#include "llvm/IR/Function.h"
39
#include "llvm/IR/IRBuilder.h"
40
#include "llvm/IR/InstrTypes.h"
41
#include "llvm/IR/Instruction.h"
42
#include "llvm/IR/Instructions.h"
43
#include "llvm/IR/IntrinsicInst.h"
44
#include "llvm/IR/Intrinsics.h"
45
#include "llvm/IR/Module.h"
46
#include "llvm/IR/PatternMatch.h"
47
#include "llvm/IR/Type.h"
48
#include "llvm/IR/User.h"
49
#include "llvm/IR/Value.h"
50
#include "llvm/Pass.h"
51
#include "llvm/Support/Casting.h"
52
#include "llvm/Support/CommandLine.h"
53
#include "llvm/Support/Compiler.h"
54
#include "llvm/Support/Debug.h"
55
#include "llvm/Support/ErrorHandling.h"
56
#include "llvm/Support/KnownBits.h"
57
#include "llvm/Support/raw_ostream.h"
58
#include "llvm/Transforms/Scalar.h"
59
#include "llvm/Transforms/Utils.h"
60
#include <algorithm>
61
#include <array>
62
#include <cassert>
63
#include <cstdint>
64
#include <cstdlib>
65
#include <deque>
66
#include <functional>
67
#include <iterator>
68
#include <map>
69
#include <set>
70
#include <utility>
71
#include <vector>
72
73
using namespace llvm;
74
75
static cl::opt<bool> DisableMemcpyIdiom("disable-memcpy-idiom",
76
  cl::Hidden, cl::init(false),
77
  cl::desc("Disable generation of memcpy in loop idiom recognition"));
78
79
static cl::opt<bool> DisableMemmoveIdiom("disable-memmove-idiom",
80
  cl::Hidden, cl::init(false),
81
  cl::desc("Disable generation of memmove in loop idiom recognition"));
82
83
static cl::opt<unsigned> RuntimeMemSizeThreshold("runtime-mem-idiom-threshold",
84
  cl::Hidden, cl::init(0), cl::desc("Threshold (in bytes) for the runtime "
85
  "check guarding the memmove."));
86
87
static cl::opt<unsigned> CompileTimeMemSizeThreshold(
88
  "compile-time-mem-idiom-threshold", cl::Hidden, cl::init(64),
89
  cl::desc("Threshold (in bytes) to perform the transformation, if the "
90
    "runtime loop count (mem transfer size) is known at compile-time."));
91
92
static cl::opt<bool> OnlyNonNestedMemmove("only-nonnested-memmove-idiom",
93
  cl::Hidden, cl::init(true),
94
  cl::desc("Only enable generating memmove in non-nested loops"));
95
96
cl::opt<bool> HexagonVolatileMemcpy("disable-hexagon-volatile-memcpy",
97
  cl::Hidden, cl::init(false),
98
  cl::desc("Enable Hexagon-specific memcpy for volatile destination."));
99
100
static cl::opt<unsigned> SimplifyLimit("hlir-simplify-limit", cl::init(10000),
101
  cl::Hidden, cl::desc("Maximum number of simplification steps in HLIR"));
102
103
static const char *HexagonVolatileMemcpyName
104
  = "hexagon_memcpy_forward_vp4cp4n2";
105
106
107
namespace llvm {
108
109
  void initializeHexagonLoopIdiomRecognizePass(PassRegistry&);
110
  Pass *createHexagonLoopIdiomPass();
111
112
} // end namespace llvm
113
114
namespace {
115
116
  class HexagonLoopIdiomRecognize : public LoopPass {
117
  public:
118
    static char ID;
119
120
19
    explicit HexagonLoopIdiomRecognize() : LoopPass(ID) {
121
19
      initializeHexagonLoopIdiomRecognizePass(*PassRegistry::getPassRegistry());
122
19
    }
123
124
10
    StringRef getPassName() const override {
125
10
      return "Recognize Hexagon-specific loop idioms";
126
10
    }
127
128
19
   void getAnalysisUsage(AnalysisUsage &AU) const override {
129
19
      AU.addRequired<LoopInfoWrapperPass>();
130
19
      AU.addRequiredID(LoopSimplifyID);
131
19
      AU.addRequiredID(LCSSAID);
132
19
      AU.addRequired<AAResultsWrapperPass>();
133
19
      AU.addPreserved<AAResultsWrapperPass>();
134
19
      AU.addRequired<ScalarEvolutionWrapperPass>();
135
19
      AU.addRequired<DominatorTreeWrapperPass>();
136
19
      AU.addRequired<TargetLibraryInfoWrapperPass>();
137
19
      AU.addPreserved<TargetLibraryInfoWrapperPass>();
138
19
    }
139
140
    bool runOnLoop(Loop *L, LPPassManager &LPM) override;
141
142
  private:
143
    int getSCEVStride(const SCEVAddRecExpr *StoreEv);
144
    bool isLegalStore(Loop *CurLoop, StoreInst *SI);
145
    void collectStores(Loop *CurLoop, BasicBlock *BB,
146
        SmallVectorImpl<StoreInst*> &Stores);
147
    bool processCopyingStore(Loop *CurLoop, StoreInst *SI, const SCEV *BECount);
148
    bool coverLoop(Loop *L, SmallVectorImpl<Instruction*> &Insts) const;
149
    bool runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, const SCEV *BECount,
150
        SmallVectorImpl<BasicBlock*> &ExitBlocks);
151
    bool runOnCountableLoop(Loop *L);
152
153
    AliasAnalysis *AA;
154
    const DataLayout *DL;
155
    DominatorTree *DT;
156
    LoopInfo *LF;
157
    const TargetLibraryInfo *TLI;
158
    ScalarEvolution *SE;
159
    bool HasMemcpy, HasMemmove;
160
  };
161
162
  struct Simplifier {
163
    struct Rule {
164
      using FuncType = std::function<Value* (Instruction*, LLVMContext&)>;
165
50
      Rule(StringRef N, FuncType F) : Name(N), Fn(F) {}
166
      StringRef Name;   // For debugging.
167
      FuncType Fn;
168
    };
169
170
50
    void addRule(StringRef N, const Rule::FuncType &F) {
171
50
      Rules.push_back(Rule(N, F));
172
50
    }
173
174
  private:
175
    struct WorkListType {
176
1.33k
      WorkListType() = default;
177
178
185k
      void push_back(Value* V) {
179
185k
        // Do not push back duplicates.
180
185k
        if (!S.count(V)) 
{ Q.push_back(V); S.insert(V); }123k
181
185k
      }
182
183
116k
      Value *pop_front_val() {
184
116k
        Value *V = Q.front(); Q.pop_front(); S.erase(V);
185
116k
        return V;
186
116k
      }
187
188
117k
      bool empty() const { return Q.empty(); }
189
190
    private:
191
      std::deque<Value*> Q;
192
      std::set<Value*> S;
193
    };
194
195
    using ValueSetType = std::set<Value *>;
196
197
    std::vector<Rule> Rules;
198
199
  public:
200
    struct Context {
201
      using ValueMapType = DenseMap<Value *, Value *>;
202
203
      Value *Root;
204
      ValueSetType Used;    // The set of all cloned values used by Root.
205
      ValueSetType Clones;  // The set of all cloned values.
206
      LLVMContext &Ctx;
207
208
      Context(Instruction *Exp)
209
4
        : Ctx(Exp->getParent()->getParent()->getContext()) {
210
4
        initialize(Exp);
211
4
      }
212
213
4
      ~Context() { cleanup(); }
214
215
      void print(raw_ostream &OS, const Value *V) const;
216
      Value *materialize(BasicBlock *B, BasicBlock::iterator At);
217
218
    private:
219
      friend struct Simplifier;
220
221
      void initialize(Instruction *Exp);
222
      void cleanup();
223
224
      template <typename FuncT> void traverse(Value *V, FuncT F);
225
      void record(Value *V);
226
      void use(Value *V);
227
      void unuse(Value *V);
228
229
      bool equal(const Instruction *I, const Instruction *J) const;
230
      Value *find(Value *Tree, Value *Sub) const;
231
      Value *subst(Value *Tree, Value *OldV, Value *NewV);
232
      void replace(Value *OldV, Value *NewV);
233
      void link(Instruction *I, BasicBlock *B, BasicBlock::iterator At);
234
    };
235
236
    Value *simplify(Context &C);
237
  };
238
239
  struct PE {
240
0
    PE(const Simplifier::Context &c, Value *v = nullptr) : C(c), V(v) {}
241
242
    const Simplifier::Context &C;
243
    const Value *V;
244
  };
245
246
  LLVM_ATTRIBUTE_USED
247
0
  raw_ostream &operator<<(raw_ostream &OS, const PE &P) {
248
0
    P.C.print(OS, P.V ? P.V : P.C.Root);
249
0
    return OS;
250
0
  }
251
252
} // end anonymous namespace
253
254
char HexagonLoopIdiomRecognize::ID = 0;
255
256
101k
INITIALIZE_PASS_BEGIN(HexagonLoopIdiomRecognize, "hexagon-loop-idiom",
257
101k
    "Recognize Hexagon-specific loop idioms", false, false)
258
101k
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
259
101k
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
260
101k
INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
261
101k
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
262
101k
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
263
101k
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
264
101k
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
265
101k
INITIALIZE_PASS_END(HexagonLoopIdiomRecognize, "hexagon-loop-idiom",
266
    "Recognize Hexagon-specific loop idioms", false, false)
267
268
template <typename FuncT>
269
408
void Simplifier::Context::traverse(Value *V, FuncT F) {
270
408
  WorkListType Q;
271
408
  Q.push_back(V);
272
408
273
26.2k
  while (!Q.empty()) {
274
25.8k
    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
275
25.8k
    if (!U || 
U->getParent()22.3k
)
276
5.72k
      continue;
277
20.1k
    if (!F(U))
278
195
      continue;
279
19.9k
    for (Value *Op : U->operands())
280
38.5k
      Q.push_back(Op);
281
19.9k
  }
282
408
}
HexagonLoopIdiomRecognition.cpp:void (anonymous namespace)::Simplifier::Context::traverse<(anonymous namespace)::Simplifier::Context::record(llvm::Value*)::$_0>(llvm::Value*, (anonymous namespace)::Simplifier::Context::record(llvm::Value*)::$_0)
Line
Count
Source
269
125
void Simplifier::Context::traverse(Value *V, FuncT F) {
270
125
  WorkListType Q;
271
125
  Q.push_back(V);
272
125
273
5.03k
  while (!Q.empty()) {
274
4.90k
    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
275
4.90k
    if (!U || 
U->getParent()3.87k
)
276
1.54k
      continue;
277
3.36k
    if (!F(U))
278
0
      continue;
279
3.36k
    for (Value *Op : U->operands())
280
6.00k
      Q.push_back(Op);
281
3.36k
  }
282
125
}
HexagonLoopIdiomRecognition.cpp:void (anonymous namespace)::Simplifier::Context::traverse<(anonymous namespace)::Simplifier::Context::use(llvm::Value*)::$_1>(llvm::Value*, (anonymous namespace)::Simplifier::Context::use(llvm::Value*)::$_1)
Line
Count
Source
269
125
void Simplifier::Context::traverse(Value *V, FuncT F) {
270
125
  WorkListType Q;
271
125
  Q.push_back(V);
272
125
273
20.7k
  while (!Q.empty()) {
274
20.6k
    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
275
20.6k
    if (!U || 
U->getParent()18.2k
)
276
4.15k
      continue;
277
16.5k
    if (!F(U))
278
0
      continue;
279
16.5k
    for (Value *Op : U->operands())
280
32.4k
      Q.push_back(Op);
281
16.5k
  }
282
125
}
HexagonLoopIdiomRecognition.cpp:void (anonymous namespace)::Simplifier::Context::traverse<(anonymous namespace)::Simplifier::Context::unuse(llvm::Value*)::$_2>(llvm::Value*, (anonymous namespace)::Simplifier::Context::unuse(llvm::Value*)::$_2)
Line
Count
Source
269
158
void Simplifier::Context::traverse(Value *V, FuncT F) {
270
158
  WorkListType Q;
271
158
  Q.push_back(V);
272
158
273
436
  while (!Q.empty()) {
274
278
    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
275
278
    if (!U || 
U->getParent()246
)
276
32
      continue;
277
246
    if (!F(U))
278
195
      continue;
279
51
    for (Value *Op : U->operands())
280
120
      Q.push_back(Op);
281
51
  }
282
158
}
283
284
0
void Simplifier::Context::print(raw_ostream &OS, const Value *V) const {
285
0
  const auto *U = dyn_cast<const Instruction>(V);
286
0
  if (!U) {
287
0
    OS << V << '(' << *V << ')';
288
0
    return;
289
0
  }
290
0
291
0
  if (U->getParent()) {
292
0
    OS << U << '(';
293
0
    U->printAsOperand(OS, true);
294
0
    OS << ')';
295
0
    return;
296
0
  }
297
0
298
0
  unsigned N = U->getNumOperands();
299
0
  if (N != 0)
300
0
    OS << U << '(';
301
0
  OS << U->getOpcodeName();
302
0
  for (const Value *Op : U->operands()) {
303
0
    OS << ' ';
304
0
    print(OS, Op);
305
0
  }
306
0
  if (N != 0)
307
0
    OS << ')';
308
0
}
309
310
4
void Simplifier::Context::initialize(Instruction *Exp) {
311
4
  // Perform a deep clone of the expression, set Root to the root
312
4
  // of the clone, and build a map from the cloned values to the
313
4
  // original ones.
314
4
  ValueMapType M;
315
4
  BasicBlock *Block = Exp->getParent();
316
4
  WorkListType Q;
317
4
  Q.push_back(Exp);
318
4
319
95
  while (!Q.empty()) {
320
91
    Value *V = Q.pop_front_val();
321
91
    if (M.find(V) != M.end())
322
5
      continue;
323
86
    if (Instruction *U = dyn_cast<Instruction>(V)) {
324
66
      if (isa<PHINode>(U) || 
U->getParent() != Block53
)
325
14
        continue;
326
52
      for (Value *Op : U->operands())
327
105
        Q.push_back(Op);
328
52
      M.insert({U, U->clone()});
329
52
    }
330
86
  }
331
4
332
52
  for (std::pair<Value*,Value*> P : M) {
333
52
    Instruction *U = cast<Instruction>(P.second);
334
157
    for (unsigned i = 0, n = U->getNumOperands(); i != n; 
++i105
) {
335
105
      auto F = M.find(U->getOperand(i));
336
105
      if (F != M.end())
337
60
        U->setOperand(i, F->second);
338
105
    }
339
52
  }
340
4
341
4
  auto R = M.find(Exp);
342
4
  assert(R != M.end());
343
4
  Root = R->second;
344
4
345
4
  record(Root);
346
4
  use(Root);
347
4
}
348
349
125
void Simplifier::Context::record(Value *V) {
350
3.36k
  auto Record = [this](Instruction *U) -> bool {
351
3.36k
    Clones.insert(U);
352
3.36k
    return true;
353
3.36k
  };
354
125
  traverse(V, Record);
355
125
}
356
357
125
void Simplifier::Context::use(Value *V) {
358
16.5k
  auto Use = [this](Instruction *U) -> bool {
359
16.5k
    Used.insert(U);
360
16.5k
    return true;
361
16.5k
  };
362
125
  traverse(V, Use);
363
125
}
364
365
158
void Simplifier::Context::unuse(Value *V) {
366
158
  if (!isa<Instruction>(V) || cast<Instruction>(V)->getParent() != nullptr)
367
0
    return;
368
158
369
246
  
auto Unuse = [this](Instruction *U) -> bool 158
{
370
246
    if (!U->use_empty())
371
195
      return false;
372
51
    Used.erase(U);
373
51
    return true;
374
51
  };
375
158
  traverse(V, Unuse);
376
158
}
377
378
125
Value *Simplifier::Context::subst(Value *Tree, Value *OldV, Value *NewV) {
379
125
  if (Tree == OldV)
380
0
    return NewV;
381
125
  if (OldV == NewV)
382
0
    return Tree;
383
125
384
125
  WorkListType Q;
385
125
  Q.push_back(Tree);
386
19.4k
  while (!Q.empty()) {
387
19.3k
    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
388
19.3k
    // If U is not an instruction, or it's not a clone, skip it.
389
19.3k
    if (!U || 
U->getParent()16.9k
)
390
3.99k
      continue;
391
45.3k
    
for (unsigned i = 0, n = U->getNumOperands(); 15.3k
i != n;
++i30.0k
) {
392
30.0k
      Value *Op = U->getOperand(i);
393
30.0k
      if (Op == OldV) {
394
158
        U->setOperand(i, NewV);
395
158
        unuse(OldV);
396
29.8k
      } else {
397
29.8k
        Q.push_back(Op);
398
29.8k
      }
399
30.0k
    }
400
15.3k
  }
401
125
  return Tree;
402
125
}
403
404
121
void Simplifier::Context::replace(Value *OldV, Value *NewV) {
405
121
  if (Root == OldV) {
406
0
    Root = NewV;
407
0
    use(Root);
408
0
    return;
409
0
  }
410
121
411
121
  // NewV may be a complex tree that has just been created by one of the
412
121
  // transformation rules. We need to make sure that it is commoned with
413
121
  // the existing Root to the maximum extent possible.
414
121
  // Identify all subtrees of NewV (including NewV itself) that have
415
121
  // equivalent counterparts in Root, and replace those subtrees with
416
121
  // these counterparts.
417
121
  WorkListType Q;
418
121
  Q.push_back(NewV);
419
883
  while (!Q.empty()) {
420
762
    Value *V = Q.pop_front_val();
421
762
    Instruction *U = dyn_cast<Instruction>(V);
422
762
    if (!U || 
U->getParent()672
)
423
92
      continue;
424
670
    if (Value *DupV = find(Root, V)) {
425
364
      if (DupV != V)
426
4
        NewV = subst(NewV, V, DupV);
427
364
    } else {
428
306
      for (Value *Op : U->operands())
429
732
        Q.push_back(Op);
430
306
    }
431
670
  }
432
121
433
121
  // Now, simply replace OldV with NewV in Root.
434
121
  Root = subst(Root, OldV, NewV);
435
121
  use(Root);
436
121
}
437
438
4
void Simplifier::Context::cleanup() {
439
358
  for (Value *V : Clones) {
440
358
    Instruction *U = cast<Instruction>(V);
441
358
    if (!U->getParent())
442
338
      U->dropAllReferences();
443
358
  }
444
4
445
358
  for (Value *V : Clones) {
446
358
    Instruction *U = cast<Instruction>(V);
447
358
    if (!U->getParent())
448
338
      U->deleteValue();
449
358
  }
450
4
}
451
452
bool Simplifier::Context::equal(const Instruction *I,
453
74.6k
                                const Instruction *J) const {
454
74.6k
  if (I == J)
455
0
    return true;
456
74.6k
  if (!I->isSameOperationAs(J))
457
51.9k
    return false;
458
22.6k
  if (isa<PHINode>(I))
459
595
    return I->isIdenticalTo(J);
460
22.0k
461
26.3k
  
for (unsigned i = 0, n = I->getNumOperands(); 22.0k
i != n;
++i4.33k
) {
462
26.1k
    Value *OpI = I->getOperand(i), *OpJ = J->getOperand(i);
463
26.1k
    if (OpI == OpJ)
464
4.09k
      continue;
465
22.0k
    auto *InI = dyn_cast<const Instruction>(OpI);
466
22.0k
    auto *InJ = dyn_cast<const Instruction>(OpJ);
467
22.0k
    if (InI && 
InJ19.8k
) {
468
18.6k
      if (!equal(InI, InJ))
469
18.4k
        return false;
470
3.40k
    } else if (InI != InJ || 
!InI22
)
471
3.40k
      return false;
472
22.0k
  }
473
22.0k
  
return true238
;
474
22.0k
}
475
476
670
Value *Simplifier::Context::find(Value *Tree, Value *Sub) const {
477
670
  Instruction *SubI = dyn_cast<Instruction>(Sub);
478
670
  WorkListType Q;
479
670
  Q.push_back(Tree);
480
670
481
69.6k
  while (!Q.empty()) {
482
69.3k
    Value *V = Q.pop_front_val();
483
69.3k
    if (V == Sub)
484
360
      return V;
485
68.9k
    Instruction *U = dyn_cast<Instruction>(V);
486
68.9k
    if (!U || 
U->getParent()61.1k
)
487
12.9k
      continue;
488
56.0k
    if (SubI && equal(SubI, U))
489
4
      return U;
490
55.9k
    assert(!isa<PHINode>(U));
491
55.9k
    for (Value *Op : U->operands())
492
113k
      Q.push_back(Op);
493
55.9k
  }
494
670
  
return nullptr306
;
495
670
}
496
497
void Simplifier::Context::link(Instruction *I, BasicBlock *B,
498
33
      BasicBlock::iterator At) {
499
33
  if (I->getParent())
500
13
    return;
501
20
502
41
  
for (Value *Op : I->operands())20
{
503
41
    if (Instruction *OpI = dyn_cast<Instruction>(Op))
504
30
      link(OpI, B, At);
505
41
  }
506
20
507
20
  B->getInstList().insert(At, I);
508
20
}
509
510
Value *Simplifier::Context::materialize(BasicBlock *B,
511
3
      BasicBlock::iterator At) {
512
3
  if (Instruction *RootI = dyn_cast<Instruction>(Root))
513
3
    link(RootI, B, At);
514
3
  return Root;
515
3
}
516
517
4
Value *Simplifier::simplify(Context &C) {
518
4
  WorkListType Q;
519
4
  Q.push_back(C.Root);
520
4
  unsigned Count = 0;
521
4
  const unsigned Limit = SimplifyLimit;
522
4
523
1.28k
  while (!Q.empty()) {
524
1.27k
    if (Count++ >= Limit)
525
0
      break;
526
1.27k
    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
527
1.27k
    if (!U || 
U->getParent()1.14k
||
!C.Used.count(U)1.03k
)
528
241
      continue;
529
1.03k
    bool Changed = false;
530
6.71k
    for (Rule &R : Rules) {
531
6.71k
      Value *W = R.Fn(U, C.Ctx);
532
6.71k
      if (!W)
533
6.59k
        continue;
534
121
      Changed = true;
535
121
      C.record(W);
536
121
      C.replace(U, W);
537
121
      Q.push_back(C.Root);
538
121
      break;
539
121
    }
540
1.03k
    if (!Changed) {
541
914
      for (Value *Op : U->operands())
542
1.82k
        Q.push_back(Op);
543
914
    }
544
1.03k
  }
545
4
  return Count < Limit ? C.Root : 
nullptr0
;
546
4
}
547
548
//===----------------------------------------------------------------------===//
549
//
550
//          Implementation of PolynomialMultiplyRecognize
551
//
552
//===----------------------------------------------------------------------===//
553
554
namespace {
555
556
  class PolynomialMultiplyRecognize {
557
  public:
558
    explicit PolynomialMultiplyRecognize(Loop *loop, const DataLayout &dl,
559
        const DominatorTree &dt, const TargetLibraryInfo &tli,
560
        ScalarEvolution &se)
561
8
      : CurLoop(loop), DL(dl), DT(dt), TLI(tli), SE(se) {}
562
563
    bool recognize();
564
565
  private:
566
    using ValueSeq = SetVector<Value *>;
567
568
1
    IntegerType *getPmpyType() const {
569
1
      LLVMContext &Ctx = CurLoop->getHeader()->getParent()->getContext();
570
1
      return IntegerType::get(Ctx, 32);
571
1
    }
572
573
    bool isPromotableTo(Value *V, IntegerType *Ty);
574
    void promoteTo(Instruction *In, IntegerType *DestTy, BasicBlock *LoopB);
575
    bool promoteTypes(BasicBlock *LoopB, BasicBlock *ExitB);
576
577
    Value *getCountIV(BasicBlock *BB);
578
    bool findCycle(Value *Out, Value *In, ValueSeq &Cycle);
579
    void classifyCycle(Instruction *DivI, ValueSeq &Cycle, ValueSeq &Early,
580
          ValueSeq &Late);
581
    bool classifyInst(Instruction *UseI, ValueSeq &Early, ValueSeq &Late);
582
    bool commutesWithShift(Instruction *I);
583
    bool highBitsAreZero(Value *V, unsigned IterCount);
584
    bool keepsHighBitsZero(Value *V, unsigned IterCount);
585
    bool isOperandShifted(Instruction *I, Value *Op);
586
    bool convertShiftsToLeft(BasicBlock *LoopB, BasicBlock *ExitB,
587
          unsigned IterCount);
588
    void cleanupLoopBody(BasicBlock *LoopB);
589
590
    struct ParsedValues {
591
7
      ParsedValues() = default;
592
593
      Value *M = nullptr;
594
      Value *P = nullptr;
595
      Value *Q = nullptr;
596
      Value *R = nullptr;
597
      Value *X = nullptr;
598
      Instruction *Res = nullptr;
599
      unsigned IterCount = 0;
600
      bool Left = false;
601
      bool Inv = false;
602
    };
603
604
    bool matchLeftShift(SelectInst *SelI, Value *CIV, ParsedValues &PV);
605
    bool matchRightShift(SelectInst *SelI, ParsedValues &PV);
606
    bool scanSelect(SelectInst *SI, BasicBlock *LoopB, BasicBlock *PrehB,
607
          Value *CIV, ParsedValues &PV, bool PreScan);
608
    unsigned getInverseMxN(unsigned QP);
609
    Value *generate(BasicBlock::iterator At, ParsedValues &PV);
610
611
    void setupPreSimplifier(Simplifier &S);
612
    void setupPostSimplifier(Simplifier &S);
613
614
    Loop *CurLoop;
615
    const DataLayout &DL;
616
    const DominatorTree &DT;
617
    const TargetLibraryInfo &TLI;
618
    ScalarEvolution &SE;
619
  };
620
621
} // end anonymous namespace
622
623
8
Value *PolynomialMultiplyRecognize::getCountIV(BasicBlock *BB) {
624
8
  pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
625
8
  if (std::distance(PI, PE) != 2)
626
0
    return nullptr;
627
8
  BasicBlock *PB = (*PI == BB) ? 
*std::next(PI)5
:
*PI3
;
628
8
629
14
  for (auto I = BB->begin(), E = BB->end(); I != E && isa<PHINode>(I); 
++I6
) {
630
12
    auto *PN = cast<PHINode>(I);
631
12
    Value *InitV = PN->getIncomingValueForBlock(PB);
632
12
    if (!isa<ConstantInt>(InitV) || 
!cast<ConstantInt>(InitV)->isZero()7
)
633
5
      continue;
634
7
    Value *IterV = PN->getIncomingValueForBlock(BB);
635
7
    if (!isa<BinaryOperator>(IterV))
636
1
      continue;
637
6
    auto *BO = dyn_cast<BinaryOperator>(IterV);
638
6
    if (BO->getOpcode() != Instruction::Add)
639
0
      continue;
640
6
    Value *IncV = nullptr;
641
6
    if (BO->getOperand(0) == PN)
642
6
      IncV = BO->getOperand(1);
643
0
    else if (BO->getOperand(1) == PN)
644
0
      IncV = BO->getOperand(0);
645
6
    if (IncV == nullptr)
646
0
      continue;
647
6
648
6
    if (auto *T = dyn_cast<ConstantInt>(IncV))
649
6
      if (T->getZExtValue() == 1)
650
6
        return PN;
651
6
  }
652
8
  
return nullptr2
;
653
8
}
654
655
2
static void replaceAllUsesOfWithIn(Value *I, Value *J, BasicBlock *BB) {
656
5
  for (auto UI = I->user_begin(), UE = I->user_end(); UI != UE;) {
657
3
    Use &TheUse = UI.getUse();
658
3
    ++UI;
659
3
    if (auto *II = dyn_cast<Instruction>(TheUse.getUser()))
660
3
      if (BB == II->getParent())
661
3
        II->replaceUsesOfWith(I, J);
662
3
  }
663
2
}
664
665
bool PolynomialMultiplyRecognize::matchLeftShift(SelectInst *SelI,
666
5
      Value *CIV, ParsedValues &PV) {
667
5
  // Match the following:
668
5
  //   select (X & (1 << i)) != 0 ? R ^ (Q << i) : R
669
5
  //   select (X & (1 << i)) == 0 ? R : R ^ (Q << i)
670
5
  // The condition may also check for equality with the masked value, i.e
671
5
  //   select (X & (1 << i)) == (1 << i) ? R ^ (Q << i) : R
672
5
  //   select (X & (1 << i)) != (1 << i) ? R : R ^ (Q << i);
673
5
674
5
  Value *CondV = SelI->getCondition();
675
5
  Value *TrueV = SelI->getTrueValue();
676
5
  Value *FalseV = SelI->getFalseValue();
677
5
678
5
  using namespace PatternMatch;
679
5
680
5
  CmpInst::Predicate P;
681
5
  Value *A = nullptr, *B = nullptr, *C = nullptr;
682
5
683
5
  if (!match(CondV, m_ICmp(P, m_And(m_Value(A), m_Value(B)), m_Value(C))) &&
684
5
      
!match(CondV, m_ICmp(P, m_Value(C), m_And(m_Value(A), m_Value(B))))1
)
685
1
    return false;
686
4
  if (P != CmpInst::ICMP_EQ && 
P != CmpInst::ICMP_NE0
)
687
0
    return false;
688
4
  // Matched: select (A & B) == C ? ... : ...
689
4
  //          select (A & B) != C ? ... : ...
690
4
691
4
  Value *X = nullptr, *Sh1 = nullptr;
692
4
  // Check (A & B) for (X & (1 << i)):
693
4
  if (match(A, m_Shl(m_One(), m_Specific(CIV)))) {
694
2
    Sh1 = A;
695
2
    X = B;
696
2
  } else if (match(B, m_Shl(m_One(), m_Specific(CIV)))) {
697
1
    Sh1 = B;
698
1
    X = A;
699
1
  } else {
700
1
    // TODO: Could also check for an induction variable containing single
701
1
    // bit shifted left by 1 in each iteration.
702
1
    return false;
703
1
  }
704
3
705
3
  bool TrueIfZero;
706
3
707
3
  // Check C against the possible values for comparison: 0 and (1 << i):
708
3
  if (match(C, m_Zero()))
709
3
    TrueIfZero = (P == CmpInst::ICMP_EQ);
710
0
  else if (C == Sh1)
711
0
    TrueIfZero = (P == CmpInst::ICMP_NE);
712
0
  else
713
0
    return false;
714
3
715
3
  // So far, matched:
716
3
  //   select (X & (1 << i)) ? ... : ...
717
3
  // including variations of the check against zero/non-zero value.
718
3
719
3
  Value *ShouldSameV = nullptr, *ShouldXoredV = nullptr;
720
3
  if (TrueIfZero) {
721
3
    ShouldSameV = TrueV;
722
3
    ShouldXoredV = FalseV;
723
3
  } else {
724
0
    ShouldSameV = FalseV;
725
0
    ShouldXoredV = TrueV;
726
0
  }
727
3
728
3
  Value *Q = nullptr, *R = nullptr, *Y = nullptr, *Z = nullptr;
729
3
  Value *T = nullptr;
730
3
  if (match(ShouldXoredV, m_Xor(m_Value(Y), m_Value(Z)))) {
731
3
    // Matched: select +++ ? ... : Y ^ Z
732
3
    //          select +++ ? Y ^ Z : ...
733
3
    // where +++ denotes previously checked matches.
734
3
    if (ShouldSameV == Y)
735
1
      T = Z;
736
2
    else if (ShouldSameV == Z)
737
2
      T = Y;
738
0
    else
739
0
      return false;
740
3
    R = ShouldSameV;
741
3
    // Matched: select +++ ? R : R ^ T
742
3
    //          select +++ ? R ^ T : R
743
3
    // depending on TrueIfZero.
744
3
745
3
  } else 
if (0
match(ShouldSameV, m_Zero())0
) {
746
0
    // Matched: select +++ ? 0 : ...
747
0
    //          select +++ ? ... : 0
748
0
    if (!SelI->hasOneUse())
749
0
      return false;
750
0
    T = ShouldXoredV;
751
0
    // Matched: select +++ ? 0 : T
752
0
    //          select +++ ? T : 0
753
0
754
0
    Value *U = *SelI->user_begin();
755
0
    if (!match(U, m_Xor(m_Specific(SelI), m_Value(R))) &&
756
0
        !match(U, m_Xor(m_Value(R), m_Specific(SelI))))
757
0
      return false;
758
0
    // Matched: xor (select +++ ? 0 : T), R
759
0
    //          xor (select +++ ? T : 0), R
760
0
  } else
761
0
    return false;
762
3
763
3
  // The xor input value T is isolated into its own match so that it could
764
3
  // be checked against an induction variable containing a shifted bit
765
3
  // (todo).
766
3
  // For now, check against (Q << i).
767
3
  if (!match(T, m_Shl(m_Value(Q), m_Specific(CIV))) &&
768
3
      
!match(T, m_Shl(m_ZExt(m_Value(Q)), m_ZExt(m_Specific(CIV))))2
)
769
0
    return false;
770
3
  // Matched: select +++ ? R : R ^ (Q << i)
771
3
  //          select +++ ? R ^ (Q << i) : R
772
3
773
3
  PV.X = X;
774
3
  PV.Q = Q;
775
3
  PV.R = R;
776
3
  PV.Left = true;
777
3
  return true;
778
3
}
779
780
bool PolynomialMultiplyRecognize::matchRightShift(SelectInst *SelI,
781
2
      ParsedValues &PV) {
782
2
  // Match the following:
783
2
  //   select (X & 1) != 0 ? (R >> 1) ^ Q : (R >> 1)
784
2
  //   select (X & 1) == 0 ? (R >> 1) : (R >> 1) ^ Q
785
2
  // The condition may also check for equality with the masked value, i.e
786
2
  //   select (X & 1) == 1 ? (R >> 1) ^ Q : (R >> 1)
787
2
  //   select (X & 1) != 1 ? (R >> 1) : (R >> 1) ^ Q
788
2
789
2
  Value *CondV = SelI->getCondition();
790
2
  Value *TrueV = SelI->getTrueValue();
791
2
  Value *FalseV = SelI->getFalseValue();
792
2
793
2
  using namespace PatternMatch;
794
2
795
2
  Value *C = nullptr;
796
2
  CmpInst::Predicate P;
797
2
  bool TrueIfZero;
798
2
799
2
  if (match(CondV, m_ICmp(P, m_Value(C), m_Zero())) ||
800
2
      
match(CondV, m_ICmp(P, m_Zero(), m_Value(C)))1
) {
801
1
    if (P != CmpInst::ICMP_EQ && 
P != CmpInst::ICMP_NE0
)
802
0
      return false;
803
1
    // Matched: select C == 0 ? ... : ...
804
1
    //          select C != 0 ? ... : ...
805
1
    TrueIfZero = (P == CmpInst::ICMP_EQ);
806
1
  } else if (match(CondV, m_ICmp(P, m_Value(C), m_One())) ||
807
1
             match(CondV, m_ICmp(P, m_One(), m_Value(C)))) {
808
0
    if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE)
809
0
      return false;
810
0
    // Matched: select C == 1 ? ... : ...
811
0
    //          select C != 1 ? ... : ...
812
0
    TrueIfZero = (P == CmpInst::ICMP_NE);
813
0
  } else
814
1
    return false;
815
1
816
1
  Value *X = nullptr;
817
1
  if (!match(C, m_And(m_Value(X), m_One())) &&
818
1
      
!match(C, m_And(m_One(), m_Value(X)))0
)
819
0
    return false;
820
1
  // Matched: select (X & 1) == +++ ? ... : ...
821
1
  //          select (X & 1) != +++ ? ... : ...
822
1
823
1
  Value *R = nullptr, *Q = nullptr;
824
1
  if (TrueIfZero) {
825
1
    // The select's condition is true if the tested bit is 0.
826
1
    // TrueV must be the shift, FalseV must be the xor.
827
1
    if (!match(TrueV, m_LShr(m_Value(R), m_One())))
828
0
      return false;
829
1
    // Matched: select +++ ? (R >> 1) : ...
830
1
    if (!match(FalseV, m_Xor(m_Specific(TrueV), m_Value(Q))) &&
831
1
        
!match(FalseV, m_Xor(m_Value(Q), m_Specific(TrueV)))0
)
832
0
      return false;
833
0
    // Matched: select +++ ? (R >> 1) : (R >> 1) ^ Q
834
0
    // with commuting ^.
835
0
  } else {
836
0
    // The select's condition is true if the tested bit is 1.
837
0
    // TrueV must be the xor, FalseV must be the shift.
838
0
    if (!match(FalseV, m_LShr(m_Value(R), m_One())))
839
0
      return false;
840
0
    // Matched: select +++ ? ... : (R >> 1)
841
0
    if (!match(TrueV, m_Xor(m_Specific(FalseV), m_Value(Q))) &&
842
0
        !match(TrueV, m_Xor(m_Value(Q), m_Specific(FalseV))))
843
0
      return false;
844
1
    // Matched: select +++ ? (R >> 1) ^ Q : (R >> 1)
845
1
    // with commuting ^.
846
1
  }
847
1
848
1
  PV.X = X;
849
1
  PV.Q = Q;
850
1
  PV.R = R;
851
1
  PV.Left = false;
852
1
  return true;
853
1
}
854
855
bool PolynomialMultiplyRecognize::scanSelect(SelectInst *SelI,
856
      BasicBlock *LoopB, BasicBlock *PrehB, Value *CIV, ParsedValues &PV,
857
5
      bool PreScan) {
858
5
  using namespace PatternMatch;
859
5
860
5
  // The basic pattern for R = P.Q is:
861
5
  // for i = 0..31
862
5
  //   R = phi (0, R')
863
5
  //   if (P & (1 << i))        ; test-bit(P, i)
864
5
  //     R' = R ^ (Q << i)
865
5
  //
866
5
  // Similarly, the basic pattern for R = (P/Q).Q - P
867
5
  // for i = 0..31
868
5
  //   R = phi(P, R')
869
5
  //   if (R & (1 << i))
870
5
  //     R' = R ^ (Q << i)
871
5
872
5
  // There exist idioms, where instead of Q being shifted left, P is shifted
873
5
  // right. This produces a result that is shifted right by 32 bits (the
874
5
  // non-shifted result is 64-bit).
875
5
  //
876
5
  // For R = P.Q, this would be:
877
5
  // for i = 0..31
878
5
  //   R = phi (0, R')
879
5
  //   if ((P >> i) & 1)
880
5
  //     R' = (R >> 1) ^ Q      ; R is cycled through the loop, so it must
881
5
  //   else                     ; be shifted by 1, not i.
882
5
  //     R' = R >> 1
883
5
  //
884
5
  // And for the inverse:
885
5
  // for i = 0..31
886
5
  //   R = phi (P, R')
887
5
  //   if (R & 1)
888
5
  //     R' = (R >> 1) ^ Q
889
5
  //   else
890
5
  //     R' = R >> 1
891
5
892
5
  // The left-shifting idioms share the same pattern:
893
5
  //   select (X & (1 << i)) ? R ^ (Q << i) : R
894
5
  // Similarly for right-shifting idioms:
895
5
  //   select (X & 1) ? (R >> 1) ^ Q
896
5
897
5
  if (matchLeftShift(SelI, CIV, PV)) {
898
3
    // If this is a pre-scan, getting this far is sufficient.
899
3
    if (PreScan)
900
1
      return true;
901
2
902
2
    // Need to make sure that the SelI goes back into R.
903
2
    auto *RPhi = dyn_cast<PHINode>(PV.R);
904
2
    if (!RPhi)
905
0
      return false;
906
2
    if (SelI != RPhi->getIncomingValueForBlock(LoopB))
907
0
      return false;
908
2
    PV.Res = SelI;
909
2
910
2
    // If X is loop invariant, it must be the input polynomial, and the
911
2
    // idiom is the basic polynomial multiply.
912
2
    if (CurLoop->isLoopInvariant(PV.X)) {
913
1
      PV.P = PV.X;
914
1
      PV.Inv = false;
915
1
    } else {
916
1
      // X is not loop invariant. If X == R, this is the inverse pmpy.
917
1
      // Otherwise, check for an xor with an invariant value. If the
918
1
      // variable argument to the xor is R, then this is still a valid
919
1
      // inverse pmpy.
920
1
      PV.Inv = true;
921
1
      if (PV.X != PV.R) {
922
1
        Value *Var = nullptr, *Inv = nullptr, *X1 = nullptr, *X2 = nullptr;
923
1
        if (!match(PV.X, m_Xor(m_Value(X1), m_Value(X2))))
924
0
          return false;
925
1
        auto *I1 = dyn_cast<Instruction>(X1);
926
1
        auto *I2 = dyn_cast<Instruction>(X2);
927
1
        if (!I1 || I1->getParent() != LoopB) {
928
0
          Var = X2;
929
0
          Inv = X1;
930
1
        } else if (!I2 || I2->getParent() != LoopB) {
931
1
          Var = X1;
932
1
          Inv = X2;
933
1
        } else
934
0
          return false;
935
1
        if (Var != PV.R)
936
0
          return false;
937
1
        PV.M = Inv;
938
1
      }
939
1
      // The input polynomial P still needs to be determined. It will be
940
1
      // the entry value of R.
941
1
      Value *EntryP = RPhi->getIncomingValueForBlock(PrehB);
942
1
      PV.P = EntryP;
943
1
    }
944
2
945
2
    return true;
946
2
  }
947
2
948
2
  if (matchRightShift(SelI, PV)) {
949
1
    // If this is an inverse pattern, the Q polynomial must be known at
950
1
    // compile time.
951
1
    if (PV.Inv && 
!isa<ConstantInt>(PV.Q)0
)
952
0
      return false;
953
1
    if (PreScan)
954
1
      return true;
955
0
    // There is no exact matching of right-shift pmpy.
956
0
    return false;
957
0
  }
958
1
959
1
  return false;
960
1
}
961
962
bool PolynomialMultiplyRecognize::isPromotableTo(Value *Val,
963
13
      IntegerType *DestTy) {
964
13
  IntegerType *T = dyn_cast<IntegerType>(Val->getType());
965
13
  if (!T || T->getBitWidth() > DestTy->getBitWidth())
966
0
    return false;
967
13
  if (T->getBitWidth() == DestTy->getBitWidth())
968
0
    return true;
969
13
  // Non-instructions are promotable. The reason why an instruction may not
970
13
  // be promotable is that it may produce a different result if its operands
971
13
  // and the result are promoted, for example, it may produce more non-zero
972
13
  // bits. While it would still be possible to represent the proper result
973
13
  // in a wider type, it may require adding additional instructions (which
974
13
  // we don't want to do).
975
13
  Instruction *In = dyn_cast<Instruction>(Val);
976
13
  if (!In)
977
0
    return true;
978
13
  // The bitwidth of the source type is smaller than the destination.
979
13
  // Check if the individual operation can be promoted.
980
13
  switch (In->getOpcode()) {
981
13
    case Instruction::PHI:
982
10
    case Instruction::ZExt:
983
10
    case Instruction::And:
984
10
    case Instruction::Or:
985
10
    case Instruction::Xor:
986
10
    case Instruction::LShr: // Shift right is ok.
987
10
    case Instruction::Select:
988
10
    case Instruction::Trunc:
989
10
      return true;
990
10
    case Instruction::ICmp:
991
2
      if (CmpInst *CI = cast<CmpInst>(In))
992
2
        return CI->isEquality() || 
CI->isUnsigned()1
;
993
0
      llvm_unreachable("Cast failed unexpectedly");
994
1
    case Instruction::Add:
995
1
      return In->hasNoSignedWrap() && In->hasNoUnsignedWrap();
996
0
  }
997
0
  return false;
998
0
}
999
1000
void PolynomialMultiplyRecognize::promoteTo(Instruction *In,
1001
13
      IntegerType *DestTy, BasicBlock *LoopB) {
1002
13
  Type *OrigTy = In->getType();
1003
13
  assert(!OrigTy->isVoidTy() && "Invalid instruction to promote");
1004
13
1005
13
  // Leave boolean values alone.
1006
13
  if (!In->getType()->isIntegerTy(1))
1007
11
    In->mutateType(DestTy);
1008
13
  unsigned DestBW = DestTy->getBitWidth();
1009
13
1010
13
  // Handle PHIs.
1011
13
  if (PHINode *P = dyn_cast<PHINode>(In)) {
1012
3
    unsigned N = P->getNumIncomingValues();
1013
9
    for (unsigned i = 0; i != N; 
++i6
) {
1014
6
      BasicBlock *InB = P->getIncomingBlock(i);
1015
6
      if (InB == LoopB)
1016
3
        continue;
1017
3
      Value *InV = P->getIncomingValue(i);
1018
3
      IntegerType *Ty = cast<IntegerType>(InV->getType());
1019
3
      // Do not promote values in PHI nodes of type i1.
1020
3
      if (Ty != P->getType()) {
1021
3
        // If the value type does not match the PHI type, the PHI type
1022
3
        // must have been promoted.
1023
3
        assert(Ty->getBitWidth() < DestBW);
1024
3
        InV = IRBuilder<>(InB->getTerminator()).CreateZExt(InV, DestTy);
1025
3
        P->setIncomingValue(i, InV);
1026
3
      }
1027
3
    }
1028
10
  } else if (ZExtInst *Z = dyn_cast<ZExtInst>(In)) {
1029
0
    Value *Op = Z->getOperand(0);
1030
0
    if (Op->getType() == Z->getType())
1031
0
      Z->replaceAllUsesWith(Op);
1032
0
    Z->eraseFromParent();
1033
0
    return;
1034
0
  }
1035
13
  if (TruncInst *T = dyn_cast<TruncInst>(In)) {
1036
1
    IntegerType *TruncTy = cast<IntegerType>(OrigTy);
1037
1
    Value *Mask = ConstantInt::get(DestTy, (1u << TruncTy->getBitWidth()) - 1);
1038
1
    Value *And = IRBuilder<>(In).CreateAnd(T->getOperand(0), Mask);
1039
1
    T->replaceAllUsesWith(And);
1040
1
    T->eraseFromParent();
1041
1
    return;
1042
1
  }
1043
12
1044
12
  // Promote immediates.
1045
37
  
for (unsigned i = 0, n = In->getNumOperands(); 12
i != n;
++i25
) {
1046
25
    if (ConstantInt *CI = dyn_cast<ConstantInt>(In->getOperand(i)))
1047
8
      if (CI->getType()->getBitWidth() < DestBW)
1048
7
        In->setOperand(i, ConstantInt::get(DestTy, CI->getZExtValue()));
1049
25
  }
1050
12
}
1051
1052
bool PolynomialMultiplyRecognize::promoteTypes(BasicBlock *LoopB,
1053
1
      BasicBlock *ExitB) {
1054
1
  assert(LoopB);
1055
1
  // Skip loops where the exit block has more than one predecessor. The values
1056
1
  // coming from the loop block will be promoted to another type, and so the
1057
1
  // values coming into the exit block from other predecessors would also have
1058
1
  // to be promoted.
1059
1
  if (!ExitB || (ExitB->getSinglePredecessor() != LoopB))
1060
0
    return false;
1061
1
  IntegerType *DestTy = getPmpyType();
1062
1
  // Check if the exit values have types that are no wider than the type
1063
1
  // that we want to promote to.
1064
1
  unsigned DestBW = DestTy->getBitWidth();
1065
1
  for (PHINode &P : ExitB->phis()) {
1066
1
    if (P.getNumIncomingValues() != 1)
1067
0
      return false;
1068
1
    assert(P.getIncomingBlock(0) == LoopB);
1069
1
    IntegerType *T = dyn_cast<IntegerType>(P.getType());
1070
1
    if (!T || T->getBitWidth() > DestBW)
1071
0
      return false;
1072
1
  }
1073
1
1074
1
  // Check all instructions in the loop.
1075
1
  for (Instruction &In : *LoopB)
1076
14
    if (!In.isTerminator() && 
!isPromotableTo(&In, DestTy)13
)
1077
0
      return false;
1078
1
1079
1
  // Perform the promotion.
1080
1
  std::vector<Instruction*> LoopIns;
1081
1
  std::transform(LoopB->begin(), LoopB->end(), std::back_inserter(LoopIns),
1082
14
                 [](Instruction &In) { return &In; });
1083
1
  for (Instruction *In : LoopIns)
1084
14
    if (!In->isTerminator())
1085
13
      promoteTo(In, DestTy, LoopB);
1086
1
1087
1
  // Fix up the PHI nodes in the exit block.
1088
1
  Instruction *EndI = ExitB->getFirstNonPHI();
1089
1
  BasicBlock::iterator End = EndI ? EndI->getIterator() : 
ExitB->end()0
;
1090
2
  for (auto I = ExitB->begin(); I != End; 
++I1
) {
1091
2
    PHINode *P = dyn_cast<PHINode>(I);
1092
2
    if (!P)
1093
1
      break;
1094
1
    Type *Ty0 = P->getIncomingValue(0)->getType();
1095
1
    Type *PTy = P->getType();
1096
1
    if (PTy != Ty0) {
1097
1
      assert(Ty0 == DestTy);
1098
1
      // In order to create the trunc, P must have the promoted type.
1099
1
      P->mutateType(Ty0);
1100
1
      Value *T = IRBuilder<>(ExitB, End).CreateTrunc(P, PTy);
1101
1
      // In order for the RAUW to work, the types of P and T must match.
1102
1
      P->mutateType(PTy);
1103
1
      P->replaceAllUsesWith(T);
1104
1
      // Final update of the P's type.
1105
1
      P->mutateType(Ty0);
1106
1
      cast<Instruction>(T)->setOperand(0, P);
1107
1
    }
1108
1
  }
1109
1
1110
1
  return true;
1111
1
}
1112
1113
bool PolynomialMultiplyRecognize::findCycle(Value *Out, Value *In,
1114
6
      ValueSeq &Cycle) {
1115
6
  // Out = ..., In, ...
1116
6
  if (Out == In)
1117
2
    return true;
1118
4
1119
4
  auto *BB = cast<Instruction>(Out)->getParent();
1120
4
  bool HadPhi = false;
1121
4
1122
5
  for (auto U : Out->users()) {
1123
5
    auto *I = dyn_cast<Instruction>(&*U);
1124
5
    if (I == nullptr || I->getParent() != BB)
1125
1
      continue;
1126
4
    // Make sure that there are no multi-iteration cycles, e.g.
1127
4
    //   p1 = phi(p2)
1128
4
    //   p2 = phi(p1)
1129
4
    // The cycle p1->p2->p1 would span two loop iterations.
1130
4
    // Check that there is only one phi in the cycle.
1131
4
    bool IsPhi = isa<PHINode>(I);
1132
4
    if (IsPhi && 
HadPhi2
)
1133
0
      return false;
1134
4
    HadPhi |= IsPhi;
1135
4
    if (Cycle.count(I))
1136
0
      return false;
1137
4
    Cycle.insert(I);
1138
4
    if (findCycle(I, In, Cycle))
1139
4
      break;
1140
0
    Cycle.remove(I);
1141
0
  }
1142
4
  return !Cycle.empty();
1143
4
}
1144
1145
void PolynomialMultiplyRecognize::classifyCycle(Instruction *DivI,
1146
2
      ValueSeq &Cycle, ValueSeq &Early, ValueSeq &Late) {
1147
2
  // All the values in the cycle that are between the phi node and the
1148
2
  // divider instruction will be classified as "early", all other values
1149
2
  // will be "late".
1150
2
1151
2
  bool IsE = true;
1152
2
  unsigned I, N = Cycle.size();
1153
4
  for (I = 0; I < N; 
++I2
) {
1154
4
    Value *V = Cycle[I];
1155
4
    if (DivI == V)
1156
0
      IsE = false;
1157
4
    else if (!isa<PHINode>(V))
1158
2
      continue;
1159
2
    // Stop if found either.
1160
2
    break;
1161
2
  }
1162
2
  // "I" is the index of either DivI or the phi node, whichever was first.
1163
2
  // "E" is "false" or "true" respectively.
1164
2
  ValueSeq &First = !IsE ? 
Early0
: Late;
1165
4
  for (unsigned J = 0; J < I; 
++J2
)
1166
2
    First.insert(Cycle[J]);
1167
2
1168
2
  ValueSeq &Second = IsE ? Early : 
Late0
;
1169
2
  Second.insert(Cycle[I]);
1170
2
  for (++I; I < N; 
++I0
) {
1171
2
    Value *V = Cycle[I];
1172
2
    if (DivI == V || 
isa<PHINode>(V)0
)
1173
2
      break;
1174
0
    Second.insert(V);
1175
0
  }
1176
2
1177
4
  for (; I < N; 
++I2
)
1178
2
    First.insert(Cycle[I]);
1179
2
}
1180
1181
bool PolynomialMultiplyRecognize::classifyInst(Instruction *UseI,
1182
8
      ValueSeq &Early, ValueSeq &Late) {
1183
8
  // Select is an exception, since the condition value does not have to be
1184
8
  // classified in the same way as the true/false values. The true/false
1185
8
  // values do have to be both early or both late.
1186
8
  if (UseI->getOpcode() == Instruction::Select) {
1187
3
    Value *TV = UseI->getOperand(1), *FV = UseI->getOperand(2);
1188
3
    if (Early.count(TV) || Early.count(FV)) {
1189
0
      if (Late.count(TV) || Late.count(FV))
1190
0
        return false;
1191
0
      Early.insert(UseI);
1192
3
    } else if (Late.count(TV) || 
Late.count(FV)0
) {
1193
3
      if (Early.count(TV) || Early.count(FV))
1194
0
        return false;
1195
3
      Late.insert(UseI);
1196
3
    }
1197
3
    return true;
1198
5
  }
1199
5
1200
5
  // Not sure what would be the example of this, but the code below relies
1201
5
  // on having at least one operand.
1202
5
  if (UseI->getNumOperands() == 0)
1203
0
    return true;
1204
5
1205
5
  bool AE = true, AL = true;
1206
10
  for (auto &I : UseI->operands()) {
1207
10
    if (Early.count(&*I))
1208
6
      AL = false;
1209
4
    else if (Late.count(&*I))
1210
1
      AE = false;
1211
10
  }
1212
5
  // If the operands appear "all early" and "all late" at the same time,
1213
5
  // then it means that none of them are actually classified as either.
1214
5
  // This is harmless.
1215
5
  if (AE && 
AL4
)
1216
0
    return true;
1217
5
  // Conversely, if they are neither "all early" nor "all late", then
1218
5
  // we have a mixture of early and late operands that is not a known
1219
5
  // exception.
1220
5
  if (!AE && 
!AL1
)
1221
0
    return false;
1222
5
1223
5
  // Check that we have covered the two special cases.
1224
5
  assert(AE != AL);
1225
5
1226
5
  if (AE)
1227
4
    Early.insert(UseI);
1228
1
  else
1229
1
    Late.insert(UseI);
1230
5
  return true;
1231
5
}
1232
1233
9
/// Return true if the opcode of \p I is one of the operations this pass
/// treats as commuting with shifts (and, or, xor, lshr, shl, select, icmp,
/// phi).
bool PolynomialMultiplyRecognize::commutesWithShift(Instruction *I) {
  switch (I->getOpcode()) {
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::Select:
    case Instruction::ICmp:
    case Instruction::PHI:
      return true;
    default:
      return false;
  }
}
1249
1250
bool PolynomialMultiplyRecognize::highBitsAreZero(Value *V,
1251
2
      unsigned IterCount) {
1252
2
  auto *T = dyn_cast<IntegerType>(V->getType());
1253
2
  if (!T)
1254
0
    return false;
1255
2
1256
2
  KnownBits Known(T->getBitWidth());
1257
2
  computeKnownBits(V, Known, DL);
1258
2
  return Known.countMinLeadingZeros() >= IterCount;
1259
2
}
1260
1261
bool PolynomialMultiplyRecognize::keepsHighBitsZero(Value *V,
1262
12
      unsigned IterCount) {
1263
12
  // Assume that all inputs to the value have the high bits zero.
1264
12
  // Check if the value itself preserves the zeros in the high bits.
1265
12
  if (auto *C = dyn_cast<ConstantInt>(V))
1266
3
    return C->getValue().countLeadingZeros() >= IterCount;
1267
9
1268
9
  if (auto *I = dyn_cast<Instruction>(V)) {
1269
9
    switch (I->getOpcode()) {
1270
9
      case Instruction::And:
1271
9
      case Instruction::Or:
1272
9
      case Instruction::Xor:
1273
9
      case Instruction::LShr:
1274
9
      case Instruction::Select:
1275
9
      case Instruction::ICmp:
1276
9
      case Instruction::PHI:
1277
9
      case Instruction::ZExt:
1278
9
        return true;
1279
0
    }
1280
0
  }
1281
0
1282
0
  return false;
1283
0
}
1284
1285
11
/// Return false only when \p I is a shl/lshr and \p Op is its operand 1
/// (the shift amount); return true for every other instruction/operand
/// combination.
bool PolynomialMultiplyRecognize::isOperandShifted(Instruction *I, Value *Op) {
  switch (I->getOpcode()) {
    case Instruction::Shl:
    case Instruction::LShr:
      return I->getOperand(1) != Op;
    default:
      return true;
  }
}
1291
1292
bool PolynomialMultiplyRecognize::convertShiftsToLeft(BasicBlock *LoopB,
1293
1
      BasicBlock *ExitB, unsigned IterCount) {
1294
1
  Value *CIV = getCountIV(LoopB);
1295
1
  if (CIV == nullptr)
1296
0
    return false;
1297
1
  auto *CIVTy = dyn_cast<IntegerType>(CIV->getType());
1298
1
  if (CIVTy == nullptr)
1299
0
    return false;
1300
1
1301
1
  ValueSeq RShifts;
1302
1
  ValueSeq Early, Late, Cycled;
1303
1
1304
1
  // Find all value cycles that contain logical right shifts by 1.
1305
13
  for (Instruction &I : *LoopB) {
1306
13
    using namespace PatternMatch;
1307
13
1308
13
    Value *V = nullptr;
1309
13
    if (!match(&I, m_LShr(m_Value(V), m_One())))
1310
11
      continue;
1311
2
    ValueSeq C;
1312
2
    if (!findCycle(&I, V, C))
1313
0
      continue;
1314
2
1315
2
    // Found a cycle.
1316
2
    C.insert(&I);
1317
2
    classifyCycle(&I, C, Early, Late);
1318
2
    Cycled.insert(C.begin(), C.end());
1319
2
    RShifts.insert(&I);
1320
2
  }
1321
1
1322
1
  // Find the set of all values affected by the shift cycles, i.e. all
1323
1
  // cycled values, and (recursively) all their users.
1324
1
  ValueSeq Users(Cycled.begin(), Cycled.end());
1325
10
  for (unsigned i = 0; i < Users.size(); 
++i9
) {
1326
9
    Value *V = Users[i];
1327
9
    if (!isa<IntegerType>(V->getType()))
1328
0
      return false;
1329
9
    auto *R = cast<Instruction>(V);
1330
9
    // If the instruction does not commute with shifts, the loop cannot
1331
9
    // be unshifted.
1332
9
    if (!commutesWithShift(R))
1333
0
      return false;
1334
22
    
for (auto I = R->user_begin(), E = R->user_end(); 9
I != E;
++I13
) {
1335
13
      auto *T = cast<Instruction>(*I);
1336
13
      // Skip users from outside of the loop. They will be handled later.
1337
13
      // Also, skip the right-shifts and phi nodes, since they mix early
1338
13
      // and late values.
1339
13
      if (T->getParent() != LoopB || 
RShifts.count(T)12
||
isa<PHINode>(T)10
)
1340
5
        continue;
1341
8
1342
8
      Users.insert(T);
1343
8
      if (!classifyInst(T, Early, Late))
1344
0
        return false;
1345
8
    }
1346
9
  }
1347
1
1348
1
  if (Users.empty())
1349
0
    return false;
1350
1
1351
1
  // Verify that high bits remain zero.
1352
1
  ValueSeq Internal(Users.begin(), Users.end());
1353
1
  ValueSeq Inputs;
1354
13
  for (unsigned i = 0; i < Internal.size(); 
++i12
) {
1355
12
    auto *R = dyn_cast<Instruction>(Internal[i]);
1356
12
    if (!R)
1357
3
      continue;
1358
19
    
for (Value *Op : R->operands())9
{
1359
19
      auto *T = dyn_cast<Instruction>(Op);
1360
19
      if (T && 
T->getParent() != LoopB14
)
1361
2
        Inputs.insert(Op);
1362
17
      else
1363
17
        Internal.insert(Op);
1364
19
    }
1365
9
  }
1366
1
  for (Value *V : Inputs)
1367
2
    if (!highBitsAreZero(V, IterCount))
1368
0
      return false;
1369
1
  for (Value *V : Internal)
1370
12
    if (!keepsHighBitsZero(V, IterCount))
1371
0
      return false;
1372
1
1373
1
  // Finally, the work can be done. Unshift each user.
1374
1
  IRBuilder<> IRB(LoopB);
1375
1
  std::map<Value*,Value*> ShiftMap;
1376
1
1377
1
  using CastMapType = std::map<std::pair<Value *, Type *>, Value *>;
1378
1
1379
1
  CastMapType CastMap;
1380
1
1381
1
  auto upcast = [] (CastMapType &CM, IRBuilder<> &IRB, Value *V,
1382
1
        IntegerType *Ty) -> Value* {
1383
0
    auto H = CM.find(std::make_pair(V, Ty));
1384
0
    if (H != CM.end())
1385
0
      return H->second;
1386
0
    Value *CV = IRB.CreateIntCast(V, Ty, false);
1387
0
    CM.insert(std::make_pair(std::make_pair(V, Ty), CV));
1388
0
    return CV;
1389
0
  };
1390
1
1391
14
  for (auto I = LoopB->begin(), E = LoopB->end(); I != E; 
++I13
) {
1392
13
    using namespace PatternMatch;
1393
13
1394
13
    if (isa<PHINode>(I) || 
!Users.count(&*I)10
)
1395
6
      continue;
1396
7
1397
7
    // Match lshr x, 1.
1398
7
    Value *V = nullptr;
1399
7
    if (match(&*I, m_LShr(m_Value(V), m_One()))) {
1400
2
      replaceAllUsesOfWithIn(&*I, V, LoopB);
1401
2
      continue;
1402
2
    }
1403
5
    // For each non-cycled operand, replace it with the corresponding
1404
5
    // value shifted left.
1405
11
    
for (auto &J : I->operands())5
{
1406
11
      Value *Op = J.get();
1407
11
      if (!isOperandShifted(&*I, Op))
1408
0
        continue;
1409
11
      if (Users.count(Op))
1410
8
        continue;
1411
3
      // Skip shifting zeros.
1412
3
      if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero())
1413
1
        continue;
1414
2
      // Check if we have already generated a shift for this value.
1415
2
      auto F = ShiftMap.find(Op);
1416
2
      Value *W = (F != ShiftMap.end()) ? 
F->second0
: nullptr;
1417
2
      if (W == nullptr) {
1418
2
        IRB.SetInsertPoint(&*I);
1419
2
        // First, the shift amount will be CIV or CIV+1, depending on
1420
2
        // whether the value is early or late. Instead of creating CIV+1,
1421
2
        // do a single shift of the value.
1422
2
        Value *ShAmt = CIV, *ShVal = Op;
1423
2
        auto *VTy = cast<IntegerType>(ShVal->getType());
1424
2
        auto *ATy = cast<IntegerType>(ShAmt->getType());
1425
2
        if (Late.count(&*I))
1426
1
          ShVal = IRB.CreateShl(Op, ConstantInt::get(VTy, 1));
1427
2
        // Second, the types of the shifted value and the shift amount
1428
2
        // must match.
1429
2
        if (VTy != ATy) {
1430
0
          if (VTy->getBitWidth() < ATy->getBitWidth())
1431
0
            ShVal = upcast(CastMap, IRB, ShVal, ATy);
1432
0
          else
1433
0
            ShAmt = upcast(CastMap, IRB, ShAmt, VTy);
1434
0
        }
1435
2
        // Ready to generate the shift and memoize it.
1436
2
        W = IRB.CreateShl(ShVal, ShAmt);
1437
2
        ShiftMap.insert(std::make_pair(Op, W));
1438
2
      }
1439
2
      I->replaceUsesOfWith(Op, W);
1440
2
    }
1441
5
  }
1442
1
1443
1
  // Update the users outside of the loop to account for having left
1444
1
  // shifts. They would normally be shifted right in the loop, so shift
1445
1
  // them right after the loop exit.
1446
1
  // Take advantage of the loop-closed SSA form, which has all the post-
1447
1
  // loop values in phi nodes.
1448
1
  IRB.SetInsertPoint(ExitB, ExitB->getFirstInsertionPt());
1449
2
  for (auto P = ExitB->begin(), Q = ExitB->end(); P != Q; 
++P1
) {
1450
2
    if (!isa<PHINode>(P))
1451
1
      break;
1452
1
    auto *PN = cast<PHINode>(P);
1453
1
    Value *U = PN->getIncomingValueForBlock(LoopB);
1454
1
    if (!Users.count(U))
1455
0
      continue;
1456
1
    Value *S = IRB.CreateLShr(PN, ConstantInt::get(PN->getType(), IterCount));
1457
1
    PN->replaceAllUsesWith(S);
1458
1
    // The above RAUW will create
1459
1
    //   S = lshr S, IterCount
1460
1
    // so we need to fix it back into
1461
1
    //   S = lshr PN, IterCount
1462
1
    cast<User>(S)->replaceUsesOfWith(S, PN);
1463
1
  }
1464
1
1465
1
  return true;
1466
1
}
1467
1468
1
/// Simplify the instructions of \p LoopB in place and then delete the ones
/// that became trivially dead.
void PolynomialMultiplyRecognize::cleanupLoopBody(BasicBlock *LoopB) {
  // First pass: redirect uses of anything SimplifyInstruction can reduce.
  for (Instruction &In : *LoopB) {
    Value *Simpler = SimplifyInstruction(&In, {DL, &TLI, &DT});
    if (Simpler)
      In.replaceAllUsesWith(Simpler);
  }

  // Second pass: remove dead instructions. Step the iterator before the
  // deletion so it never refers to the instruction being removed.
  BasicBlock::iterator It = LoopB->begin();
  while (It != LoopB->end()) {
    Instruction *Cur = &*It;
    ++It;
    RecursivelyDeleteTriviallyDeadInstructions(Cur, &TLI);
  }
}
1478
1479
1
unsigned PolynomialMultiplyRecognize::getInverseMxN(unsigned QP) {
1480
1
  // Arrays of coefficients of Q and the inverse, C.
1481
1
  // Q[i] = coefficient at x^i.
1482
1
  std::array<char,32> Q, C;
1483
1
1484
33
  for (unsigned i = 0; i < 32; 
++i32
) {
1485
32
    Q[i] = QP & 1;
1486
32
    QP >>= 1;
1487
32
  }
1488
1
  assert(Q[0] == 1);
1489
1
1490
1
  // Find C, such that
1491
1
  // (Q[n]*x^n + ... + Q[1]*x + Q[0]) * (C[n]*x^n + ... + C[1]*x + C[0]) = 1
1492
1
  //
1493
1
  // For it to have a solution, Q[0] must be 1. Since this is Z2[x], the
1494
1
  // operations * and + are & and ^ respectively.
1495
1
  //
1496
1
  // Find C[i] recursively, by comparing i-th coefficient in the product
1497
1
  // with 0 (or 1 for i=0).
1498
1
  //
1499
1
  // C[0] = 1, since C[0] = Q[0], and Q[0] = 1.
1500
1
  C[0] = 1;
1501
32
  for (unsigned i = 1; i < 32; 
++i31
) {
1502
31
    // Solve for C[i] in:
1503
31
    //   C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i]Q[0] = 0
1504
31
    // This is equivalent to
1505
31
    //   C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i] = 0
1506
31
    // which is
1507
31
    //   C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] = C[i]
1508
31
    unsigned T = 0;
1509
527
    for (unsigned j = 0; j < i; 
++j496
)
1510
496
      T = T ^ (C[j] & Q[i-j]);
1511
31
    C[i] = T;
1512
31
  }
1513
1
1514
1
  unsigned QV = 0;
1515
33
  for (unsigned i = 0; i < 32; 
++i32
)
1516
32
    if (C[i])
1517
32
      QV |= (1 << i);
1518
1
1519
1
  return QV;
1520
1
}
1521
1522
/// Emit, before \p At, the hexagon_M4_pmpyw intrinsic call(s) computing the
/// value described by \p PV.
///
/// - If PV.M is set, P is first xor-ed with it and the xor-ed value is
///   xor-ed into the final result again.
/// - If the iteration count is not 32, P is masked to its low IterCount
///   bits.
/// - For the inverse idiom (PV.Inv), P is additionally multiplied by the
///   inverse of the constant Q, truncated to Q's type and masked again.
///
/// \param At Insertion point for the generated code.
/// \param PV Parsed idiom values (P, Q, M, IterCount, Inv).
/// \returns The value produced by the generated code.
Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At,
      ParsedValues &PV) {
  IRBuilder<> B(&*At);
  Module *M = At->getParent()->getParent()->getParent();
  Function *PMF = Intrinsic::getDeclaration(M, Intrinsic::hexagon_M4_pmpyw);

  Value *P = PV.P, *Q = PV.Q, *P0 = P;
  unsigned IC = PV.IterCount;

  if (PV.M != nullptr)
    P0 = P = B.CreateXor(P, PV.M);

  // Bit mask that clears the high bits beyond IterCount. It is kept as an
  // APInt so that IterCount == 32 is handled without any `1 << 32`.
  const APInt LowBits = APInt::getLowBitsSet(32, IC);
  auto *BMI = ConstantInt::get(P->getType(), LowBits);

  if (IC != 32)
    P = B.CreateAnd(P, BMI);

  if (PV.Inv) {
    auto *QI = dyn_cast<ConstantInt>(PV.Q);
    assert(QI && QI->getBitWidth() <= 32);

    // Again, clearing bits beyond IterCount. The original
    // `(1 << IterCount) - 1` was undefined for IterCount == 32.
    const unsigned LowMask = static_cast<unsigned>(LowBits.getZExtValue());
    unsigned Tmp = (QI->getZExtValue() | 1) & LowMask;
    unsigned QV = getInverseMxN(Tmp) & LowMask;
    auto *QVI = ConstantInt::get(QI->getType(), QV);
    P = B.CreateCall(PMF, {P, QVI});
    P = B.CreateTrunc(P, QI->getType());
    if (IC != 32)
      P = B.CreateAnd(P, BMI);
  }

  Value *R = B.CreateCall(PMF, {P, Q});

  if (PV.M != nullptr)
    R = B.CreateXor(R, B.CreateIntCast(P0, R->getType(), false));

  return R;
}
1562
1563
0
static bool hasZeroSignBit(const Value *V) {
1564
0
  if (const auto *CI = dyn_cast<const ConstantInt>(V))
1565
0
    return (CI->getType()->getSignBit() & CI->getSExtValue()) == 0;
1566
0
  const Instruction *I = dyn_cast<const Instruction>(V);
1567
0
  if (!I)
1568
0
    return false;
1569
0
  switch (I->getOpcode()) {
1570
0
    case Instruction::LShr:
1571
0
      if (const auto SI = dyn_cast<const ConstantInt>(I->getOperand(1)))
1572
0
        return SI->getZExtValue() > 0;
1573
0
      return false;
1574
0
    case Instruction::Or:
1575
0
    case Instruction::Xor:
1576
0
      return hasZeroSignBit(I->getOperand(0)) &&
1577
0
             hasZeroSignBit(I->getOperand(1));
1578
0
    case Instruction::And:
1579
0
      return hasZeroSignBit(I->getOperand(0)) ||
1580
0
             hasZeroSignBit(I->getOperand(1));
1581
0
  }
1582
0
  return false;
1583
0
}
1584
1585
7
/// Register with \p S the rewrite rules that are applied before the idiom
/// is matched. Each rule receives an instruction and returns either a
/// replacement value built with a detached IRBuilder, or nullptr when the
/// rule does not apply.
void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier &S) {
  // Sink zext past bitwise operations:
  //   (zext (BitOp x y)) -> (BitOp (zext x) (zext y))
  S.addRule("sink-zext",
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      if (I->getOpcode() != Instruction::ZExt)
        return nullptr;
      auto *Src = dyn_cast<Instruction>(I->getOperand(0));
      if (!Src)
        return nullptr;
      const unsigned Opc = Src->getOpcode();
      if (Opc != Instruction::And && Opc != Instruction::Or &&
          Opc != Instruction::Xor)
        return nullptr;
      IRBuilder<> B(Ctx);
      Type *WideTy = I->getType();
      Value *Lhs = B.CreateZExt(Src->getOperand(0), WideTy);
      Value *Rhs = B.CreateZExt(Src->getOperand(1), WideTy);
      return B.CreateBinOp(cast<BinaryOperator>(Src)->getOpcode(), Lhs, Rhs);
    });

  // (xor (and x a) (and y a)) -> (and (xor x y) a)
  S.addRule("xor/and -> and/xor",
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      if (I->getOpcode() != Instruction::Xor)
        return nullptr;
      auto *Lhs = dyn_cast<Instruction>(I->getOperand(0));
      auto *Rhs = dyn_cast<Instruction>(I->getOperand(1));
      if (!Lhs || !Rhs)
        return nullptr;
      const bool BothAnd = Lhs->getOpcode() == Instruction::And &&
                           Rhs->getOpcode() == Instruction::And;
      if (!BothAnd)
        return nullptr;
      // Both ands must share their operand 1.
      Value *Mask = Lhs->getOperand(1);
      if (Mask != Rhs->getOperand(1))
        return nullptr;
      IRBuilder<> B(Ctx);
      Value *Mixed = B.CreateXor(Lhs->getOperand(0), Rhs->getOperand(0));
      return B.CreateAnd(Mixed, Mask);
    });

  // (Op (select c x y) z) -> (select c (Op x z) (Op y z))
  // (Op x (select c y z)) -> (select c (Op x y) (Op x z))
  S.addRule("sink binop into select",
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      auto *BO = dyn_cast<BinaryOperator>(I);
      if (!BO)
        return nullptr;
      const Instruction::BinaryOps Opc = BO->getOpcode();
      Value *Lhs = BO->getOperand(0);
      Value *Rhs = BO->getOperand(1);
      if (auto *Sel = dyn_cast<SelectInst>(Lhs)) {
        IRBuilder<> B(Ctx);
        Value *OnTrue = B.CreateBinOp(Opc, Sel->getTrueValue(), Rhs);
        Value *OnFalse = B.CreateBinOp(Opc, Sel->getFalseValue(), Rhs);
        return B.CreateSelect(Sel->getCondition(), OnTrue, OnFalse);
      }
      if (auto *Sel = dyn_cast<SelectInst>(Rhs)) {
        IRBuilder<> B(Ctx);
        Value *OnTrue = B.CreateBinOp(Opc, Lhs, Sel->getTrueValue());
        Value *OnFalse = B.CreateBinOp(Opc, Lhs, Sel->getFalseValue());
        return B.CreateSelect(Sel->getCondition(), OnTrue, OnFalse);
      }
      return nullptr;
    });

  // (select c (select c x y) z) -> (select c x z)
  // (select c x (select c y z)) -> (select c x z)
  S.addRule("fold select-select",
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      auto *Outer = dyn_cast<SelectInst>(I);
      if (!Outer)
        return nullptr;
      Value *Cond = Outer->getCondition();
      Value *TrueV = Outer->getTrueValue();
      Value *FalseV = Outer->getFalseValue();
      if (auto *Inner = dyn_cast<SelectInst>(TrueV))
        if (Inner->getCondition() == Cond)
          return IRBuilder<>(Ctx).CreateSelect(Cond, Inner->getTrueValue(),
                                               FalseV);
      if (auto *Inner = dyn_cast<SelectInst>(FalseV))
        if (Inner->getCondition() == Cond)
          return IRBuilder<>(Ctx).CreateSelect(Cond, TrueV,
                                               Inner->getFalseValue());
      return nullptr;
    });

  // (or (lshr x 1) 0x800.0) -> (xor (lshr x 1) 0x800.0)
  S.addRule("or-signbit -> xor-signbit",
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      if (I->getOpcode() != Instruction::Or)
        return nullptr;
      Value *Lhs = I->getOperand(0);
      auto *Msb = dyn_cast<ConstantInt>(I->getOperand(1));
      if (!Msb)
        return nullptr;
      if (Msb->getZExtValue() != Msb->getType()->getSignBit())
        return nullptr;
      // With a zero sign bit on the left, or-ing in the sign bit is the
      // same as xor-ing it in.
      if (!hasZeroSignBit(Lhs))
        return nullptr;
      return IRBuilder<>(Ctx).CreateXor(Lhs, Msb);
    });

  // (lshr (BitOp x y) c) -> (BitOp (lshr x c) (lshr y c))
  S.addRule("sink lshr into binop",
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      if (I->getOpcode() != Instruction::LShr)
        return nullptr;
      auto *BitOp = dyn_cast<BinaryOperator>(I->getOperand(0));
      if (!BitOp)
        return nullptr;
      const Instruction::BinaryOps Opc = BitOp->getOpcode();
      if (Opc != Instruction::And && Opc != Instruction::Or &&
          Opc != Instruction::Xor)
        return nullptr;
      IRBuilder<> B(Ctx);
      Value *Amt = I->getOperand(1);
      Value *Lhs = B.CreateLShr(BitOp->getOperand(0), Amt);
      Value *Rhs = B.CreateLShr(BitOp->getOperand(1), Amt);
      return B.CreateBinOp(Opc, Lhs, Rhs);
    });

  // (BitOp1 (BitOp2 x a) b) -> (BitOp2 x (BitOp1 a b))
  S.addRule("expose bitop-const",
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      auto IsBitwise = [](unsigned Opc) -> bool {
        return Opc == Instruction::And || Opc == Instruction::Or ||
               Opc == Instruction::Xor;
      };
      auto *Outer = dyn_cast<BinaryOperator>(I);
      if (!Outer || !IsBitwise(Outer->getOpcode()))
        return nullptr;
      auto *Inner = dyn_cast<BinaryOperator>(Outer->getOperand(0));
      if (!Inner || !IsBitwise(Inner->getOpcode()))
        return nullptr;
      // Both operations need a constant as their operand 1.
      auto *CA = dyn_cast<ConstantInt>(Inner->getOperand(1));
      auto *CB = dyn_cast<ConstantInt>(Outer->getOperand(1));
      if (!CA || !CB)
        return nullptr;
      IRBuilder<> B(Ctx);
      Value *X = Inner->getOperand(0);
      Value *Folded = B.CreateBinOp(Outer->getOpcode(), CA, CB);
      return B.CreateBinOp(Inner->getOpcode(), X, Folded);
    });
}
1732
1733
1
/// Register the rewrite rules that recognize() runs after type promotion.
///
/// A single rule is registered. It removes a redundant inner mask:
///   (and (xor (and x a) y) b) -> (and (xor x y) b),  if b == b & a
/// The rewrite is valid because every bit kept by the outer mask b is also
/// kept by the inner mask a, so masking x with a first changes nothing that
/// survives the outer 'and'.
///
/// \param S  Simplifier that receives the rule.
void PolynomialMultiplyRecognize::setupPostSimplifier(Simplifier &S) {
  S.addRule("(and (xor (and x a) y) b) -> (and (xor x y) b), if b == b&a",
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      if (I->getOpcode() != Instruction::And)
        return nullptr;
      Instruction *Xor = dyn_cast<Instruction>(I->getOperand(0));
      ConstantInt *C0 = dyn_cast<ConstantInt>(I->getOperand(1));
      if (!Xor || !C0)
        return nullptr;
      if (Xor->getOpcode() != Instruction::Xor)
        return nullptr;

      // Returns V as an instruction if it is an 'and', otherwise null.
      auto AsAnd = [](Value *V) -> Instruction* {
        Instruction *In = dyn_cast<Instruction>(V);
        return (In && In->getOpcode() == Instruction::And) ? In : nullptr;
      };

      // Pick the first xor operand that is an 'and'; the other operand (Y)
      // may be any value, not necessarily an instruction.
      Value *Y = Xor->getOperand(1);
      Instruction *And0 = AsAnd(Xor->getOperand(0));
      if (!And0) {
        Y = Xor->getOperand(0);
        And0 = AsAnd(Xor->getOperand(1));
      }
      // Neither xor operand is an 'and': the pattern does not match.
      if (!And0)
        return nullptr;

      ConstantInt *C1 = dyn_cast<ConstantInt>(And0->getOperand(1));
      if (!C1)
        return nullptr;
      // b == b & a holds exactly when the bits of b are a subset of the bits
      // of a.  Compare at full width; both constants have the type of I.
      if (!C0->getValue().isSubsetOf(C1->getValue()))
        return nullptr;
      IRBuilder<> B(Ctx);
      return B.CreateAnd(B.CreateXor(And0->getOperand(0), Y), C0);
    });
}
1760
1761
8
bool PolynomialMultiplyRecognize::recognize() {
1762
8
  LLVM_DEBUG(dbgs() << "Starting PolynomialMultiplyRecognize on loop\n"
1763
8
                    << *CurLoop << '\n');
1764
8
  // Restrictions:
1765
8
  // - The loop must consist of a single block.
1766
8
  // - The iteration count must be known at compile-time.
1767
8
  // - The loop must have an induction variable starting from 0, and
1768
8
  //   incremented in each iteration of the loop.
1769
8
  BasicBlock *LoopB = CurLoop->getHeader();
1770
8
  LLVM_DEBUG(dbgs() << "Loop header:\n" << *LoopB);
1771
8
1772
8
  if (LoopB != CurLoop->getLoopLatch())
1773
1
    return false;
1774
7
  BasicBlock *ExitB = CurLoop->getExitBlock();
1775
7
  if (ExitB == nullptr)
1776
0
    return false;
1777
7
  BasicBlock *EntryB = CurLoop->getLoopPreheader();
1778
7
  if (EntryB == nullptr)
1779
0
    return false;
1780
7
1781
7
  unsigned IterCount = 0;
1782
7
  const SCEV *CT = SE.getBackedgeTakenCount(CurLoop);
1783
7
  if (isa<SCEVCouldNotCompute>(CT))
1784
0
    return false;
1785
7
  if (auto *CV = dyn_cast<SCEVConstant>(CT))
1786
4
    IterCount = CV->getValue()->getZExtValue() + 1;
1787
7
1788
7
  Value *CIV = getCountIV(LoopB);
1789
7
  ParsedValues PV;
1790
7
  Simplifier PreSimp;
1791
7
  PV.IterCount = IterCount;
1792
7
  LLVM_DEBUG(dbgs() << "Loop IV: " << *CIV << "\nIterCount: " << IterCount
1793
7
                    << '\n');
1794
7
1795
7
  setupPreSimplifier(PreSimp);
1796
7
1797
7
  // Perform a preliminary scan of select instructions to see if any of them
1798
7
  // looks like a generator of the polynomial multiply steps. Assume that a
1799
7
  // loop can only contain a single transformable operation, so stop the
1800
7
  // traversal after the first reasonable candidate was found.
1801
7
  // XXX: Currently this approach can modify the loop before being 100% sure
1802
7
  // that the transformation can be carried out.
1803
7
  bool FoundPreScan = false;
1804
19
  auto FeedsPHI = [LoopB](const Value *V) -> bool {
1805
26
    for (const Value *U : V->users()) {
1806
26
      if (const auto *P = dyn_cast<const PHINode>(U))
1807
7
        if (P->getParent() == LoopB)
1808
4
          return true;
1809
26
    }
1810
19
    
return false15
;
1811
19
  };
1812
121
  for (Instruction &In : *LoopB) {
1813
121
    SelectInst *SI = dyn_cast<SelectInst>(&In);
1814
121
    if (!SI || 
!FeedsPHI(SI)18
)
1815
118
      continue;
1816
3
1817
3
    Simplifier::Context C(SI);
1818
3
    Value *T = PreSimp.simplify(C);
1819
3
    SelectInst *SelI = (T && isa<SelectInst>(T)) ? cast<SelectInst>(T) : 
SI0
;
1820
3
    LLVM_DEBUG(dbgs() << "scanSelect(pre-scan): " << PE(C, SelI) << '\n');
1821
3
    if (scanSelect(SelI, LoopB, EntryB, CIV, PV, true)) {
1822
2
      FoundPreScan = true;
1823
2
      if (SelI != SI) {
1824
2
        Value *NewSel = C.materialize(LoopB, SI->getIterator());
1825
2
        SI->replaceAllUsesWith(NewSel);
1826
2
        RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI);
1827
2
      }
1828
2
      break;
1829
2
    }
1830
3
  }
1831
7
1832
7
  if (!FoundPreScan) {
1833
5
    LLVM_DEBUG(dbgs() << "Have not found candidates for pmpy\n");
1834
5
    return false;
1835
5
  }
1836
2
1837
2
  if (!PV.Left) {
1838
1
    // The right shift version actually only returns the higher bits of
1839
1
    // the result (each iteration discards the LSB). If we want to convert it
1840
1
    // to a left-shifting loop, the working data type must be at least as
1841
1
    // wide as the target's pmpy instruction.
1842
1
    if (!promoteTypes(LoopB, ExitB))
1843
0
      return false;
1844
1
    // Run post-promotion simplifications.
1845
1
    Simplifier PostSimp;
1846
1
    setupPostSimplifier(PostSimp);
1847
11
    for (Instruction &In : *LoopB) {
1848
11
      SelectInst *SI = dyn_cast<SelectInst>(&In);
1849
11
      if (!SI || 
!FeedsPHI(SI)1
)
1850
10
        continue;
1851
1
      Simplifier::Context C(SI);
1852
1
      Value *T = PostSimp.simplify(C);
1853
1
      SelectInst *SelI = dyn_cast_or_null<SelectInst>(T);
1854
1
      if (SelI != SI) {
1855
1
        Value *NewSel = C.materialize(LoopB, SI->getIterator());
1856
1
        SI->replaceAllUsesWith(NewSel);
1857
1
        RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI);
1858
1
      }
1859
1
      break;
1860
1
    }
1861
1
1862
1
    if (!convertShiftsToLeft(LoopB, ExitB, IterCount))
1863
0
      return false;
1864
1
    cleanupLoopBody(LoopB);
1865
1
  }
1866
2
1867
2
  // Scan the loop again, find the generating select instruction.
1868
2
  bool FoundScan = false;
1869
18
  for (Instruction &In : *LoopB) {
1870
18
    SelectInst *SelI = dyn_cast<SelectInst>(&In);
1871
18
    if (!SelI)
1872
16
      continue;
1873
2
    LLVM_DEBUG(dbgs() << "scanSelect: " << *SelI << '\n');
1874
2
    FoundScan = scanSelect(SelI, LoopB, EntryB, CIV, PV, false);
1875
2
    if (FoundScan)
1876
2
      break;
1877
2
  }
1878
2
  assert(FoundScan);
1879
2
1880
2
  LLVM_DEBUG({
1881
2
    StringRef PP = (PV.M ? "(P+M)" : "P");
1882
2
    if (!PV.Inv)
1883
2
      dbgs() << "Found pmpy idiom: R = " << PP << ".Q\n";
1884
2
    else
1885
2
      dbgs() << "Found inverse pmpy idiom: R = (" << PP << "/Q).Q) + "
1886
2
             << PP << "\n";
1887
2
    dbgs() << "  Res:" << *PV.Res << "\n  P:" << *PV.P << "\n";
1888
2
    if (PV.M)
1889
2
      dbgs() << "  M:" << *PV.M << "\n";
1890
2
    dbgs() << "  Q:" << *PV.Q << "\n";
1891
2
    dbgs() << "  Iteration count:" << PV.IterCount << "\n";
1892
2
  });
1893
2
1894
2
  BasicBlock::iterator At(EntryB->getTerminator());
1895
2
  Value *PM = generate(At, PV);
1896
2
  if (PM == nullptr)
1897
0
    return false;
1898
2
1899
2
  if (PM->getType() != PV.Res->getType())
1900
1
    PM = IRBuilder<>(&*At).CreateIntCast(PM, PV.Res->getType(), false);
1901
2
1902
2
  PV.Res->replaceAllUsesWith(PM);
1903
2
  PV.Res->eraseFromParent();
1904
2
  return true;
1905
2
}
1906
1907
9
int HexagonLoopIdiomRecognize::getSCEVStride(const SCEVAddRecExpr *S) {
1908
9
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getOperand(1)))
1909
9
    return SC->getAPInt().getSExtValue();
1910
0
  return 0;
1911
0
}
1912
1913
3
/// Decide whether \p SI is a store this pass can turn into a memory copy.
///
/// The store qualifies when all of the following hold:
///  - it is simple, or it is volatile and HexagonVolatileMemcpy is set;
///  - the stored type is a whole number of bytes and its bit size fits in
///    32 bits;
///  - its address is an affine add-recurrence on \p CurLoop whose constant
///    stride has the same magnitude as the store size;
///  - the stored value is a simple load whose address is an affine
///    add-recurrence on \p CurLoop with the same step as the store.
bool HexagonLoopIdiomRecognize::isLegalStore(Loop *CurLoop, StoreInst *SI) {
  // Volatile stores are only tolerated when HexagonVolatileMemcpy is enabled.
  const bool VolatileAllowed = SI->isVolatile() && HexagonVolatileMemcpy;
  if (!VolatileAllowed && !SI->isSimple())
    return false;

  Value *StoredVal = SI->getValueOperand();
  Value *StorePtr = SI->getPointerOperand();

  // The size must be byte-aligned and must not overflow an unsigned.
  const uint64_t SizeInBits = DL->getTypeSizeInBits(StoredVal->getType());
  const bool PartialByte = (SizeInBits & 7) != 0;
  const bool TooLarge = (SizeInBits >> 32) != 0;
  if (PartialByte || TooLarge)
    return false;

  // The store address has to look like {base,+,step} on the current loop;
  // anything else is a store we cannot handle.
  auto *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
  if (!StoreEv)
    return false;
  if (StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
    return false;

  // Every byte is touched only if the stride magnitude equals the store
  // size.  A stride of 0 also covers the non-constant case.
  const int Stride = getSCEVStride(StoreEv);
  if (Stride == 0)
    return false;
  const unsigned StoreSize =
      DL->getTypeStoreSize(SI->getValueOperand()->getType());
  if (StoreSize != unsigned(std::abs(Stride)))
    return false;

  // The stored value must come straight from a non-volatile load.
  LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand());
  if (!LI)
    return false;
  if (!LI->isSimple())
    return false;

  // The load address has to be a strided access on the current loop too.
  Value *LoadPtr = LI->getPointerOperand();
  auto *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LoadPtr));
  if (!LoadEv)
    return false;
  if (LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
    return false;

  // Finally, load and store must advance by the same step; if they do, the
  // store can be converted into a memcpy.
  return StoreEv->getOperand(1) == LoadEv->getOperand(1);
}
1962
1963
/// mayLoopAccessLocation - Return true if the specified loop might access the
1964
/// specified pointer location, which is a loop-strided access.  The 'Access'
1965
/// argument specifies what the verboten forms of access are (read or write).
1966
static bool
1967
mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
1968
                      const SCEV *BECount, unsigned StoreSize,
1969
                      AliasAnalysis &AA,
1970
9
                      SmallPtrSetImpl<Instruction *> &Ignored) {
1971
9
  // Get the location that may be stored across the loop.  Since the access
1972
9
  // is strided positively through memory, we say that the modified location
1973
9
  // starts at the pointer and has infinite size.
1974
9
  LocationSize AccessSize = LocationSize::unknown();
1975
9
1976
9
  // If the loop iterates a fixed number of times, we can refine the access
1977
9
  // size to be exactly the size of the memset, which is (BECount+1)*StoreSize
1978
9
  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
1979
0
    AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
1980
0
                                       StoreSize);
1981
9
1982
9
  // TODO: For this to be really effective, we have to dive into the pointer
1983
9
  // operand in the store.  Store to &A[i] of 100 will always return may alias
1984
9
  // with store of &A[100], we need to StoreLoc to be "A" with size of 100,
1985
9
  // which will then no-alias a store to &A[100].
1986
9
  MemoryLocation StoreLoc(Ptr, AccessSize);
1987
9
1988
9
  for (auto *B : L->blocks())
1989
11
    for (auto &I : *B)
1990
64
      if (Ignored.count(&I) == 0 &&
1991
64
          isModOrRefSet(
1992
55
              intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access)))
1993
3
        return true;
1994
9
1995
9
  
return false6
;
1996
9
}
1997
1998
void HexagonLoopIdiomRecognize::collectStores(Loop *CurLoop, BasicBlock *BB,
1999
6
      SmallVectorImpl<StoreInst*> &Stores) {
2000
6
  Stores.clear();
2001
6
  for (Instruction &I : *BB)
2002
107
    if (StoreInst *SI = dyn_cast<StoreInst>(&I))
2003
3
      if (isLegalStore(CurLoop, SI))
2004
3
        Stores.push_back(SI);
2005
6
}
2006
2007
bool HexagonLoopIdiomRecognize::processCopyingStore(Loop *CurLoop,
2008
3
      StoreInst *SI, const SCEV *BECount) {
2009
3
  assert((SI->isSimple() || (SI->isVolatile() && HexagonVolatileMemcpy)) &&
2010
3
         "Expected only non-volatile stores, or Hexagon-specific memcpy"
2011
3
         "to volatile destination.");
2012
3
2013
3
  Value *StorePtr = SI->getPointerOperand();
2014
3
  auto *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
2015
3
  unsigned Stride = getSCEVStride(StoreEv);
2016
3
  unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType());
2017
3
  if (Stride != StoreSize)
2018
0
    return false;
2019
3
2020
3
  // See if the pointer expression is an AddRec like {base,+,1} on the current
2021
3
  // loop, which indicates a strided load.  If we have something else, it's a
2022
3
  // random load we can't handle.
2023
3
  LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand());
2024
3
  auto *LoadEv = cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand()));
2025
3
2026
3
  // The trip count of the loop and the base pointer of the addrec SCEV is
2027
3
  // guaranteed to be loop invariant, which means that it should dominate the
2028
3
  // header.  This allows us to insert code for it in the preheader.
2029
3
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
2030
3
  Instruction *ExpPt = Preheader->getTerminator();
2031
3
  IRBuilder<> Builder(ExpPt);
2032
3
  SCEVExpander Expander(*SE, *DL, "hexagon-loop-idiom");
2033
3
2034
3
  Type *IntPtrTy = Builder.getIntPtrTy(*DL, SI->getPointerAddressSpace());
2035
3
2036
3
  // Okay, we have a strided store "p[i]" of a loaded value.  We can turn
2037
3
  // this into a memcpy/memmove in the loop preheader now if we want.  However,
2038
3
  // this would be unsafe to do if there is anything else in the loop that may
2039
3
  // read or write the memory region we're storing to.  For memcpy, this
2040
3
  // includes the load that feeds the stores.  Check for an alias by generating
2041
3
  // the base address and checking everything.
2042
3
  Value *StoreBasePtr = Expander.expandCodeFor(StoreEv->getStart(),
2043
3
      Builder.getInt8PtrTy(SI->getPointerAddressSpace()), ExpPt);
2044
3
  Value *LoadBasePtr = nullptr;
2045
3
2046
3
  bool Overlap = false;
2047
3
  bool DestVolatile = SI->isVolatile();
2048
3
  Type *BECountTy = BECount->getType();
2049
3
2050
3
  if (DestVolatile) {
2051
0
    // The trip count must fit in i32, since it is the type of the "num_words"
2052
0
    // argument to hexagon_memcpy_forward_vp4cp4n2.
2053
0
    if (StoreSize != 4 || DL->getTypeSizeInBits(BECountTy) > 32) {
2054
0
CleanupAndExit:
2055
0
      // If we generated new code for the base pointer, clean up.
2056
0
      Expander.clear();
2057
0
      if (StoreBasePtr && (LoadBasePtr != StoreBasePtr)) {
2058
0
        RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI);
2059
0
        StoreBasePtr = nullptr;
2060
0
      }
2061
0
      if (LoadBasePtr) {
2062
0
        RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI);
2063
0
        LoadBasePtr = nullptr;
2064
0
      }
2065
0
      return false;
2066
3
    }
2067
0
  }
2068
3
2069
3
  SmallPtrSet<Instruction*, 2> Ignore1;
2070
3
  Ignore1.insert(SI);
2071
3
  if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
2072
3
                            StoreSize, *AA, Ignore1)) {
2073
3
    // Check if the load is the offending instruction.
2074
3
    Ignore1.insert(LI);
2075
3
    if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
2076
3
                              BECount, StoreSize, *AA, Ignore1)) {
2077
0
      // Still bad. Nothing we can do.
2078
0
      goto CleanupAndExit;
2079
0
    }
2080
3
    // It worked with the load ignored.
2081
3
    Overlap = true;
2082
3
  }
2083
3
2084
3
  if (!Overlap) {
2085
0
    if (DisableMemcpyIdiom || !HasMemcpy)
2086
0
      goto CleanupAndExit;
2087
3
  } else {
2088
3
    // Don't generate memmove if this function will be inlined. This is
2089
3
    // because the caller will undergo this transformation after inlining.
2090
3
    Function *Func = CurLoop->getHeader()->getParent();
2091
3
    if (Func->hasFnAttribute(Attribute::AlwaysInline))
2092
0
      goto CleanupAndExit;
2093
3
2094
3
    // In case of a memmove, the call to memmove will be executed instead
2095
3
    // of the loop, so we need to make sure that there is nothing else in
2096
3
    // the loop than the load, store and instructions that these two depend
2097
3
    // on.
2098
3
    SmallVector<Instruction*,2> Insts;
2099
3
    Insts.push_back(SI);
2100
3
    Insts.push_back(LI);
2101
3
    if (!coverLoop(CurLoop, Insts))
2102
0
      goto CleanupAndExit;
2103
3
2104
3
    if (DisableMemmoveIdiom || !HasMemmove)
2105
0
      goto CleanupAndExit;
2106
3
    bool IsNested = CurLoop->getParentLoop() != nullptr;
2107
3
    if (IsNested && 
OnlyNonNestedMemmove0
)
2108
0
      goto CleanupAndExit;
2109
3
  }
2110
3
2111
3
  // For a memcpy, we have to make sure that the input array is not being
2112
3
  // mutated by the loop.
2113
3
  LoadBasePtr = Expander.expandCodeFor(LoadEv->getStart(),
2114
3
      Builder.getInt8PtrTy(LI->getPointerAddressSpace()), ExpPt);
2115
3
2116
3
  SmallPtrSet<Instruction*, 2> Ignore2;
2117
3
  Ignore2.insert(SI);
2118
3
  if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
2119
3
                            StoreSize, *AA, Ignore2))
2120
0
    goto CleanupAndExit;
2121
3
2122
3
  // Check the stride.
2123
3
  bool StridePos = getSCEVStride(LoadEv) >= 0;
2124
3
2125
3
  // Currently, the volatile memcpy only emulates traversing memory forward.
2126
3
  if (!StridePos && 
DestVolatile0
)
2127
0
    goto CleanupAndExit;
2128
3
2129
3
  bool RuntimeCheck = (Overlap || 
DestVolatile0
);
2130
3
2131
3
  BasicBlock *ExitB;
2132
3
  if (RuntimeCheck) {
2133
3
    // The runtime check needs a single exit block.
2134
3
    SmallVector<BasicBlock*, 8> ExitBlocks;
2135
3
    CurLoop->getUniqueExitBlocks(ExitBlocks);
2136
3
    if (ExitBlocks.size() != 1)
2137
0
      goto CleanupAndExit;
2138
3
    ExitB = ExitBlocks[0];
2139
3
  }
2140
3
2141
3
  // The # stored bytes is (BECount+1)*Size.  Expand the trip count out to
2142
3
  // pointer size if it isn't already.
2143
3
  LLVMContext &Ctx = SI->getContext();
2144
3
  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy);
2145
3
  DebugLoc DLoc = SI->getDebugLoc();
2146
3
2147
3
  const SCEV *NumBytesS =
2148
3
      SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW);
2149
3
  if (StoreSize != 1)
2150
3
    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize),
2151
3
                               SCEV::FlagNUW);
2152
3
  Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, ExpPt);
2153
3
  if (Instruction *In = dyn_cast<Instruction>(NumBytes))
2154
3
    if (Value *Simp = SimplifyInstruction(In, {*DL, TLI, DT}))
2155
0
      NumBytes = Simp;
2156
3
2157
3
  CallInst *NewCall;
2158
3
2159
3
  if (RuntimeCheck) {
2160
3
    unsigned Threshold = RuntimeMemSizeThreshold;
2161
3
    if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) {
2162
0
      uint64_t C = CI->getZExtValue();
2163
0
      if (Threshold != 0 && C < Threshold)
2164
0
        goto CleanupAndExit;
2165
0
      if (C < CompileTimeMemSizeThreshold)
2166
0
        goto CleanupAndExit;
2167
3
    }
2168
3
2169
3
    BasicBlock *Header = CurLoop->getHeader();
2170
3
    Function *Func = Header->getParent();
2171
3
    Loop *ParentL = LF->getLoopFor(Preheader);
2172
3
    StringRef HeaderName = Header->getName();
2173
3
2174
3
    // Create a new (empty) preheader, and update the PHI nodes in the
2175
3
    // header to use the new preheader.
2176
3
    BasicBlock *NewPreheader = BasicBlock::Create(Ctx, HeaderName+".rtli.ph",
2177
3
                                                  Func, Header);
2178
3
    if (ParentL)
2179
0
      ParentL->addBasicBlockToLoop(NewPreheader, *LF);
2180
3
    IRBuilder<>(NewPreheader).CreateBr(Header);
2181
8
    for (auto &In : *Header) {
2182
8
      PHINode *PN = dyn_cast<PHINode>(&In);
2183
8
      if (!PN)
2184
3
        break;
2185
5
      int bx = PN->getBasicBlockIndex(Preheader);
2186
5
      if (bx >= 0)
2187
5
        PN->setIncomingBlock(bx, NewPreheader);
2188
5
    }
2189
3
    DT->addNewBlock(NewPreheader, Preheader);
2190
3
    DT->changeImmediateDominator(Header, NewPreheader);
2191
3
2192
3
    // Check for safe conditions to execute memmove.
2193
3
    // If stride is positive, copying things from higher to lower addresses
2194
3
    // is equivalent to memmove.  For negative stride, it's the other way
2195
3
    // around.  Copying forward in memory with positive stride may not be
2196
3
    // same as memmove since we may be copying values that we just stored
2197
3
    // in some previous iteration.
2198
3
    Value *LA = Builder.CreatePtrToInt(LoadBasePtr, IntPtrTy);
2199
3
    Value *SA = Builder.CreatePtrToInt(StoreBasePtr, IntPtrTy);
2200
3
    Value *LowA = StridePos ? SA : 
LA0
;
2201
3
    Value *HighA = StridePos ? LA : 
SA0
;
2202
3
    Value *CmpA = Builder.CreateICmpULT(LowA, HighA);
2203
3
    Value *Cond = CmpA;
2204
3
2205
3
    // Check for distance between pointers. Since the case LowA < HighA
2206
3
    // is checked for above, assume LowA >= HighA.
2207
3
    Value *Dist = Builder.CreateSub(LowA, HighA);
2208
3
    Value *CmpD = Builder.CreateICmpSLE(NumBytes, Dist);
2209
3
    Value *CmpEither = Builder.CreateOr(Cond, CmpD);
2210
3
    Cond = CmpEither;
2211
3
2212
3
    if (Threshold != 0) {
2213
0
      Type *Ty = NumBytes->getType();
2214
0
      Value *Thr = ConstantInt::get(Ty, Threshold);
2215
0
      Value *CmpB = Builder.CreateICmpULT(Thr, NumBytes);
2216
0
      Value *CmpBoth = Builder.CreateAnd(Cond, CmpB);
2217
0
      Cond = CmpBoth;
2218
0
    }
2219
3
    BasicBlock *MemmoveB = BasicBlock::Create(Ctx, Header->getName()+".rtli",
2220
3
                                              Func, NewPreheader);
2221
3
    if (ParentL)
2222
0
      ParentL->addBasicBlockToLoop(MemmoveB, *LF);
2223
3
    Instruction *OldT = Preheader->getTerminator();
2224
3
    Builder.CreateCondBr(Cond, MemmoveB, NewPreheader);
2225
3
    OldT->eraseFromParent();
2226
3
    Preheader->setName(Preheader->getName()+".old");
2227
3
    DT->addNewBlock(MemmoveB, Preheader);
2228
3
    // Find the new immediate dominator of the exit block.
2229
3
    BasicBlock *ExitD = Preheader;
2230
6
    for (auto PI = pred_begin(ExitB), PE = pred_end(ExitB); PI != PE; 
++PI3
) {
2231
3
      BasicBlock *PB = *PI;
2232
3
      ExitD = DT->findNearestCommonDominator(ExitD, PB);
2233
3
      if (!ExitD)
2234
0
        break;
2235
3
    }
2236
3
    // If the prior immediate dominator of ExitB was dominated by the
2237
3
    // old preheader, then the old preheader becomes the new immediate
2238
3
    // dominator.  Otherwise don't change anything (because the newly
2239
3
    // added blocks are dominated by the old preheader).
2240
3
    if (ExitD && DT->dominates(Preheader, ExitD)) {
2241
3
      DomTreeNode *BN = DT->getNode(ExitB);
2242
3
      DomTreeNode *DN = DT->getNode(ExitD);
2243
3
      BN->setIDom(DN);
2244
3
    }
2245
3
2246
3
    // Add a call to memmove to the conditional block.
2247
3
    IRBuilder<> CondBuilder(MemmoveB);
2248
3
    CondBuilder.CreateBr(ExitB);
2249
3
    CondBuilder.SetInsertPoint(MemmoveB->getTerminator());
2250
3
2251
3
    if (DestVolatile) {
2252
0
      Type *Int32Ty = Type::getInt32Ty(Ctx);
2253
0
      Type *Int32PtrTy = Type::getInt32PtrTy(Ctx);
2254
0
      Type *VoidTy = Type::getVoidTy(Ctx);
2255
0
      Module *M = Func->getParent();
2256
0
      FunctionCallee Fn = M->getOrInsertFunction(
2257
0
          HexagonVolatileMemcpyName, VoidTy, Int32PtrTy, Int32PtrTy, Int32Ty);
2258
0
2259
0
      const SCEV *OneS = SE->getConstant(Int32Ty, 1);
2260
0
      const SCEV *BECount32 = SE->getTruncateOrZeroExtend(BECount, Int32Ty);
2261
0
      const SCEV *NumWordsS = SE->getAddExpr(BECount32, OneS, SCEV::FlagNUW);
2262
0
      Value *NumWords = Expander.expandCodeFor(NumWordsS, Int32Ty,
2263
0
                                               MemmoveB->getTerminator());
2264
0
      if (Instruction *In = dyn_cast<Instruction>(NumWords))
2265
0
        if (Value *Simp = SimplifyInstruction(In, {*DL, TLI, DT}))
2266
0
          NumWords = Simp;
2267
0
2268
0
      Value *Op0 = (StoreBasePtr->getType() == Int32PtrTy)
2269
0
                      ? StoreBasePtr
2270
0
                      : CondBuilder.CreateBitCast(StoreBasePtr, Int32PtrTy);
2271
0
      Value *Op1 = (LoadBasePtr->getType() == Int32PtrTy)
2272
0
                      ? LoadBasePtr
2273
0
                      : CondBuilder.CreateBitCast(LoadBasePtr, Int32PtrTy);
2274
0
      NewCall = CondBuilder.CreateCall(Fn, {Op0, Op1, NumWords});
2275
3
    } else {
2276
3
      NewCall = CondBuilder.CreateMemMove(StoreBasePtr, SI->getAlignment(),
2277
3
                                          LoadBasePtr, LI->getAlignment(),
2278
3
                                          NumBytes);
2279
3
    }
2280
3
  } else {
2281
0
    NewCall = Builder.CreateMemCpy(StoreBasePtr, SI->getAlignment(),
2282
0
                                   LoadBasePtr, LI->getAlignment(),
2283
0
                                   NumBytes);
2284
0
    // Okay, the memcpy has been formed.  Zap the original store and
2285
0
    // anything that feeds into it.
2286
0
    RecursivelyDeleteTriviallyDeadInstructions(SI, TLI);
2287
0
  }
2288
3
2289
3
  NewCall->setDebugLoc(DLoc);
2290
3
2291
3
  LLVM_DEBUG(dbgs() << "  Formed " << (Overlap ? "memmove: " : "memcpy: ")
2292
3
                    << *NewCall << "\n"
2293
3
                    << "    from load ptr=" << *LoadEv << " at: " << *LI << "\n"
2294
3
                    << "    from store ptr=" << *StoreEv << " at: " << *SI
2295
3
                    << "\n");
2296
3
2297
3
  return true;
2298
3
}
2299
2300
// Check if the instructions in Insts, together with their dependencies
2301
// cover the loop in the sense that the loop could be safely eliminated once
2302
// the instructions in Insts are removed.
2303
bool HexagonLoopIdiomRecognize::coverLoop(Loop *L,
2304
3
      SmallVectorImpl<Instruction*> &Insts) const {
2305
3
  SmallSet<BasicBlock*,8> LoopBlocks;
2306
3
  for (auto *B : L->blocks())
2307
4
    LoopBlocks.insert(B);
2308
3
2309
3
  SetVector<Instruction*> Worklist(Insts.begin(), Insts.end());
2310
3
2311
3
  // Collect all instructions from the loop that the instructions in Insts
2312
3
  // depend on (plus their dependencies, etc.).  These instructions will
2313
3
  // constitute the expression trees that feed those in Insts, but the trees
2314
3
  // will be limited only to instructions contained in the loop.
2315
21
  for (unsigned i = 0; i < Worklist.size(); 
++i18
) {
2316
18
    Instruction *In = Worklist[i];
2317
51
    for (auto I = In->op_begin(), E = In->op_end(); I != E; 
++I33
) {
2318
33
      Instruction *OpI = dyn_cast<Instruction>(I);
2319
33
      if (!OpI)
2320
8
        continue;
2321
25
      BasicBlock *PB = OpI->getParent();
2322
25
      if (!LoopBlocks.count(PB))
2323
4
        continue;
2324
21
      Worklist.insert(OpI);
2325
21
    }
2326
18
  }
2327
3
2328
3
  // Scan all instructions in the loop, if any of them have a user outside
2329
3
  // of the loop, or outside of the expressions collected above, then either
2330
3
  // the loop has a side-effect visible outside of it, or there are
2331
3
  // instructions in it that are not involved in the original set Insts.
2332
4
  for (auto *B : L->blocks()) {
2333
27
    for (auto &In : *B) {
2334
27
      if (isa<BranchInst>(In) || 
isa<DbgInfoIntrinsic>(In)23
)
2335
4
        continue;
2336
23
      if (!Worklist.count(&In) && 
In.mayHaveSideEffects()5
)
2337
0
        return false;
2338
29
      
for (const auto &K : In.users())23
{
2339
29
        Instruction *UseI = dyn_cast<Instruction>(K);
2340
29
        if (!UseI)
2341
0
          continue;
2342
29
        BasicBlock *UseB = UseI->getParent();
2343
29
        if (LF->getLoopFor(UseB) != L)
2344
0
          return false;
2345
29
      }
2346
23
    }
2347
4
  }
2348
3
2349
3
  return true;
2350
3
}
2351
2352
/// runOnLoopBlock - Process the specified block, which lives in a counted loop
2353
/// with the specified backedge count.  This block is known to be in the current
2354
/// loop and not in any subloops.
2355
bool HexagonLoopIdiomRecognize::runOnLoopBlock(Loop *CurLoop, BasicBlock *BB,
2356
7
      const SCEV *BECount, SmallVectorImpl<BasicBlock*> &ExitBlocks) {
2357
7
  // We can only promote stores in this block if they are unconditionally
2358
7
  // executed in the loop.  For a block to be unconditionally executed, it has
2359
7
  // to dominate all the exit blocks of the loop.  Verify this now.
2360
7
  auto DominatedByBB = [this,BB] (BasicBlock *EB) -> bool {
2361
7
    return DT->dominates(BB, EB);
2362
7
  };
2363
7
  if (!all_of(ExitBlocks, DominatedByBB))
2364
1
    return false;
2365
6
2366
6
  bool MadeChange = false;
2367
6
  // Look for store instructions, which may be optimized to memset/memcpy.
2368
6
  SmallVector<StoreInst*,8> Stores;
2369
6
  collectStores(CurLoop, BB, Stores);
2370
6
2371
6
  // Optimize the store into a memcpy, if it feeds an similarly strided load.
2372
6
  for (auto &SI : Stores)
2373
3
    MadeChange |= processCopyingStore(CurLoop, SI, BECount);
2374
6
2375
6
  return MadeChange;
2376
6
}
2377
2378
8
/// Run idiom recognition on loop \p L, which must have a computable
/// backedge-taken count.
///
/// Polynomial-multiply recognition is tried first; if it succeeds nothing
/// else is attempted.  Otherwise, provided memcpy or memmove is available,
/// every block that belongs directly to \p L is scanned for copying stores.
///
/// \returns true if the loop was changed.
bool HexagonLoopIdiomRecognize::runOnCountableLoop(Loop *L) {
  PolynomialMultiplyRecognize PMR(L, *DL, *DT, *TLI, *SE);
  if (PMR.recognize())
    return true;

  // The remaining transformation emits memcpy/memmove calls, so at least
  // one of them has to be available.
  if (!HasMemcpy && !HasMemmove)
    return false;

  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  assert(!isa<SCEVCouldNotCompute>(BECount) &&
         "runOnCountableLoop() called on a loop without a predictable "
         "backedge-taken count");

  SmallVector<BasicBlock *, 8> ExitBlocks;
  L->getUniqueExitBlocks(ExitBlocks);

  bool Changed = false;

  // Scan all the blocks in the loop that are not in subloops.
  for (auto *BB : L->getBlocks()) {
    // Ignore blocks in subloops.
    if (LF->getLoopFor(BB) != L)
      continue;
    Changed |= runOnLoopBlock(L, BB, BECount, ExitBlocks);
  }

  return Changed;
}
2406
2407
10
/// Loop-pass entry point.  Rejects loops the pass does not apply to, caches
/// the analyses used by the rest of the pass, and hands loops with a
/// loop-invariant backedge-taken count to runOnCountableLoop().
///
/// \returns true if the loop was changed.
bool HexagonLoopIdiomRecognize::runOnLoop(Loop *L, LPPassManager &LPM) {
  // Only act on modules that target Hexagon.
  const Module &M = *L->getHeader()->getParent()->getParent();
  const Triple TargetTriple(M.getTargetTriple());
  if (TargetTriple.getArch() != Triple::hexagon)
    return false;

  if (skipLoop(L))
    return false;

  // If the loop could not be converted to canonical form, it must have an
  // indirectbr in it, just give up.
  if (L->getLoopPreheader() == nullptr)
    return false;

  // Disable loop idiom recognition if the function's name is a common idiom.
  const StringRef FuncName = L->getHeader()->getParent()->getName();
  const bool IsIdiomFunc =
      FuncName == "memset" || FuncName == "memcpy" || FuncName == "memmove";
  if (IsIdiomFunc)
    return false;

  // Cache the analyses for use by the helper methods.
  AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
  DL = &L->getHeader()->getModule()->getDataLayout();
  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  LF = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();

  HasMemcpy = TLI->has(LibFunc_memcpy);
  HasMemmove = TLI->has(LibFunc_memmove);

  // Loops without a loop-invariant backedge-taken count are left alone.
  if (!SE->hasLoopInvariantBackedgeTakenCount(L))
    return false;
  return runOnCountableLoop(L);
}
2439
2440
12
/// Factory for the pass: returns a newly allocated HexagonLoopIdiomRecognize
/// instance.
Pass *llvm::createHexagonLoopIdiomPass() {
  Pass *IdiomPass = new HexagonLoopIdiomRecognize();
  return IdiomPass;
}