Coverage Report

Created: 2019-07-24 05:18

/Users/buildslave/jenkins/workspace/clang-stage2-coverage-R/llvm/tools/clang/lib/Sema/SemaCoroutine.cpp
Line
Count
Source (jump to first uncovered line)
1
//===-- SemaCoroutine.cpp - Semantic Analysis for Coroutines --------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
//  This file implements semantic analysis for C++ Coroutines.
10
//
11
//  This file contains references to sections of the Coroutines TS, which
12
//  can be found at http://wg21.link/coroutines.
13
//
14
//===----------------------------------------------------------------------===//
15
16
#include "CoroutineStmtBuilder.h"
17
#include "clang/AST/ASTLambda.h"
18
#include "clang/AST/Decl.h"
19
#include "clang/AST/ExprCXX.h"
20
#include "clang/AST/StmtCXX.h"
21
#include "clang/Lex/Preprocessor.h"
22
#include "clang/Sema/Initialization.h"
23
#include "clang/Sema/Overload.h"
24
#include "clang/Sema/ScopeInfo.h"
25
#include "clang/Sema/SemaInternal.h"
26
27
using namespace clang;
28
using namespace sema;
29
30
static LookupResult lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
31
627
                                 SourceLocation Loc, bool &Res) {
32
627
  DeclarationName DN = S.PP.getIdentifierInfo(Name);
33
627
  LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
34
627
  // Suppress diagnostics when a private member is selected. The same warnings
35
627
  // will be produced again when building the call.
36
627
  LR.suppressDiagnostics();
37
627
  Res = S.LookupQualifiedName(LR, RD);
38
627
  return LR;
39
627
}
40
41
static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
42
283
                         SourceLocation Loc) {
43
283
  bool Res;
44
283
  lookupMember(S, Name, RD, Loc, Res);
45
283
  return Res;
46
283
}
47
48
/// Look up the std::coroutine_traits<...>::promise_type for the given
49
/// function type.
50
static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD,
51
211
                                  SourceLocation KwLoc) {
52
211
  const FunctionProtoType *FnType = FD->getType()->castAs<FunctionProtoType>();
53
211
  const SourceLocation FuncLoc = FD->getLocation();
54
211
  // FIXME: Cache std::coroutine_traits once we've found it.
55
211
  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
56
211
  if (!StdExp) {
57
4
    S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
58
4
        << "std::experimental::coroutine_traits";
59
4
    return QualType();
60
4
  }
61
207
62
207
  ClassTemplateDecl *CoroTraits = S.lookupCoroutineTraits(KwLoc, FuncLoc);
63
207
  if (!CoroTraits) {
64
0
    return QualType();
65
0
  }
66
207
67
207
  // Form template argument list for coroutine_traits<R, P1, P2, ...> according
68
207
  // to [dcl.fct.def.coroutine]3
69
207
  TemplateArgumentListInfo Args(KwLoc, KwLoc);
70
401
  auto AddArg = [&](QualType T) {
71
401
    Args.addArgument(TemplateArgumentLoc(
72
401
        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc)));
73
401
  };
74
207
  AddArg(FnType->getReturnType());
75
207
  // If the function is a non-static member function, add the type
76
207
  // of the implicit object parameter before the formal parameters.
77
207
  if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
78
58
    if (MD->isInstance()) {
79
46
      // [over.match.funcs]4
80
46
      // For non-static member functions, the type of the implicit object
81
46
      // parameter is
82
46
      //  -- "lvalue reference to cv X" for functions declared without a
83
46
      //      ref-qualifier or with the & ref-qualifier
84
46
      //  -- "rvalue reference to cv X" for functions declared with the &&
85
46
      //      ref-qualifier
86
46
      QualType T = MD->getThisType()->getAs<PointerType>()->getPointeeType();
87
46
      T = FnType->getRefQualifier() == RQ_RValue
88
46
              ? 
S.Context.getRValueReferenceType(T)7
89
46
              : 
S.Context.getLValueReferenceType(T, /*SpelledAsLValue*/ true)39
;
90
46
      AddArg(T);
91
46
    }
92
58
  }
93
207
  for (QualType T : FnType->getParamTypes())
94
148
    AddArg(T);
95
207
96
207
  // Build the template-id.
97
207
  QualType CoroTrait =
98
207
      S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args);
99
207
  if (CoroTrait.isNull())
100
0
    return QualType();
101
207
  if (S.RequireCompleteType(KwLoc, CoroTrait,
102
207
                            diag::err_coroutine_type_missing_specialization))
103
1
    return QualType();
104
206
105
206
  auto *RD = CoroTrait->getAsCXXRecordDecl();
106
206
  assert(RD && "specialization of class template is not a class?");
107
206
108
206
  // Look up the ::promise_type member.
109
206
  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc,
110
206
                 Sema::LookupOrdinaryName);
111
206
  S.LookupQualifiedName(R, RD);
112
206
  auto *Promise = R.getAsSingle<TypeDecl>();
113
206
  if (!Promise) {
114
5
    S.Diag(FuncLoc,
115
5
           diag::err_implied_std_coroutine_traits_promise_type_not_found)
116
5
        << RD;
117
5
    return QualType();
118
5
  }
119
201
  // The promise type is required to be a class type.
120
201
  QualType PromiseType = S.Context.getTypeDeclType(Promise);
121
201
122
201
  auto buildElaboratedType = [&]() {
123
201
    auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
124
201
    NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
125
201
                                      CoroTrait.getTypePtr());
126
201
    return S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
127
201
  };
128
201
129
201
  if (!PromiseType->getAsCXXRecordDecl()) {
130
1
    S.Diag(FuncLoc,
131
1
           diag::err_implied_std_coroutine_traits_promise_type_not_class)
132
1
        << buildElaboratedType();
133
1
    return QualType();
134
1
  }
135
200
  if (S.RequireCompleteType(FuncLoc, buildElaboratedType(),
136
200
                            diag::err_coroutine_promise_type_incomplete))
137
1
    return QualType();
138
199
139
199
  return PromiseType;
140
199
}
141
142
/// Look up the std::experimental::coroutine_handle<PromiseType>.
143
static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
144
544
                                          SourceLocation Loc) {
145
544
  if (PromiseType.isNull())
146
0
    return QualType();
147
544
148
544
  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
149
544
  assert(StdExp && "Should already be diagnosed");
150
544
151
544
  LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"),
152
544
                      Loc, Sema::LookupOrdinaryName);
153
544
  if (!S.LookupQualifiedName(Result, StdExp)) {
154
1
    S.Diag(Loc, diag::err_implied_coroutine_type_not_found)
155
1
        << "std::experimental::coroutine_handle";
156
1
    return QualType();
157
1
  }
158
543
159
543
  ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
160
543
  if (!CoroHandle) {
161
0
    Result.suppressDiagnostics();
162
0
    // We found something weird. Complain about the first thing we found.
163
0
    NamedDecl *Found = *Result.begin();
164
0
    S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
165
0
    return QualType();
166
0
  }
167
543
168
543
  // Form template argument list for coroutine_handle<Promise>.
169
543
  TemplateArgumentListInfo Args(Loc, Loc);
170
543
  Args.addArgument(TemplateArgumentLoc(
171
543
      TemplateArgument(PromiseType),
172
543
      S.Context.getTrivialTypeSourceInfo(PromiseType, Loc)));
173
543
174
543
  // Build the template-id.
175
543
  QualType CoroHandleType =
176
543
      S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args);
177
543
  if (CoroHandleType.isNull())
178
0
    return QualType();
179
543
  if (S.RequireCompleteType(Loc, CoroHandleType,
180
543
                            diag::err_coroutine_type_missing_specialization))
181
0
    return QualType();
182
543
183
543
  return CoroHandleType;
184
543
}
185
186
static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
187
1.30k
                                    StringRef Keyword) {
188
1.30k
  // [expr.await]p2 dictates that 'co_await' and 'co_yield' must be used within
189
1.30k
  // a function body.
190
1.30k
  // FIXME: This also covers [expr.await]p2: "An await-expression shall not
191
1.30k
  // appear in a default argument." But the diagnostic QoI here could be
192
1.30k
  // improved to inform the user that default arguments specifically are not
193
1.30k
  // allowed.
194
1.30k
  auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
195
1.30k
  if (!FD) {
196
1
    S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
197
1
                    ? 
diag::err_coroutine_objc_method0
198
1
                    : diag::err_coroutine_outside_function) << Keyword;
199
1
    return false;
200
1
  }
201
1.29k
202
1.29k
  // An enumeration for mapping the diagnostic type to the correct diagnostic
203
1.29k
  // selection index.
204
1.29k
  enum InvalidFuncDiag {
205
1.29k
    DiagCtor = 0,
206
1.29k
    DiagDtor,
207
1.29k
    DiagMain,
208
1.29k
    DiagConstexpr,
209
1.29k
    DiagAutoRet,
210
1.29k
    DiagVarargs,
211
1.29k
    DiagConsteval,
212
1.29k
  };
213
1.29k
  bool Diagnosed = false;
214
1.29k
  auto DiagInvalid = [&](InvalidFuncDiag ID) {
215
9
    S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword;
216
9
    Diagnosed = true;
217
9
    return false;
218
9
  };
219
1.29k
220
1.29k
  // Diagnose when a constructor, destructor
221
1.29k
  // or the function 'main' are declared as a coroutine.
222
1.29k
  auto *MD = dyn_cast<CXXMethodDecl>(FD);
223
1.29k
  // [class.ctor]p11: "A constructor shall not be a coroutine."
224
1.29k
  if (MD && 
isa<CXXConstructorDecl>(MD)338
)
225
2
    return DiagInvalid(DiagCtor);
226
1.29k
  // [class.dtor]p17: "A destructor shall not be a coroutine."
227
1.29k
  else if (MD && 
isa<CXXDestructorDecl>(MD)336
)
228
1
    return DiagInvalid(DiagDtor);
229
1.29k
  // [basic.start.main]p3: "The function main shall not be a coroutine."
230
1.29k
  else if (FD->isMain())
231
1
    return DiagInvalid(DiagMain);
232
1.29k
233
1.29k
  // Emit a diagnostics for each of the following conditions which is not met.
234
1.29k
  // [expr.const]p2: "An expression e is a core constant expression unless the
235
1.29k
  // evaluation of e [...] would evaluate one of the following expressions:
236
1.29k
  // [...] an await-expression [...] a yield-expression."
237
1.29k
  if (FD->isConstexpr())
238
2
    DiagInvalid(FD->isConsteval() ? 
DiagConsteval0
: DiagConstexpr);
239
1.29k
  // [dcl.spec.auto]p15: "A function declared with a return type that uses a
240
1.29k
  // placeholder type shall not be a coroutine."
241
1.29k
  if (FD->getReturnType()->isUndeducedType())
242
2
    DiagInvalid(DiagAutoRet);
243
1.29k
  // [dcl.fct.def.coroutine]p1: "The parameter-declaration-clause of the
244
1.29k
  // coroutine shall not terminate with an ellipsis that is not part of a
245
1.29k
  // parameter-declaration."
246
1.29k
  if (FD->isVariadic())
247
1
    DiagInvalid(DiagVarargs);
248
1.29k
249
1.29k
  return !Diagnosed;
250
1.29k
}
251
252
static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
253
548
                                                 SourceLocation Loc) {
254
548
  DeclarationName OpName =
255
548
      SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
256
548
  LookupResult Operators(SemaRef, OpName, SourceLocation(),
257
548
                         Sema::LookupOperatorName);
258
548
  SemaRef.LookupName(Operators, S);
259
548
260
548
  assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
261
548
  const auto &Functions = Operators.asUnresolvedSet();
262
548
  bool IsOverloaded =
263
548
      Functions.size() > 1 ||
264
548
      
(524
Functions.size() == 1524
&&
isa<FunctionTemplateDecl>(*Functions.begin())38
);
265
548
  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
266
548
      SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
267
548
      DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
268
548
      Functions.begin(), Functions.end());
269
548
  assert(CoawaitOp);
270
548
  return CoawaitOp;
271
548
}
272
273
/// Build a call to 'operator co_await' if there is a suitable operator for
274
/// the given expression.
275
static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
276
                                           Expr *E,
277
543
                                           UnresolvedLookupExpr *Lookup) {
278
543
  UnresolvedSet<16> Functions;
279
543
  Functions.append(Lookup->decls_begin(), Lookup->decls_end());
280
543
  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
281
543
}
282
283
static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
284
439
                                           SourceLocation Loc, Expr *E) {
285
439
  ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
286
439
  if (R.isInvalid())
287
0
    return ExprError();
288
439
  return buildOperatorCoawaitCall(SemaRef, Loc, E,
289
439
                                  cast<UnresolvedLookupExpr>(R.get()));
290
439
}
291
292
static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
293
1.02k
                              MultiExprArg CallArgs) {
294
1.02k
  StringRef Name = S.Context.BuiltinInfo.getName(Id);
295
1.02k
  LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName);
296
1.02k
  S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true);
297
1.02k
298
1.02k
  auto *BuiltInDecl = R.getAsSingle<FunctionDecl>();
299
1.02k
  assert(BuiltInDecl && "failed to find builtin declaration");
300
1.02k
301
1.02k
  ExprResult DeclRef =
302
1.02k
      S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(), VK_LValue, Loc);
303
1.02k
  assert(DeclRef.isUsable() && "Builtin reference cannot fail");
304
1.02k
305
1.02k
  ExprResult Call =
306
1.02k
      S.BuildCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc);
307
1.02k
308
1.02k
  assert(!Call.isInvalid() && "Call to builtin cannot fail!");
309
1.02k
  return Call.get();
310
1.02k
}
311
312
static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
313
544
                                       SourceLocation Loc) {
314
544
  QualType CoroHandleType = lookupCoroutineHandleType(S, PromiseType, Loc);
315
544
  if (CoroHandleType.isNull())
316
1
    return ExprError();
317
543
318
543
  DeclContext *LookupCtx = S.computeDeclContext(CoroHandleType);
319
543
  LookupResult Found(S, &S.PP.getIdentifierTable().get("from_address"), Loc,
320
543
                     Sema::LookupOrdinaryName);
321
543
  if (!S.LookupQualifiedName(Found, LookupCtx)) {
322
1
    S.Diag(Loc, diag::err_coroutine_handle_missing_member)
323
1
        << "from_address";
324
1
    return ExprError();
325
1
  }
326
542
327
542
  Expr *FramePtr =
328
542
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
329
542
330
542
  CXXScopeSpec SS;
331
542
  ExprResult FromAddr =
332
542
      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
333
542
  if (FromAddr.isInvalid())
334
0
    return ExprError();
335
542
336
542
  return S.BuildCallExpr(nullptr, FromAddr.get(), Loc, FramePtr, Loc);
337
542
}
338
339
struct ReadySuspendResumeResult {
340
  enum AwaitCallType { ACT_Ready, ACT_Suspend, ACT_Resume };
341
  Expr *Results[3];
342
  OpaqueValueExpr *OpaqueValue;
343
  bool IsInvalid;
344
};
345
346
static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
347
2.59k
                                  StringRef Name, MultiExprArg Args) {
348
2.59k
  DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
349
2.59k
350
2.59k
  // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
351
2.59k
  CXXScopeSpec SS;
352
2.59k
  ExprResult Result = S.BuildMemberReferenceExpr(
353
2.59k
      Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
354
2.59k
      SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
355
2.59k
      /*Scope=*/nullptr);
356
2.59k
  if (Result.isInvalid())
357
18
    return ExprError();
358
2.57k
359
2.57k
  // We meant exactly what we asked for. No need for typo correction.
360
2.57k
  if (auto *TE = dyn_cast<TypoExpr>(Result.get())) {
361
3
    S.clearDelayedTypo(TE);
362
3
    S.Diag(Loc, diag::err_no_member)
363
3
        << NameInfo.getName() << Base->getType()->getAsCXXRecordDecl()
364
3
        << Base->getSourceRange();
365
3
    return ExprError();
366
3
  }
367
2.57k
368
2.57k
  return S.BuildCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
369
2.57k
}
370
371
// See if return type is coroutine-handle and if so, invoke builtin coro-resume
372
// on its address. This is to enable experimental support for coroutine-handle
373
// returning await_suspend that results in a guaranteed tail call to the target
374
// coroutine.
375
static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
376
528
                           SourceLocation Loc) {
377
528
  if (RetType->isReferenceType())
378
2
    return nullptr;
379
526
  Type const *T = RetType.getTypePtr();
380
526
  if (!T->isClassType() && !T->isStructureType())
381
525
    return nullptr;
382
1
383
1
  // FIXME: Add convertability check to coroutine_handle<>. Possibly via
384
1
  // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
385
1
  // a private function in SemaExprCXX.cpp
386
1
387
1
  ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", None);
388
1
  if (AddressExpr.isInvalid())
389
0
    return nullptr;
390
1
391
1
  Expr *JustAddress = AddressExpr.get();
392
1
  // FIXME: Check that the type of AddressExpr is void*
393
1
  return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume,
394
1
                          JustAddress);
395
1
}
396
397
/// Build calls to await_ready, await_suspend, and await_resume for a co_await
398
/// expression.
399
static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
400
544
                                                  SourceLocation Loc, Expr *E) {
401
544
  OpaqueValueExpr *Operand = new (S.Context)
402
544
      OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
403
544
404
544
  // Assume invalid until we see otherwise.
405
544
  ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true};
406
544
407
544
  ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc);
408
544
  if (CoroHandleRes.isInvalid())
409
2
    return Calls;
410
542
  Expr *CoroHandle = CoroHandleRes.get();
411
542
412
542
  const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
413
542
  MultiExprArg Args[] = {None, CoroHandle, None};
414
2.12k
  for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; 
++I1.58k
) {
415
1.60k
    ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]);
416
1.60k
    if (Result.isInvalid())
417
14
      return Calls;
418
1.58k
    Calls.Results[I] = Result.get();
419
1.58k
  }
420
542
421
542
  // Assume the calls are valid; all further checking should make them invalid.
422
542
  Calls.IsInvalid = false;
423
528
424
528
  using ACT = ReadySuspendResumeResult::AwaitCallType;
425
528
  CallExpr *AwaitReady = cast<CallExpr>(Calls.Results[ACT::ACT_Ready]);
426
528
  if (!AwaitReady->getType()->isDependentType()) {
427
528
    // [expr.await]p3 [...]
428
528
    // — await-ready is the expression e.await_ready(), contextually converted
429
528
    // to bool.
430
528
    ExprResult Conv = S.PerformContextuallyConvertToBool(AwaitReady);
431
528
    if (Conv.isInvalid()) {
432
1
      S.Diag(AwaitReady->getDirectCallee()->getBeginLoc(),
433
1
             diag::note_await_ready_no_bool_conversion);
434
1
      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
435
1
          << AwaitReady->getDirectCallee() << E->getSourceRange();
436
1
      Calls.IsInvalid = true;
437
1
    }
438
528
    Calls.Results[ACT::ACT_Ready] = Conv.get();
439
528
  }
440
528
  CallExpr *AwaitSuspend = cast<CallExpr>(Calls.Results[ACT::ACT_Suspend]);
441
528
  if (!AwaitSuspend->getType()->isDependentType()) {
442
528
    // [expr.await]p3 [...]
443
528
    //   - await-suspend is the expression e.await_suspend(h), which shall be
444
528
    //     a prvalue of type void or bool.
445
528
    QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
446
528
447
528
    // Experimental support for coroutine_handle returning await_suspend.
448
528
    if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc))
449
1
      Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
450
527
    else {
451
527
      // non-class prvalues always have cv-unqualified types
452
527
      if (RetType->isReferenceType() ||
453
527
          
(525
!RetType->isBooleanType()525
&&
!RetType->isVoidType()522
)) {
454
3
        S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
455
3
               diag::err_await_suspend_invalid_return_type)
456
3
            << RetType;
457
3
        S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
458
3
            << AwaitSuspend->getDirectCallee();
459
3
        Calls.IsInvalid = true;
460
3
      }
461
527
    }
462
528
  }
463
528
464
528
  return Calls;
465
542
}
466
467
static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
468
                                   SourceLocation Loc, StringRef Name,
469
989
                                   MultiExprArg Args) {
470
989
471
989
  // Form a reference to the promise.
472
989
  ExprResult PromiseRef = S.BuildDeclRefExpr(
473
989
      Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
474
989
  if (PromiseRef.isInvalid())
475
0
    return ExprError();
476
989
477
989
  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
478
989
}
479
480
260
VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
481
260
  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
482
260
  auto *FD = cast<FunctionDecl>(CurContext);
483
260
  bool IsThisDependentType = [&] {
484
260
    if (auto *MD = dyn_cast_or_null<CXXMethodDecl>(FD))
485
78
      return MD->isInstance() && 
MD->getThisType()->isDependentType()66
;
486
182
    else
487
182
      return false;
488
260
  }();
489
260
490
260
  QualType T = FD->getType()->isDependentType() || 
IsThisDependentType222
491
260
                   ? 
Context.DependentTy49
492
260
                   : 
lookupPromiseType(*this, FD, Loc)211
;
493
260
  if (T.isNull())
494
12
    return nullptr;
495
248
496
248
  auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
497
248
                             &PP.getIdentifierTable().get("__promise"), T,
498
248
                             Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
499
248
  CheckVariableDeclarationType(VD);
500
248
  if (VD->isInvalidDecl())
501
0
    return nullptr;
502
248
503
248
  auto *ScopeInfo = getCurFunction();
504
248
  // Build a list of arguments, based on the coroutine functions arguments,
505
248
  // that will be passed to the promise type's constructor.
506
248
  llvm::SmallVector<Expr *, 4> CtorArgExprs;
507
248
508
248
  // Add implicit object parameter.
509
248
  if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
510
75
    if (MD->isInstance() && 
!isLambdaCallOperator(MD)64
) {
511
54
      ExprResult ThisExpr = ActOnCXXThis(Loc);
512
54
      if (ThisExpr.isInvalid())
513
0
        return nullptr;
514
54
      ThisExpr = CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
515
54
      if (ThisExpr.isInvalid())
516
0
        return nullptr;
517
54
      CtorArgExprs.push_back(ThisExpr.get());
518
54
    }
519
75
  }
520
248
521
248
  auto &Moves = ScopeInfo->CoroutineParameterMoves;
522
248
  for (auto *PD : FD->parameters()) {
523
197
    if (PD->getType()->isDependentType())
524
40
      continue;
525
157
526
157
    auto RefExpr = ExprEmpty();
527
157
    auto Move = Moves.find(PD);
528
157
    assert(Move != Moves.end() &&
529
157
           "Coroutine function parameter not inserted into move map");
530
157
    // If a reference to the function parameter exists in the coroutine
531
157
    // frame, use that reference.
532
157
    auto *MoveDecl =
533
157
        cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
534
157
    RefExpr =
535
157
        BuildDeclRefExpr(MoveDecl, MoveDecl->getType().getNonReferenceType(),
536
157
                         ExprValueKind::VK_LValue, FD->getLocation());
537
157
    if (RefExpr.isInvalid())
538
0
      return nullptr;
539
157
    CtorArgExprs.push_back(RefExpr.get());
540
157
  }
541
248
542
248
  // Create an initialization sequence for the promise type using the
543
248
  // constructor arguments, wrapped in a parenthesized list expression.
544
248
  Expr *PLE = ParenListExpr::Create(Context, FD->getLocation(),
545
248
                                    CtorArgExprs, FD->getLocation());
546
248
  InitializedEntity Entity = InitializedEntity::InitializeVariable(VD);
547
248
  InitializationKind Kind = InitializationKind::CreateForInit(
548
248
      VD->getLocation(), /*DirectInit=*/true, PLE);
549
248
  InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs,
550
248
                                 /*TopLevelOfInitList=*/false,
551
248
                                 /*TreatUnavailableAsInvalid=*/false);
552
248
553
248
  // Attempt to initialize the promise type with the arguments.
554
248
  // If that fails, fall back to the promise type's default constructor.
555
248
  if (InitSeq) {
556
142
    ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs);
557
142
    if (Result.isInvalid()) {
558
0
      VD->setInvalidDecl();
559
142
    } else if (Result.get()) {
560
142
      VD->setInit(MaybeCreateExprWithCleanups(Result.get()));
561
142
      VD->setInitStyle(VarDecl::CallInit);
562
142
      CheckCompleteVariableDeclaration(VD);
563
142
    }
564
142
  } else
565
106
    ActOnUninitializedDecl(VD);
566
248
567
248
  FD->addDecl(VD);
568
248
  return VD;
569
248
}
570
571
/// Check that this is a context in which a coroutine suspension can appear.
572
static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
573
                                                StringRef Keyword,
574
1.30k
                                                bool IsImplicit = false) {
575
1.30k
  if (!isValidCoroutineContext(S, Loc, Keyword))
576
9
    return nullptr;
577
1.29k
578
1.29k
  assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
579
1.29k
580
1.29k
  auto *ScopeInfo = S.getCurFunction();
581
1.29k
  assert(ScopeInfo && "missing function scope for function");
582
1.29k
583
1.29k
  if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && 
!IsImplicit364
)
584
260
    ScopeInfo->setFirstCoroutineStmt(Loc, Keyword);
585
1.29k
586
1.29k
  if (ScopeInfo->CoroutinePromise)
587
1.08k
    return ScopeInfo;
588
208
589
208
  if (!S.buildCoroutineParameterMoves(Loc))
590
0
    return nullptr;
591
208
592
208
  ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
593
208
  if (!ScopeInfo->CoroutinePromise)
594
12
    return nullptr;
595
196
596
196
  return ScopeInfo;
597
196
}
598
599
bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc,
600
281
                                   StringRef Keyword) {
601
281
  if (!checkCoroutineContext(*this, KWLoc, Keyword))
602
21
    return false;
603
260
  auto *ScopeInfo = getCurFunction();
604
260
  assert(ScopeInfo->CoroutinePromise);
605
260
606
260
  // If we have existing coroutine statements then we have already built
607
260
  // the initial and final suspend points.
608
260
  if (!ScopeInfo->NeedsCoroutineSuspends)
609
64
    return true;
610
196
611
196
  ScopeInfo->setNeedsCoroutineSuspends(false);
612
196
613
196
  auto *Fn = cast<FunctionDecl>(CurContext);
614
196
  SourceLocation Loc = Fn->getLocation();
615
196
  // Build the initial suspend point
616
385
  auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
617
385
    ExprResult Suspend =
618
385
        buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None);
619
385
    if (Suspend.isInvalid())
620
3
      return StmtError();
621
382
    Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get());
622
382
    if (Suspend.isInvalid())
623
0
      return StmtError();
624
382
    Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(),
625
382
                                       /*IsImplicit*/ true);
626
382
    Suspend = ActOnFinishFullExpr(Suspend.get(), /*DiscardedValue*/ false);
627
382
    if (Suspend.isInvalid()) {
628
6
      Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required)
629
6
          << ((Name == "initial_suspend") ? 
05
:
11
);
630
6
      Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
631
6
      return StmtError();
632
6
    }
633
376
    return cast<Stmt>(Suspend.get());
634
376
  };
635
196
636
196
  StmtResult InitSuspend = buildSuspends("initial_suspend");
637
196
  if (InitSuspend.isInvalid())
638
7
    return true;
639
189
640
189
  StmtResult FinalSuspend = buildSuspends("final_suspend");
641
189
  if (FinalSuspend.isInvalid())
642
2
    return true;
643
187
644
187
  ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
645
187
646
187
  return true;
647
187
}
648
649
// Recursively walks up the scope hierarchy until either a 'catch' or a function
650
// scope is found, whichever comes first.
651
169
static bool isWithinCatchScope(Scope *S) {
652
169
  // 'co_await' and 'co_yield' keywords are disallowed within catch blocks, but
653
169
  // lambdas that use 'co_await' are allowed. The loop below ends when a
654
169
  // function scope is found in order to ensure the following behavior:
655
169
  //
656
169
  // void foo() {      // <- function scope
657
169
  //   try {           //
658
169
  //     co_await x;   // <- 'co_await' is OK within a function scope
659
169
  //   } catch {       // <- catch scope
660
169
  //     co_await x;   // <- 'co_await' is not OK within a catch scope
661
169
  //     []() {        // <- function scope
662
169
  //       co_await x; // <- 'co_await' is OK within a function scope
663
169
  //     }();
664
169
  //   }
665
169
  // }
666
195
  while (S && 
!(S->getFlags() & Scope::FnScope)189
) {
667
29
    if (S->getFlags() & Scope::CatchScope)
668
3
      return true;
669
26
    S = S->getParent();
670
26
  }
671
169
  
return false166
;
672
169
}
673
674
// [expr.await]p2, emphasis added: "An await-expression shall appear only in
675
// a *potentially evaluated* expression within the compound-statement of a
676
// function-body *outside of a handler* [...] A context within a function
677
// where an await-expression can appear is called a suspension context of the
678
// function."
679
static void checkSuspensionContext(Sema &S, SourceLocation Loc,
680
169
                                   StringRef Keyword) {
681
169
  // First emphasis of [expr.await]p2: must be a potentially evaluated context.
682
169
  // That is, 'co_await' and 'co_yield' cannot appear in subexpressions of
683
169
  // \c sizeof.
684
169
  if (S.isUnevaluatedContext())
685
6
    S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
686
169
687
169
  // Second emphasis of [expr.await]p2: must be outside of an exception handler.
688
169
  if (isWithinCatchScope(S.getCurScope()))
689
3
    S.Diag(Loc, diag::err_coroutine_within_handler) << Keyword;
690
169
}
691
692
120
ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
693
120
  if (!ActOnCoroutineBodyStart(S, Loc, "co_await")) {
694
11
    CorrectDelayedTyposInExpr(E);
695
11
    return ExprError();
696
11
  }
697
109
698
109
  checkSuspensionContext(*this, Loc, "co_await");
699
109
700
109
  if (E->getType()->isPlaceholderType()) {
701
2
    ExprResult R = CheckPlaceholderExpr(E);
702
2
    if (R.isInvalid()) 
return ExprError()0
;
703
2
    E = R.get();
704
2
  }
705
109
  ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
706
109
  if (Lookup.isInvalid())
707
0
    return ExprError();
708
109
  return BuildUnresolvedCoawaitExpr(Loc, E,
709
109
                                   cast<UnresolvedLookupExpr>(Lookup.get()));
710
109
}
711
712
ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
713
132
                                            UnresolvedLookupExpr *Lookup) {
714
132
  auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
715
132
  if (!FSI)
716
0
    return ExprError();
717
132
718
132
  if (E->getType()->isPlaceholderType()) {
719
0
    ExprResult R = CheckPlaceholderExpr(E);
720
0
    if (R.isInvalid())
721
0
      return ExprError();
722
0
    E = R.get();
723
0
  }
724
132
725
132
  auto *Promise = FSI->CoroutinePromise;
726
132
  if (Promise->getType()->isDependentType()) {
727
24
    Expr *Res =
728
24
        new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
729
24
    return Res;
730
24
  }
731
108
732
108
  auto *RD = Promise->getType()->getAsCXXRecordDecl();
733
108
  if (lookupMember(*this, "await_transform", RD, Loc)) {
734
22
    ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E);
735
22
    if (R.isInvalid()) {
736
4
      Diag(Loc,
737
4
           diag::note_coroutine_promise_implicit_await_transform_required_here)
738
4
          << E->getSourceRange();
739
4
      return ExprError();
740
4
    }
741
18
    E = R.get();
742
18
  }
743
108
  ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
744
104
  if (Awaitable.isInvalid())
745
4
    return ExprError();
746
100
747
100
  return BuildResolvedCoawaitExpr(Loc, Awaitable.get());
748
100
}
749
750
ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
751
591
                                  bool IsImplicit) {
752
591
  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit);
753
591
  if (!Coroutine)
754
0
    return ExprError();
755
591
756
591
  if (E->getType()->isPlaceholderType()) {
757
0
    ExprResult R = CheckPlaceholderExpr(E);
758
0
    if (R.isInvalid()) return ExprError();
759
0
    E = R.get();
760
0
  }
761
591
762
591
  if (E->getType()->isDependentType()) {
763
103
    Expr *Res = new (Context)
764
103
        CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
765
103
    return Res;
766
103
  }
767
488
768
488
  // If the expression is a temporary, materialize it as an lvalue so that we
769
488
  // can use it multiple times.
770
488
  if (E->getValueKind() == VK_RValue)
771
437
    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
772
488
773
488
  // The location of the `co_await` token cannot be used when constructing
774
488
  // the member call expressions since it's before the location of `Expr`, which
775
488
  // is used as the start of the member call expression.
776
488
  SourceLocation CallLoc = E->getExprLoc();
777
488
778
488
  // Build the await_ready, await_suspend, await_resume calls.
779
488
  ReadySuspendResumeResult RSS =
780
488
      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, CallLoc, E);
781
488
  if (RSS.IsInvalid)
782
19
    return ExprError();
783
469
784
469
  Expr *Res =
785
469
      new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
786
469
                                RSS.Results[2], RSS.OpaqueValue, IsImplicit);
787
469
788
469
  return Res;
789
469
}
790
791
63
ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
792
63
  if (!ActOnCoroutineBodyStart(S, Loc, "co_yield")) {
793
3
    CorrectDelayedTyposInExpr(E);
794
3
    return ExprError();
795
3
  }
796
60
797
60
  checkSuspensionContext(*this, Loc, "co_yield");
798
60
799
60
  // Build yield_value call.
800
60
  ExprResult Awaitable = buildPromiseCall(
801
60
      *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E);
802
60
  if (Awaitable.isInvalid())
803
3
    return ExprError();
804
57
805
57
  // Build 'operator co_await' call.
806
57
  Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
807
57
  if (Awaitable.isInvalid())
808
0
    return ExprError();
809
57
810
57
  return BuildCoyieldExpr(Loc, Awaitable.get());
811
57
}
812
76
ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
813
76
  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
814
76
  if (!Coroutine)
815
0
    return ExprError();
816
76
817
76
  if (E->getType()->isPlaceholderType()) {
818
0
    ExprResult R = CheckPlaceholderExpr(E);
819
0
    if (R.isInvalid()) return ExprError();
820
0
    E = R.get();
821
0
  }
822
76
823
76
  if (E->getType()->isDependentType()) {
824
20
    Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
825
20
    return Res;
826
20
  }
827
56
828
56
  // If the expression is a temporary, materialize it as an lvalue so that we
829
56
  // can use it multiple times.
830
56
  if (E->getValueKind() == VK_RValue)
831
56
    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
832
56
833
56
  // Build the await_ready, await_suspend, await_resume calls.
834
56
  ReadySuspendResumeResult RSS =
835
56
      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E);
836
56
  if (RSS.IsInvalid)
837
1
    return ExprError();
838
55
839
55
  Expr *Res =
840
55
      new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
841
55
                                RSS.Results[2], RSS.OpaqueValue);
842
55
843
55
  return Res;
844
55
}
845
846
90
StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
847
90
  if (!ActOnCoroutineBodyStart(S, Loc, "co_return")) {
848
6
    CorrectDelayedTyposInExpr(E);
849
6
    return StmtError();
850
6
  }
851
84
  return BuildCoreturnStmt(Loc, E);
852
84
}
853
854
StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
855
220
                                   bool IsImplicit) {
856
220
  auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit);
857
220
  if (!FSI)
858
0
    return StmtError();
859
220
860
220
  if (E && 
E->getType()->isPlaceholderType()31
&&
861
220
      
!E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)2
) {
862
0
    ExprResult R = CheckPlaceholderExpr(E);
863
0
    if (R.isInvalid()) return StmtError();
864
0
    E = R.get();
865
0
  }
866
220
867
220
  // Move the return value if we can
868
220
  if (E) {
869
31
    auto NRVOCandidate = this->getCopyElisionCandidate(E->getType(), E, CES_AsIfByStdMove);
870
31
    if (NRVOCandidate) {
871
5
      InitializedEntity Entity =
872
5
          InitializedEntity::InitializeResult(Loc, E->getType(), NRVOCandidate);
873
5
      ExprResult MoveResult = this->PerformMoveOrCopyInitialization(
874
5
          Entity, NRVOCandidate, E->getType(), E);
875
5
      if (MoveResult.get())
876
5
        E = MoveResult.get();
877
5
    }
878
31
  }
879
220
880
220
  // FIXME: If the operand is a reference to a variable that's about to go out
881
220
  // of scope, we should treat the operand as an xvalue for this overload
882
220
  // resolution.
883
220
  VarDecl *Promise = FSI->CoroutinePromise;
884
220
  ExprResult PC;
885
220
  if (E && 
(31
isa<InitListExpr>(E)31
||
!E->getType()->isVoidType()28
)) {
886
30
    PC = buildPromiseCall(*this, Promise, Loc, "return_value", E);
887
190
  } else {
888
190
    E = MakeFullDiscardedValueExpr(E).get();
889
190
    PC = buildPromiseCall(*this, Promise, Loc, "return_void", None);
890
190
  }
891
220
  if (PC.isInvalid())
892
4
    return StmtError();
893
216
894
216
  Expr *PCE = ActOnFinishFullExpr(PC.get(), /*DiscardedValue*/ false).get();
895
216
896
216
  Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit);
897
216
  return Res;
898
216
}
899
900
/// Look up the std::nothrow object.
901
6
static Expr *buildStdNoThrowDeclRef(Sema &S, SourceLocation Loc) {
902
6
  NamespaceDecl *Std = S.getStdNamespace();
903
6
  assert(Std && "Should already be diagnosed");
904
6
905
6
  LookupResult Result(S, &S.PP.getIdentifierTable().get("nothrow"), Loc,
906
6
                      Sema::LookupOrdinaryName);
907
6
  if (!S.LookupQualifiedName(Result, Std)) {
908
0
    // FIXME: <experimental/coroutine> should have been included already.
909
0
    // If we require it to include <new> then this diagnostic is no longer
910
0
    // needed.
911
0
    S.Diag(Loc, diag::err_implicit_coroutine_std_nothrow_type_not_found);
912
0
    return nullptr;
913
0
  }
914
6
915
6
  auto *VD = Result.getAsSingle<VarDecl>();
916
6
  if (!VD) {
917
0
    Result.suppressDiagnostics();
918
0
    // We found something weird. Complain about the first thing we found.
919
0
    NamedDecl *Found = *Result.begin();
920
0
    S.Diag(Found->getLocation(), diag::err_malformed_std_nothrow);
921
0
    return nullptr;
922
0
  }
923
6
924
6
  ExprResult DR = S.BuildDeclRefExpr(VD, VD->getType(), VK_LValue, Loc);
925
6
  if (DR.isInvalid())
926
0
    return nullptr;
927
6
928
6
  return DR.get();
929
6
}
930
931
// Find an appropriate delete for the promise.
932
static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc,
933
162
                                          QualType PromiseType) {
934
162
  FunctionDecl *OperatorDelete = nullptr;
935
162
936
162
  DeclarationName DeleteName =
937
162
      S.Context.DeclarationNames.getCXXOperatorName(OO_Delete);
938
162
939
162
  auto *PointeeRD = PromiseType->getAsCXXRecordDecl();
940
162
  assert(PointeeRD && "PromiseType must be a CxxRecordDecl type");
941
162
942
162
  if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete))
943
0
    return nullptr;
944
162
945
162
  if (!OperatorDelete) {
946
160
    // Look for a global declaration.
947
160
    const bool CanProvideSize = S.isCompleteType(Loc, PromiseType);
948
160
    const bool Overaligned = false;
949
160
    OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize,
950
160
                                                     Overaligned, DeleteName);
951
160
  }
952
162
  S.MarkFunctionReferenced(Loc, OperatorDelete);
953
162
  return OperatorDelete;
954
162
}
955
956
957
260
void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
958
260
  FunctionScopeInfo *Fn = getCurFunction();
959
260
  assert(Fn && Fn->isCoroutine() && "not a coroutine");
960
260
  if (!Body) {
961
11
    assert(FD->isInvalidDecl() &&
962
11
           "a null body is only allowed for invalid declarations");
963
11
    return;
964
11
  }
965
249
  // We have a function that uses coroutine keywords, but we failed to build
966
249
  // the promise type.
967
249
  if (!Fn->CoroutinePromise)
968
12
    return FD->setInvalidDecl();
969
237
970
237
  if (isa<CoroutineBodyStmt>(Body)) {
971
41
    // Nothing todo. the body is already a transformed coroutine body statement.
972
41
    return;
973
41
  }
974
196
975
196
  // Coroutines [stmt.return]p1:
976
196
  //   A return statement shall not appear in a coroutine.
977
196
  if (Fn->FirstReturnLoc.isValid()) {
978
13
    assert(Fn->FirstCoroutineStmtLoc.isValid() &&
979
13
                   "first coroutine location not set");
980
13
    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
981
13
    Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
982
13
            << Fn->getFirstCoroutineStmtKeyword();
983
13
  }
984
196
  CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body);
985
196
  if (Builder.isInvalid() || 
!Builder.buildStatements()187
)
986
20
    return FD->setInvalidDecl();
987
176
988
176
  // Build body for the coroutine wrapper statement.
989
176
  Body = CoroutineBodyStmt::Create(Context, Builder);
990
176
}
991
992
CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD,
993
                                           sema::FunctionScopeInfo &Fn,
994
                                           Stmt *Body)
995
    : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
996
      IsPromiseDependentType(
997
          !Fn.CoroutinePromise ||
998
240
          Fn.CoroutinePromise->getType()->isDependentType()) {
999
240
  this->Body = Body;
1000
240
1001
240
  for (auto KV : Fn.CoroutineParameterMoves)
1002
148
    this->ParamMovesVector.push_back(KV.second);
1003
240
  this->ParamMoves = this->ParamMovesVector;
1004
240
1005
240
  if (!IsPromiseDependentType) {
1006
191
    PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
1007
191
    assert(PromiseRecordDecl && "Type should have already been checked");
1008
191
  }
1009
240
  this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend();
1010
240
}
1011
1012
187
bool CoroutineStmtBuilder::buildStatements() {
1013
187
  assert(this->IsValid && "coroutine already invalid");
1014
187
  this->IsValid = makeReturnObject();
1015
187
  if (this->IsValid && 
!IsPromiseDependentType186
)
1016
139
    buildDependentStatements();
1017
187
  return this->IsValid;
1018
187
}
1019
1020
175
bool CoroutineStmtBuilder::buildDependentStatements() {
1021
175
  assert(this->IsValid && "coroutine already invalid");
1022
175
  assert(!this->IsPromiseDependentType &&
1023
175
         "coroutine cannot have a dependent promise type");
1024
175
  this->IsValid = makeOnException() && 
makeOnFallthrough()172
&&
1025
175
                  
makeGroDeclAndReturnStmt()169
&&
makeReturnOnAllocFailure()167
&&
1026
175
                  
makeNewAndDeleteExpr()164
;
1027
175
  return this->IsValid;
1028
175
}
1029
1030
240
bool CoroutineStmtBuilder::makePromiseStmt() {
1031
240
  // Form a declaration statement for the promise declaration, so that AST
1032
240
  // visitors can more easily find it.
1033
240
  StmtResult PromiseStmt =
1034
240
      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
1035
240
  if (PromiseStmt.isInvalid())
1036
0
    return false;
1037
240
1038
240
  this->Promise = PromiseStmt.get();
1039
240
  return true;
1040
240
}
1041
1042
240
bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() {
1043
240
  if (Fn.hasInvalidCoroutineSuspends())
1044
9
    return false;
1045
231
  this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
1046
231
  this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second);
1047
231
  return true;
1048
231
}
1049
1050
static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
1051
                                     CXXRecordDecl *PromiseRecordDecl,
1052
12
                                     FunctionScopeInfo &Fn) {
1053
12
  auto Loc = E->getExprLoc();
1054
12
  if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) {
1055
12
    auto *Decl = DeclRef->getDecl();
1056
12
    if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) {
1057
12
      if (Method->isStatic())
1058
11
        return true;
1059
1
      else
1060
1
        Loc = Decl->getLocation();
1061
12
    }
1062
12
  }
1063
12
1064
12
  S.Diag(
1065
1
      Loc,
1066
1
      diag::err_coroutine_promise_get_return_object_on_allocation_failure)
1067
1
      << PromiseRecordDecl;
1068
1
  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1069
1
      << Fn.getFirstCoroutineStmtKeyword();
1070
1
  return false;
1071
12
}
1072
1073
167
bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
1074
167
  assert(!IsPromiseDependentType &&
1075
167
         "cannot make statement while the promise type is dependent");
1076
167
1077
167
  // [dcl.fct.def.coroutine]/8
1078
167
  // The unqualified-id get_return_object_on_allocation_failure is looked up in
1079
167
  // the scope of class P by class member access lookup (3.4.5). ...
1080
167
  // If an allocation function returns nullptr, ... the coroutine return value
1081
167
  // is obtained by a call to ... get_return_object_on_allocation_failure().
1082
167
1083
167
  DeclarationName DN =
1084
167
      S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
1085
167
  LookupResult Found(S, DN, Loc, Sema::LookupMemberName);
1086
167
  if (!S.LookupQualifiedName(Found, PromiseRecordDecl))
1087
155
    return true;
1088
12
1089
12
  CXXScopeSpec SS;
1090
12
  ExprResult DeclNameExpr =
1091
12
      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
1092
12
  if (DeclNameExpr.isInvalid())
1093
0
    return false;
1094
12
1095
12
  if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn))
1096
1
    return false;
1097
11
1098
11
  ExprResult ReturnObjectOnAllocationFailure =
1099
11
      S.BuildCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc);
1100
11
  if (ReturnObjectOnAllocationFailure.isInvalid())
1101
0
    return false;
1102
11
1103
11
  StmtResult ReturnStmt =
1104
11
      S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get());
1105
11
  if (ReturnStmt.isInvalid()) {
1106
2
    S.Diag(Found.getFoundDecl()->getLocation(), diag::note_member_declared_here)
1107
2
        << DN;
1108
2
    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1109
2
        << Fn.getFirstCoroutineStmtKeyword();
1110
2
    return false;
1111
2
  }
1112
9
1113
9
  this->ReturnStmtOnAllocFailure = ReturnStmt.get();
1114
9
  return true;
1115
9
}
1116
1117
164
bool CoroutineStmtBuilder::makeNewAndDeleteExpr() {
1118
164
  // Form and check allocation and deallocation calls.
1119
164
  assert(!IsPromiseDependentType &&
1120
164
         "cannot make statement while the promise type is dependent");
1121
164
  QualType PromiseType = Fn.CoroutinePromise->getType();
1122
164
1123
164
  if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
1124
0
    return false;
1125
164
1126
164
  const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;
1127
164
1128
164
  // [dcl.fct.def.coroutine]/7
1129
164
  // Lookup allocation functions using a parameter list composed of the
1130
164
  // requested size of the coroutine state being allocated, followed by
1131
164
  // the coroutine function's arguments. If a matching allocation function
1132
164
  // exists, use it. Otherwise, use an allocation function that just takes
1133
164
  // the requested size.
1134
164
1135
164
  FunctionDecl *OperatorNew = nullptr;
1136
164
  FunctionDecl *OperatorDelete = nullptr;
1137
164
  FunctionDecl *UnusedResult = nullptr;
1138
164
  bool PassAlignment = false;
1139
164
  SmallVector<Expr *, 1> PlacementArgs;
1140
164
1141
164
  // [dcl.fct.def.coroutine]/7
1142
164
  // "The allocation function’s name is looked up in the scope of P.
1143
164
  // [...] If the lookup finds an allocation function in the scope of P,
1144
164
  // overload resolution is performed on a function call created by assembling
1145
164
  // an argument list. The first argument is the amount of space requested,
1146
164
  // and has type std::size_t. The lvalues p1 ... pn are the succeeding
1147
164
  // arguments."
1148
164
  //
1149
164
  // ...where "p1 ... pn" are defined earlier as:
1150
164
  //
1151
164
  // [dcl.fct.def.coroutine]/3
1152
164
  // "For a coroutine f that is a non-static member function, let P1 denote the
1153
164
  // type of the implicit object parameter (13.3.1) and P2 ... Pn be the types
1154
164
  // of the function parameters; otherwise let P1 ... Pn be the types of the
1155
164
  // function parameters. Let p1 ... pn be lvalues denoting those objects."
1156
164
  if (auto *MD = dyn_cast<CXXMethodDecl>(&FD)) {
1157
47
    if (MD->isInstance() && 
!isLambdaCallOperator(MD)40
) {
1158
36
      ExprResult ThisExpr = S.ActOnCXXThis(Loc);
1159
36
      if (ThisExpr.isInvalid())
1160
0
        return false;
1161
36
      ThisExpr = S.CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
1162
36
      if (ThisExpr.isInvalid())
1163
0
        return false;
1164
36
      PlacementArgs.push_back(ThisExpr.get());
1165
36
    }
1166
47
  }
1167
164
  for (auto *PD : FD.parameters()) {
1168
119
    if (PD->getType()->isDependentType())
1169
0
      continue;
1170
119
1171
119
    // Build a reference to the parameter.
1172
119
    auto PDLoc = PD->getLocation();
1173
119
    ExprResult PDRefExpr =
1174
119
        S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
1175
119
                           ExprValueKind::VK_LValue, PDLoc);
1176
119
    if (PDRefExpr.isInvalid())
1177
0
      return false;
1178
119
1179
119
    PlacementArgs.push_back(PDRefExpr.get());
1180
119
  }
1181
164
  S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
1182
164
                            /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1183
164
                            /*isArray*/ false, PassAlignment, PlacementArgs,
1184
164
                            OperatorNew, UnusedResult, /*Diagnose*/ false);
1185
164
1186
164
  // [dcl.fct.def.coroutine]/7
1187
164
  // "If no matching function is found, overload resolution is performed again
1188
164
  // on a function call created by passing just the amount of space required as
1189
164
  // an argument of type std::size_t."
1190
164
  if (!OperatorNew && 
!PlacementArgs.empty()160
) {
1191
90
    PlacementArgs.clear();
1192
90
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
1193
90
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1194
90
                              /*isArray*/ false, PassAlignment, PlacementArgs,
1195
90
                              OperatorNew, UnusedResult, /*Diagnose*/ false);
1196
90
  }
1197
164
1198
164
  // [dcl.fct.def.coroutine]/7
1199
164
  // "The allocation function’s name is looked up in the scope of P. If this
1200
164
  // lookup fails, the allocation function’s name is looked up in the global
1201
164
  // scope."
1202
164
  if (!OperatorNew) {
1203
158
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Global,
1204
158
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1205
158
                              /*isArray*/ false, PassAlignment, PlacementArgs,
1206
158
                              OperatorNew, UnusedResult);
1207
158
  }
1208
164
1209
164
  bool IsGlobalOverload =
1210
164
      OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
1211
164
  // If we didn't find a class-local new declaration and non-throwing new
1212
164
  // was is required then we need to lookup the non-throwing global operator
1213
164
  // instead.
1214
164
  if (RequiresNoThrowAlloc && 
(9
!OperatorNew9
||
IsGlobalOverload9
)) {
1215
6
    auto *StdNoThrow = buildStdNoThrowDeclRef(S, Loc);
1216
6
    if (!StdNoThrow)
1217
0
      return false;
1218
6
    PlacementArgs = {StdNoThrow};
1219
6
    OperatorNew = nullptr;
1220
6
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Both,
1221
6
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1222
6
                              /*isArray*/ false, PassAlignment, PlacementArgs,
1223
6
                              OperatorNew, UnusedResult);
1224
6
  }
1225
164
1226
164
  if (!OperatorNew)
1227
0
    return false;
1228
164
1229
164
  if (RequiresNoThrowAlloc) {
1230
9
    const auto *FT = OperatorNew->getType()->getAs<FunctionProtoType>();
1231
9
    if (!FT->isNothrow(/*ResultIfDependent*/ false)) {
1232
2
      S.Diag(OperatorNew->getLocation(),
1233
2
             diag::err_coroutine_promise_new_requires_nothrow)
1234
2
          << OperatorNew;
1235
2
      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
1236
2
          << OperatorNew;
1237
2
      return false;
1238
2
    }
1239
162
  }
1240
162
1241
162
  if ((OperatorDelete = findDeleteForPromise(S, Loc, PromiseType)) == nullptr)
1242
0
    return false;
1243
162
1244
162
  Expr *FramePtr =
1245
162
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
1246
162
1247
162
  Expr *FrameSize =
1248
162
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {});
1249
162
1250
162
  // Make new call.
1251
162
1252
162
  ExprResult NewRef =
1253
162
      S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc);
1254
162
  if (NewRef.isInvalid())
1255
0
    return false;
1256
162
1257
162
  SmallVector<Expr *, 2> NewArgs(1, FrameSize);
1258
162
  for (auto Arg : PlacementArgs)
1259
16
    NewArgs.push_back(Arg);
1260
162
1261
162
  ExprResult NewExpr =
1262
162
      S.BuildCallExpr(S.getCurScope(), NewRef.get(), Loc, NewArgs, Loc);
1263
162
  NewExpr = S.ActOnFinishFullExpr(NewExpr.get(), /*DiscardedValue*/ false);
1264
162
  if (NewExpr.isInvalid())
1265
0
    return false;
1266
162
1267
162
  // Make delete call.
1268
162
1269
162
  QualType OpDeleteQualType = OperatorDelete->getType();
1270
162
1271
162
  ExprResult DeleteRef =
1272
162
      S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc);
1273
162
  if (DeleteRef.isInvalid())
1274
0
    return false;
1275
162
1276
162
  Expr *CoroFree =
1277
162
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr});
1278
162
1279
162
  SmallVector<Expr *, 2> DeleteArgs{CoroFree};
1280
162
1281
162
  // Check if we need to pass the size.
1282
162
  const auto *OpDeleteType =
1283
162
      OpDeleteQualType.getTypePtr()->getAs<FunctionProtoType>();
1284
162
  if (OpDeleteType->getNumParams() > 1)
1285
1
    DeleteArgs.push_back(FrameSize);
1286
162
1287
162
  ExprResult DeleteExpr =
1288
162
      S.BuildCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc);
1289
162
  DeleteExpr =
1290
162
      S.ActOnFinishFullExpr(DeleteExpr.get(), /*DiscardedValue*/ false);
1291
162
  if (DeleteExpr.isInvalid())
1292
0
    return false;
1293
162
1294
162
  this->Allocate = NewExpr.get();
1295
162
  this->Deallocate = DeleteExpr.get();
1296
162
1297
162
  return true;
1298
162
}
1299
1300
172
bool CoroutineStmtBuilder::makeOnFallthrough() {
1301
172
  assert(!IsPromiseDependentType &&
1302
172
         "cannot make statement while the promise type is dependent");
1303
172
1304
172
  // [dcl.fct.def.coroutine]/4
1305
172
  // The unqualified-ids 'return_void' and 'return_value' are looked up in
1306
172
  // the scope of class P. If both are found, the program is ill-formed.
1307
172
  bool HasRVoid, HasRValue;
1308
172
  LookupResult LRVoid =
1309
172
      lookupMember(S, "return_void", PromiseRecordDecl, Loc, HasRVoid);
1310
172
  LookupResult LRValue =
1311
172
      lookupMember(S, "return_value", PromiseRecordDecl, Loc, HasRValue);
1312
172
1313
172
  StmtResult Fallthrough;
1314
172
  if (HasRVoid && 
HasRValue122
) {
1315
2
    // FIXME Improve this diagnostic
1316
2
    S.Diag(FD.getLocation(),
1317
2
           diag::err_coroutine_promise_incompatible_return_functions)
1318
2
        << PromiseRecordDecl;
1319
2
    S.Diag(LRVoid.getRepresentativeDecl()->getLocation(),
1320
2
           diag::note_member_first_declared_here)
1321
2
        << LRVoid.getLookupName();
1322
2
    S.Diag(LRValue.getRepresentativeDecl()->getLocation(),
1323
2
           diag::note_member_first_declared_here)
1324
2
        << LRValue.getLookupName();
1325
2
    return false;
1326
170
  } else if (!HasRVoid && 
!HasRValue50
) {
1327
1
    // FIXME: The PDTS currently specifies this case as UB, not ill-formed.
1328
1
    // However we still diagnose this as an error since until the PDTS is fixed.
1329
1
    S.Diag(FD.getLocation(),
1330
1
           diag::err_coroutine_promise_requires_return_function)
1331
1
        << PromiseRecordDecl;
1332
1
    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1333
1
        << PromiseRecordDecl;
1334
1
    return false;
1335
169
  } else if (HasRVoid) {
1336
120
    // If the unqualified-id return_void is found, flowing off the end of a
1337
120
    // coroutine is equivalent to a co_return with no operand. Otherwise,
1338
120
    // flowing off the end of a coroutine results in undefined behavior.
1339
120
    Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr,
1340
120
                                      /*IsImplicit*/false);
1341
120
    Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
1342
120
    if (Fallthrough.isInvalid())
1343
0
      return false;
1344
169
  }
1345
169
1346
169
  this->OnFallthrough = Fallthrough.get();
1347
169
  return true;
1348
169
}
1349
1350
175
bool CoroutineStmtBuilder::makeOnException() {
1351
175
  // Try to form 'p.unhandled_exception();'
1352
175
  assert(!IsPromiseDependentType &&
1353
175
         "cannot make statement while the promise type is dependent");
1354
175
1355
175
  const bool RequireUnhandledException = S.getLangOpts().CXXExceptions;
1356
175
1357
175
  if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) {
1358
26
    auto DiagID =
1359
26
        RequireUnhandledException
1360
26
            ? 
diag::err_coroutine_promise_unhandled_exception_required2
1361
26
            : diag::
1362
24
                  warn_coroutine_promise_unhandled_exception_required_with_exceptions;
1363
26
    S.Diag(Loc, DiagID) << PromiseRecordDecl;
1364
26
    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1365
26
        << PromiseRecordDecl;
1366
26
    return !RequireUnhandledException;
1367
26
  }
1368
149
1369
149
  // If exceptions are disabled, don't try to build OnException.
1370
149
  if (!S.getLangOpts().CXXExceptions)
1371
34
    return true;
1372
115
1373
115
  ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc,
1374
115
                                                   "unhandled_exception", None);
1375
115
  UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc,
1376
115
                                             /*DiscardedValue*/ false);
1377
115
  if (UnhandledException.isInvalid())
1378
0
    return false;
1379
115
1380
115
  // Since the body of the coroutine will be wrapped in try-catch, it will
1381
115
  // be incompatible with SEH __try if present in a function.
1382
115
  if (!S.getLangOpts().Borland && Fn.FirstSEHTryLoc.isValid()) {
1383
1
    S.Diag(Fn.FirstSEHTryLoc, diag::err_seh_in_a_coroutine_with_cxx_exceptions);
1384
1
    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1385
1
        << Fn.getFirstCoroutineStmtKeyword();
1386
1
    return false;
1387
1
  }
1388
114
1389
114
  this->OnException = UnhandledException.get();
1390
114
  return true;
1391
114
}
1392
1393
187
bool CoroutineStmtBuilder::makeReturnObject() {
1394
187
  // Build implicit 'p.get_return_object()' expression and form initialization
1395
187
  // of return type from it.
1396
187
  ExprResult ReturnObject =
1397
187
      buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None);
1398
187
  if (ReturnObject.isInvalid())
1399
1
    return false;
1400
186
1401
186
  this->ReturnValue = ReturnObject.get();
1402
186
  return true;
1403
186
}
1404
1405
2
/// Emit follow-up notes for a diagnostic that involves the expression \p E.
///
/// If \p E is a CXXMemberCallExpr, a 'note_member_declared_here' note is
/// emitted at the location of the called method. Regardless of the kind of
/// \p E, a 'note_declared_coroutine_here' note is emitted at the first
/// coroutine statement recorded in \p Fn, together with its keyword.
///
/// \param S  Sema instance used to emit the notes.
/// \param E  Expression to inspect; must not be null, since it is passed to
///           dyn_cast.
/// \param Fn Function scope info providing FirstCoroutineStmtLoc and the
///           first coroutine statement keyword.
static void noteMemberDeclaredHere(Sema &S, Expr *E, FunctionScopeInfo &Fn) {
1406
2
  if (auto *MbrRef = dyn_cast<CXXMemberCallExpr>(E)) {
1407
2
    auto *MethodDecl = MbrRef->getMethodDecl();
1408
2
    S.Diag(MethodDecl->getLocation(), diag::note_member_declared_here)
1409
2
        << MethodDecl;
1410
2
  }
1411
2
  // Always point back at the statement that made this function a coroutine.
  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1412
2
      << Fn.getFirstCoroutineStmtKeyword();
1413
2
}
1414
1415
169
/// Form the statements that produce the coroutine's return object ("gro")
/// from the previously built this->ReturnValue expression.
///
/// Three cases are handled:
///  - The function returns void: ReturnValue is finished as a full
///    expression and stored in ResultDecl; no return statement is built.
///  - The function returns non-void but ReturnValue has type void: an
///    initialization of the function result is attempted (its result is
///    ignored; it is only performed for the diagnostic), notes are added,
///    and the function fails.
///  - Otherwise: a variable named '__coro_gro' of ReturnValue's type is
///    created in the function, initialized from ReturnValue, wrapped in a
///    declaration statement (stored in ResultDecl), and a 'return __coro_gro;'
///    statement is built (stored in ReturnStmt).
///
/// Requires that the promise type is not dependent and that ReturnValue has
/// already been formed (both are asserted).
///
/// \returns true on success, false if any step produced an invalid result.
bool CoroutineStmtBuilder::makeGroDeclAndReturnStmt() {
1416
169
  assert(!IsPromiseDependentType &&
1417
169
         "cannot make statement while the promise type is dependent");
1418
169
  assert(this->ReturnValue && "ReturnValue must be already formed");
1419
169
1420
169
  QualType const GroType = this->ReturnValue->getType();
1421
169
  assert(!GroType->isDependentType() &&
1422
169
         "get_return_object type must no longer be dependent");
1423
169
1424
169
  QualType const FnRetType = FD.getReturnType();
1425
169
  assert(!FnRetType->isDependentType() &&
1426
169
         "function return type must no longer be dependent");
1427
169
1428
169
  // A void-returning coroutine only needs the get_return_object() call
  // evaluated; there is nothing to declare or return.
  if (FnRetType->isVoidType()) {
1429
67
    ExprResult Res =
1430
67
        S.ActOnFinishFullExpr(this->ReturnValue, Loc, /*DiscardedValue*/ false);
1431
67
    if (Res.isInvalid())
1432
0
      return false;
1433
67
1434
67
    this->ResultDecl = Res.get();
1435
67
    return true;
1436
67
  }
1437
102
1438
102
  if (GroType->isVoidType()) {
1439
1
    // Trigger a nice error message.
1440
1
    InitializedEntity Entity =
1441
1
        InitializedEntity::InitializeResult(Loc, FnRetType, false);
1442
1
    S.PerformMoveOrCopyInitialization(Entity, nullptr, FnRetType, ReturnValue);
1443
1
    noteMemberDeclaredHere(S, ReturnValue, Fn);
1444
1
    return false;
1445
1
  }
1446
101
1447
101
  auto *GroDecl = VarDecl::Create(
1448
101
      S.Context, &FD, FD.getLocation(), FD.getLocation(),
1449
101
      &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
1450
101
      S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
1451
101
1452
101
  S.CheckVariableDeclarationType(GroDecl);
1453
101
  if (GroDecl->isInvalidDecl())
1454
0
    return false;
1455
101
1456
101
  InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
1457
101
  ExprResult Res = S.PerformMoveOrCopyInitialization(Entity, nullptr, GroType,
1458
101
                                                     this->ReturnValue);
1459
101
  if (Res.isInvalid())
1460
0
    return false;
1461
101
1462
101
  Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false);
1463
101
  if (Res.isInvalid())
1464
0
    return false;
1465
101
1466
101
  S.AddInitializerToDecl(GroDecl, Res.get(),
1467
101
                         /*DirectInit=*/false);
1468
101
1469
101
  S.FinalizeDeclaration(GroDecl);
1470
101
1471
101
  // Form a declaration statement for the return declaration, so that AST
1472
101
  // visitors can more easily find it.
1473
101
  StmtResult GroDeclStmt =
1474
101
      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
1475
101
  if (GroDeclStmt.isInvalid())
1476
0
    return false;
1477
101
1478
101
  this->ResultDecl = GroDeclStmt.get();
1479
101
1480
101
  ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
1481
101
  if (declRef.isInvalid())
1482
0
    return false;
1483
101
1484
101
  StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
1485
101
  if (ReturnStmt.isInvalid()) {
1486
1
    noteMemberDeclaredHere(S, ReturnValue, Fn);
1487
1
    return false;
1488
1
  }
1489
100
  // If the return statement picked the gro variable as its NRVO candidate,
  // record that on the variable itself.
  if (cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl)
1490
87
    GroDecl->setNRVOVariable(true);
1491
100
1492
100
  this->ReturnStmt = ReturnStmt.get();
1493
100
  return true;
1494
100
}
1495
1496
// Create a static_cast<T&&>(expr).
//
// If T is null, the type of E is used instead. The rvalue-reference target
// type is built with an empty source location and declaration name, and the
// cast itself is located at the beginning of E. The result of
// BuildCXXNamedCast is unwrapped with .get() without an isInvalid() check.
// NOTE(review): this presumably yields a null Expr* if the cast fails to
// build -- confirm callers tolerate that.
1497
52
static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) {
1498
52
  if (T.isNull())
1499
52
    T = E->getType();
1500
52
  QualType TargetType = S.BuildReferenceType(
1501
52
      T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName());
1502
52
  SourceLocation ExprLoc = E->getBeginLoc();
1503
52
  TypeSourceInfo *TargetLoc =
1504
52
      S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc);
1505
52
1506
52
  return S
1507
52
      .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E,
1508
52
                         SourceRange(ExprLoc, ExprLoc), E->getSourceRange())
1509
52
      .get();
1510
52
}
1511
1512
/// Build a variable declaration for move parameter.
///
/// The variable is created in S.CurContext with storage class SC_None, uses
/// \p Loc for both of its source locations and for its trivial type source
/// info, and is marked implicit. No initializer is attached here.
///
/// \param S    Sema instance providing the AST context and current context.
/// \param Loc  Source location used for the declaration.
/// \param Type Type of the new variable.
/// \param II   Identifier naming the new variable.
/// \returns the newly created VarDecl.
1513
static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type,
1514
160
                             IdentifierInfo *II) {
1515
160
  TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc);
1516
160
  VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type,
1517
160
                                  TInfo, SC_None);
1518
160
  Decl->setImplicit();
1519
160
  return Decl;
1520
160
}
1521
1522
// Build statements that move coroutine function parameters to the coroutine
1523
// frame, and store them on the function scope info.
//
// For every parameter of the current function whose type is not dependent, an
// implicit variable with the parameter's type and identifier is declared and
// direct-initialized from the parameter. Parameters of C++ record type or
// rvalue-reference type are wrapped in a static_cast to an rvalue reference
// first; all others are used as plain lvalues. The resulting declaration
// statement is recorded in CoroutineParameterMoves, keyed by the parameter.
//
// Must be called with a FunctionDecl as the current context and at most once
// per function (both are asserted). Returns false if building a reference to
// a parameter or its declaration statement fails, true otherwise.
1524
260
bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) {
1525
260
  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
1526
260
  auto *FD = cast<FunctionDecl>(CurContext);
1527
260
1528
260
  auto *ScopeInfo = getCurFunction();
1529
260
  assert(ScopeInfo->CoroutineParameterMoves.empty() &&
1530
260
         "Should not build parameter moves twice");
1531
260
1532
260
  for (auto *PD : FD->parameters()) {
1533
200
    // Dependent parameters are skipped here; no move is recorded for them.
    if (PD->getType()->isDependentType())
1534
40
      continue;
1535
160
1536
160
    ExprResult PDRefExpr =
1537
160
        BuildDeclRefExpr(PD, PD->getType().getNonReferenceType(),
1538
160
                         ExprValueKind::VK_LValue, Loc); // FIXME: scope?
1539
160
    if (PDRefExpr.isInvalid())
1540
0
      return false;
1541
160
1542
160
    Expr *CExpr = nullptr;
1543
160
    // NOTE(review): the stray "116" and the line breaks inside the condition
    // below look like coverage-report region-count residue rather than part
    // of the source expression -- confirm against the original file.
    if (PD->getType()->getAsCXXRecordDecl() ||
1544
160
        
PD->getType()->isRValueReferenceType()116
)
1545
52
      CExpr = castForMoving(*this, PDRefExpr.get());
1546
108
    else
1547
108
      CExpr = PDRefExpr.get();
1548
160
1549
160
    auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
1550
160
    AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
1551
160
1552
160
    // Convert decl to a statement.
1553
160
    StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
1554
160
    if (Stmt.isInvalid())
1555
0
      return false;
1556
160
1557
160
    ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
1558
160
  }
1559
260
  return true;
1560
260
}
1561
1562
41
/// Create a CoroutineBodyStmt from the given constructor arguments.
///
/// \param Args Arguments forwarded to CoroutineBodyStmt::Create.
/// \returns the created statement, or StmtError() if Create returned null.
StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) {
1563
41
  CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args);
1564
41
  if (!Res)
1565
0
    return StmtError();
1566
41
  return Res;
1567
41
}
1568
1569
/// Find the 'coroutine_traits' class template in the namespace returned by
/// lookupStdExperimentalNamespace(), caching it in StdCoroutineTraitsCache.
///
/// The lookup is only performed while the cache is empty. If the name is not
/// found, 'err_implied_coroutine_type_not_found' is reported at \p KwLoc; if
/// it is found but is not a single ClassTemplateDecl,
/// 'err_malformed_std_coroutine_traits' is reported at the first found
/// declaration. Both cases return nullptr.
///
/// \param KwLoc   Location used for the "not found" diagnostic.
/// \param FuncLoc Location used for the name lookup itself.
/// \returns the cached class template. This is still null when the cache was
/// empty and lookupStdExperimentalNamespace() returned null, in which case no
/// diagnostic is emitted by this function.
ClassTemplateDecl *Sema::lookupCoroutineTraits(SourceLocation KwLoc,
1570
207
                                               SourceLocation FuncLoc) {
1571
207
  if (!StdCoroutineTraitsCache) {
1572
31
    if (auto StdExp = lookupStdExperimentalNamespace()) {
1573
31
      LookupResult Result(*this,
1574
31
                          &PP.getIdentifierTable().get("coroutine_traits"),
1575
31
                          FuncLoc, LookupOrdinaryName);
1576
31
      if (!LookupQualifiedName(Result, StdExp)) {
1577
0
        Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
1578
0
            << "std::experimental::coroutine_traits";
1579
0
        return nullptr;
1580
0
      }
1581
31
      // The assignment populates the cache; a null result means the name did
      // not resolve to exactly one class template.
      if (!(StdCoroutineTraitsCache =
1582
31
                Result.getAsSingle<ClassTemplateDecl>())) {
1583
0
        Result.suppressDiagnostics();
1584
0
        NamedDecl *Found = *Result.begin();
1585
0
        Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
1586
0
        return nullptr;
1587
0
      }
1588
207
    }
1589
31
  }
1590
207
  return StdCoroutineTraitsCache;
1591
207
}