Coverage Report

Created: 2019-07-24 05:18

/Users/buildslave/jenkins/workspace/clang-stage2-coverage-R/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Line
Count
Source
1
//===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
// This pass builds the coroutine frame and outlines resume and destroy parts
9
// of the coroutine into separate functions.
10
//
11
// We present a coroutine to LLVM as an ordinary function with suspension
12
// points marked up with intrinsics. We let the optimizer party on the coroutine
13
// as a single function for as long as possible. Shortly before the coroutine is
14
// eligible to be inlined into its callers, we split up the coroutine into parts
15
// corresponding to the initial, resume, and destroy invocations of the coroutine,
16
// add them to the current SCC and restart the IPO pipeline to optimize the
17
// coroutine subfunctions we extracted before proceeding to the caller of the
18
// coroutine.
19
//===----------------------------------------------------------------------===//
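
For orientation, here is a minimal, hypothetical C++ coroutine of the kind this pipeline lowers. The type and function names (Task, example) are illustrative only, and the sketch assumes the C++20 <coroutine> header (the TS-era equivalent at the time of this report lived in <experimental/coroutine>). The frontend turns it into an ordinary function annotated with llvm.coro.* intrinsics, which this pass later splits into resume/destroy/cleanup parts.

#include <coroutine>

struct Task {
  struct promise_type {
    Task get_return_object() {
      return Task(std::coroutine_handle<promise_type>::from_promise(*this));
    }
    std::suspend_always initial_suspend() noexcept { return {}; }
    std::suspend_always final_suspend() noexcept { return {}; }
    void return_void() {}
    void unhandled_exception() {}
  };

  explicit Task(std::coroutine_handle<promise_type> h) : handle(h) {}
  Task(Task &&other) noexcept : handle(other.handle) { other.handle = nullptr; }
  ~Task() {
    if (handle)
      handle.destroy();  // dispatches through the outlined destroy part once split
  }

  std::coroutine_handle<promise_type> handle;
};

Task example() {
  co_await std::suspend_always{};  // becomes an llvm.coro.suspend point in IR
}
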
20
21
#include "CoroInstr.h"
22
#include "CoroInternal.h"
23
#include "llvm/ADT/DenseMap.h"
24
#include "llvm/ADT/SmallPtrSet.h"
25
#include "llvm/ADT/SmallVector.h"
26
#include "llvm/ADT/StringRef.h"
27
#include "llvm/ADT/Twine.h"
28
#include "llvm/Analysis/CallGraph.h"
29
#include "llvm/Analysis/CallGraphSCCPass.h"
30
#include "llvm/Transforms/Utils/Local.h"
31
#include "llvm/IR/Argument.h"
32
#include "llvm/IR/Attributes.h"
33
#include "llvm/IR/BasicBlock.h"
34
#include "llvm/IR/CFG.h"
35
#include "llvm/IR/CallSite.h"
36
#include "llvm/IR/CallingConv.h"
37
#include "llvm/IR/Constants.h"
38
#include "llvm/IR/DataLayout.h"
39
#include "llvm/IR/DerivedTypes.h"
40
#include "llvm/IR/Function.h"
41
#include "llvm/IR/GlobalValue.h"
42
#include "llvm/IR/GlobalVariable.h"
43
#include "llvm/IR/IRBuilder.h"
44
#include "llvm/IR/InstIterator.h"
45
#include "llvm/IR/InstrTypes.h"
46
#include "llvm/IR/Instruction.h"
47
#include "llvm/IR/Instructions.h"
48
#include "llvm/IR/IntrinsicInst.h"
49
#include "llvm/IR/LLVMContext.h"
50
#include "llvm/IR/LegacyPassManager.h"
51
#include "llvm/IR/Module.h"
52
#include "llvm/IR/Type.h"
53
#include "llvm/IR/Value.h"
54
#include "llvm/IR/Verifier.h"
55
#include "llvm/Pass.h"
56
#include "llvm/Support/Casting.h"
57
#include "llvm/Support/Debug.h"
58
#include "llvm/Support/raw_ostream.h"
59
#include "llvm/Transforms/Scalar.h"
60
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
61
#include "llvm/Transforms/Utils/Cloning.h"
62
#include "llvm/Transforms/Utils/ValueMapper.h"
63
#include <cassert>
64
#include <cstddef>
65
#include <cstdint>
66
#include <initializer_list>
67
#include <iterator>
68
69
using namespace llvm;
70
71
#define DEBUG_TYPE "coro-split"
72
73
// Create an entry block for a resume function with a switch that will jump to
74
// suspend points.
75
33
static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
76
33
  LLVMContext &C = F.getContext();
77
33
78
33
  // resume.entry:
79
33
  //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
80
33
  //  i32 2
81
33
  //  % index = load i32, i32* %index.addr
82
33
  //  switch i32 %index, label %unreachable [
83
33
  //    i32 0, label %resume.0
84
33
  //    i32 1, label %resume.1
85
33
  //    ...
86
33
  //  ]
87
33
88
33
  auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
89
33
  auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
90
33
91
33
  IRBuilder<> Builder(NewEntry);
92
33
  auto *FramePtr = Shape.FramePtr;
93
33
  auto *FrameTy = Shape.FrameTy;
94
33
  auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
95
33
      FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
96
33
  auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
97
33
  auto *Switch =
98
33
      Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
99
33
  Shape.ResumeSwitch = Switch;
100
33
101
33
  size_t SuspendIndex = 0;
102
39
  for (CoroSuspendInst *S : Shape.CoroSuspends) {
103
39
    ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
104
39
105
39
    // Replace CoroSave with a store to Index:
106
39
    //    %index.addr = getelementptr %f.frame... (index field number)
107
39
    //    store i32 0, i32* %index.addr1
108
39
    auto *Save = S->getCoroSave();
109
39
    Builder.SetInsertPoint(Save);
110
39
    if (S->isFinal()) {
111
2
      // Final suspend point is represented by storing zero in ResumeFnAddr.
112
2
      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
113
2
                                                          0, "ResumeFn.addr");
114
2
      auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
115
2
          cast<PointerType>(GepIndex->getType())->getElementType()));
116
2
      Builder.CreateStore(NullPtr, GepIndex);
117
37
    } else {
118
37
      auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
119
37
          FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
120
37
      Builder.CreateStore(IndexVal, GepIndex);
121
37
    }
122
39
    Save->replaceAllUsesWith(ConstantTokenNone::get(C));
123
39
    Save->eraseFromParent();
124
39
125
39
    // Split block before and after coro.suspend and add a jump from an entry
126
39
    // switch:
127
39
    //
128
39
    //  whateverBB:
129
39
    //    whatever
130
39
    //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
131
39
    //    switch i8 %0, label %suspend[i8 0, label %resume
132
39
    //                                 i8 1, label %cleanup]
133
39
    // becomes:
134
39
    //
135
39
    //  whateverBB:
136
39
    //     whatever
137
39
    //     br label %resume.0.landing
138
39
    //
139
39
    //  resume.0: ; <--- jump from the switch in the resume.entry
140
39
    //     %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
141
39
    //     br label %resume.0.landing
142
39
    //
143
39
    //  resume.0.landing:
144
39
    //     %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
145
39
    //     switch i8 % 1, label %suspend [i8 0, label %resume
146
39
    //                                    i8 1, label %cleanup]
147
39
148
39
    auto *SuspendBB = S->getParent();
149
39
    auto *ResumeBB =
150
39
        SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
151
39
    auto *LandingBB = ResumeBB->splitBasicBlock(
152
39
        S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
153
39
    Switch->addCase(IndexVal, ResumeBB);
154
39
155
39
    cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
156
39
    auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
157
39
    S->replaceAllUsesWith(PN);
158
39
    PN->addIncoming(Builder.getInt8(-1), SuspendBB);
159
39
    PN->addIncoming(S, ResumeBB);
160
39
161
39
    ++SuspendIndex;
162
39
  }
163
33
164
33
  Builder.SetInsertPoint(UnreachBB);
165
33
  Builder.CreateUnreachable();
166
33
167
33
  return NewEntry;
168
33
}
169
170
// In Resumers, we replace fallthrough coro.end with ret void and delete the
171
// rest of the block.
172
static void replaceFallthroughCoroEnd(IntrinsicInst *End,
173
99
                                      ValueToValueMapTy &VMap) {
174
99
  auto *NewE = cast<IntrinsicInst>(VMap[End]);
175
99
  ReturnInst::Create(NewE->getContext(), nullptr, NewE);
176
99
177
99
  // Remove the rest of the block, by splitting it into an unreachable block.
178
99
  auto *BB = NewE->getParent();
179
99
  BB->splitBasicBlock(NewE);
180
99
  BB->getTerminator()->eraseFromParent();
181
99
}
182
183
// In Resumers, we replace unwind coro.end with True to force the immediate
184
// unwind to the caller.
185
99
static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
186
99
  if (Shape.CoroEnds.empty())
187
0
    return;
188
99
189
99
  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
190
99
  auto *True = ConstantInt::getTrue(Context);
191
105
  for (CoroEndInst *CE : Shape.CoroEnds) {
192
105
    if (!CE->isUnwind())
193
99
      continue;
194
6
195
6
    auto *NewCE = cast<IntrinsicInst>(VMap[CE]);
196
6
197
6
    // If coro.end has an associated bundle, add cleanupret instruction.
198
6
    if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
199
3
      Value *FromPad = Bundle->Inputs[0];
200
3
      auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
201
3
      NewCE->getParent()->splitBasicBlock(NewCE);
202
3
      CleanupRet->getParent()->getTerminator()->eraseFromParent();
203
3
    }
204
6
205
6
    NewCE->replaceAllUsesWith(True);
206
6
    NewCE->eraseFromParent();
207
6
  }
208
99
}
209
210
// Rewrite final suspend point handling. We do not use the suspend index to
211
// represent the final suspend point. Instead we zero-out ResumeFnAddr in the
212
// coroutine frame, since it is undefined behavior to resume a coroutine
213
// suspended at the final suspend point. Thus, in the resume function, we can
214
// simply remove the last case (when coro::Shape is built, the final suspend
215
// point (if present) is always the last element of CoroSuspends array).
216
// In the destroy function, we add a code sequence to check if ResumeFnAddr
217
// is null, and if so, jump to the appropriate label to handle cleanup from the
218
// final suspend point.
219
static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
220
                               coro::Shape &Shape, SwitchInst *Switch,
221
6
                               bool IsDestroy) {
222
6
  assert(Shape.HasFinalSuspend);
223
6
  auto FinalCaseIt = std::prev(Switch->case_end());
224
6
  BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
225
6
  Switch->removeCase(FinalCaseIt);
226
6
  if (IsDestroy) {
227
4
    BasicBlock *OldSwitchBB = Switch->getParent();
228
4
    auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
229
4
    Builder.SetInsertPoint(OldSwitchBB->getTerminator());
230
4
    auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
231
4
                                                        0, 0, "ResumeFn.addr");
232
4
    auto *Load = Builder.CreateLoad(
233
4
        Shape.FrameTy->getElementType(coro::Shape::ResumeField), GepIndex);
234
4
    auto *NullPtr =
235
4
        ConstantPointerNull::get(cast<PointerType>(Load->getType()));
236
4
    auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
237
4
    Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
238
4
    OldSwitchBB->getTerminator()->eraseFromParent();
239
4
  }
240
6
}
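
At the source level, the null-ResumeFn convention used above is what std::coroutine_handle<>::done() observes in this lowering. A small, illustrative sketch follows (the function name resumeIfNotDone is hypothetical; it assumes the C++20 <coroutine> header).

#include <coroutine>

// Resume a coroutine only if it is not suspended at its final suspend point.
// done() is true exactly when resuming would be undefined behavior; in this
// lowering that corresponds to the zeroed resume-function slot checked above.
void resumeIfNotDone(std::coroutine_handle<> h) {
  if (h && !h.done())
    h.resume();  // enters the resume clone through its entry switch
}
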
241
242
// Create a resume clone by cloning the body of the original function, setting
243
// a new entry block and replacing coro.suspend with an appropriate value to
244
// force the resume or cleanup path for every suspend point.
245
static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
246
99
                             BasicBlock *ResumeEntry, int8_t FnIndex) {
247
99
  Module *M = F.getParent();
248
99
  auto *FrameTy = Shape.FrameTy;
249
99
  auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
250
99
  auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());
251
99
252
99
  Function *NewF =
253
99
      Function::Create(FnTy, GlobalValue::LinkageTypes::ExternalLinkage,
254
99
                       F.getName() + Suffix, M);
255
99
  NewF->addParamAttr(0, Attribute::NonNull);
256
99
  NewF->addParamAttr(0, Attribute::NoAlias);
257
99
258
99
  ValueToValueMapTy VMap;
259
99
  // Replace all args with undefs. The buildCoroutineFrame algorithm has already
260
99
  // rewritten accesses to the args that occur after suspend points with loads
261
99
  // and stores to/from the coroutine frame.
262
99
  for (Argument &A : F.args())
263
81
    VMap[&A] = UndefValue::get(A.getType());
264
99
265
99
  SmallVector<ReturnInst *, 4> Returns;
266
99
267
99
  CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
268
99
  NewF->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
269
99
270
99
  // Remove old returns.
271
99
  for (ReturnInst *Return : Returns)
272
99
    changeToUnreachable(Return, /*UseLLVMTrap=*/false);
273
99
274
99
  // Remove old return attributes.
275
99
  NewF->removeAttributes(
276
99
      AttributeList::ReturnIndex,
277
99
      AttributeFuncs::typeIncompatible(NewF->getReturnType()));
278
99
279
99
  // Make AllocaSpillBlock the new entry block.
280
99
  auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
281
99
  auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
282
99
  Entry->moveBefore(&NewF->getEntryBlock());
283
99
  Entry->getTerminator()->eraseFromParent();
284
99
  BranchInst::Create(SwitchBB, Entry);
285
99
  Entry->setName("entry" + Suffix);
286
99
287
99
  // Clear all predecessors of the new entry block.
288
99
  auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
289
99
  Entry->replaceAllUsesWith(Switch->getDefaultDest());
290
99
291
99
  IRBuilder<> Builder(&NewF->getEntryBlock().front());
292
99
293
99
  // Remap frame pointer.
294
99
  Argument *NewFramePtr = &*NewF->arg_begin();
295
99
  Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
296
99
  NewFramePtr->takeName(OldFramePtr);
297
99
  OldFramePtr->replaceAllUsesWith(NewFramePtr);
298
99
299
99
  // Remap vFrame pointer.
300
99
  auto *NewVFrame = Builder.CreateBitCast(
301
99
      NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
302
99
  Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
303
99
  OldVFrame->replaceAllUsesWith(NewVFrame);
304
99
305
99
  // Rewrite final suspend handling, as it is not done via the switch (this
306
99
  // allows removing the final case from the switch, since it is undefined
307
99
  // behavior to resume a coroutine suspended at the final suspend point).
308
99
  if (Shape.HasFinalSuspend) {
309
6
    auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
310
6
    bool IsDestroy = FnIndex != 0;
311
6
    handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
312
6
  }
313
99
314
99
  // Replace coro suspend with the appropriate resume index.
315
99
  // Replacing coro.suspend with (0) will result in control flow proceeding to
316
99
  // a resume label associated with a suspend point, replacing it with (1) will
317
99
  // result in control flow proceeding to a cleanup label associated with this
318
99
  // suspend point.
319
99
  auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
320
117
  for (CoroSuspendInst *CS : Shape.CoroSuspends) {
321
117
    auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
322
117
    MappedCS->replaceAllUsesWith(NewValue);
323
117
    MappedCS->eraseFromParent();
324
117
  }
325
99
326
99
  // Remove coro.end intrinsics.
327
99
  replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
328
99
  replaceUnwindCoroEnds(Shape, VMap);
329
99
  // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
330
99
  // to suppress deallocation code.
331
99
  coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
332
99
                        /*Elide=*/FnIndex == 2);
333
99
334
99
  NewF->setCallingConv(CallingConv::Fast);
335
99
336
99
  return NewF;
337
99
}
338
339
37
static void removeCoroEnds(coro::Shape &Shape) {
340
37
  if (Shape.CoroEnds.empty())
341
0
    return;
342
37
343
37
  LLVMContext &Context = Shape.CoroEnds.front()->getContext();
344
37
  auto *False = ConstantInt::getFalse(Context);
345
37
346
39
  for (CoroEndInst *CE : Shape.CoroEnds) {
347
39
    CE->replaceAllUsesWith(False);
348
39
    CE->eraseFromParent();
349
39
  }
350
37
}
351
352
37
static void replaceFrameSize(coro::Shape &Shape) {
353
37
  if (Shape.CoroSizes.empty())
354
4
    return;
355
33
356
33
  // In the same function all coro.sizes should have the same result type.
357
33
  auto *SizeIntrin = Shape.CoroSizes.back();
358
33
  Module *M = SizeIntrin->getModule();
359
33
  const DataLayout &DL = M->getDataLayout();
360
33
  auto Size = DL.getTypeAllocSize(Shape.FrameTy);
361
33
  auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);
362
33
363
33
  for (CoroSizeInst *CS : Shape.CoroSizes) {
364
33
    CS->replaceAllUsesWith(SizeConstant);
365
33
    CS->eraseFromParent();
366
33
  }
367
33
}
368
369
// Create a global constant array containing pointers to functions provided and
370
// set Info parameter of CoroBegin to point at this constant. Example:
371
//
372
//   @f.resumers = internal constant [2 x void(%f.frame*)*]
373
//                    [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
374
//   define void @f() {
375
//     ...
376
//     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
377
//                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
378
//
379
// Assumes that all the functions have the same signature.
380
static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
381
33
                        std::initializer_list<Function *> Fns) {
382
33
  SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
383
33
  assert(!Args.empty());
384
33
  Function *Part = *Fns.begin();
385
33
  Module *M = Part->getParent();
386
33
  auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
387
33
388
33
  auto *ConstVal = ConstantArray::get(ArrTy, Args);
389
33
  auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
390
33
                                GlobalVariable::PrivateLinkage, ConstVal,
391
33
                                F.getName() + Twine(".resumers"));
392
33
393
33
  // Update coro.begin instruction to refer to this constant.
394
33
  LLVMContext &C = F.getContext();
395
33
  auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
396
33
  CoroBegin->getId()->setInfo(BC);
397
33
}
398
399
// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
400
static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
401
33
                            Function *DestroyFn, Function *CleanupFn) {
402
33
  IRBuilder<> Builder(Shape.FramePtr->getNextNode());
403
33
  auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
404
33
      Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
405
33
      "resume.addr");
406
33
  Builder.CreateStore(ResumeFn, ResumeAddr);
407
33
408
33
  Value *DestroyOrCleanupFn = DestroyFn;
409
33
410
33
  CoroIdInst *CoroId = Shape.CoroBegin->getId();
411
33
  if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
412
10
    // If there is a CoroAlloc and it returns false (meaning we elide the
413
10
    // allocation, use CleanupFn instead of DestroyFn).
414
10
    DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
415
10
  }
416
33
417
33
  auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
418
33
      Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
419
33
      "destroy.addr");
420
33
  Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
421
33
}
422
423
136
static void postSplitCleanup(Function &F) {
424
136
  removeUnreachableBlocks(F);
425
136
  legacy::FunctionPassManager FPM(F.getParent());
426
136
427
136
  FPM.add(createVerifierPass());
428
136
  FPM.add(createSCCPPass());
429
136
  FPM.add(createCFGSimplificationPass());
430
136
  FPM.add(createEarlyCSEPass());
431
136
  FPM.add(createCFGSimplificationPass());
432
136
433
136
  FPM.doInitialization();
434
136
  FPM.run(F);
435
136
  FPM.doFinalization();
436
136
}
437
438
// Assuming we arrived at the block NewBlock from the Prev instruction, store
439
// the PHIs' incoming values in the ResolvedValues map.
440
static void
441
scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
442
1
                          DenseMap<Value *, Value *> &ResolvedValues) {
443
1
  auto *PrevBB = Prev->getParent();
444
1
  for (PHINode &PN : NewBlock->phis()) {
445
0
    auto V = PN.getIncomingValueForBlock(PrevBB);
446
0
    // See if we already resolved it.
447
0
    auto VI = ResolvedValues.find(V);
448
0
    if (VI != ResolvedValues.end())
449
0
      V = VI->second;
450
0
    // Remember the value.
451
0
    ResolvedValues[&PN] = V;
452
0
  }
453
1
}
454
455
// Replace a sequence of branches leading to a ret with a clone of that ret
456
// instruction. For a suspend instruction represented by a switch, track the PHI
457
// values and select the correct case successor when possible.
458
1
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
459
1
  DenseMap<Value *, Value *> ResolvedValues;
460
1
461
1
  Instruction *I = InitialInst;
462
2
  while (I->isTerminator()) {
463
2
    if (isa<ReturnInst>(I)) {
464
1
      if (I != InitialInst)
465
1
        ReplaceInstWithInst(InitialInst, I->clone());
466
1
      return true;
467
1
    }
468
1
    if (auto *BR = dyn_cast<BranchInst>(I)) {
469
1
      if (BR->isUnconditional()) {
470
1
        BasicBlock *BB = BR->getSuccessor(0);
471
1
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
472
1
        I = BB->getFirstNonPHIOrDbgOrLifetime();
473
1
        continue;
474
1
      }
475
0
    } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
476
0
      Value *V = SI->getCondition();
477
0
      auto it = ResolvedValues.find(V);
478
0
      if (it != ResolvedValues.end())
479
0
        V = it->second;
480
0
      if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
481
0
        BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
482
0
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
483
0
        I = BB->getFirstNonPHIOrDbgOrLifetime();
484
0
        continue;
485
0
      }
486
0
    }
487
0
    return false;
488
0
  }
489
1
  return false;
490
1
}
491
492
// Add musttail to any resume instruction that is immediately followed by a
493
// suspend (i.e. ret). We do this even at -O0 to support the guaranteed tail call
494
// for symmetrical coroutine control transfer (C++ Coroutines TS extension).
495
// This transformation is done only in the resume part of the coroutine, which has
496
// the same signature and calling convention as the coro.resume call.
497
33
static void addMustTailToCoroResumes(Function &F) {
498
33
  bool changed = false;
499
33
500
33
  // Collect potential resume instructions.
501
33
  SmallVector<CallInst *, 4> Resumes;
502
33
  for (auto &I : instructions(F))
503
427
    if (auto *Call = dyn_cast<CallInst>(&I))
504
75
      if (auto *CalledValue = Call->getCalledValue())
505
75
        // CoroEarly pass replaced coro resumes with indirect calls to an
506
75
        // address returned by the CoroSubFnInst intrinsic. See if it is one of those.
507
75
        if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
508
1
          Resumes.push_back(Call);
509
33
510
33
  // Set musttail on those that are followed by a ret instruction.
511
33
  for (CallInst *Call : Resumes)
512
1
    if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
513
1
      Call->setTailCallKind(CallInst::TCK_MustTail);
514
1
      changed = true;
515
1
    }
516
33
517
33
  if (changed)
518
1
    removeUnreachableBlocks(F);
519
33
}
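
The frontend-level pattern that depends on this musttail marking is symmetric transfer: an awaiter whose await_suspend returns a coroutine handle. A minimal sketch, assuming the C++20 <coroutine> header (the awaiter name is illustrative):

#include <coroutine>

// Returning a handle from await_suspend asks the compiler to resume `next`
// as if by a tail call; the musttail added above is what lets arbitrarily
// long chains of such transfers run in constant stack space.
struct SymmetricAwaiter {
  std::coroutine_handle<> next;  // coroutine to transfer control to
  bool await_ready() const noexcept { return false; }
  std::coroutine_handle<> await_suspend(std::coroutine_handle<>) noexcept {
    return next;
  }
  void await_resume() noexcept {}
};
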
520
521
// The coroutine has no suspend points. Remove the heap allocation for the coroutine
522
// frame if possible.
523
4
static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
524
4
  auto *CoroId = CoroBegin->getId();
525
4
  auto *AllocInst = CoroId->getCoroAlloc();
526
4
  coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
527
4
  if (AllocInst) {
528
4
    IRBuilder<> Builder(AllocInst);
529
4
    // FIXME: Need to handle overaligned members.
530
4
    auto *Frame = Builder.CreateAlloca(FrameTy);
531
4
    auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
532
4
    AllocInst->replaceAllUsesWith(Builder.getFalse());
533
4
    AllocInst->eraseFromParent();
534
4
    CoroBegin->replaceAllUsesWith(VFrame);
535
4
  } else {
536
0
    CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
537
0
  }
538
4
  CoroBegin->eraseFromParent();
539
4
}
540
541
// SimplifySuspendPoint needs to check that there are no calls between
542
// coro_save and coro_suspend, since any of those calls may potentially resume
543
// the coroutine, and if that is the case we cannot eliminate the suspend point.
544
10
static bool hasCallsInBlockBetween(Instruction *From, Instruction *To) {
545
25
  for (Instruction *I = From; I != To; I = I->getNextNode()) {
546
17
    // Assume that no intrinsic can resume the coroutine.
547
17
    if (isa<IntrinsicInst>(I))
548
7
      continue;
549
10
550
10
    if (CallSite(I))
551
2
      return true;
552
10
  }
553
10
  return false;
554
10
}
555
556
2
static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) {
557
2
  SmallPtrSet<BasicBlock *, 8> Set;
558
2
  SmallVector<BasicBlock *, 8> Worklist;
559
2
560
2
  Set.insert(SaveBB);
561
2
  Worklist.push_back(ResDesBB);
562
2
563
2
  // Accumulate all blocks between SaveBB and ResDesBB. Because CoroSaveIntr
564
2
  // returns a token consumed by the suspend instruction, all blocks in between
565
2
  // will have to eventually hit SaveBB when going backwards from ResDesBB.
566
7
  while (!Worklist.empty()) {
567
5
    auto *BB = Worklist.pop_back_val();
568
5
    Set.insert(BB);
569
5
    for (auto *Pred : predecessors(BB))
570
6
      if (Set.count(Pred) == 0)
571
3
        Worklist.push_back(Pred);
572
5
  }
573
2
574
2
  // SaveBB and ResDesBB are checked separately in hasCallsBetween.
575
2
  Set.erase(SaveBB);
576
2
  Set.erase(ResDesBB);
577
2
578
2
  for (auto *BB : Set)
579
3
    if (hasCallsInBlockBetween(BB->getFirstNonPHI(), nullptr))
580
1
      return true;
581
2
582
2
  return false;
583
2
}
584
585
5
static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) {
586
5
  auto *SaveBB = Save->getParent();
587
5
  auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent();
588
5
589
5
  if (SaveBB == ResumeOrDestroyBB)
590
2
    return hasCallsInBlockBetween(Save->getNextNode(), ResumeOrDestroy);
591
3
592
3
  // Any calls from Save to the end of the block?
593
3
  if (hasCallsInBlockBetween(Save->getNextNode(), nullptr))
594
1
    return true;
595
2
596
2
  // Any calls from the beginning of the block up to ResumeOrDestroy?
597
2
  if (hasCallsInBlockBetween(ResumeOrDestroyBB->getFirstNonPHI(),
598
2
                             ResumeOrDestroy))
599
0
    return true;
600
2
601
2
  // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB?
602
2
  if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB))
603
1
    return true;
604
1
605
1
  return false;
606
1
}
607
608
// If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the
609
// suspend point and replace it with normal control flow.
610
static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
611
42
                                 CoroBeginInst *CoroBegin) {
612
42
  Instruction *Prev = Suspend->getPrevNode();
613
42
  if (!Prev) {
614
1
    auto *Pred = Suspend->getParent()->getSinglePredecessor();
615
1
    if (!Pred)
616
0
      return false;
617
1
    Prev = Pred->getTerminator();
618
1
  }
619
42
620
42
  CallSite CS{Prev};
621
42
  if (!CS)
622
1
    return false;
623
41
624
41
  auto *CallInstr = CS.getInstruction();
625
41
626
41
  auto *Callee = CS.getCalledValue()->stripPointerCasts();
627
41
628
41
  // See if the callsite is for resumption or destruction of the coroutine.
629
41
  auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
630
41
  if (!SubFn)
631
34
    return false;
632
7
633
7
  // If it does not refer to the current coroutine, we cannot do anything with it.
634
7
  if (SubFn->getFrame() != CoroBegin)
635
2
    return false;
636
5
637
5
  // See if the transformation is safe. Specifically, see if there are any
638
5
  // calls in between Save and CallInstr. They can potentially resume the
639
5
  // coroutine, rendering this optimization unsafe.
640
5
  auto *Save = Suspend->getCoroSave();
641
5
  if (hasCallsBetween(Save, CallInstr))
642
2
    return false;
643
3
644
3
  // Replace llvm.coro.suspend with the value that results in resumption over
645
3
  // the resume or cleanup path.
646
3
  Suspend->replaceAllUsesWith(SubFn->getRawIndex());
647
3
  Suspend->eraseFromParent();
648
3
  Save->eraseFromParent();
649
3
650
3
  // No longer need a call to coro.resume or coro.destroy.
651
3
  if (auto *Invoke = dyn_cast<InvokeInst>(CallInstr)) {
652
1
    BranchInst::Create(Invoke->getNormalDest(), Invoke);
653
1
  }
654
3
655
3
  // Grab the CalledValue from CS before erasing the CallInstr.
656
3
  auto *CalledValue = CS.getCalledValue();
657
3
  CallInstr->eraseFromParent();
658
3
659
3
  // If it has no more users, remove it. Usually it is a bitcast of SubFn.
660
3
  if (CalledValue != SubFn && CalledValue->user_empty())
661
3
    if (auto *I = dyn_cast<Instruction>(CalledValue))
662
3
      I->eraseFromParent();
663
3
664
3
  // Now we are good to remove SubFn.
665
3
  if (SubFn->user_empty())
666
3
    SubFn->eraseFromParent();
667
3
668
3
  return true;
669
3
}
670
671
// Remove suspend points that are simplified.
672
37
static void simplifySuspendPoints(coro::Shape &Shape) {
673
37
  auto &S = Shape.CoroSuspends;
674
37
  size_t I = 0, N = S.size();
675
37
  if (N == 0)
676
1
    return;
677
42
  while (true) {
678
42
    if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
679
3
      if (--N == I)
680
3
        break;
681
0
      std::swap(S[I], S[N]);
682
0
      continue;
683
0
    }
684
39
    if (++I == N)
685
33
      break;
686
39
  }
687
36
  S.resize(N);
688
36
}
689
690
37
static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) {
691
37
  // Collect all blocks in which we need to look for instructions to relocate.
692
37
  SmallPtrSet<BasicBlock *, 4> RelocBlocks;
693
37
  SmallVector<BasicBlock *, 4> Work;
694
37
  Work.push_back(CB->getParent());
695
37
696
65
  do {
697
65
    BasicBlock *Current = Work.pop_back_val();
698
65
    for (BasicBlock *BB : predecessors(Current))
699
42
      if (RelocBlocks.count(BB) == 0) {
700
28
        RelocBlocks.insert(BB);
701
28
        Work.push_back(BB);
702
28
      }
703
65
  } while (!Work.empty());
704
37
  return RelocBlocks;
705
37
}
706
707
static SmallPtrSet<Instruction *, 8>
708
getNotRelocatableInstructions(CoroBeginInst *CoroBegin,
709
37
                              SmallPtrSetImpl<BasicBlock *> &RelocBlocks) {
710
37
  SmallPtrSet<Instruction *, 8> DoNotRelocate;
711
37
  // Collect all instructions that we should not relocate
712
37
  SmallVector<Instruction *, 8> Work;
713
37
714
37
  // Start with CoroBegin and terminators of all preceding blocks.
715
37
  Work.push_back(CoroBegin);
716
37
  BasicBlock *CoroBeginBB = CoroBegin->getParent();
717
37
  for (BasicBlock *BB : RelocBlocks)
718
28
    if (BB != CoroBeginBB)
719
28
      Work.push_back(BB->getTerminator());
720
37
721
37
  // For every instruction in the Work list, place its operands in DoNotRelocate
722
37
  // set.
723
203
  do {
724
203
    Instruction *Current = Work.pop_back_val();
725
203
    LLVM_DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current << "\n");
726
203
    DoNotRelocate.insert(Current);
727
519
    for (Value *U : Current->operands()) {
728
519
      auto *I = dyn_cast<Instruction>(U);
729
519
      if (!I)
730
364
        continue;
731
155
732
155
      if (auto *A = dyn_cast<AllocaInst>(I)) {
733
4
        // Stores to alloca instructions that occur before the coroutine frame
734
4
        // is allocated should not be moved; the stored values may be used by
735
4
        // the coroutine frame allocator. The operands to those stores must also
736
4
        // remain in place.
737
4
        for (const auto &User : A->users())
738
13
          if (auto *SI = dyn_cast<llvm::StoreInst>(User))
739
3
            if (RelocBlocks.count(SI->getParent()) != 0 &&
740
3
                DoNotRelocate.count(SI) == 0) {
741
1
              Work.push_back(SI);
742
1
              DoNotRelocate.insert(SI);
743
1
            }
744
4
        continue;
745
4
      }
746
151
747
151
      if (DoNotRelocate.count(I) == 0) {
748
137
        Work.push_back(I);
749
137
        DoNotRelocate.insert(I);
750
137
      }
751
151
    }
752
203
  } while (!Work.empty());
753
37
  return DoNotRelocate;
754
37
}
755
756
37
static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) {
757
37
  // Analyze which non-alloca instructions are needed for allocation and
758
37
  // relocate the rest to after coro.begin. We need to do this because some of the
759
37
  // targets of those instructions may be placed into coroutine frame memory
760
37
  // which becomes available only after the coro.begin intrinsic.
761
37
762
37
  auto BlockSet = getCoroBeginPredBlocks(CoroBegin);
763
37
  auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet);
764
37
765
37
  Instruction *InsertPt = CoroBegin->getNextNode();
766
37
  BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well.
767
183
  for (auto B = BB.begin(), E = BB.end(); B != E;) {
768
169
    Instruction &I = *B++;
769
169
    if (isa<AllocaInst>(&I))
770
28
      continue;
771
141
    if (&I == CoroBegin)
772
23
      break;
773
118
    if (DoNotRelocateSet.count(&I))
774
109
      continue;
775
9
    I.moveBefore(InsertPt);
776
9
  }
777
37
}
778
779
37
static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
780
37
  EliminateUnreachableBlocks(F);
781
37
782
37
  coro::Shape Shape(F);
783
37
  if (!Shape.CoroBegin)
784
0
    return;
785
37
786
37
  simplifySuspendPoints(Shape);
787
37
  relocateInstructionBefore(Shape.CoroBegin, F);
788
37
  buildCoroutineFrame(F, Shape);
789
37
  replaceFrameSize(Shape);
790
37
791
37
  // If there are no suspend points, no split is required; just remove
792
37
  // the allocation and deallocation blocks, as they are not needed.
793
37
  if (Shape.CoroSuspends.empty()) {
794
4
    handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
795
4
    removeCoroEnds(Shape);
796
4
    postSplitCleanup(F);
797
4
    coro::updateCallGraph(F, {}, CG, SCC);
798
4
    return;
799
4
  }
800
33
801
33
  auto *ResumeEntry = createResumeEntryBlock(F, Shape);
802
33
  auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
803
33
  auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
804
33
  auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);
805
33
806
33
  // We no longer need coro.end in F.
807
33
  removeCoroEnds(Shape);
808
33
809
33
  postSplitCleanup(F);
810
33
  postSplitCleanup(*ResumeClone);
811
33
  postSplitCleanup(*DestroyClone);
812
33
  postSplitCleanup(*CleanupClone);
813
33
814
33
  addMustTailToCoroResumes(*ResumeClone);
815
33
816
33
  // Store the addresses of the resume/destroy/cleanup functions in the coroutine frame.
817
33
  updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
818
33
819
33
  // Create a constant array referring to the resume/destroy/cleanup functions,
820
33
  // pointed to by the last argument of @llvm.coro.info, so that the CoroElide
821
33
  // pass can determine the correct function to call.
822
33
  setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});
823
33
824
33
  // Update call graph and add the functions we created to the SCC.
825
33
  coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
826
33
}
827
828
// When we see the coroutine for the first time, we insert an indirect call to a
829
// devirt trigger function and mark the coroutine as now ready to be
830
// split.
831
10
static void prepareForSplit(Function &F, CallGraph &CG) {
832
10
  Module &M = *F.getParent();
833
10
  LLVMContext &Context = F.getContext();
834
#ifndef NDEBUG
835
  Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
836
  assert(DevirtFn && "coro.devirt.trigger function not found");
837
#endif
838
839
10
  F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
840
10
841
10
  // Insert an indirect call sequence that will be devirtualized by CoroElide
842
10
  // pass:
843
10
  //    %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
844
10
  //    %1 = bitcast i8* %0 to void(i8*)*
845
10
  //    call void %1(i8* null)
846
10
  coro::LowererBase Lowerer(M);
847
10
  Instruction *InsertPt = F.getEntryBlock().getTerminator();
848
10
  auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context));
849
10
  auto *DevirtFnAddr =
850
10
      Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
851
10
  FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Context),
852
10
                                         {Type::getInt8PtrTy(Context)}, false);
853
10
  auto *IndirectCall = CallInst::Create(FnTy, DevirtFnAddr, Null, "", InsertPt);
854
10
855
10
  // Update CG graph with an indirect call we just added.
856
10
  CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
857
10
}
858
859
// Make sure that there is a devirtualization trigger function that the CoroSplit
860
// pass uses to force a restart of the CGSCC pipeline. If the devirt trigger
861
// function is not found, we will create one and add it to the current SCC.
862
47
static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
863
47
  Module &M = CG.getModule();
864
47
  if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
865
20
    return;
866
27
867
27
  LLVMContext &C = M.getContext();
868
27
  auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
869
27
                                 /*isVarArg=*/false);
870
27
  Function *DevirtFn =
871
27
      Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
872
27
                       CORO_DEVIRT_TRIGGER_FN, &M);
873
27
  DevirtFn->addFnAttr(Attribute::AlwaysInline);
874
27
  auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
875
27
  ReturnInst::Create(C, Entry);
876
27
877
27
  auto *Node = CG.getOrInsertFunction(DevirtFn);
878
27
879
27
  SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
880
27
  Nodes.push_back(Node);
881
27
  SCC.initialize(Nodes);
882
27
}
883
884
//===----------------------------------------------------------------------===//
885
//                              Top Level Driver
886
//===----------------------------------------------------------------------===//
887
888
namespace {
889
890
struct CoroSplit : public CallGraphSCCPass {
891
  static char ID; // Pass identification, replacement for typeid
892
893
55
  CoroSplit() : CallGraphSCCPass(ID) {
894
55
    initializeCoroSplitPass(*PassRegistry::getPassRegistry());
895
55
  }
896
897
  bool Run = false;
898
899
  // A coroutine is identified by the presence of the coro.begin intrinsic; if
900
  // we don't have any, this pass has nothing to do.
901
55
  bool doInitialization(CallGraph &CG) override {
902
55
    Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
903
55
    return CallGraphSCCPass::doInitialization(CG);
904
55
  }
905
906
681
  bool runOnSCC(CallGraphSCC &SCC) override {
907
681
    if (!Run)
908
211
      return false;
909
470
910
470
    // Find coroutines for processing.
911
470
    SmallVector<Function *, 4> Coroutines;
912
470
    for (CallGraphNode *CGN : SCC)
913
480
      if (auto *F = CGN->getFunction())
914
424
        if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
915
47
          Coroutines.push_back(F);
916
470
917
470
    if (Coroutines.empty())
918
423
      return false;
919
47
920
47
    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
921
47
    createDevirtTriggerFunc(CG, SCC);
922
47
923
47
    for (Function *F : Coroutines) {
924
47
      Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
925
47
      StringRef Value = Attr.getValueAsString();
926
47
      LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
927
47
                        << "' state: " << Value << "\n");
928
47
      if (Value == UNPREPARED_FOR_SPLIT) {
929
10
        prepareForSplit(*F, CG);
930
10
        continue;
931
10
      }
932
37
      F->removeFnAttr(CORO_PRESPLIT_ATTR);
933
37
      splitCoroutine(*F, CG, SCC);
934
37
    }
935
47
    return true;
936
47
  }
937
938
55
  void getAnalysisUsage(AnalysisUsage &AU) const override {
939
55
    CallGraphSCCPass::getAnalysisUsage(AU);
940
55
  }
941
942
0
  StringRef getPassName() const override { return "Coroutine Splitting"; }
943
};
944
945
} // end anonymous namespace
946
947
char CoroSplit::ID = 0;
948
949
11.0k
INITIALIZE_PASS_BEGIN(
950
11.0k
    CoroSplit, "coro-split",
951
11.0k
    "Split coroutine into a set of functions driving its state machine", false,
952
11.0k
    false)
953
11.0k
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
954
11.0k
INITIALIZE_PASS_END(
955
    CoroSplit, "coro-split",
956
    "Split coroutine into a set of functions driving its state machine", false,
957
    false)
958
959
36
Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }
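
As a usage sketch (assuming the legacy pass manager and the pass factory declarations in llvm/Transforms/Coroutines.h; the wrapper function name is illustrative), the split pass is normally scheduled together with the other coroutine passes:

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Coroutines.h"

// Run a simple coroutine lowering pipeline over a module. In-tree builds get
// the same passes wired into the standard pipelines via
// addCoroutinePassesToExtensionPoints when coroutines are enabled.
void runCoroutineLowering(llvm::Module &M) {
  llvm::legacy::PassManager PM;
  PM.add(llvm::createCoroEarlyPass());    // lower coroutine intrinsics to canonical form
  PM.add(llvm::createCoroSplitPass());    // this file: outline resume/destroy/cleanup
  PM.add(llvm::createCoroElidePass());    // elide heap allocations where provably safe
  PM.add(llvm::createCoroCleanupPass());  // lower remaining coro intrinsics
  PM.run(M);
}
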