Coverage Report

Created: 2019-07-24 05:18

/Users/buildslave/jenkins/workspace/clang-stage2-coverage-R/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2
//                                    instrinsics
3
//
4
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5
// See https://llvm.org/LICENSE.txt for license information.
6
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7
//
8
//===----------------------------------------------------------------------===//
9
//
10
// This pass replaces masked memory intrinsics - when unsupported by the target
11
// - with a chain of basic blocks, that deal with the elements one-by-one if the
12
// appropriate mask bit is set.
13
//
14
//===----------------------------------------------------------------------===//
15
16
#include "llvm/ADT/Twine.h"
17
#include "llvm/Analysis/TargetTransformInfo.h"
18
#include "llvm/CodeGen/TargetSubtargetInfo.h"
19
#include "llvm/IR/BasicBlock.h"
20
#include "llvm/IR/Constant.h"
21
#include "llvm/IR/Constants.h"
22
#include "llvm/IR/DerivedTypes.h"
23
#include "llvm/IR/Function.h"
24
#include "llvm/IR/IRBuilder.h"
25
#include "llvm/IR/InstrTypes.h"
26
#include "llvm/IR/Instruction.h"
27
#include "llvm/IR/Instructions.h"
28
#include "llvm/IR/IntrinsicInst.h"
29
#include "llvm/IR/Intrinsics.h"
30
#include "llvm/IR/Type.h"
31
#include "llvm/IR/Value.h"
32
#include "llvm/Pass.h"
33
#include "llvm/Support/Casting.h"
34
#include <algorithm>
35
#include <cassert>
36
37
using namespace llvm;
38
39
#define DEBUG_TYPE "scalarize-masked-mem-intrin"
40
41
namespace {
42
43
class ScalarizeMaskedMemIntrin : public FunctionPass {
44
  const TargetTransformInfo *TTI = nullptr;
45
46
public:
47
  static char ID; // Pass identification, replacement for typeid
48
49
36.3k
  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
50
36.3k
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
51
36.3k
  }
52
53
  bool runOnFunction(Function &F) override;
54
55
499k
  StringRef getPassName() const override {
56
499k
    return "Scalarize Masked Memory Intrinsics";
57
499k
  }
58
59
36.1k
  void getAnalysisUsage(AnalysisUsage &AU) const override {
60
36.1k
    AU.addRequired<TargetTransformInfoWrapperPass>();
61
36.1k
  }
62
63
private:
64
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
65
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
66
};
67
68
} // end anonymous namespace
69
70
char ScalarizeMaskedMemIntrin::ID = 0;
71
72
INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
73
                "Scalarize unsupported masked memory intrinsics", false, false)
74
75
36.3k
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
76
36.3k
  return new ScalarizeMaskedMemIntrin();
77
36.3k
}
78
79
621
static bool isConstantIntVector(Value *Mask) {
80
621
  Constant *C = dyn_cast<Constant>(Mask);
81
621
  if (!C)
82
536
    return false;
83
85
84
85
  unsigned NumElts = Mask->getType()->getVectorNumElements();
85
566
  for (unsigned i = 0; i != NumElts; 
++i481
) {
86
481
    Constant *CElt = C->getAggregateElement(i);
87
481
    if (!CElt || !isa<ConstantInt>(CElt))
88
0
      return false;
89
481
  }
90
85
91
85
  return true;
92
85
}
93
94
// Translate a masked load intrinsic like
95
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
96
//                               <16 x i1> %mask, <16 x i32> %passthru)
97
// to a chain of basic blocks, with loading element one-by-one if
98
// the appropriate mask bit is set
99
//
100
//  %1 = bitcast i8* %addr to i32*
101
//  %2 = extractelement <16 x i1> %mask, i32 0
102
//  br i1 %2, label %cond.load, label %else
103
//
104
// cond.load:                                        ; preds = %0
105
//  %3 = getelementptr i32* %1, i32 0
106
//  %4 = load i32* %3
107
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
108
//  br label %else
109
//
110
// else:                                             ; preds = %0, %cond.load
111
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
112
//  %6 = extractelement <16 x i1> %mask, i32 1
113
//  br i1 %6, label %cond.load1, label %else2
114
//
115
// cond.load1:                                       ; preds = %else
116
//  %7 = getelementptr i32* %1, i32 1
117
//  %8 = load i32* %7
118
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
119
//  br label %else2
120
//
121
// else2:                                          ; preds = %else, %cond.load1
122
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
123
//  %10 = extractelement <16 x i1> %mask, i32 2
124
//  br i1 %10, label %cond.load4, label %else5
125
//
126
131
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
127
131
  Value *Ptr = CI->getArgOperand(0);
128
131
  Value *Alignment = CI->getArgOperand(1);
129
131
  Value *Mask = CI->getArgOperand(2);
130
131
  Value *Src0 = CI->getArgOperand(3);
131
131
132
131
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
133
131
  VectorType *VecType = cast<VectorType>(CI->getType());
134
131
135
131
  Type *EltTy = VecType->getElementType();
136
131
137
131
  IRBuilder<> Builder(CI->getContext());
138
131
  Instruction *InsertPt = CI;
139
131
  BasicBlock *IfBlock = CI->getParent();
140
131
141
131
  Builder.SetInsertPoint(InsertPt);
142
131
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
143
131
144
131
  // Short-cut if the mask is all-true.
145
131
  if (isa<Constant>(Mask) && 
cast<Constant>(Mask)->isAllOnesValue()41
) {
146
3
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
147
3
    CI->replaceAllUsesWith(NewI);
148
3
    CI->eraseFromParent();
149
3
    return;
150
3
  }
151
128
152
128
  // Adjust alignment for the scalar instruction.
153
128
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
154
128
  // Bitcast %addr from i8* to EltTy*
155
128
  Type *NewPtrType =
156
128
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
157
128
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
158
128
  unsigned VectorWidth = VecType->getNumElements();
159
128
160
128
  // The result vector
161
128
  Value *VResult = Src0;
162
128
163
128
  if (isConstantIntVector(Mask)) {
164
214
    for (unsigned Idx = 0; Idx < VectorWidth; 
++Idx176
) {
165
176
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
166
101
        continue;
167
75
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
168
75
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
169
75
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
170
75
    }
171
38
    CI->replaceAllUsesWith(VResult);
172
38
    CI->eraseFromParent();
173
38
    return;
174
38
  }
175
90
176
772
  
for (unsigned Idx = 0; 90
Idx < VectorWidth;
++Idx682
) {
177
682
    // Fill the "else" block, created in the previous iteration
178
682
    //
179
682
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
180
682
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
181
682
    //  br i1 %mask_1, label %cond.load, label %else
182
682
    //
183
682
184
682
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
185
682
186
682
    // Create "cond" block
187
682
    //
188
682
    //  %EltAddr = getelementptr i32* %1, i32 0
189
682
    //  %Elt = load i32* %EltAddr
190
682
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
191
682
    //
192
682
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
193
682
                                                     "cond.load");
194
682
    Builder.SetInsertPoint(InsertPt);
195
682
196
682
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
197
682
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
198
682
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
199
682
200
682
    // Create "else" block, fill it in the next iteration
201
682
    BasicBlock *NewIfBlock =
202
682
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
203
682
    Builder.SetInsertPoint(InsertPt);
204
682
    Instruction *OldBr = IfBlock->getTerminator();
205
682
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
206
682
    OldBr->eraseFromParent();
207
682
    BasicBlock *PrevIfBlock = IfBlock;
208
682
    IfBlock = NewIfBlock;
209
682
210
682
    // Create the phi to join the new and previous value.
211
682
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
212
682
    Phi->addIncoming(NewVResult, CondBlock);
213
682
    Phi->addIncoming(VResult, PrevIfBlock);
214
682
    VResult = Phi;
215
682
  }
216
90
217
90
  CI->replaceAllUsesWith(VResult);
218
90
  CI->eraseFromParent();
219
90
220
90
  ModifiedDT = true;
221
90
}
222
223
// Translate a masked store intrinsic, like
224
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
225
//                               <16 x i1> %mask)
226
// to a chain of basic blocks, that stores element one-by-one if
227
// the appropriate mask bit is set
228
//
229
//   %1 = bitcast i8* %addr to i32*
230
//   %2 = extractelement <16 x i1> %mask, i32 0
231
//   br i1 %2, label %cond.store, label %else
232
//
233
// cond.store:                                       ; preds = %0
234
//   %3 = extractelement <16 x i32> %val, i32 0
235
//   %4 = getelementptr i32* %1, i32 0
236
//   store i32 %3, i32* %4
237
//   br label %else
238
//
239
// else:                                             ; preds = %0, %cond.store
240
//   %5 = extractelement <16 x i1> %mask, i32 1
241
//   br i1 %5, label %cond.store1, label %else2
242
//
243
// cond.store1:                                      ; preds = %else
244
//   %6 = extractelement <16 x i32> %val, i32 1
245
//   %7 = getelementptr i32* %1, i32 1
246
//   store i32 %6, i32* %7
247
//   br label %else2
248
//   . . .
249
376
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
250
376
  Value *Src = CI->getArgOperand(0);
251
376
  Value *Ptr = CI->getArgOperand(1);
252
376
  Value *Alignment = CI->getArgOperand(2);
253
376
  Value *Mask = CI->getArgOperand(3);
254
376
255
376
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
256
376
  VectorType *VecType = cast<VectorType>(Src->getType());
257
376
258
376
  Type *EltTy = VecType->getElementType();
259
376
260
376
  IRBuilder<> Builder(CI->getContext());
261
376
  Instruction *InsertPt = CI;
262
376
  BasicBlock *IfBlock = CI->getParent();
263
376
  Builder.SetInsertPoint(InsertPt);
264
376
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
265
376
266
376
  // Short-cut if the mask is all-true.
267
376
  if (isa<Constant>(Mask) && 
cast<Constant>(Mask)->isAllOnesValue()17
) {
268
3
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
269
3
    CI->eraseFromParent();
270
3
    return;
271
3
  }
272
373
273
373
  // Adjust alignment for the scalar instruction.
274
373
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
275
373
  // Bitcast %addr from i8* to EltTy*
276
373
  Type *NewPtrType =
277
373
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
278
373
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
279
373
  unsigned VectorWidth = VecType->getNumElements();
280
373
281
373
  if (isConstantIntVector(Mask)) {
282
70
    for (unsigned Idx = 0; Idx < VectorWidth; 
++Idx56
) {
283
56
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
284
45
        continue;
285
11
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
286
11
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
287
11
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
288
11
    }
289
14
    CI->eraseFromParent();
290
14
    return;
291
14
  }
292
359
293
3.57k
  
for (unsigned Idx = 0; 359
Idx < VectorWidth;
++Idx3.21k
) {
294
3.21k
    // Fill the "else" block, created in the previous iteration
295
3.21k
    //
296
3.21k
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
297
3.21k
    //  br i1 %mask_1, label %cond.store, label %else
298
3.21k
    //
299
3.21k
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
300
3.21k
301
3.21k
    // Create "cond" block
302
3.21k
    //
303
3.21k
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
304
3.21k
    //  %EltAddr = getelementptr i32* %1, i32 0
305
3.21k
    //  %store i32 %OneElt, i32* %EltAddr
306
3.21k
    //
307
3.21k
    BasicBlock *CondBlock =
308
3.21k
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
309
3.21k
    Builder.SetInsertPoint(InsertPt);
310
3.21k
311
3.21k
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
312
3.21k
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
313
3.21k
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
314
3.21k
315
3.21k
    // Create "else" block, fill it in the next iteration
316
3.21k
    BasicBlock *NewIfBlock =
317
3.21k
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
318
3.21k
    Builder.SetInsertPoint(InsertPt);
319
3.21k
    Instruction *OldBr = IfBlock->getTerminator();
320
3.21k
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
321
3.21k
    OldBr->eraseFromParent();
322
3.21k
    IfBlock = NewIfBlock;
323
3.21k
  }
324
359
  CI->eraseFromParent();
325
359
326
359
  ModifiedDT = true;
327
359
}
328
329
// Translate a masked gather intrinsic like
330
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
331
//                               <16 x i1> %Mask, <16 x i32> %Src)
332
// to a chain of basic blocks, with loading element one-by-one if
333
// the appropriate mask bit is set
334
//
335
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
336
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
337
// br i1 %Mask0, label %cond.load, label %else
338
//
339
// cond.load:
340
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
341
// %Load0 = load i32, i32* %Ptr0, align 4
342
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
343
// br label %else
344
//
345
// else:
346
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
347
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
348
// br i1 %Mask1, label %cond.load1, label %else2
349
//
350
// cond.load1:
351
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
352
// %Load1 = load i32, i32* %Ptr1, align 4
353
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
354
// br label %else2
355
// . . .
356
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
357
// ret <16 x i32> %Result
358
90
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
359
90
  Value *Ptrs = CI->getArgOperand(0);
360
90
  Value *Alignment = CI->getArgOperand(1);
361
90
  Value *Mask = CI->getArgOperand(2);
362
90
  Value *Src0 = CI->getArgOperand(3);
363
90
364
90
  VectorType *VecType = cast<VectorType>(CI->getType());
365
90
  Type *EltTy = VecType->getElementType();
366
90
367
90
  IRBuilder<> Builder(CI->getContext());
368
90
  Instruction *InsertPt = CI;
369
90
  BasicBlock *IfBlock = CI->getParent();
370
90
  Builder.SetInsertPoint(InsertPt);
371
90
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
372
90
373
90
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
374
90
375
90
  // The result vector
376
90
  Value *VResult = Src0;
377
90
  unsigned VectorWidth = VecType->getNumElements();
378
90
379
90
  // Shorten the way if the mask is a vector of constants.
380
90
  if (isConstantIntVector(Mask)) {
381
267
    for (unsigned Idx = 0; Idx < VectorWidth; 
++Idx237
) {
382
237
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
383
18
        continue;
384
219
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
385
219
      LoadInst *Load =
386
219
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
387
219
      VResult =
388
219
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
389
219
    }
390
30
    CI->replaceAllUsesWith(VResult);
391
30
    CI->eraseFromParent();
392
30
    return;
393
30
  }
394
60
395
461
  
for (unsigned Idx = 0; 60
Idx < VectorWidth;
++Idx401
) {
396
401
    // Fill the "else" block, created in the previous iteration
397
401
    //
398
401
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
399
401
    //  br i1 %Mask1, label %cond.load, label %else
400
401
    //
401
401
402
401
    Value *Predicate =
403
401
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
404
401
405
401
    // Create "cond" block
406
401
    //
407
401
    //  %EltAddr = getelementptr i32* %1, i32 0
408
401
    //  %Elt = load i32* %EltAddr
409
401
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
410
401
    //
411
401
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
412
401
    Builder.SetInsertPoint(InsertPt);
413
401
414
401
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
415
401
    LoadInst *Load =
416
401
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
417
401
    Value *NewVResult =
418
401
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
419
401
420
401
    // Create "else" block, fill it in the next iteration
421
401
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
422
401
    Builder.SetInsertPoint(InsertPt);
423
401
    Instruction *OldBr = IfBlock->getTerminator();
424
401
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
425
401
    OldBr->eraseFromParent();
426
401
    BasicBlock *PrevIfBlock = IfBlock;
427
401
    IfBlock = NewIfBlock;
428
401
429
401
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
430
401
    Phi->addIncoming(NewVResult, CondBlock);
431
401
    Phi->addIncoming(VResult, PrevIfBlock);
432
401
    VResult = Phi;
433
401
  }
434
60
435
60
  CI->replaceAllUsesWith(VResult);
436
60
  CI->eraseFromParent();
437
60
438
60
  ModifiedDT = true;
439
60
}
440
441
// Translate a masked scatter intrinsic, like
442
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
443
//                                  <16 x i1> %Mask)
444
// to a chain of basic blocks, that stores element one-by-one if
445
// the appropriate mask bit is set.
446
//
447
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
448
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
449
// br i1 %Mask0, label %cond.store, label %else
450
//
451
// cond.store:
452
// %Elt0 = extractelement <16 x i32> %Src, i32 0
453
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
454
// store i32 %Elt0, i32* %Ptr0, align 4
455
// br label %else
456
//
457
// else:
458
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
459
// br i1 %Mask1, label %cond.store1, label %else2
460
//
461
// cond.store1:
462
// %Elt1 = extractelement <16 x i32> %Src, i32 1
463
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
464
// store i32 %Elt1, i32* %Ptr1, align 4
465
// br label %else2
466
//   . . .
467
30
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
468
30
  Value *Src = CI->getArgOperand(0);
469
30
  Value *Ptrs = CI->getArgOperand(1);
470
30
  Value *Alignment = CI->getArgOperand(2);
471
30
  Value *Mask = CI->getArgOperand(3);
472
30
473
30
  assert(isa<VectorType>(Src->getType()) &&
474
30
         "Unexpected data type in masked scatter intrinsic");
475
30
  assert(isa<VectorType>(Ptrs->getType()) &&
476
30
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
477
30
         "Vector of pointers is expected in masked scatter intrinsic");
478
30
479
30
  IRBuilder<> Builder(CI->getContext());
480
30
  Instruction *InsertPt = CI;
481
30
  BasicBlock *IfBlock = CI->getParent();
482
30
  Builder.SetInsertPoint(InsertPt);
483
30
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
484
30
485
30
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
486
30
  unsigned VectorWidth = Src->getType()->getVectorNumElements();
487
30
488
30
  // Shorten the way if the mask is a vector of constants.
489
30
  if (isConstantIntVector(Mask)) {
490
15
    for (unsigned Idx = 0; Idx < VectorWidth; 
++Idx12
) {
491
12
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
492
2
        continue;
493
10
      Value *OneElt =
494
10
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
495
10
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
496
10
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
497
10
    }
498
3
    CI->eraseFromParent();
499
3
    return;
500
3
  }
501
27
502
180
  
for (unsigned Idx = 0; 27
Idx < VectorWidth;
++Idx153
) {
503
153
    // Fill the "else" block, created in the previous iteration
504
153
    //
505
153
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
506
153
    //  br i1 %Mask1, label %cond.store, label %else
507
153
    //
508
153
    Value *Predicate =
509
153
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
510
153
511
153
    // Create "cond" block
512
153
    //
513
153
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
514
153
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
515
153
    //  %store i32 %Elt1, i32* %Ptr1
516
153
    //
517
153
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
518
153
    Builder.SetInsertPoint(InsertPt);
519
153
520
153
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
521
153
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
522
153
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
523
153
524
153
    // Create "else" block, fill it in the next iteration
525
153
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
526
153
    Builder.SetInsertPoint(InsertPt);
527
153
    Instruction *OldBr = IfBlock->getTerminator();
528
153
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
529
153
    OldBr->eraseFromParent();
530
153
    IfBlock = NewIfBlock;
531
153
  }
532
27
  CI->eraseFromParent();
533
27
534
27
  ModifiedDT = true;
535
27
}
536
537
63
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
538
63
  Value *Ptr = CI->getArgOperand(0);
539
63
  Value *Mask = CI->getArgOperand(1);
540
63
  Value *PassThru = CI->getArgOperand(2);
541
63
542
63
  VectorType *VecType = cast<VectorType>(CI->getType());
543
63
544
63
  Type *EltTy = VecType->getElementType();
545
63
546
63
  IRBuilder<> Builder(CI->getContext());
547
63
  Instruction *InsertPt = CI;
548
63
  BasicBlock *IfBlock = CI->getParent();
549
63
550
63
  Builder.SetInsertPoint(InsertPt);
551
63
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
552
63
553
63
  unsigned VectorWidth = VecType->getNumElements();
554
63
555
63
  // The result vector
556
63
  Value *VResult = PassThru;
557
63
558
665
  for (unsigned Idx = 0; Idx < VectorWidth; 
++Idx602
) {
559
602
    // Fill the "else" block, created in the previous iteration
560
602
    //
561
602
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
562
602
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
563
602
    //  br i1 %mask_1, label %cond.load, label %else
564
602
    //
565
602
566
602
    Value *Predicate =
567
602
        Builder.CreateExtractElement(Mask, Idx);
568
602
569
602
    // Create "cond" block
570
602
    //
571
602
    //  %EltAddr = getelementptr i32* %1, i32 0
572
602
    //  %Elt = load i32* %EltAddr
573
602
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
574
602
    //
575
602
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
576
602
                                                     "cond.load");
577
602
    Builder.SetInsertPoint(InsertPt);
578
602
579
602
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
580
602
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
581
602
582
602
    // Move the pointer if there are more blocks to come.
583
602
    Value *NewPtr;
584
602
    if ((Idx + 1) != VectorWidth)
585
539
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
586
602
587
602
    // Create "else" block, fill it in the next iteration
588
602
    BasicBlock *NewIfBlock =
589
602
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
590
602
    Builder.SetInsertPoint(InsertPt);
591
602
    Instruction *OldBr = IfBlock->getTerminator();
592
602
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
593
602
    OldBr->eraseFromParent();
594
602
    BasicBlock *PrevIfBlock = IfBlock;
595
602
    IfBlock = NewIfBlock;
596
602
597
602
    // Create the phi to join the new and previous value.
598
602
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
599
602
    ResultPhi->addIncoming(NewVResult, CondBlock);
600
602
    ResultPhi->addIncoming(VResult, PrevIfBlock);
601
602
    VResult = ResultPhi;
602
602
603
602
    // Add a PHI for the pointer if this isn't the last iteration.
604
602
    if ((Idx + 1) != VectorWidth) {
605
539
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
606
539
      PtrPhi->addIncoming(NewPtr, CondBlock);
607
539
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
608
539
      Ptr = PtrPhi;
609
539
    }
610
602
  }
611
63
612
63
  CI->replaceAllUsesWith(VResult);
613
63
  CI->eraseFromParent();
614
63
615
63
  ModifiedDT = true;
616
63
}
617
618
64
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
619
64
  Value *Src = CI->getArgOperand(0);
620
64
  Value *Ptr = CI->getArgOperand(1);
621
64
  Value *Mask = CI->getArgOperand(2);
622
64
623
64
  VectorType *VecType = cast<VectorType>(Src->getType());
624
64
625
64
  IRBuilder<> Builder(CI->getContext());
626
64
  Instruction *InsertPt = CI;
627
64
  BasicBlock *IfBlock = CI->getParent();
628
64
629
64
  Builder.SetInsertPoint(InsertPt);
630
64
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
631
64
632
64
  Type *EltTy = VecType->getVectorElementType();
633
64
634
64
  unsigned VectorWidth = VecType->getNumElements();
635
64
636
674
  for (unsigned Idx = 0; Idx < VectorWidth; 
++Idx610
) {
637
610
    // Fill the "else" block, created in the previous iteration
638
610
    //
639
610
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
640
610
    //  br i1 %mask_1, label %cond.store, label %else
641
610
    //
642
610
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
643
610
644
610
    // Create "cond" block
645
610
    //
646
610
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
647
610
    //  %EltAddr = getelementptr i32* %1, i32 0
648
610
    //  %store i32 %OneElt, i32* %EltAddr
649
610
    //
650
610
    BasicBlock *CondBlock =
651
610
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
652
610
    Builder.SetInsertPoint(InsertPt);
653
610
654
610
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
655
610
    Builder.CreateAlignedStore(OneElt, Ptr, 1);
656
610
657
610
    // Move the pointer if there are more blocks to come.
658
610
    Value *NewPtr;
659
610
    if ((Idx + 1) != VectorWidth)
660
546
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
661
610
662
610
    // Create "else" block, fill it in the next iteration
663
610
    BasicBlock *NewIfBlock =
664
610
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
665
610
    Builder.SetInsertPoint(InsertPt);
666
610
    Instruction *OldBr = IfBlock->getTerminator();
667
610
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
668
610
    OldBr->eraseFromParent();
669
610
    BasicBlock *PrevIfBlock = IfBlock;
670
610
    IfBlock = NewIfBlock;
671
610
672
610
    // Add a PHI for the pointer if this isn't the last iteration.
673
610
    if ((Idx + 1) != VectorWidth) {
674
546
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
675
546
      PtrPhi->addIncoming(NewPtr, CondBlock);
676
546
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
677
546
      Ptr = PtrPhi;
678
546
    }
679
610
  }
680
64
  CI->eraseFromParent();
681
64
682
64
  ModifiedDT = true;
683
64
}
684
685
499k
bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
686
499k
  bool EverMadeChange = false;
687
499k
688
499k
  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
689
499k
690
499k
  bool MadeChange = true;
691
999k
  while (MadeChange) {
692
500k
    MadeChange = false;
693
3.21M
    for (Function::iterator I = F.begin(); I != F.end();) {
694
2.71M
      BasicBlock *BB = &*I++;
695
2.71M
      bool ModifiedDTOnIteration = false;
696
2.71M
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
697
2.71M
698
2.71M
      // Restart BB iteration if the dominator tree of the Function was changed
699
2.71M
      if (ModifiedDTOnIteration)
700
663
        break;
701
2.71M
    }
702
500k
703
500k
    EverMadeChange |= MadeChange;
704
500k
  }
705
499k
706
499k
  return EverMadeChange;
707
499k
}
708
709
2.71M
bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
710
2.71M
  bool MadeChange = false;
711
2.71M
712
2.71M
  BasicBlock::iterator CurInstIterator = BB.begin();
713
17.1M
  while (CurInstIterator != BB.end()) {
714
14.4M
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
715
1.84M
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
716
14.4M
    if (ModifiedDT)
717
663
      return true;
718
14.4M
  }
719
2.71M
720
2.71M
  
return MadeChange2.71M
;
721
2.71M
}
722
723
bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
724
1.84M
                                                bool &ModifiedDT) {
725
1.84M
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
726
1.84M
  if (II) {
727
390k
    switch (II->getIntrinsicID()) {
728
390k
    default:
729
387k
      break;
730
390k
    case Intrinsic::masked_load:
731
512
      // Scalarize unsupported vector masked load
732
512
      if (TTI->isLegalMaskedLoad(CI->getType()))
733
381
        return false;
734
131
      scalarizeMaskedLoad(CI, ModifiedDT);
735
131
      return true;
736
770
    case Intrinsic::masked_store:
737
770
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
738
394
        return false;
739
376
      scalarizeMaskedStore(CI, ModifiedDT);
740
376
      return true;
741
414
    case Intrinsic::masked_gather:
742
414
      if (TTI->isLegalMaskedGather(CI->getType()))
743
324
        return false;
744
90
      scalarizeMaskedGather(CI, ModifiedDT);
745
90
      return true;
746
133
    case Intrinsic::masked_scatter:
747
133
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
748
103
        return false;
749
30
      scalarizeMaskedScatter(CI, ModifiedDT);
750
30
      return true;
751
314
    case Intrinsic::masked_expandload:
752
314
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
753
251
        return false;
754
63
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
755
63
      return true;
756
229
    case Intrinsic::masked_compressstore:
757
229
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
758
165
        return false;
759
64
      scalarizeMaskedCompressStore(CI, ModifiedDT);
760
64
      return true;
761
390k
    }
762
390k
  }
763
1.84M
764
1.84M
  return false;
765
1.84M
}