Coverage Report

Created: 2020-09-15 12:33

/Users/buildslave/jenkins/workspace/coverage/llvm-project/clang/lib/Sema/SemaCoroutine.cpp
Line
Count
Source (jump to first uncovered line)
1
//===-- SemaCoroutine.cpp - Semantic Analysis for Coroutines --------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
//  This file implements semantic analysis for C++ Coroutines.
10
//
11
//  This file contains references to sections of the Coroutines TS, which
12
//  can be found at http://wg21.link/coroutines.
13
//
14
//===----------------------------------------------------------------------===//
15
16
#include "CoroutineStmtBuilder.h"
17
#include "clang/AST/ASTLambda.h"
18
#include "clang/AST/Decl.h"
19
#include "clang/AST/ExprCXX.h"
20
#include "clang/AST/StmtCXX.h"
21
#include "clang/Basic/Builtins.h"
22
#include "clang/Lex/Preprocessor.h"
23
#include "clang/Sema/Initialization.h"
24
#include "clang/Sema/Overload.h"
25
#include "clang/Sema/ScopeInfo.h"
26
#include "clang/Sema/SemaInternal.h"
27
#include "llvm/ADT/SmallSet.h"
28
29
using namespace clang;
30
using namespace sema;
31
32
/// Perform a qualified member-name lookup of \p Name inside \p RD.
///
/// Access diagnostics are suppressed on the returned result; the same
/// diagnostics are produced again when the call expression is built.
/// \p Res receives the return value of LookupQualifiedName (whether anything
/// was found).
static LookupResult lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
                                 SourceLocation Loc, bool &Res) {
  LookupResult Lookup(S, DeclarationName(S.PP.getIdentifierInfo(Name)), Loc,
                      Sema::LookupMemberName);
  // A private member may be selected here; stay quiet now because the warning
  // is reported again once the call is formed.
  Lookup.suppressDiagnostics();
  Res = S.LookupQualifiedName(Lookup, RD);
  return Lookup;
}
42
43
static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
44
302
                         SourceLocation Loc) {
45
302
  bool Res;
46
302
  lookupMember(S, Name, RD, Loc, Res);
47
302
  return Res;
48
302
}
49
50
/// Look up the std::coroutine_traits<...>::promise_type for the given
51
/// function type.
52
static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD,
53
228
                                  SourceLocation KwLoc) {
54
228
  const FunctionProtoType *FnType = FD->getType()->castAs<FunctionProtoType>();
55
228
  const SourceLocation FuncLoc = FD->getLocation();
56
  // FIXME: Cache std::coroutine_traits once we've found it.
57
228
  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
58
228
  if (!StdExp) {
59
4
    S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
60
4
        << "std::experimental::coroutine_traits";
61
4
    return QualType();
62
4
  }
63
224
64
224
  ClassTemplateDecl *CoroTraits = S.lookupCoroutineTraits(KwLoc, FuncLoc);
65
224
  if (!CoroTraits) {
66
0
    return QualType();
67
0
  }
68
224
69
  // Form template argument list for coroutine_traits<R, P1, P2, ...> according
70
  // to [dcl.fct.def.coroutine]3
71
224
  TemplateArgumentListInfo Args(KwLoc, KwLoc);
72
424
  auto AddArg = [&](QualType T) {
73
424
    Args.addArgument(TemplateArgumentLoc(
74
424
        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc)));
75
424
  };
76
224
  AddArg(FnType->getReturnType());
77
  // If the function is a non-static member function, add the type
78
  // of the implicit object parameter before the formal parameters.
79
224
  if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
80
61
    if (MD->isInstance()) {
81
      // [over.match.funcs]4
82
      // For non-static member functions, the type of the implicit object
83
      // parameter is
84
      //  -- "lvalue reference to cv X" for functions declared without a
85
      //      ref-qualifier or with the & ref-qualifier
86
      //  -- "rvalue reference to cv X" for functions declared with the &&
87
      //      ref-qualifier
88
47
      QualType T = MD->getThisType()->castAs<PointerType>()->getPointeeType();
89
47
      T = FnType->getRefQualifier() == RQ_RValue
90
7
              ? S.Context.getRValueReferenceType(T)
91
40
              : S.Context.getLValueReferenceType(T, /*SpelledAsLValue*/ true);
92
47
      AddArg(T);
93
47
    }
94
61
  }
95
224
  for (QualType T : FnType->getParamTypes())
96
153
    AddArg(T);
97
224
98
  // Build the template-id.
99
224
  QualType CoroTrait =
100
224
      S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args);
101
224
  if (CoroTrait.isNull())
102
0
    return QualType();
103
224
  if (S.RequireCompleteType(KwLoc, CoroTrait,
104
224
                            diag::err_coroutine_type_missing_specialization))
105
1
    return QualType();
106
223
107
223
  auto *RD = CoroTrait->getAsCXXRecordDecl();
108
223
  assert(RD && "specialization of class template is not a class?");
109
223
110
  // Look up the ::promise_type member.
111
223
  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc,
112
223
                 Sema::LookupOrdinaryName);
113
223
  S.LookupQualifiedName(R, RD);
114
223
  auto *Promise = R.getAsSingle<TypeDecl>();
115
223
  if (!Promise) {
116
6
    S.Diag(FuncLoc,
117
6
           diag::err_implied_std_coroutine_traits_promise_type_not_found)
118
6
        << RD;
119
6
    return QualType();
120
6
  }
121
  // The promise type is required to be a class type.
122
217
  QualType PromiseType = S.Context.getTypeDeclType(Promise);
123
217
124
217
  auto buildElaboratedType = [&]() {
125
217
    auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
126
217
    NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
127
217
                                      CoroTrait.getTypePtr());
128
217
    return S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
129
217
  };
130
217
131
217
  if (!PromiseType->getAsCXXRecordDecl()) {
132
1
    S.Diag(FuncLoc,
133
1
           diag::err_implied_std_coroutine_traits_promise_type_not_class)
134
1
        << buildElaboratedType();
135
1
    return QualType();
136
1
  }
137
216
  if (S.RequireCompleteType(FuncLoc, buildElaboratedType(),
138
216
                            diag::err_coroutine_promise_type_incomplete))
139
1
    return QualType();
140
215
141
215
  return PromiseType;
142
215
}
143
144
/// Look up the std::experimental::coroutine_handle<PromiseType>.
145
static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
146
582
                                          SourceLocation Loc) {
147
582
  if (PromiseType.isNull())
148
0
    return QualType();
149
582
150
582
  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
151
582
  assert(StdExp && "Should already be diagnosed");
152
582
153
582
  LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"),
154
582
                      Loc, Sema::LookupOrdinaryName);
155
582
  if (!S.LookupQualifiedName(Result, StdExp)) {
156
1
    S.Diag(Loc, diag::err_implied_coroutine_type_not_found)
157
1
        << "std::experimental::coroutine_handle";
158
1
    return QualType();
159
1
  }
160
581
161
581
  ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
162
581
  if (!CoroHandle) {
163
0
    Result.suppressDiagnostics();
164
    // We found something weird. Complain about the first thing we found.
165
0
    NamedDecl *Found = *Result.begin();
166
0
    S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
167
0
    return QualType();
168
0
  }
169
581
170
  // Form template argument list for coroutine_handle<Promise>.
171
581
  TemplateArgumentListInfo Args(Loc, Loc);
172
581
  Args.addArgument(TemplateArgumentLoc(
173
581
      TemplateArgument(PromiseType),
174
581
      S.Context.getTrivialTypeSourceInfo(PromiseType, Loc)));
175
581
176
  // Build the template-id.
177
581
  QualType CoroHandleType =
178
581
      S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args);
179
581
  if (CoroHandleType.isNull())
180
0
    return QualType();
181
581
  if (S.RequireCompleteType(Loc, CoroHandleType,
182
581
                            diag::err_coroutine_type_missing_specialization))
183
0
    return QualType();
184
581
185
581
  return CoroHandleType;
186
581
}
187
188
static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
189
1.40k
                                    StringRef Keyword) {
190
  // [expr.await]p2 dictates that 'co_await' and 'co_yield' must be used within
191
  // a function body.
192
  // FIXME: This also covers [expr.await]p2: "An await-expression shall not
193
  // appear in a default argument." But the diagnostic QoI here could be
194
  // improved to inform the user that default arguments specifically are not
195
  // allowed.
196
1.40k
  auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
197
1.40k
  if (!FD) {
198
1
    S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
199
0
                    ? diag::err_coroutine_objc_method
200
1
                    : diag::err_coroutine_outside_function) << Keyword;
201
1
    return false;
202
1
  }
203
1.40k
204
  // An enumeration for mapping the diagnostic type to the correct diagnostic
205
  // selection index.
206
1.40k
  enum InvalidFuncDiag {
207
1.40k
    DiagCtor = 0,
208
1.40k
    DiagDtor,
209
1.40k
    DiagMain,
210
1.40k
    DiagConstexpr,
211
1.40k
    DiagAutoRet,
212
1.40k
    DiagVarargs,
213
1.40k
    DiagConsteval,
214
1.40k
  };
215
1.40k
  bool Diagnosed = false;
216
9
  auto DiagInvalid = [&](InvalidFuncDiag ID) {
217
9
    S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword;
218
9
    Diagnosed = true;
219
9
    return false;
220
9
  };
221
1.40k
222
  // Diagnose when a constructor, destructor
223
  // or the function 'main' are declared as a coroutine.
224
1.40k
  auto *MD = dyn_cast<CXXMethodDecl>(FD);
225
  // [class.ctor]p11: "A constructor shall not be a coroutine."
226
1.40k
  if (MD && 
isa<CXXConstructorDecl>(MD)370
)
227
2
    return DiagInvalid(DiagCtor);
228
  // [class.dtor]p17: "A destructor shall not be a coroutine."
229
1.40k
  else if (MD && 
isa<CXXDestructorDecl>(MD)368
)
230
1
    return DiagInvalid(DiagDtor);
231
  // [basic.start.main]p3: "The function main shall not be a coroutine."
232
1.40k
  else if (FD->isMain())
233
1
    return DiagInvalid(DiagMain);
234
1.39k
235
  // Emit a diagnostics for each of the following conditions which is not met.
236
  // [expr.const]p2: "An expression e is a core constant expression unless the
237
  // evaluation of e [...] would evaluate one of the following expressions:
238
  // [...] an await-expression [...] a yield-expression."
239
1.39k
  if (FD->isConstexpr())
240
2
    DiagInvalid(FD->isConsteval() ? 
DiagConsteval0
: DiagConstexpr);
241
  // [dcl.spec.auto]p15: "A function declared with a return type that uses a
242
  // placeholder type shall not be a coroutine."
243
1.39k
  if (FD->getReturnType()->isUndeducedType())
244
2
    DiagInvalid(DiagAutoRet);
245
  // [dcl.fct.def.coroutine]p1: "The parameter-declaration-clause of the
246
  // coroutine shall not terminate with an ellipsis that is not part of a
247
  // parameter-declaration."
248
1.39k
  if (FD->isVariadic())
249
1
    DiagInvalid(DiagVarargs);
250
1.39k
251
1.39k
  return !Diagnosed;
252
1.39k
}
253
254
/// Look up 'operator co_await' unqualified from scope \p S and wrap the
/// candidates found in an UnresolvedLookupExpr that is marked as requiring
/// ADL.
static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
                                                 SourceLocation Loc) {
  DeclarationName CoawaitName =
      SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
  LookupResult CoawaitLookup(SemaRef, CoawaitName, SourceLocation(),
                             Sema::LookupOperatorName);
  SemaRef.LookupName(CoawaitLookup, S);

  assert(!CoawaitLookup.isAmbiguous() &&
         "Operator lookup cannot be ambiguous");
  const auto &Candidates = CoawaitLookup.asUnresolvedSet();

  // More than one candidate, or a single function template, is an overload set.
  bool IsOverloaded = Candidates.size() > 1;
  if (Candidates.size() == 1)
    IsOverloaded = isa<FunctionTemplateDecl>(*Candidates.begin());

  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
      SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
      DeclarationNameInfo(CoawaitName, Loc), /*RequiresADL*/ true,
      IsOverloaded, Candidates.begin(), Candidates.end());
  assert(CoawaitOp);
  return CoawaitOp;
}
274
275
/// Build a call to 'operator co_await' if there is a suitable operator for
276
/// the given expression.
277
static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
278
                                           Expr *E,
279
582
                                           UnresolvedLookupExpr *Lookup) {
280
582
  UnresolvedSet<16> Functions;
281
582
  Functions.append(Lookup->decls_begin(), Lookup->decls_end());
282
582
  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
283
582
}
284
285
/// Look up 'operator co_await' from scope \p S and apply it to \p E.
static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
                                           SourceLocation Loc, Expr *E) {
  ExprResult LookupRes = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
  if (LookupRes.isInvalid())
    return ExprError();

  auto *Lookup = cast<UnresolvedLookupExpr>(LookupRes.get());
  return buildOperatorCoawaitCall(SemaRef, Loc, E, Lookup);
}
293
294
/// Build a call to the builtin \p Id with arguments \p CallArgs.
///
/// The builtin is looked up in the translation-unit scope with builtin
/// creation allowed. Every step is expected to succeed, which is checked only
/// by assertions.
static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
                              MultiExprArg CallArgs) {
  IdentifierInfo &BuiltinName =
      S.Context.Idents.get(S.Context.BuiltinInfo.getName(Id));
  LookupResult BuiltinLookup(S, &BuiltinName, Loc, Sema::LookupOrdinaryName);
  S.LookupName(BuiltinLookup, S.TUScope, /*AllowBuiltinCreation=*/true);

  auto *BuiltinFn = BuiltinLookup.getAsSingle<FunctionDecl>();
  assert(BuiltinFn && "failed to find builtin declaration");

  ExprResult Callee =
      S.BuildDeclRefExpr(BuiltinFn, BuiltinFn->getType(), VK_LValue, Loc);
  assert(Callee.isUsable() && "Builtin reference cannot fail");

  ExprResult Call =
      S.BuildCallExpr(/*Scope=*/nullptr, Callee.get(), Loc, CallArgs, Loc);
  assert(!Call.isInvalid() && "Call to builtin cannot fail!");

  return Call.get();
}
313
314
/// Build coroutine_handle<PromiseType>::from_address(__builtin_coro_frame()).
///
/// Returns an invalid result if the handle type cannot be formed, if it has
/// no 'from_address' member, or if referencing that member fails.
static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
                                       SourceLocation Loc) {
  QualType HandleTy = lookupCoroutineHandleType(S, PromiseType, Loc);
  if (HandleTy.isNull())
    return ExprError();

  DeclContext *HandleCtx = S.computeDeclContext(HandleTy);
  LookupResult FromAddrLookup(S, &S.PP.getIdentifierTable().get("from_address"),
                              Loc, Sema::LookupOrdinaryName);
  if (!S.LookupQualifiedName(FromAddrLookup, HandleCtx)) {
    S.Diag(Loc, diag::err_coroutine_handle_missing_member) << "from_address";
    return ExprError();
  }

  Expr *FrameAddr =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});

  CXXScopeSpec EmptySS;
  ExprResult FromAddrRef =
      S.BuildDeclarationNameExpr(EmptySS, FromAddrLookup, /*NeedsADL=*/false);
  if (FromAddrRef.isInvalid())
    return ExprError();

  return S.BuildCallExpr(/*Scope=*/nullptr, FromAddrRef.get(), Loc, FrameAddr,
                         Loc);
}
340
341
struct ReadySuspendResumeResult {
342
  enum AwaitCallType { ACT_Ready, ACT_Suspend, ACT_Resume };
343
  Expr *Results[3];
344
  OpaqueValueExpr *OpaqueValue;
345
  bool IsInvalid;
346
};
347
348
/// Build the call Base.Name(Args...).
///
/// Returns an invalid result if the member reference cannot be formed. If the
/// member reference comes back as a TypoExpr, the typo correction is dropped
/// and err_no_member is reported instead, because the caller asked for exactly
/// this name.
static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
                                  StringRef Name, MultiExprArg Args) {
  DeclarationNameInfo MemberName(&S.PP.getIdentifierTable().get(Name), Loc);

  // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
  CXXScopeSpec SS;
  ExprResult MemberRef = S.BuildMemberReferenceExpr(
      Base, Base->getType(), Loc, /*IsPtr=*/false, SS, SourceLocation(),
      nullptr, MemberName, /*TemplateArgs=*/nullptr, /*Scope=*/nullptr);
  if (MemberRef.isInvalid())
    return ExprError();

  // No typo correction wanted: report the missing member directly.
  auto *Typo = dyn_cast<TypoExpr>(MemberRef.get());
  if (Typo != nullptr) {
    S.clearDelayedTypo(Typo);
    S.Diag(Loc, diag::err_no_member)
        << MemberName.getName() << Base->getType()->getAsCXXRecordDecl()
        << Base->getSourceRange();
    return ExprError();
  }

  return S.BuildCallExpr(nullptr, MemberRef.get(), Loc, Args, Loc, nullptr);
}
372
373
// See if return type is coroutine-handle and if so, invoke builtin coro-resume
374
// on its address. This is to enable experimental support for coroutine-handle
375
// returning await_suspend that results in a guaranteed tail call to the target
376
// coroutine.
377
static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
378
566
                           SourceLocation Loc) {
379
566
  if (RetType->isReferenceType())
380
2
    return nullptr;
381
564
  Type const *T = RetType.getTypePtr();
382
564
  if (!T->isClassType() && !T->isStructureType())
383
561
    return nullptr;
384
3
385
  // FIXME: Add convertability check to coroutine_handle<>. Possibly via
386
  // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
387
  // a private function in SemaExprCXX.cpp
388
3
389
3
  ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", None);
390
3
  if (AddressExpr.isInvalid())
391
0
    return nullptr;
392
3
393
3
  Expr *JustAddress = AddressExpr.get();
394
3
395
  // Check that the type of AddressExpr is void*
396
3
  if (!JustAddress->getType().getTypePtr()->isVoidPointerType())
397
1
    S.Diag(cast<CallExpr>(JustAddress)->getCalleeDecl()->getLocation(),
398
1
           diag::warn_coroutine_handle_address_invalid_return_type)
399
1
        << JustAddress->getType();
400
3
401
  // The coroutine handle used to obtain the address is no longer needed
402
  // at this point, clean it up to avoid unnecessarily long lifetime which
403
  // could lead to unnecessary spilling.
404
3
  JustAddress = S.MaybeCreateExprWithCleanups(JustAddress);
405
3
  return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume,
406
3
                          JustAddress);
407
3
}
408
409
/// Build calls to await_ready, await_suspend, and await_resume for a co_await
410
/// expression.
411
static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
412
582
                                                  SourceLocation Loc, Expr *E) {
413
582
  OpaqueValueExpr *Operand = new (S.Context)
414
582
      OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
415
582
416
  // Assume invalid until we see otherwise.
417
582
  ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true};
418
582
419
582
  ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc);
420
582
  if (CoroHandleRes.isInvalid())
421
2
    return Calls;
422
580
  Expr *CoroHandle = CoroHandleRes.get();
423
580
424
580
  const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
425
580
  MultiExprArg Args[] = {None, CoroHandle, None};
426
2.28k
  for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; 
++I1.70k
) {
427
1.71k
    ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]);
428
1.71k
    if (Result.isInvalid())
429
14
      return Calls;
430
1.70k
    Calls.Results[I] = Result.get();
431
1.70k
  }
432
580
433
  // Assume the calls are valid; all further checking should make them invalid.
434
566
  Calls.IsInvalid = false;
435
566
436
566
  using ACT = ReadySuspendResumeResult::AwaitCallType;
437
566
  CallExpr *AwaitReady = cast<CallExpr>(Calls.Results[ACT::ACT_Ready]);
438
566
  if (!AwaitReady->getType()->isDependentType()) {
439
    // [expr.await]p3 [...]
440
    // — await-ready is the expression e.await_ready(), contextually converted
441
    // to bool.
442
566
    ExprResult Conv = S.PerformContextuallyConvertToBool(AwaitReady);
443
566
    if (Conv.isInvalid()) {
444
1
      S.Diag(AwaitReady->getDirectCallee()->getBeginLoc(),
445
1
             diag::note_await_ready_no_bool_conversion);
446
1
      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
447
1
          << AwaitReady->getDirectCallee() << E->getSourceRange();
448
1
      Calls.IsInvalid = true;
449
1
    }
450
566
    Calls.Results[ACT::ACT_Ready] = Conv.get();
451
566
  }
452
566
  CallExpr *AwaitSuspend = cast<CallExpr>(Calls.Results[ACT::ACT_Suspend]);
453
566
  if (!AwaitSuspend->getType()->isDependentType()) {
454
    // [expr.await]p3 [...]
455
    //   - await-suspend is the expression e.await_suspend(h), which shall be
456
    //     a prvalue of type void, bool, or std::coroutine_handle<Z> for some
457
    //     type Z.
458
566
    QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
459
566
460
    // Experimental support for coroutine_handle returning await_suspend.
461
566
    if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc))
462
3
      Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
463
563
    else {
464
      // non-class prvalues always have cv-unqualified types
465
563
      if (RetType->isReferenceType() ||
466
561
          (!RetType->isBooleanType() && 
!RetType->isVoidType()557
)) {
467
3
        S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
468
3
               diag::err_await_suspend_invalid_return_type)
469
3
            << RetType;
470
3
        S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
471
3
            << AwaitSuspend->getDirectCallee();
472
3
        Calls.IsInvalid = true;
473
3
      }
474
563
    }
475
566
  }
476
566
477
566
  return Calls;
478
580
}
479
480
/// Build the member call Promise.Name(Args...) on the promise variable
/// \p Promise, referenced as an lvalue of its non-reference type.
static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
                                   SourceLocation Loc, StringRef Name,
                                   MultiExprArg Args) {
  QualType PromiseTy = Promise->getType().getNonReferenceType();
  ExprResult PromiseRef = S.BuildDeclRefExpr(Promise, PromiseTy, VK_LValue, Loc);
  if (PromiseRef.isInvalid())
    return ExprError();

  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
}
492
493
282
/// Create and initialize the implicit '__promise' variable of the coroutine
/// that is the current function context.
///
/// The promise type is DependentTy when the function type, or the 'this' type
/// of an instance member, is dependent; otherwise it comes from
/// lookupPromiseType().
///
/// Candidate constructor arguments are collected in this order:
///  - '*this' for instance methods that are not lambda call operators;
///  - references to the recorded parameter moves for non-dependent
///    parameters.
/// If any arguments were collected, direct-initialization from them is
/// attempted. If that is not viable, or there are no arguments, the variable
/// goes through ActOnUninitializedDecl instead.
///
/// \returns the VarDecl (added to the function), or nullptr on failure.
VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
  auto *FD = cast<FunctionDecl>(CurContext);

  // FD is known non-null, so a plain dyn_cast suffices.
  auto *Method = dyn_cast<CXXMethodDecl>(FD);
  const bool IsThisDependentType = Method && Method->isInstance() &&
                                   Method->getThisType()->isDependentType();

  QualType PromiseTy;
  if (FD->getType()->isDependentType() || IsThisDependentType)
    PromiseTy = Context.DependentTy;
  else
    PromiseTy = lookupPromiseType(*this, FD, Loc);
  if (PromiseTy.isNull())
    return nullptr;

  auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
                             &PP.getIdentifierTable().get("__promise"),
                             PromiseTy,
                             Context.getTrivialTypeSourceInfo(PromiseTy, Loc),
                             SC_None);
  CheckVariableDeclarationType(VD);
  if (VD->isInvalidDecl())
    return nullptr;

  auto *ScopeInfo = getCurFunction();

  // Arguments that, if present, are offered to the promise's constructor.
  llvm::SmallVector<Expr *, 4> CtorArgExprs;

  // Implicit object parameter first.
  if (Method && Method->isInstance() && !isLambdaCallOperator(Method)) {
    ExprResult ThisExpr = ActOnCXXThis(Loc);
    if (ThisExpr.isInvalid())
      return nullptr;
    ThisExpr = CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
    if (ThisExpr.isInvalid())
      return nullptr;
    CtorArgExprs.push_back(ThisExpr.get());
  }

  // Then the function's parameters, referenced through the copies recorded in
  // the coroutine frame's parameter-move map.
  auto &Moves = ScopeInfo->CoroutineParameterMoves;
  for (auto *Param : FD->parameters()) {
    if (Param->getType()->isDependentType())
      continue;

    auto MoveIt = Moves.find(Param);
    assert(MoveIt != Moves.end() &&
           "Coroutine function parameter not inserted into move map");
    auto *MoveDecl =
        cast<VarDecl>(cast<DeclStmt>(MoveIt->second)->getSingleDecl());
    ExprResult MoveRef =
        BuildDeclRefExpr(MoveDecl, MoveDecl->getType().getNonReferenceType(),
                         ExprValueKind::VK_LValue, FD->getLocation());
    if (MoveRef.isInvalid())
      return nullptr;
    CtorArgExprs.push_back(MoveRef.get());
  }

  if (CtorArgExprs.empty()) {
    // Nothing to pass: use the promise type's default initialization.
    ActOnUninitializedDecl(VD);
  } else {
    // Direct-initialize from a parenthesized list of the collected arguments.
    Expr *ParenList = ParenListExpr::Create(Context, FD->getLocation(),
                                            CtorArgExprs, FD->getLocation());
    InitializedEntity Entity = InitializedEntity::InitializeVariable(VD);
    InitializationKind Kind = InitializationKind::CreateForInit(
        VD->getLocation(), /*DirectInit=*/true, ParenList);
    InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs,
                                   /*TopLevelOfInitList=*/false,
                                   /*TreatUnavailableAsInvalid=*/false);

    if (!InitSeq) {
      // Not viable with these arguments: fall back to default initialization.
      ActOnUninitializedDecl(VD);
    } else {
      ExprResult Init = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs);
      if (Init.isInvalid()) {
        VD->setInvalidDecl();
      } else if (Expr *InitExpr = Init.get()) {
        VD->setInit(MaybeCreateExprWithCleanups(InitExpr));
        VD->setInitStyle(VarDecl::CallInit);
        CheckCompleteVariableDeclaration(VD);
      }
    }
  }

  FD->addDecl(VD);
  return VD;
}
590
591
/// Check that this is a context in which a coroutine suspension can appear.
592
static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
593
                                                StringRef Keyword,
594
1.40k
                                                bool IsImplicit = false) {
595
1.40k
  if (!isValidCoroutineContext(S, Loc, Keyword))
596
9
    return nullptr;
597
1.39k
598
1.39k
  assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
599
1.39k
600
1.39k
  auto *ScopeInfo = S.getCurFunction();
601
1.39k
  assert(ScopeInfo && "missing function scope for function");
602
1.39k
603
1.39k
  if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && 
!IsImplicit393
)
604
281
    ScopeInfo->setFirstCoroutineStmt(Loc, Keyword);
605
1.39k
606
1.39k
  if (ScopeInfo->CoroutinePromise)
607
1.16k
    return ScopeInfo;
608
227
609
227
  if (!S.buildCoroutineParameterMoves(Loc))
610
1
    return nullptr;
611
226
612
226
  ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
613
226
  if (!ScopeInfo->CoroutinePromise)
614
13
    return nullptr;
615
213
616
213
  return ScopeInfo;
617
213
}
618
619
/// Recursively check \p E and all its children to see if any call target
620
/// (including constructor call) is declared noexcept. Also any value returned
621
/// from the call has a noexcept destructor.
622
static void checkNoThrow(Sema &S, const Stmt *E,
623
5.24k
                         llvm::SmallPtrSetImpl<const Decl *> &ThrowingDecls) {
624
1.82k
  auto checkDeclNoexcept = [&](const Decl *D, bool IsDtor = false) {
625
    // In the case of dtor, the call to dtor is implicit and hence we should
626
    // pass nullptr to canCalleeThrow.
627
1.82k
    if (Sema::canCalleeThrow(S, IsDtor ? 
nullptr292
:
cast<Expr>(E)1.53k
, D)) {
628
19
      if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
629
        // co_await promise.final_suspend() could end up calling
630
        // __builtin_coro_resume for symmetric transfer if await_suspend()
631
        // returns a handle. In that case, even __builtin_coro_resume is not
632
        // declared as noexcept and may throw, it does not throw _into_ the
633
        // coroutine that just suspended, but rather throws back out from
634
        // whoever called coroutine_handle::resume(), hence we claim that
635
        // logically it does not throw.
636
19
        if (FD->getBuiltinID() == Builtin::BI__builtin_coro_resume)
637
1
          return;
638
18
      }
639
18
      if (ThrowingDecls.empty()) {
640
        // First time seeing an error, emit the error message.
641
2
        S.Diag(cast<FunctionDecl>(S.CurContext)->getLocation(),
642
2
               diag::err_coroutine_promise_final_suspend_requires_nothrow);
643
2
      }
644
18
      ThrowingDecls.insert(D);
645
18
    }
646
1.82k
  };
647
5.24k
  auto SC = E->getStmtClass();
648
5.24k
  if (SC == Expr::CXXConstructExprClass) {
649
288
    auto const *Ctor = cast<CXXConstructExpr>(E)->getConstructor();
650
288
    checkDeclNoexcept(Ctor);
651
    // Check the corresponding destructor of the constructor.
652
288
    checkDeclNoexcept(Ctor->getParent()->getDestructor(), true);
653
4.95k
  } else if (SC == Expr::CallExprClass || 
SC == Expr::CXXMemberCallExprClass4.49k
||
654
3.66k
             SC == Expr::CXXOperatorCallExprClass) {
655
1.30k
    if (!cast<CallExpr>(E)->isTypeDependent()) {
656
1.24k
      checkDeclNoexcept(cast<CallExpr>(E)->getCalleeDecl());
657
1.24k
      auto ReturnType = cast<CallExpr>(E)->getCallReturnType(S.getASTContext());
658
      // Check the destructor of the call return type, if any.
659
1.24k
      if (ReturnType.isDestructedType() ==
660
4
          QualType::DestructionKind::DK_cxx_destructor) {
661
4
        const auto *T =
662
4
            cast<RecordType>(ReturnType.getCanonicalType().getTypePtr());
663
4
        checkDeclNoexcept(
664
4
            dyn_cast<CXXRecordDecl>(T->getDecl())->getDestructor(), true);
665
4
      }
666
1.24k
    }
667
1.30k
  }
668
5.14k
  for (const auto *Child : E->children()) {
669
5.14k
    if (!Child)
670
162
      continue;
671
4.98k
    checkNoThrow(S, Child, ThrowingDecls);
672
4.98k
  }
673
5.24k
}
674
675
260
bool Sema::checkFinalSuspendNoThrow(const Stmt *FinalSuspend) {
676
260
  llvm::SmallPtrSet<const Decl *, 4> ThrowingDecls;
677
  // We first collect all declarations that should not throw but not declared
678
  // with noexcept. We then sort them based on the location before printing.
679
  // This is to avoid emitting the same note multiple times on the same
680
  // declaration, and also provide a deterministic order for the messages.
681
260
  checkNoThrow(*this, FinalSuspend, ThrowingDecls);
682
260
  auto SortedDecls = llvm::SmallVector<const Decl *, 4>{ThrowingDecls.begin(),
683
260
                                                        ThrowingDecls.end()};
684
38
  sort(SortedDecls, [](const Decl *A, const Decl *B) {
685
38
    return A->getEndLoc() < B->getEndLoc();
686
38
  });
687
18
  for (const auto *D : SortedDecls) {
688
18
    Diag(D->getEndLoc(), diag::note_coroutine_function_declare_noexcept);
689
18
  }
690
260
  return ThrowingDecls.empty();
691
260
}
692
693
bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc,
694
302
                                   StringRef Keyword) {
695
302
  if (!checkCoroutineContext(*this, KWLoc, Keyword))
696
23
    return false;
697
279
  auto *ScopeInfo = getCurFunction();
698
279
  assert(ScopeInfo->CoroutinePromise);
699
279
700
  // If we have existing coroutine statements then we have already built
701
  // the initial and final suspend points.
702
279
  if (!ScopeInfo->NeedsCoroutineSuspends)
703
66
    return true;
704
213
705
213
  ScopeInfo->setNeedsCoroutineSuspends(false);
706
213
707
213
  auto *Fn = cast<FunctionDecl>(CurContext);
708
213
  SourceLocation Loc = Fn->getLocation();
709
  // Build the initial suspend point
710
419
  auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
711
419
    ExprResult Suspend =
712
419
        buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None);
713
419
    if (Suspend.isInvalid())
714
3
      return StmtError();
715
416
    Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get());
716
416
    if (Suspend.isInvalid())
717
0
      return StmtError();
718
416
    Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(),
719
416
                                       /*IsImplicit*/ true);
720
416
    Suspend = ActOnFinishFullExpr(Suspend.get(), /*DiscardedValue*/ false);
721
416
    if (Suspend.isInvalid()) {
722
6
      Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required)
723
5
          << ((Name == "initial_suspend") ? 0 : 
11
);
724
6
      Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
725
6
      return StmtError();
726
6
    }
727
410
    return cast<Stmt>(Suspend.get());
728
410
  };
729
213
730
213
  StmtResult InitSuspend = buildSuspends("initial_suspend");
731
213
  if (InitSuspend.isInvalid())
732
7
    return true;
733
206
734
206
  StmtResult FinalSuspend = buildSuspends("final_suspend");
735
206
  if (FinalSuspend.isInvalid() || 
!checkFinalSuspendNoThrow(FinalSuspend.get())204
)
736
3
    return true;
737
203
738
203
  ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
739
203
740
203
  return true;
741
203
}
742
743
// Recursively walks up the scope hierarchy until either a 'catch' or a function
744
// scope is found, whichever comes first.
745
177
static bool isWithinCatchScope(Scope *S) {
746
  // 'co_await' and 'co_yield' keywords are disallowed within catch blocks, but
747
  // lambdas that use 'co_await' are allowed. The loop below ends when a
748
  // function scope is found in order to ensure the following behavior:
749
  //
750
  // void foo() {      // <- function scope
751
  //   try {           //
752
  //     co_await x;   // <- 'co_await' is OK within a function scope
753
  //   } catch {       // <- catch scope
754
  //     co_await x;   // <- 'co_await' is not OK within a catch scope
755
  //     []() {        // <- function scope
756
  //       co_await x; // <- 'co_await' is OK within a function scope
757
  //     }();
758
  //   }
759
  // }
760
203
  while (S && 
!(S->getFlags() & Scope::FnScope)197
) {
761
29
    if (S->getFlags() & Scope::CatchScope)
762
3
      return true;
763
26
    S = S->getParent();
764
26
  }
765
174
  return false;
766
177
}
767
768
// [expr.await]p2, emphasis added: "An await-expression shall appear only in
769
// a *potentially evaluated* expression within the compound-statement of a
770
// function-body *outside of a handler* [...] A context within a function
771
// where an await-expression can appear is called a suspension context of the
772
// function."
773
static void checkSuspensionContext(Sema &S, SourceLocation Loc,
774
177
                                   StringRef Keyword) {
775
  // First emphasis of [expr.await]p2: must be a potentially evaluated context.
776
  // That is, 'co_await' and 'co_yield' cannot appear in subexpressions of
777
  // \c sizeof.
778
177
  if (S.isUnevaluatedContext())
779
6
    S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
780
177
781
  // Second emphasis of [expr.await]p2: must be outside of an exception handler.
782
177
  if (isWithinCatchScope(S.getCurScope()))
783
3
    S.Diag(Loc, diag::err_coroutine_within_handler) << Keyword;
784
177
}
785
786
130
/// Parser entry point for `co_await E`.
///
/// Validates the context and resolves a placeholder-typed operand. It then
/// looks up `operator co_await` and hands off to BuildUnresolvedCoawaitExpr.
ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
  if (!ActOnCoroutineBodyStart(S, Loc, "co_await")) {
    // Resolve any delayed typos in the operand before bailing out.
    CorrectDelayedTyposInExpr(E);
    return ExprError();
  }

  checkSuspensionContext(*this, Loc, "co_await");

  if (E->getType()->isPlaceholderType()) {
    ExprResult Resolved = CheckPlaceholderExpr(E);
    if (Resolved.isInvalid())
      return ExprError();
    E = Resolved.get();
  }

  ExprResult OperatorLookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
  if (OperatorLookup.isInvalid())
    return ExprError();

  auto *Unresolved = cast<UnresolvedLookupExpr>(OperatorLookup.get());
  return BuildUnresolvedCoawaitExpr(Loc, E, Unresolved);
}
805
806
/// Build a `co_await E` whose `operator co_await` has not been resolved yet.
///
/// If the promise type is dependent, a DependentCoawaitExpr is produced.
/// Otherwise the operand is first passed through `promise.await_transform`
/// (when the promise declares that member). The result is then combined with
/// \p Lookup to form the awaitable, and the work finishes in
/// BuildResolvedCoawaitExpr.
ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
                                            UnresolvedLookupExpr *Lookup) {
  auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
  if (!FSI)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  auto *Promise = FSI->CoroutinePromise;
  QualType PromiseTy = Promise->getType();

  // Nothing can be looked up in a dependent promise type yet.
  if (PromiseTy->isDependentType()) {
    Expr *Deferred =
        new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
    return Deferred;
  }

  auto *PromiseRD = PromiseTy->getAsCXXRecordDecl();
  const bool HasAwaitTransform =
      lookupMember(*this, "await_transform", PromiseRD, Loc);
  if (HasAwaitTransform) {
    ExprResult Transformed =
        buildPromiseCall(*this, Promise, Loc, "await_transform", E);
    if (Transformed.isInvalid()) {
      Diag(Loc,
           diag::note_coroutine_promise_implicit_await_transform_required_here)
          << E->getSourceRange();
      return ExprError();
    }
    E = Transformed.get();
  }

  ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
  if (Awaitable.isInvalid())
    return ExprError();

  return BuildResolvedCoawaitExpr(Loc, Awaitable.get());
}
843
844
/// Build a CoawaitExpr over an already-resolved awaitable \p E.
///
/// A dependent operand yields a dependent CoawaitExpr. Otherwise the
/// await_ready / await_suspend / await_resume calls are built on the
/// (materialized) operand.
ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
                                          bool IsImplicit) {
  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit);
  if (!Coroutine)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  if (E->getType()->isDependentType()) {
    Expr *Dependent = new (Context)
        CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
    return Dependent;
  }

  // The operand is referenced by several generated calls, so a prvalue is
  // materialized as an lvalue temporary first.
  if (E->getValueKind() == VK_RValue)
    E = CreateMaterializeTemporaryExpr(E->getType(), E,
                                       /*BoundToLvalueReference=*/true);

  // Use the operand's location rather than the `co_await` keyword's. The
  // member calls built below start at the operand, and the keyword precedes
  // it.
  SourceLocation CallLoc = E->getExprLoc();

  ReadySuspendResumeResult Calls =
      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, CallLoc, E);
  if (Calls.IsInvalid)
    return ExprError();

  Expr *Await = new (Context)
      CoawaitExpr(Loc, E, Calls.Results[0], Calls.Results[1], Calls.Results[2],
                  Calls.OpaqueValue, IsImplicit);
  return Await;
}
884
885
63
/// Parser entry point for `co_yield E`.
///
/// Builds `promise.yield_value(E)`, applies `operator co_await` to it, and
/// finishes in BuildCoyieldExpr.
ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
  if (!ActOnCoroutineBodyStart(S, Loc, "co_yield")) {
    // Resolve any delayed typos in the operand before bailing out.
    CorrectDelayedTyposInExpr(E);
    return ExprError();
  }

  checkSuspensionContext(*this, Loc, "co_yield");

  VarDecl *Promise = getCurFunction()->CoroutinePromise;
  ExprResult Yielded = buildPromiseCall(*this, Promise, Loc, "yield_value", E);
  if (Yielded.isInvalid())
    return ExprError();

  ExprResult Awaitable =
      buildOperatorCoawaitCall(*this, S, Loc, Yielded.get());
  if (Awaitable.isInvalid())
    return ExprError();

  return BuildCoyieldExpr(Loc, Awaitable.get());
}
906
78
/// Build a CoyieldExpr over the awaitable \p E.
///
/// A dependent operand yields a dependent CoyieldExpr. Otherwise the
/// await_ready / await_suspend / await_resume calls are built on the
/// (materialized) operand.
ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
  if (!Coroutine)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  if (E->getType()->isDependentType()) {
    Expr *Dependent = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
    return Dependent;
  }

  // The operand is referenced by several generated calls, so a prvalue is
  // materialized as an lvalue temporary first.
  if (E->getValueKind() == VK_RValue)
    E = CreateMaterializeTemporaryExpr(E->getType(), E,
                                       /*BoundToLvalueReference=*/true);

  ReadySuspendResumeResult Calls =
      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E);
  if (Calls.IsInvalid)
    return ExprError();

  Expr *Yield = new (Context)
      CoyieldExpr(Loc, E, Calls.Results[0], Calls.Results[1], Calls.Results[2],
                  Calls.OpaqueValue);
  return Yield;
}
939
940
101
/// Parser entry point for `co_return [E]`.
///
/// \p E may be null when no operand is written.
StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
  if (ActOnCoroutineBodyStart(S, Loc, "co_return"))
    return BuildCoreturnStmt(Loc, E);

  // Resolve any delayed typos in the operand before bailing out.
  CorrectDelayedTyposInExpr(E);
  return StmtError();
}
947
948
/// Build a CoreturnStmt.
///
/// The promise call depends on the operand:
/// - `promise.return_value(E)` when there is a non-void operand or an
///   init-list;
/// - `promise.return_void()` otherwise. In that case the (possibly null)
///   operand is kept as a discarded-value expression.
StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
                                   bool IsImplicit) {
  auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit);
  if (!FSI)
    return StmtError();

  // Resolve placeholder-typed operands. Overload-set placeholders are
  // deliberately left untouched here.
  if (E && E->getType()->isPlaceholderType() &&
      !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return StmtError();
    E = Checked.get();
  }

  // If the operand is a copy-elision candidate, initialize from it as if it
  // had been wrapped in std::move.
  if (E) {
    auto *NRVOCandidate =
        getCopyElisionCandidate(E->getType(), E, CES_AsIfByStdMove);
    if (NRVOCandidate) {
      InitializedEntity Entity =
          InitializedEntity::InitializeResult(Loc, E->getType(), NRVOCandidate);
      ExprResult Moved = PerformMoveOrCopyInitialization(
          Entity, NRVOCandidate, E->getType(), E);
      if (Moved.get())
        E = Moved.get();
    }
  }

  // FIXME: If the operand is a reference to a variable that's about to go out
  // of scope, we should treat the operand as an xvalue for this overload
  // resolution.
  VarDecl *Promise = FSI->CoroutinePromise;
  const bool ReturnsValue =
      E && (isa<InitListExpr>(E) || !E->getType()->isVoidType());

  ExprResult PromiseCall;
  if (ReturnsValue) {
    PromiseCall = buildPromiseCall(*this, Promise, Loc, "return_value", E);
  } else {
    E = MakeFullDiscardedValueExpr(E).get();
    PromiseCall = buildPromiseCall(*this, Promise, Loc, "return_void", None);
  }
  if (PromiseCall.isInvalid())
    return StmtError();

  Expr *FinishedCall =
      ActOnFinishFullExpr(PromiseCall.get(), /*DiscardedValue*/ false).get();

  Stmt *Coreturn =
      new (Context) CoreturnStmt(Loc, E, FinishedCall, IsImplicit);
  return Coreturn;
}
993
994
/// Look up the std::nothrow object.
995
6
static Expr *buildStdNoThrowDeclRef(Sema &S, SourceLocation Loc) {
996
6
  NamespaceDecl *Std = S.getStdNamespace();
997
6
  assert(Std && "Should already be diagnosed");
998
6
999
6
  LookupResult Result(S, &S.PP.getIdentifierTable().get("nothrow"), Loc,
1000
6
                      Sema::LookupOrdinaryName);
1001
6
  if (!S.LookupQualifiedName(Result, Std)) {
1002
    // FIXME: <experimental/coroutine> should have been included already.
1003
    // If we require it to include <new> then this diagnostic is no longer
1004
    // needed.
1005
0
    S.Diag(Loc, diag::err_implicit_coroutine_std_nothrow_type_not_found);
1006
0
    return nullptr;
1007
0
  }
1008
6
1009
6
  auto *VD = Result.getAsSingle<VarDecl>();
1010
6
  if (!VD) {
1011
0
    Result.suppressDiagnostics();
1012
    // We found something weird. Complain about the first thing we found.
1013
0
    NamedDecl *Found = *Result.begin();
1014
0
    S.Diag(Found->getLocation(), diag::err_malformed_std_nothrow);
1015
0
    return nullptr;
1016
0
  }
1017
6
1018
6
  ExprResult DR = S.BuildDeclRefExpr(VD, VD->getType(), VK_LValue, Loc);
1019
6
  if (DR.isInvalid())
1020
0
    return nullptr;
1021
6
1022
6
  return DR.get();
1023
6
}
1024
1025
// Find an appropriate delete for the promise.
1026
static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc,
1027
176
                                          QualType PromiseType) {
1028
176
  FunctionDecl *OperatorDelete = nullptr;
1029
176
1030
176
  DeclarationName DeleteName =
1031
176
      S.Context.DeclarationNames.getCXXOperatorName(OO_Delete);
1032
176
1033
176
  auto *PointeeRD = PromiseType->getAsCXXRecordDecl();
1034
176
  assert(PointeeRD && "PromiseType must be a CxxRecordDecl type");
1035
176
1036
176
  if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete))
1037
0
    return nullptr;
1038
176
1039
176
  if (!OperatorDelete) {
1040
    // Look for a global declaration.
1041
174
    const bool CanProvideSize = S.isCompleteType(Loc, PromiseType);
1042
174
    const bool Overaligned = false;
1043
174
    OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize,
1044
174
                                                     Overaligned, DeleteName);
1045
174
  }
1046
176
  S.MarkFunctionReferenced(Loc, OperatorDelete);
1047
176
  return OperatorDelete;
1048
176
}
1049
1050
1051
281
/// Finish a function body that uses coroutine statements.
///
/// On success, \p Body is replaced with a CoroutineBodyStmt assembled by a
/// CoroutineStmtBuilder. If the promise or the builder fails, \p FD is marked
/// invalid instead.
void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
  FunctionScopeInfo *Fn = getCurFunction();
  assert(Fn && Fn->isCoroutine() && "not a coroutine");

  if (!Body) {
    assert(FD->isInvalidDecl() &&
           "a null body is only allowed for invalid declarations");
    return;
  }

  // Coroutine keywords were used, but no promise variable could be built.
  if (!Fn->CoroutinePromise) {
    FD->setInvalidDecl();
    return;
  }

  // Nothing to do: the body is already a transformed coroutine body.
  if (isa<CoroutineBodyStmt>(Body))
    return;

  // Coroutines [stmt.return]p1:
  //   A return statement shall not appear in a coroutine.
  if (Fn->FirstReturnLoc.isValid()) {
    assert(Fn->FirstCoroutineStmtLoc.isValid() &&
           "first coroutine location not set");
    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
    Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
        << Fn->getFirstCoroutineStmtKeyword();
  }

  CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body);
  if (Builder.isInvalid() || !Builder.buildStatements()) {
    FD->setInvalidDecl();
    return;
  }

  // Replace the body with the coroutine wrapper statement.
  Body = CoroutineBodyStmt::Create(Context, Builder);
}
1085
1086
/// Set up a builder for the coroutine body of \p FD.
///
/// The constructor records the parameter-move statements and, when the
/// promise type is not dependent, its record declaration. It then eagerly
/// builds the promise DeclStmt and the initial/final suspend expressions.
/// The result is reflected in IsValid.
CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD,
                                           sema::FunctionScopeInfo &Fn,
                                           Stmt *Body)
    : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
      IsPromiseDependentType(
          !Fn.CoroutinePromise ||
          Fn.CoroutinePromise->getType()->isDependentType()) {
  // The constructor parameter shadows the member of the same name.
  this->Body = Body;

  for (const auto &ParamAndMove : Fn.CoroutineParameterMoves)
    ParamMovesVector.push_back(ParamAndMove.second);
  ParamMoves = ParamMovesVector;

  if (!IsPromiseDependentType) {
    PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
    assert(PromiseRecordDecl && "Type should have already been checked");
  }

  const bool PromiseBuilt = makePromiseStmt();
  IsValid = PromiseBuilt && makeInitialAndFinalSuspend();
}
1105
1106
203
bool CoroutineStmtBuilder::buildStatements() {
1107
203
  assert(this->IsValid && "coroutine already invalid");
1108
203
  this->IsValid = makeReturnObject();
1109
203
  if (this->IsValid && 
!IsPromiseDependentType202
)
1110
152
    buildDependentStatements();
1111
203
  return this->IsValid;
1112
203
}
1113
1114
189
bool CoroutineStmtBuilder::buildDependentStatements() {
1115
189
  assert(this->IsValid && "coroutine already invalid");
1116
189
  assert(!this->IsPromiseDependentType &&
1117
189
         "coroutine cannot have a dependent promise type");
1118
189
  this->IsValid = makeOnException() && 
makeOnFallthrough()186
&&
1119
183
                  makeGroDeclAndReturnStmt() && 
makeReturnOnAllocFailure()181
&&
1120
178
                  makeNewAndDeleteExpr();
1121
189
  return this->IsValid;
1122
189
}
1123
1124
260
bool CoroutineStmtBuilder::makePromiseStmt() {
1125
  // Form a declaration statement for the promise declaration, so that AST
1126
  // visitors can more easily find it.
1127
260
  StmtResult PromiseStmt =
1128
260
      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
1129
260
  if (PromiseStmt.isInvalid())
1130
0
    return false;
1131
260
1132
260
  this->Promise = PromiseStmt.get();
1133
260
  return true;
1134
260
}
1135
1136
260
bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() {
1137
260
  if (Fn.hasInvalidCoroutineSuspends())
1138
10
    return false;
1139
250
  this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
1140
250
  this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second);
1141
250
  return true;
1142
250
}
1143
1144
static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
1145
                                     CXXRecordDecl *PromiseRecordDecl,
1146
12
                                     FunctionScopeInfo &Fn) {
1147
12
  auto Loc = E->getExprLoc();
1148
12
  if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) {
1149
12
    auto *Decl = DeclRef->getDecl();
1150
12
    if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) {
1151
12
      if (Method->isStatic())
1152
11
        return true;
1153
1
      else
1154
1
        Loc = Decl->getLocation();
1155
12
    }
1156
12
  }
1157
12
1158
1
  S.Diag(
1159
1
      Loc,
1160
1
      diag::err_coroutine_promise_get_return_object_on_allocation_failure)
1161
1
      << PromiseRecordDecl;
1162
1
  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1163
1
      << Fn.getFirstCoroutineStmtKeyword();
1164
1
  return false;
1165
12
}
1166
1167
181
bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
1168
181
  assert(!IsPromiseDependentType &&
1169
181
         "cannot make statement while the promise type is dependent");
1170
181
1171
  // [dcl.fct.def.coroutine]/8
1172
  // The unqualified-id get_return_object_on_allocation_failure is looked up in
1173
  // the scope of class P by class member access lookup (3.4.5). ...
1174
  // If an allocation function returns nullptr, ... the coroutine return value
1175
  // is obtained by a call to ... get_return_object_on_allocation_failure().
1176
181
1177
181
  DeclarationName DN =
1178
181
      S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
1179
181
  LookupResult Found(S, DN, Loc, Sema::LookupMemberName);
1180
181
  if (!S.LookupQualifiedName(Found, PromiseRecordDecl))
1181
169
    return true;
1182
12
1183
12
  CXXScopeSpec SS;
1184
12
  ExprResult DeclNameExpr =
1185
12
      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
1186
12
  if (DeclNameExpr.isInvalid())
1187
0
    return false;
1188
12
1189
12
  if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn))
1190
1
    return false;
1191
11
1192
11
  ExprResult ReturnObjectOnAllocationFailure =
1193
11
      S.BuildCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc);
1194
11
  if (ReturnObjectOnAllocationFailure.isInvalid())
1195
0
    return false;
1196
11
1197
11
  StmtResult ReturnStmt =
1198
11
      S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get());
1199
11
  if (ReturnStmt.isInvalid()) {
1200
2
    S.Diag(Found.getFoundDecl()->getLocation(), diag::note_member_declared_here)
1201
2
        << DN;
1202
2
    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1203
2
        << Fn.getFirstCoroutineStmtKeyword();
1204
2
    return false;
1205
2
  }
1206
9
1207
9
  this->ReturnStmtOnAllocFailure = ReturnStmt.get();
1208
9
  return true;
1209
9
}
1210
1211
178
bool CoroutineStmtBuilder::makeNewAndDeleteExpr() {
1212
  // Form and check allocation and deallocation calls.
1213
178
  assert(!IsPromiseDependentType &&
1214
178
         "cannot make statement while the promise type is dependent");
1215
178
  QualType PromiseType = Fn.CoroutinePromise->getType();
1216
178
1217
178
  if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
1218
0
    return false;
1219
178
1220
178
  const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;
1221
178
1222
  // [dcl.fct.def.coroutine]/7
1223
  // Lookup allocation functions using a parameter list composed of the
1224
  // requested size of the coroutine state being allocated, followed by
1225
  // the coroutine function's arguments. If a matching allocation function
1226
  // exists, use it. Otherwise, use an allocation function that just takes
1227
  // the requested size.
1228
178
1229
178
  FunctionDecl *OperatorNew = nullptr;
1230
178
  FunctionDecl *OperatorDelete = nullptr;
1231
178
  FunctionDecl *UnusedResult = nullptr;
1232
178
  bool PassAlignment = false;
1233
178
  SmallVector<Expr *, 1> PlacementArgs;
1234
178
1235
  // [dcl.fct.def.coroutine]/7
1236
  // "The allocation function’s name is looked up in the scope of P.
1237
  // [...] If the lookup finds an allocation function in the scope of P,
1238
  // overload resolution is performed on a function call created by assembling
1239
  // an argument list. The first argument is the amount of space requested,
1240
  // and has type std::size_t. The lvalues p1 ... pn are the succeeding
1241
  // arguments."
1242
  //
1243
  // ...where "p1 ... pn" are defined earlier as:
1244
  //
1245
  // [dcl.fct.def.coroutine]/3
1246
  // "For a coroutine f that is a non-static member function, let P1 denote the
1247
  // type of the implicit object parameter (13.3.1) and P2 ... Pn be the types
1248
  // of the function parameters; otherwise let P1 ... Pn be the types of the
1249
  // function parameters. Let p1 ... pn be lvalues denoting those objects."
1250
178
  if (auto *MD = dyn_cast<CXXMethodDecl>(&FD)) {
1251
50
    if (MD->isInstance() && 
!isLambdaCallOperator(MD)42
) {
1252
37
      ExprResult ThisExpr = S.ActOnCXXThis(Loc);
1253
37
      if (ThisExpr.isInvalid())
1254
0
        return false;
1255
37
      ThisExpr = S.CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
1256
37
      if (ThisExpr.isInvalid())
1257
0
        return false;
1258
37
      PlacementArgs.push_back(ThisExpr.get());
1259
37
    }
1260
50
  }
1261
178
  for (auto *PD : FD.parameters()) {
1262
122
    if (PD->getType()->isDependentType())
1263
0
      continue;
1264
122
1265
    // Build a reference to the parameter.
1266
122
    auto PDLoc = PD->getLocation();
1267
122
    ExprResult PDRefExpr =
1268
122
        S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
1269
122
                           ExprValueKind::VK_LValue, PDLoc);
1270
122
    if (PDRefExpr.isInvalid())
1271
0
      return false;
1272
122
1273
122
    PlacementArgs.push_back(PDRefExpr.get());
1274
122
  }
1275
178
  S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
1276
178
                            /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1277
178
                            /*isArray*/ false, PassAlignment, PlacementArgs,
1278
178
                            OperatorNew, UnusedResult, /*Diagnose*/ false);
1279
178
1280
  // [dcl.fct.def.coroutine]/7
1281
  // "If no matching function is found, overload resolution is performed again
1282
  // on a function call created by passing just the amount of space required as
1283
  // an argument of type std::size_t."
1284
178
  if (!OperatorNew && 
!PlacementArgs.empty()174
) {
1285
94
    PlacementArgs.clear();
1286
94
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
1287
94
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1288
94
                              /*isArray*/ false, PassAlignment, PlacementArgs,
1289
94
                              OperatorNew, UnusedResult, /*Diagnose*/ false);
1290
94
  }
1291
178
1292
  // [dcl.fct.def.coroutine]/7
1293
  // "The allocation function’s name is looked up in the scope of P. If this
1294
  // lookup fails, the allocation function’s name is looked up in the global
1295
  // scope."
1296
178
  if (!OperatorNew) {
1297
172
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Global,
1298
172
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1299
172
                              /*isArray*/ false, PassAlignment, PlacementArgs,
1300
172
                              OperatorNew, UnusedResult);
1301
172
  }
1302
178
1303
178
  bool IsGlobalOverload =
1304
178
      OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
1305
  // If we didn't find a class-local new declaration and non-throwing new
1306
  // was is required then we need to lookup the non-throwing global operator
1307
  // instead.
1308
178
  if (RequiresNoThrowAlloc && 
(9
!OperatorNew9
||
IsGlobalOverload9
)) {
1309
6
    auto *StdNoThrow = buildStdNoThrowDeclRef(S, Loc);
1310
6
    if (!StdNoThrow)
1311
0
      return false;
1312
6
    PlacementArgs = {StdNoThrow};
1313
6
    OperatorNew = nullptr;
1314
6
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Both,
1315
6
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1316
6
                              /*isArray*/ false, PassAlignment, PlacementArgs,
1317
6
                              OperatorNew, UnusedResult);
1318
6
  }
1319
178
1320
178
  if (!OperatorNew)
1321
0
    return false;
1322
178
1323
178
  if (RequiresNoThrowAlloc) {
1324
9
    const auto *FT = OperatorNew->getType()->castAs<FunctionProtoType>();
1325
9
    if (!FT->isNothrow(/*ResultIfDependent*/ false)) {
1326
2
      S.Diag(OperatorNew->getLocation(),
1327
2
             diag::err_coroutine_promise_new_requires_nothrow)
1328
2
          << OperatorNew;
1329
2
      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
1330
2
          << OperatorNew;
1331
2
      return false;
1332
2
    }
1333
176
  }
1334
176
1335
176
  if ((OperatorDelete = findDeleteForPromise(S, Loc, PromiseType)) == nullptr)
1336
0
    return false;
1337
176
1338
176
  Expr *FramePtr =
1339
176
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
1340
176
1341
176
  Expr *FrameSize =
1342
176
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {});
1343
176
1344
  // Make new call.
1345
176
1346
176
  ExprResult NewRef =
1347
176
      S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc);
1348
176
  if (NewRef.isInvalid())
1349
0
    return false;
1350
176
1351
176
  SmallVector<Expr *, 2> NewArgs(1, FrameSize);
1352
176
  for (auto Arg : PlacementArgs)
1353
16
    NewArgs.push_back(Arg);
1354
176
1355
176
  ExprResult NewExpr =
1356
176
      S.BuildCallExpr(S.getCurScope(), NewRef.get(), Loc, NewArgs, Loc);
1357
176
  NewExpr = S.ActOnFinishFullExpr(NewExpr.get(), /*DiscardedValue*/ false);
1358
176
  if (NewExpr.isInvalid())
1359
0
    return false;
1360
176
1361
  // Make delete call.
1362
176
1363
176
  QualType OpDeleteQualType = OperatorDelete->getType();
1364
176
1365
176
  ExprResult DeleteRef =
1366
176
      S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc);
1367
176
  if (DeleteRef.isInvalid())
1368
0
    return false;
1369
176
1370
176
  Expr *CoroFree =
1371
176
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr});
1372
176
1373
176
  SmallVector<Expr *, 2> DeleteArgs{CoroFree};
1374
176
1375
  // Check if we need to pass the size.
1376
176
  const auto *OpDeleteType =
1377
176
      OpDeleteQualType.getTypePtr()->castAs<FunctionProtoType>();
1378
176
  if (OpDeleteType->getNumParams() > 1)
1379
1
    DeleteArgs.push_back(FrameSize);
1380
176
1381
176
  ExprResult DeleteExpr =
1382
176
      S.BuildCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc);
1383
176
  DeleteExpr =
1384
176
      S.ActOnFinishFullExpr(DeleteExpr.get(), /*DiscardedValue*/ false);
1385
176
  if (DeleteExpr.isInvalid())
1386
0
    return false;
1387
176
1388
176
  this->Allocate = NewExpr.get();
1389
176
  this->Deallocate = DeleteExpr.get();
1390
176
1391
176
  return true;
1392
176
}
1393
1394
186
bool CoroutineStmtBuilder::makeOnFallthrough() {
1395
186
  assert(!IsPromiseDependentType &&
1396
186
         "cannot make statement while the promise type is dependent");
1397
186
1398
  // [dcl.fct.def.coroutine]/4
1399
  // The unqualified-ids 'return_void' and 'return_value' are looked up in
1400
  // the scope of class P. If both are found, the program is ill-formed.
1401
186
  bool HasRVoid, HasRValue;
1402
186
  LookupResult LRVoid =
1403
186
      lookupMember(S, "return_void", PromiseRecordDecl, Loc, HasRVoid);
1404
186
  LookupResult LRValue =
1405
186
      lookupMember(S, "return_value", PromiseRecordDecl, Loc, HasRValue);
1406
186
1407
186
  StmtResult Fallthrough;
1408
186
  if (HasRVoid && 
HasRValue133
) {
1409
    // FIXME Improve this diagnostic
1410
2
    S.Diag(FD.getLocation(),
1411
2
           diag::err_coroutine_promise_incompatible_return_functions)
1412
2
        << PromiseRecordDecl;
1413
2
    S.Diag(LRVoid.getRepresentativeDecl()->getLocation(),
1414
2
           diag::note_member_first_declared_here)
1415
2
        << LRVoid.getLookupName();
1416
2
    S.Diag(LRValue.getRepresentativeDecl()->getLocation(),
1417
2
           diag::note_member_first_declared_here)
1418
2
        << LRValue.getLookupName();
1419
2
    return false;
1420
184
  } else if (!HasRVoid && 
!HasRValue53
) {
1421
    // FIXME: The PDTS currently specifies this case as UB, not ill-formed.
1422
    // However we still diagnose this as an error since until the PDTS is fixed.
1423
1
    S.Diag(FD.getLocation(),
1424
1
           diag::err_coroutine_promise_requires_return_function)
1425
1
        << PromiseRecordDecl;
1426
1
    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1427
1
        << PromiseRecordDecl;
1428
1
    return false;
1429
183
  } else if (HasRVoid) {
1430
    // If the unqualified-id return_void is found, flowing off the end of a
1431
    // coroutine is equivalent to a co_return with no operand. Otherwise,
1432
    // flowing off the end of a coroutine results in undefined behavior.
1433
131
    Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr,
1434
131
                                      /*IsImplicit*/false);
1435
131
    Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
1436
131
    if (Fallthrough.isInvalid())
1437
0
      return false;
1438
183
  }
1439
183
1440
183
  this->OnFallthrough = Fallthrough.get();
1441
183
  return true;
1442
183
}
1443
1444
189
bool CoroutineStmtBuilder::makeOnException() {
1445
  // Try to form 'p.unhandled_exception();'
1446
189
  assert(!IsPromiseDependentType &&
1447
189
         "cannot make statement while the promise type is dependent");
1448
189
1449
189
  const bool RequireUnhandledException = S.getLangOpts().CXXExceptions;
1450
189
1451
189
  if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) {
1452
27
    auto DiagID =
1453
27
        RequireUnhandledException
1454
2
            ? diag::err_coroutine_promise_unhandled_exception_required
1455
25
            : diag::
1456
25
                  warn_coroutine_promise_unhandled_exception_required_with_exceptions;
1457
27
    S.Diag(Loc, DiagID) << PromiseRecordDecl;
1458
27
    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1459
27
        << PromiseRecordDecl;
1460
27
    return !RequireUnhandledException;
1461
27
  }
1462
162
1463
  // If exceptions are disabled, don't try to build OnException.
1464
162
  if (!S.getLangOpts().CXXExceptions)
1465
45
    return true;
1466
117
1467
117
  ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc,
1468
117
                                                   "unhandled_exception", None);
1469
117
  UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc,
1470
117
                                             /*DiscardedValue*/ false);
1471
117
  if (UnhandledException.isInvalid())
1472
0
    return false;
1473
117
1474
  // Since the body of the coroutine will be wrapped in try-catch, it will
1475
  // be incompatible with SEH __try if present in a function.
1476
117
  if (!S.getLangOpts().Borland && Fn.FirstSEHTryLoc.isValid()) {
1477
1
    S.Diag(Fn.FirstSEHTryLoc, diag::err_seh_in_a_coroutine_with_cxx_exceptions);
1478
1
    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1479
1
        << Fn.getFirstCoroutineStmtKeyword();
1480
1
    return false;
1481
1
  }
1482
116
1483
116
  this->OnException = UnhandledException.get();
1484
116
  return true;
1485
116
}
1486
1487
203
bool CoroutineStmtBuilder::makeReturnObject() {
1488
  // Build implicit 'p.get_return_object()' expression and form initialization
1489
  // of return type from it.
1490
203
  ExprResult ReturnObject =
1491
203
      buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None);
1492
203
  if (ReturnObject.isInvalid())
1493
1
    return false;
1494
202
1495
202
  this->ReturnValue = ReturnObject.get();
1496
202
  return true;
1497
202
}
1498
1499
2
/// Emit follow-up notes for a failed use of a promise member call \p E:
/// where the called member is declared (only when \p E is a member call),
/// and where the first coroutine statement of \p Fn appears.
static void noteMemberDeclaredHere(Sema &S, Expr *E, FunctionScopeInfo &Fn) {
  if (const auto *Call = dyn_cast<CXXMemberCallExpr>(E)) {
    CXXMethodDecl *Callee = Call->getMethodDecl();
    S.Diag(Callee->getLocation(), diag::note_member_declared_here) << Callee;
  }

  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
      << Fn.getFirstCoroutineStmtKeyword();
}
1508
1509
183
bool CoroutineStmtBuilder::makeGroDeclAndReturnStmt() {
1510
183
  assert(!IsPromiseDependentType &&
1511
183
         "cannot make statement while the promise type is dependent");
1512
183
  assert(this->ReturnValue && "ReturnValue must be already formed");
1513
183
1514
183
  QualType const GroType = this->ReturnValue->getType();
1515
183
  assert(!GroType->isDependentType() &&
1516
183
         "get_return_object type must no longer be dependent");
1517
183
1518
183
  QualType const FnRetType = FD.getReturnType();
1519
183
  assert(!FnRetType->isDependentType() &&
1520
183
         "get_return_object type must no longer be dependent");
1521
183
1522
183
  if (FnRetType->isVoidType()) {
1523
73
    ExprResult Res =
1524
73
        S.ActOnFinishFullExpr(this->ReturnValue, Loc, /*DiscardedValue*/ false);
1525
73
    if (Res.isInvalid())
1526
0
      return false;
1527
73
1528
73
    this->ResultDecl = Res.get();
1529
73
    return true;
1530
73
  }
1531
110
1532
110
  if (GroType->isVoidType()) {
1533
    // Trigger a nice error message.
1534
1
    InitializedEntity Entity =
1535
1
        InitializedEntity::InitializeResult(Loc, FnRetType, false);
1536
1
    S.PerformMoveOrCopyInitialization(Entity, nullptr, FnRetType, ReturnValue);
1537
1
    noteMemberDeclaredHere(S, ReturnValue, Fn);
1538
1
    return false;
1539
1
  }
1540
109
1541
109
  auto *GroDecl = VarDecl::Create(
1542
109
      S.Context, &FD, FD.getLocation(), FD.getLocation(),
1543
109
      &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
1544
109
      S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
1545
109
1546
109
  S.CheckVariableDeclarationType(GroDecl);
1547
109
  if (GroDecl->isInvalidDecl())
1548
0
    return false;
1549
109
1550
109
  InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
1551
109
  ExprResult Res = S.PerformMoveOrCopyInitialization(Entity, nullptr, GroType,
1552
109
                                                     this->ReturnValue);
1553
109
  if (Res.isInvalid())
1554
0
    return false;
1555
109
1556
109
  Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false);
1557
109
  if (Res.isInvalid())
1558
0
    return false;
1559
109
1560
109
  S.AddInitializerToDecl(GroDecl, Res.get(),
1561
109
                         /*DirectInit=*/false);
1562
109
1563
109
  S.FinalizeDeclaration(GroDecl);
1564
109
1565
  // Form a declaration statement for the return declaration, so that AST
1566
  // visitors can more easily find it.
1567
109
  StmtResult GroDeclStmt =
1568
109
      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
1569
109
  if (GroDeclStmt.isInvalid())
1570
0
    return false;
1571
109
1572
109
  this->ResultDecl = GroDeclStmt.get();
1573
109
1574
109
  ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
1575
109
  if (declRef.isInvalid())
1576
0
    return false;
1577
109
1578
109
  StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
1579
109
  if (ReturnStmt.isInvalid()) {
1580
1
    noteMemberDeclaredHere(S, ReturnValue, Fn);
1581
1
    return false;
1582
1
  }
1583
108
  if (cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl)
1584
94
    GroDecl->setNRVOVariable(true);
1585
108
1586
108
  this->ReturnStmt = ReturnStmt.get();
1587
108
  return true;
1588
108
}
1589
1590
// Create a static_cast\<T&&>(expr).
1591
52
static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) {
1592
52
  if (T.isNull())
1593
52
    T = E->getType();
1594
52
  QualType TargetType = S.BuildReferenceType(
1595
52
      T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName());
1596
52
  SourceLocation ExprLoc = E->getBeginLoc();
1597
52
  TypeSourceInfo *TargetLoc =
1598
52
      S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc);
1599
52
1600
52
  return S
1601
52
      .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E,
1602
52
                         SourceRange(ExprLoc, ExprLoc), E->getSourceRange())
1603
52
      .get();
1604
52
}
1605
1606
/// Build a variable declaration for move parameter.
1607
static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type,
1608
165
                             IdentifierInfo *II) {
1609
165
  TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc);
1610
165
  VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type,
1611
165
                                  TInfo, SC_None);
1612
165
  Decl->setImplicit();
1613
165
  return Decl;
1614
165
}
1615
1616
// Build statements that move coroutine function parameters to the coroutine
1617
// frame, and store them on the function scope info.
1618
283
bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) {
1619
283
  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
1620
283
  auto *FD = cast<FunctionDecl>(CurContext);
1621
283
1622
283
  auto *ScopeInfo = getCurFunction();
1623
283
  if (!ScopeInfo->CoroutineParameterMoves.empty())
1624
1
    return false;
1625
282
1626
282
  for (auto *PD : FD->parameters()) {
1627
208
    if (PD->getType()->isDependentType())
1628
43
      continue;
1629
165
1630
165
    ExprResult PDRefExpr =
1631
165
        BuildDeclRefExpr(PD, PD->getType().getNonReferenceType(),
1632
165
                         ExprValueKind::VK_LValue, Loc); // FIXME: scope?
1633
165
    if (PDRefExpr.isInvalid())
1634
0
      return false;
1635
165
1636
165
    Expr *CExpr = nullptr;
1637
165
    if (PD->getType()->getAsCXXRecordDecl() ||
1638
121
        PD->getType()->isRValueReferenceType())
1639
52
      CExpr = castForMoving(*this, PDRefExpr.get());
1640
113
    else
1641
113
      CExpr = PDRefExpr.get();
1642
165
1643
165
    auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
1644
165
    AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
1645
165
1646
    // Convert decl to a statement.
1647
165
    StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
1648
165
    if (Stmt.isInvalid())
1649
0
      return false;
1650
165
1651
165
    ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
1652
165
  }
1653
282
  return true;
1654
282
}
1655
1656
44
/// Create the CoroutineBodyStmt node from \p Args. A null result from
/// CoroutineBodyStmt::Create is reported as StmtError().
StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) {
  if (CoroutineBodyStmt *Body = CoroutineBodyStmt::Create(Context, Args))
    return Body;
  return StmtError();
}
1662
1663
/// Find (and cache) the 'std::experimental::coroutine_traits' class template.
///
/// Returns the cached declaration when one was already found. Returns null
/// without a diagnostic if the std::experimental namespace cannot be found.
/// Returns null with a diagnostic if the name is missing from that namespace
/// or does not resolve to a single class template.
ClassTemplateDecl *Sema::lookupCoroutineTraits(SourceLocation KwLoc,
                                               SourceLocation FuncLoc) {
  if (StdCoroutineTraitsCache)
    return StdCoroutineTraitsCache;

  auto *StdExp = lookupStdExperimentalNamespace();
  if (!StdExp)
    return nullptr;

  LookupResult Result(*this, &PP.getIdentifierTable().get("coroutine_traits"),
                      FuncLoc, LookupOrdinaryName);
  if (!LookupQualifiedName(Result, StdExp)) {
    Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
        << "std::experimental::coroutine_traits";
    return nullptr;
  }

  StdCoroutineTraitsCache = Result.getAsSingle<ClassTemplateDecl>();
  if (!StdCoroutineTraitsCache) {
    // Something other than a single class template was found: report it as
    // malformed at the first declaration the lookup produced.
    Result.suppressDiagnostics();
    NamedDecl *Found = *Result.begin();
    Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
    return nullptr;
  }
  return StdCoroutineTraitsCache;
}