Coverage Report

Created: 2017-10-03 07:32

/Users/buildslave/jenkins/sharedspace/clang-stage2-coverage-R@2/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2
//                                    instrinsics
3
//
4
//                     The LLVM Compiler Infrastructure
5
//
6
// This file is distributed under the University of Illinois Open Source
7
// License. See LICENSE.TXT for details.
8
//
9
//===----------------------------------------------------------------------===//
10
//
11
// This pass replaces masked memory intrinsics - when unsupported by the target
12
// - with a chain of basic blocks, that deal with the elements one-by-one if the
13
// appropriate mask bit is set.
14
//
15
//===----------------------------------------------------------------------===//
16
17
#include "llvm/ADT/Twine.h"
18
#include "llvm/Analysis/TargetTransformInfo.h"
19
#include "llvm/IR/BasicBlock.h"
20
#include "llvm/IR/Constant.h"
21
#include "llvm/IR/Constants.h"
22
#include "llvm/IR/DerivedTypes.h"
23
#include "llvm/IR/Function.h"
24
#include "llvm/IR/IRBuilder.h"
25
#include "llvm/IR/InstrTypes.h"
26
#include "llvm/IR/Instruction.h"
27
#include "llvm/IR/Instructions.h"
28
#include "llvm/IR/IntrinsicInst.h"
29
#include "llvm/IR/Intrinsics.h"
30
#include "llvm/IR/Type.h"
31
#include "llvm/IR/Value.h"
32
#include "llvm/Pass.h"
33
#include "llvm/Support/Casting.h"
34
#include "llvm/Target/TargetSubtargetInfo.h"
35
#include <algorithm>
36
#include <cassert>
37
38
using namespace llvm;
39
40
#define DEBUG_TYPE "scalarize-masked-mem-intrin"
41
42
namespace {
43
44
class ScalarizeMaskedMemIntrin : public FunctionPass {
45
  const TargetTransformInfo *TTI = nullptr;
46
47
public:
48
  static char ID; // Pass identification, replacement for typeid
49
50
33.5k
  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
51
33.5k
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
52
33.5k
  }
53
54
  bool runOnFunction(Function &F) override;
55
56
23
  StringRef getPassName() const override {
57
23
    return "Scalarize Masked Memory Intrinsics";
58
23
  }
59
60
33.4k
  void getAnalysisUsage(AnalysisUsage &AU) const override {
61
33.4k
    AU.addRequired<TargetTransformInfoWrapperPass>();
62
33.4k
  }
63
64
private:
65
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
66
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
67
};
68
69
} // end anonymous namespace
70
71
char ScalarizeMaskedMemIntrin::ID = 0;
72
73
INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
74
                "Scalarize unsupported masked memory intrinsics", false, false)
75
76
33.5k
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
77
33.5k
  return new ScalarizeMaskedMemIntrin();
78
33.5k
}
79
80
// Translate a masked load intrinsic like
81
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
82
//                               <16 x i1> %mask, <16 x i32> %passthru)
83
// to a chain of basic blocks, with loading element one-by-one if
84
// the appropriate mask bit is set
85
//
86
//  %1 = bitcast i8* %addr to i32*
87
//  %2 = extractelement <16 x i1> %mask, i32 0
88
//  %3 = icmp eq i1 %2, true
89
//  br i1 %3, label %cond.load, label %else
90
//
91
// cond.load:                                        ; preds = %0
92
//  %4 = getelementptr i32* %1, i32 0
93
//  %5 = load i32* %4
94
//  %6 = insertelement <16 x i32> undef, i32 %5, i32 0
95
//  br label %else
96
//
97
// else:                                             ; preds = %0, %cond.load
98
//  %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
99
//  %7 = extractelement <16 x i1> %mask, i32 1
100
//  %8 = icmp eq i1 %7, true
101
//  br i1 %8, label %cond.load1, label %else2
102
//
103
// cond.load1:                                       ; preds = %else
104
//  %9 = getelementptr i32* %1, i32 1
105
//  %10 = load i32* %9
106
//  %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
107
//  br label %else2
108
//
109
// else2:                                          ; preds = %else, %cond.load1
110
//  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
111
//  %12 = extractelement <16 x i1> %mask, i32 2
112
//  %13 = icmp eq i1 %12, true
113
//  br i1 %13, label %cond.load4, label %else5
114
//
115
1
static void scalarizeMaskedLoad(CallInst *CI) {
116
1
  Value *Ptr = CI->getArgOperand(0);
117
1
  Value *Alignment = CI->getArgOperand(1);
118
1
  Value *Mask = CI->getArgOperand(2);
119
1
  Value *Src0 = CI->getArgOperand(3);
120
1
121
1
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
122
1
  VectorType *VecType = dyn_cast<VectorType>(CI->getType());
123
1
  assert(VecType && "Unexpected return type of masked load intrinsic");
124
1
125
1
  Type *EltTy = CI->getType()->getVectorElementType();
126
1
127
1
  IRBuilder<> Builder(CI->getContext());
128
1
  Instruction *InsertPt = CI;
129
1
  BasicBlock *IfBlock = CI->getParent();
130
1
  BasicBlock *CondBlock = nullptr;
131
1
  BasicBlock *PrevIfBlock = CI->getParent();
132
1
133
1
  Builder.SetInsertPoint(InsertPt);
134
1
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
135
1
136
1
  // Short-cut if the mask is all-true.
137
1
  bool IsAllOnesMask =
138
0
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
139
1
140
1
  if (
IsAllOnesMask1
) {
141
0
    Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
142
0
    CI->replaceAllUsesWith(NewI);
143
0
    CI->eraseFromParent();
144
0
    return;
145
0
  }
146
1
147
1
  // Adjust alignment for the scalar instruction.
148
1
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
149
1
  // Bitcast %addr fron i8* to EltTy*
150
1
  Type *NewPtrType =
151
1
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
152
1
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
153
1
  unsigned VectorWidth = VecType->getNumElements();
154
1
155
1
  Value *UndefVal = UndefValue::get(VecType);
156
1
157
1
  // The result vector
158
1
  Value *VResult = UndefVal;
159
1
160
1
  if (
isa<ConstantVector>(Mask)1
) {
161
0
    for (unsigned Idx = 0; 
Idx < VectorWidth0
;
++Idx0
) {
162
0
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
163
0
        continue;
164
0
      Value *Gep =
165
0
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
166
0
      LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
167
0
      VResult =
168
0
          Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
169
0
    }
170
0
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
171
0
    CI->replaceAllUsesWith(NewI);
172
0
    CI->eraseFromParent();
173
0
    return;
174
0
  }
175
1
176
1
  PHINode *Phi = nullptr;
177
1
  Value *PrevPhi = UndefVal;
178
1
179
5
  for (unsigned Idx = 0; 
Idx < VectorWidth5
;
++Idx4
) {
180
4
    // Fill the "else" block, created in the previous iteration
181
4
    //
182
4
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
183
4
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
184
4
    //  %to_load = icmp eq i1 %mask_1, true
185
4
    //  br i1 %to_load, label %cond.load, label %else
186
4
    //
187
4
    if (
Idx > 04
) {
188
3
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
189
3
      Phi->addIncoming(VResult, CondBlock);
190
3
      Phi->addIncoming(PrevPhi, PrevIfBlock);
191
3
      PrevPhi = Phi;
192
3
      VResult = Phi;
193
3
    }
194
4
195
4
    Value *Predicate =
196
4
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
197
4
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
198
4
                                    ConstantInt::get(Predicate->getType(), 1));
199
4
200
4
    // Create "cond" block
201
4
    //
202
4
    //  %EltAddr = getelementptr i32* %1, i32 0
203
4
    //  %Elt = load i32* %EltAddr
204
4
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
205
4
    //
206
4
    CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
207
4
    Builder.SetInsertPoint(InsertPt);
208
4
209
4
    Value *Gep =
210
4
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
211
4
    LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
212
4
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
213
4
214
4
    // Create "else" block, fill it in the next iteration
215
4
    BasicBlock *NewIfBlock =
216
4
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
217
4
    Builder.SetInsertPoint(InsertPt);
218
4
    Instruction *OldBr = IfBlock->getTerminator();
219
4
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
220
4
    OldBr->eraseFromParent();
221
4
    PrevIfBlock = IfBlock;
222
4
    IfBlock = NewIfBlock;
223
4
  }
224
1
225
1
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
226
1
  Phi->addIncoming(VResult, CondBlock);
227
1
  Phi->addIncoming(PrevPhi, PrevIfBlock);
228
1
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
229
1
  CI->replaceAllUsesWith(NewI);
230
1
  CI->eraseFromParent();
231
1
}
232
233
// Translate a masked store intrinsic, like
234
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
235
//                               <16 x i1> %mask)
236
// to a chain of basic blocks, that stores element one-by-one if
237
// the appropriate mask bit is set
238
//
239
//   %1 = bitcast i8* %addr to i32*
240
//   %2 = extractelement <16 x i1> %mask, i32 0
241
//   %3 = icmp eq i1 %2, true
242
//   br i1 %3, label %cond.store, label %else
243
//
244
// cond.store:                                       ; preds = %0
245
//   %4 = extractelement <16 x i32> %val, i32 0
246
//   %5 = getelementptr i32* %1, i32 0
247
//   store i32 %4, i32* %5
248
//   br label %else
249
//
250
// else:                                             ; preds = %0, %cond.store
251
//   %6 = extractelement <16 x i1> %mask, i32 1
252
//   %7 = icmp eq i1 %6, true
253
//   br i1 %7, label %cond.store1, label %else2
254
//
255
// cond.store1:                                      ; preds = %else
256
//   %8 = extractelement <16 x i32> %val, i32 1
257
//   %9 = getelementptr i32* %1, i32 1
258
//   store i32 %8, i32* %9
259
//   br label %else2
260
//   . . .
261
1
static void scalarizeMaskedStore(CallInst *CI) {
262
1
  Value *Src = CI->getArgOperand(0);
263
1
  Value *Ptr = CI->getArgOperand(1);
264
1
  Value *Alignment = CI->getArgOperand(2);
265
1
  Value *Mask = CI->getArgOperand(3);
266
1
267
1
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
268
1
  VectorType *VecType = dyn_cast<VectorType>(Src->getType());
269
1
  assert(VecType && "Unexpected data type in masked store intrinsic");
270
1
271
1
  Type *EltTy = VecType->getElementType();
272
1
273
1
  IRBuilder<> Builder(CI->getContext());
274
1
  Instruction *InsertPt = CI;
275
1
  BasicBlock *IfBlock = CI->getParent();
276
1
  Builder.SetInsertPoint(InsertPt);
277
1
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
278
1
279
1
  // Short-cut if the mask is all-true.
280
1
  bool IsAllOnesMask =
281
0
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
282
1
283
1
  if (
IsAllOnesMask1
) {
284
0
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
285
0
    CI->eraseFromParent();
286
0
    return;
287
0
  }
288
1
289
1
  // Adjust alignment for the scalar instruction.
290
1
  AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
291
1
  // Bitcast %addr fron i8* to EltTy*
292
1
  Type *NewPtrType =
293
1
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
294
1
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
295
1
  unsigned VectorWidth = VecType->getNumElements();
296
1
297
1
  if (
isa<ConstantVector>(Mask)1
) {
298
0
    for (unsigned Idx = 0; 
Idx < VectorWidth0
;
++Idx0
) {
299
0
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
300
0
        continue;
301
0
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
302
0
      Value *Gep =
303
0
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
304
0
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
305
0
    }
306
0
    CI->eraseFromParent();
307
0
    return;
308
0
  }
309
1
310
5
  
for (unsigned Idx = 0; 1
Idx < VectorWidth5
;
++Idx4
) {
311
4
    // Fill the "else" block, created in the previous iteration
312
4
    //
313
4
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
314
4
    //  %to_store = icmp eq i1 %mask_1, true
315
4
    //  br i1 %to_store, label %cond.store, label %else
316
4
    //
317
4
    Value *Predicate =
318
4
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
319
4
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
320
4
                                    ConstantInt::get(Predicate->getType(), 1));
321
4
322
4
    // Create "cond" block
323
4
    //
324
4
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
325
4
    //  %EltAddr = getelementptr i32* %1, i32 0
326
4
    //  %store i32 %OneElt, i32* %EltAddr
327
4
    //
328
4
    BasicBlock *CondBlock =
329
4
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
330
4
    Builder.SetInsertPoint(InsertPt);
331
4
332
4
    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
333
4
    Value *Gep =
334
4
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
335
4
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
336
4
337
4
    // Create "else" block, fill it in the next iteration
338
4
    BasicBlock *NewIfBlock =
339
4
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
340
4
    Builder.SetInsertPoint(InsertPt);
341
4
    Instruction *OldBr = IfBlock->getTerminator();
342
4
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
343
4
    OldBr->eraseFromParent();
344
4
    IfBlock = NewIfBlock;
345
4
  }
346
1
  CI->eraseFromParent();
347
1
}
348
349
// Translate a masked gather intrinsic like
350
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
351
//                               <16 x i1> %Mask, <16 x i32> %Src)
352
// to a chain of basic blocks, with loading element one-by-one if
353
// the appropriate mask bit is set
354
//
355
// % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
356
// % Mask0 = extractelement <16 x i1> %Mask, i32 0
357
// % ToLoad0 = icmp eq i1 % Mask0, true
358
// br i1 % ToLoad0, label %cond.load, label %else
359
//
360
// cond.load:
361
// % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
362
// % Load0 = load i32, i32* % Ptr0, align 4
363
// % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
364
// br label %else
365
//
366
// else:
367
// %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
368
// % Mask1 = extractelement <16 x i1> %Mask, i32 1
369
// % ToLoad1 = icmp eq i1 % Mask1, true
370
// br i1 % ToLoad1, label %cond.load1, label %else2
371
//
372
// cond.load1:
373
// % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
374
// % Load1 = load i32, i32* % Ptr1, align 4
375
// % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
376
// br label %else2
377
// . . .
378
// % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
379
// ret <16 x i32> %Result
380
62
static void scalarizeMaskedGather(CallInst *CI) {
381
62
  Value *Ptrs = CI->getArgOperand(0);
382
62
  Value *Alignment = CI->getArgOperand(1);
383
62
  Value *Mask = CI->getArgOperand(2);
384
62
  Value *Src0 = CI->getArgOperand(3);
385
62
386
62
  VectorType *VecType = dyn_cast<VectorType>(CI->getType());
387
62
388
62
  assert(VecType && "Unexpected return type of masked load intrinsic");
389
62
390
62
  IRBuilder<> Builder(CI->getContext());
391
62
  Instruction *InsertPt = CI;
392
62
  BasicBlock *IfBlock = CI->getParent();
393
62
  BasicBlock *CondBlock = nullptr;
394
62
  BasicBlock *PrevIfBlock = CI->getParent();
395
62
  Builder.SetInsertPoint(InsertPt);
396
62
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
397
62
398
62
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
399
62
400
62
  Value *UndefVal = UndefValue::get(VecType);
401
62
402
62
  // The result vector
403
62
  Value *VResult = UndefVal;
404
62
  unsigned VectorWidth = VecType->getNumElements();
405
62
406
62
  // Shorten the way if the mask is a vector of constants.
407
62
  bool IsConstMask = isa<ConstantVector>(Mask);
408
62
409
62
  if (
IsConstMask62
) {
410
155
    for (unsigned Idx = 0; 
Idx < VectorWidth155
;
++Idx142
) {
411
142
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
412
13
        continue;
413
129
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
414
129
                                                "Ptr" + Twine(Idx));
415
129
      LoadInst *Load =
416
129
          Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
417
129
      VResult = Builder.CreateInsertElement(
418
129
          VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
419
129
    }
420
13
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
421
13
    CI->replaceAllUsesWith(NewI);
422
13
    CI->eraseFromParent();
423
13
    return;
424
13
  }
425
49
426
49
  PHINode *Phi = nullptr;
427
49
  Value *PrevPhi = UndefVal;
428
49
429
357
  for (unsigned Idx = 0; 
Idx < VectorWidth357
;
++Idx308
) {
430
308
    // Fill the "else" block, created in the previous iteration
431
308
    //
432
308
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
433
308
    //  %ToLoad1 = icmp eq i1 %Mask1, true
434
308
    //  br i1 %ToLoad1, label %cond.load, label %else
435
308
    //
436
308
    if (
Idx > 0308
) {
437
259
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
438
259
      Phi->addIncoming(VResult, CondBlock);
439
259
      Phi->addIncoming(PrevPhi, PrevIfBlock);
440
259
      PrevPhi = Phi;
441
259
      VResult = Phi;
442
259
    }
443
308
444
308
    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
445
308
                                                    "Mask" + Twine(Idx));
446
308
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
447
308
                                    ConstantInt::get(Predicate->getType(), 1),
448
308
                                    "ToLoad" + Twine(Idx));
449
308
450
308
    // Create "cond" block
451
308
    //
452
308
    //  %EltAddr = getelementptr i32* %1, i32 0
453
308
    //  %Elt = load i32* %EltAddr
454
308
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
455
308
    //
456
308
    CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
457
308
    Builder.SetInsertPoint(InsertPt);
458
308
459
308
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
460
308
                                              "Ptr" + Twine(Idx));
461
308
    LoadInst *Load =
462
308
        Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
463
308
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
464
308
                                          "Res" + Twine(Idx));
465
308
466
308
    // Create "else" block, fill it in the next iteration
467
308
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
468
308
    Builder.SetInsertPoint(InsertPt);
469
308
    Instruction *OldBr = IfBlock->getTerminator();
470
308
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
471
308
    OldBr->eraseFromParent();
472
308
    PrevIfBlock = IfBlock;
473
308
    IfBlock = NewIfBlock;
474
308
  }
475
62
476
62
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
477
62
  Phi->addIncoming(VResult, CondBlock);
478
62
  Phi->addIncoming(PrevPhi, PrevIfBlock);
479
62
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
480
62
  CI->replaceAllUsesWith(NewI);
481
62
  CI->eraseFromParent();
482
62
}
483
484
// Translate a masked scatter intrinsic, like
485
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
486
//                                  <16 x i1> %Mask)
487
// to a chain of basic blocks, that stores element one-by-one if
488
// the appropriate mask bit is set.
489
//
490
// % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
491
// % Mask0 = extractelement <16 x i1> % Mask, i32 0
492
// % ToStore0 = icmp eq i1 % Mask0, true
493
// br i1 %ToStore0, label %cond.store, label %else
494
//
495
// cond.store:
496
// % Elt0 = extractelement <16 x i32> %Src, i32 0
497
// % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
498
// store i32 %Elt0, i32* % Ptr0, align 4
499
// br label %else
500
//
501
// else:
502
// % Mask1 = extractelement <16 x i1> % Mask, i32 1
503
// % ToStore1 = icmp eq i1 % Mask1, true
504
// br i1 % ToStore1, label %cond.store1, label %else2
505
//
506
// cond.store1:
507
// % Elt1 = extractelement <16 x i32> %Src, i32 1
508
// % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
509
// store i32 % Elt1, i32* % Ptr1, align 4
510
// br label %else2
511
//   . . .
512
13
static void scalarizeMaskedScatter(CallInst *CI) {
513
13
  Value *Src = CI->getArgOperand(0);
514
13
  Value *Ptrs = CI->getArgOperand(1);
515
13
  Value *Alignment = CI->getArgOperand(2);
516
13
  Value *Mask = CI->getArgOperand(3);
517
13
518
13
  assert(isa<VectorType>(Src->getType()) &&
519
13
         "Unexpected data type in masked scatter intrinsic");
520
13
  assert(isa<VectorType>(Ptrs->getType()) &&
521
13
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
522
13
         "Vector of pointers is expected in masked scatter intrinsic");
523
13
524
13
  IRBuilder<> Builder(CI->getContext());
525
13
  Instruction *InsertPt = CI;
526
13
  BasicBlock *IfBlock = CI->getParent();
527
13
  Builder.SetInsertPoint(InsertPt);
528
13
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
529
13
530
13
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
531
13
  unsigned VectorWidth = Src->getType()->getVectorNumElements();
532
13
533
13
  // Shorten the way if the mask is a vector of constants.
534
13
  bool IsConstMask = isa<ConstantVector>(Mask);
535
13
536
13
  if (
IsConstMask13
) {
537
12
    for (unsigned Idx = 0; 
Idx < VectorWidth12
;
++Idx10
) {
538
10
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
539
0
        continue;
540
10
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
541
10
                                                   "Elt" + Twine(Idx));
542
10
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
543
10
                                                "Ptr" + Twine(Idx));
544
10
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
545
10
    }
546
2
    CI->eraseFromParent();
547
2
    return;
548
2
  }
549
123
  
for (unsigned Idx = 0; 11
Idx < VectorWidth123
;
++Idx112
) {
550
112
    // Fill the "else" block, created in the previous iteration
551
112
    //
552
112
    //  % Mask1 = extractelement <16 x i1> % Mask, i32 Idx
553
112
    //  % ToStore = icmp eq i1 % Mask1, true
554
112
    //  br i1 % ToStore, label %cond.store, label %else
555
112
    //
556
112
    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
557
112
                                                    "Mask" + Twine(Idx));
558
112
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
559
112
                                    ConstantInt::get(Predicate->getType(), 1),
560
112
                                    "ToStore" + Twine(Idx));
561
112
562
112
    // Create "cond" block
563
112
    //
564
112
    //  % Elt1 = extractelement <16 x i32> %Src, i32 1
565
112
    //  % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
566
112
    //  %store i32 % Elt1, i32* % Ptr1
567
112
    //
568
112
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
569
112
    Builder.SetInsertPoint(InsertPt);
570
112
571
112
    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
572
112
                                                 "Elt" + Twine(Idx));
573
112
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
574
112
                                              "Ptr" + Twine(Idx));
575
112
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
576
112
577
112
    // Create "else" block, fill it in the next iteration
578
112
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
579
112
    Builder.SetInsertPoint(InsertPt);
580
112
    Instruction *OldBr = IfBlock->getTerminator();
581
112
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
582
112
    OldBr->eraseFromParent();
583
112
    IfBlock = NewIfBlock;
584
112
  }
585
13
  CI->eraseFromParent();
586
13
}
587
588
596k
bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
589
596k
  if (skipFunction(F))
590
520
    return false;
591
595k
592
595k
  bool EverMadeChange = false;
593
595k
594
595k
  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
595
595k
596
595k
  bool MadeChange = true;
597
1.19M
  while (
MadeChange1.19M
) {
598
595k
    MadeChange = false;
599
4.57M
    for (Function::iterator I = F.begin(); 
I != F.end()4.57M
;) {
600
3.98M
      BasicBlock *BB = &*I++;
601
3.98M
      bool ModifiedDTOnIteration = false;
602
3.98M
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
603
3.98M
604
3.98M
      // Restart BB iteration if the dominator tree of the Function was changed
605
3.98M
      if (ModifiedDTOnIteration)
606
77
        break;
607
3.98M
    }
608
595k
609
595k
    EverMadeChange |= MadeChange;
610
595k
  }
611
596k
612
596k
  return EverMadeChange;
613
596k
}
614
615
3.98M
bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
616
3.98M
  bool MadeChange = false;
617
3.98M
618
3.98M
  BasicBlock::iterator CurInstIterator = BB.begin();
619
24.3M
  while (
CurInstIterator != BB.end()24.3M
) {
620
20.3M
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
621
2.74M
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
622
20.3M
    if (ModifiedDT)
623
77
      return true;
624
20.3M
  }
625
3.98M
626
3.98M
  return MadeChange;
627
3.98M
}
628
629
bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
630
2.74M
                                                bool &ModifiedDT) {
631
2.74M
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
632
2.74M
  if (
II2.74M
) {
633
407k
    switch (II->getIntrinsicID()) {
634
407k
    default:
635
407k
      break;
636
212
    case Intrinsic::masked_load:
637
212
      // Scalarize unsupported vector masked load
638
212
      if (
!TTI->isLegalMaskedLoad(CI->getType())212
) {
639
1
        scalarizeMaskedLoad(CI);
640
1
        ModifiedDT = true;
641
1
        return true;
642
1
      }
643
211
      return false;
644
101
    case Intrinsic::masked_store:
645
101
      if (
!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())101
) {
646
1
        scalarizeMaskedStore(CI);
647
1
        ModifiedDT = true;
648
1
        return true;
649
1
      }
650
100
      return false;
651
237
    case Intrinsic::masked_gather:
652
237
      if (
!TTI->isLegalMaskedGather(CI->getType())237
) {
653
62
        scalarizeMaskedGather(CI);
654
62
        ModifiedDT = true;
655
62
        return true;
656
62
      }
657
175
      return false;
658
74
    case Intrinsic::masked_scatter:
659
74
      if (
!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())74
) {
660
13
        scalarizeMaskedScatter(CI);
661
13
        ModifiedDT = true;
662
13
        return true;
663
13
      }
664
61
      return false;
665
407k
    }
666
407k
  }
667
2.74M
668
2.74M
  return false;
669
2.74M
}