Coverage Report

Created: 2020-02-18 08:44

/Users/buildslave/jenkins/workspace/coverage/llvm-project/clang/lib/Sema/SemaCoroutine.cpp
Line
Count
Source (jump to first uncovered line)
1
//===-- SemaCoroutine.cpp - Semantic Analysis for Coroutines --------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
//  This file implements semantic analysis for C++ Coroutines.
10
//
11
//  This file contains references to sections of the Coroutines TS, which
12
//  can be found at http://wg21.link/coroutines.
13
//
14
//===----------------------------------------------------------------------===//
15
16
#include "CoroutineStmtBuilder.h"
17
#include "clang/AST/ASTLambda.h"
18
#include "clang/AST/Decl.h"
19
#include "clang/AST/ExprCXX.h"
20
#include "clang/AST/StmtCXX.h"
21
#include "clang/Basic/Builtins.h"
22
#include "clang/Lex/Preprocessor.h"
23
#include "clang/Sema/Initialization.h"
24
#include "clang/Sema/Overload.h"
25
#include "clang/Sema/ScopeInfo.h"
26
#include "clang/Sema/SemaInternal.h"
27
28
using namespace clang;
29
using namespace sema;
30
31
static LookupResult lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
32
637
                                 SourceLocation Loc, bool &Res) {
33
637
  DeclarationName DN = S.PP.getIdentifierInfo(Name);
34
637
  LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
35
637
  // Suppress diagnostics when a private member is selected. The same warnings
36
637
  // will be produced again when building the call.
37
637
  LR.suppressDiagnostics();
38
637
  Res = S.LookupQualifiedName(LR, RD);
39
637
  return LR;
40
637
}
41
42
static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
43
287
                         SourceLocation Loc) {
44
287
  bool Res;
45
287
  lookupMember(S, Name, RD, Loc, Res);
46
287
  return Res;
47
287
}
48
49
/// Look up the std::coroutine_traits<...>::promise_type for the given
50
/// function type.
51
static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD,
52
214
                                  SourceLocation KwLoc) {
53
214
  const FunctionProtoType *FnType = FD->getType()->castAs<FunctionProtoType>();
54
214
  const SourceLocation FuncLoc = FD->getLocation();
55
214
  // FIXME: Cache std::coroutine_traits once we've found it.
56
214
  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
57
214
  if (!StdExp) {
58
4
    S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
59
4
        << "std::experimental::coroutine_traits";
60
4
    return QualType();
61
4
  }
62
210
63
210
  ClassTemplateDecl *CoroTraits = S.lookupCoroutineTraits(KwLoc, FuncLoc);
64
210
  if (!CoroTraits) {
65
0
    return QualType();
66
0
  }
67
210
68
210
  // Form template argument list for coroutine_traits<R, P1, P2, ...> according
69
210
  // to [dcl.fct.def.coroutine]3
70
210
  TemplateArgumentListInfo Args(KwLoc, KwLoc);
71
406
  auto AddArg = [&](QualType T) {
72
406
    Args.addArgument(TemplateArgumentLoc(
73
406
        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc)));
74
406
  };
75
210
  AddArg(FnType->getReturnType());
76
210
  // If the function is a non-static member function, add the type
77
210
  // of the implicit object parameter before the formal parameters.
78
210
  if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
79
59
    if (MD->isInstance()) {
80
47
      // [over.match.funcs]4
81
47
      // For non-static member functions, the type of the implicit object
82
47
      // parameter is
83
47
      //  -- "lvalue reference to cv X" for functions declared without a
84
47
      //      ref-qualifier or with the & ref-qualifier
85
47
      //  -- "rvalue reference to cv X" for functions declared with the &&
86
47
      //      ref-qualifier
87
47
      QualType T = MD->getThisType()->castAs<PointerType>()->getPointeeType();
88
47
      T = FnType->getRefQualifier() == RQ_RValue
89
47
              ? 
S.Context.getRValueReferenceType(T)7
90
47
              : 
S.Context.getLValueReferenceType(T, /*SpelledAsLValue*/ true)40
;
91
47
      AddArg(T);
92
47
    }
93
59
  }
94
210
  for (QualType T : FnType->getParamTypes())
95
149
    AddArg(T);
96
210
97
210
  // Build the template-id.
98
210
  QualType CoroTrait =
99
210
      S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args);
100
210
  if (CoroTrait.isNull())
101
0
    return QualType();
102
210
  if (S.RequireCompleteType(KwLoc, CoroTrait,
103
210
                            diag::err_coroutine_type_missing_specialization))
104
1
    return QualType();
105
209
106
209
  auto *RD = CoroTrait->getAsCXXRecordDecl();
107
209
  assert(RD && "specialization of class template is not a class?");
108
209
109
209
  // Look up the ::promise_type member.
110
209
  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc,
111
209
                 Sema::LookupOrdinaryName);
112
209
  S.LookupQualifiedName(R, RD);
113
209
  auto *Promise = R.getAsSingle<TypeDecl>();
114
209
  if (!Promise) {
115
6
    S.Diag(FuncLoc,
116
6
           diag::err_implied_std_coroutine_traits_promise_type_not_found)
117
6
        << RD;
118
6
    return QualType();
119
6
  }
120
203
  // The promise type is required to be a class type.
121
203
  QualType PromiseType = S.Context.getTypeDeclType(Promise);
122
203
123
203
  auto buildElaboratedType = [&]() {
124
203
    auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
125
203
    NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
126
203
                                      CoroTrait.getTypePtr());
127
203
    return S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
128
203
  };
129
203
130
203
  if (!PromiseType->getAsCXXRecordDecl()) {
131
1
    S.Diag(FuncLoc,
132
1
           diag::err_implied_std_coroutine_traits_promise_type_not_class)
133
1
        << buildElaboratedType();
134
1
    return QualType();
135
1
  }
136
202
  if (S.RequireCompleteType(FuncLoc, buildElaboratedType(),
137
202
                            diag::err_coroutine_promise_type_incomplete))
138
1
    return QualType();
139
201
140
201
  return PromiseType;
141
201
}
142
143
/// Look up the std::experimental::coroutine_handle<PromiseType>.
144
static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
145
549
                                          SourceLocation Loc) {
146
549
  if (PromiseType.isNull())
147
0
    return QualType();
148
549
149
549
  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
150
549
  assert(StdExp && "Should already be diagnosed");
151
549
152
549
  LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"),
153
549
                      Loc, Sema::LookupOrdinaryName);
154
549
  if (!S.LookupQualifiedName(Result, StdExp)) {
155
1
    S.Diag(Loc, diag::err_implied_coroutine_type_not_found)
156
1
        << "std::experimental::coroutine_handle";
157
1
    return QualType();
158
1
  }
159
548
160
548
  ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
161
548
  if (!CoroHandle) {
162
0
    Result.suppressDiagnostics();
163
0
    // We found something weird. Complain about the first thing we found.
164
0
    NamedDecl *Found = *Result.begin();
165
0
    S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
166
0
    return QualType();
167
0
  }
168
548
169
548
  // Form template argument list for coroutine_handle<Promise>.
170
548
  TemplateArgumentListInfo Args(Loc, Loc);
171
548
  Args.addArgument(TemplateArgumentLoc(
172
548
      TemplateArgument(PromiseType),
173
548
      S.Context.getTrivialTypeSourceInfo(PromiseType, Loc)));
174
548
175
548
  // Build the template-id.
176
548
  QualType CoroHandleType =
177
548
      S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args);
178
548
  if (CoroHandleType.isNull())
179
0
    return QualType();
180
548
  if (S.RequireCompleteType(Loc, CoroHandleType,
181
548
                            diag::err_coroutine_type_missing_specialization))
182
0
    return QualType();
183
548
184
548
  return CoroHandleType;
185
548
}
186
187
static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
188
1.32k
                                    StringRef Keyword) {
189
1.32k
  // [expr.await]p2 dictates that 'co_await' and 'co_yield' must be used within
190
1.32k
  // a function body.
191
1.32k
  // FIXME: This also covers [expr.await]p2: "An await-expression shall not
192
1.32k
  // appear in a default argument." But the diagnostic QoI here could be
193
1.32k
  // improved to inform the user that default arguments specifically are not
194
1.32k
  // allowed.
195
1.32k
  auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
196
1.32k
  if (!FD) {
197
1
    S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
198
1
                    ? 
diag::err_coroutine_objc_method0
199
1
                    : diag::err_coroutine_outside_function) << Keyword;
200
1
    return false;
201
1
  }
202
1.32k
203
1.32k
  // An enumeration for mapping the diagnostic type to the correct diagnostic
204
1.32k
  // selection index.
205
1.32k
  enum InvalidFuncDiag {
206
1.32k
    DiagCtor = 0,
207
1.32k
    DiagDtor,
208
1.32k
    DiagMain,
209
1.32k
    DiagConstexpr,
210
1.32k
    DiagAutoRet,
211
1.32k
    DiagVarargs,
212
1.32k
    DiagConsteval,
213
1.32k
  };
214
1.32k
  bool Diagnosed = false;
215
1.32k
  auto DiagInvalid = [&](InvalidFuncDiag ID) {
216
9
    S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword;
217
9
    Diagnosed = true;
218
9
    return false;
219
9
  };
220
1.32k
221
1.32k
  // Diagnose when a constructor, destructor
222
1.32k
  // or the function 'main' are declared as a coroutine.
223
1.32k
  auto *MD = dyn_cast<CXXMethodDecl>(FD);
224
1.32k
  // [class.ctor]p11: "A constructor shall not be a coroutine."
225
1.32k
  if (MD && 
isa<CXXConstructorDecl>(MD)359
)
226
2
    return DiagInvalid(DiagCtor);
227
1.32k
  // [class.dtor]p17: "A destructor shall not be a coroutine."
228
1.32k
  else if (MD && 
isa<CXXDestructorDecl>(MD)357
)
229
1
    return DiagInvalid(DiagDtor);
230
1.32k
  // [basic.start.main]p3: "The function main shall not be a coroutine."
231
1.32k
  else if (FD->isMain())
232
1
    return DiagInvalid(DiagMain);
233
1.32k
234
1.32k
  // Emit a diagnostics for each of the following conditions which is not met.
235
1.32k
  // [expr.const]p2: "An expression e is a core constant expression unless the
236
1.32k
  // evaluation of e [...] would evaluate one of the following expressions:
237
1.32k
  // [...] an await-expression [...] a yield-expression."
238
1.32k
  if (FD->isConstexpr())
239
2
    DiagInvalid(FD->isConsteval() ? 
DiagConsteval0
: DiagConstexpr);
240
1.32k
  // [dcl.spec.auto]p15: "A function declared with a return type that uses a
241
1.32k
  // placeholder type shall not be a coroutine."
242
1.32k
  if (FD->getReturnType()->isUndeducedType())
243
2
    DiagInvalid(DiagAutoRet);
244
1.32k
  // [dcl.fct.def.coroutine]p1: "The parameter-declaration-clause of the
245
1.32k
  // coroutine shall not terminate with an ellipsis that is not part of a
246
1.32k
  // parameter-declaration."
247
1.32k
  if (FD->isVariadic())
248
1
    DiagInvalid(DiagVarargs);
249
1.32k
250
1.32k
  return !Diagnosed;
251
1.32k
}
252
253
static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
254
559
                                                 SourceLocation Loc) {
255
559
  DeclarationName OpName =
256
559
      SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
257
559
  LookupResult Operators(SemaRef, OpName, SourceLocation(),
258
559
                         Sema::LookupOperatorName);
259
559
  SemaRef.LookupName(Operators, S);
260
559
261
559
  assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
262
559
  const auto &Functions = Operators.asUnresolvedSet();
263
559
  bool IsOverloaded =
264
559
      Functions.size() > 1 ||
265
559
      
(535
Functions.size() == 1535
&&
isa<FunctionTemplateDecl>(*Functions.begin())38
);
266
559
  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
267
559
      SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
268
559
      DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
269
559
      Functions.begin(), Functions.end());
270
559
  assert(CoawaitOp);
271
559
  return CoawaitOp;
272
559
}
273
274
/// Build a call to 'operator co_await' if there is a suitable operator for
275
/// the given expression.
276
static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
277
                                           Expr *E,
278
552
                                           UnresolvedLookupExpr *Lookup) {
279
552
  UnresolvedSet<16> Functions;
280
552
  Functions.append(Lookup->decls_begin(), Lookup->decls_end());
281
552
  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
282
552
}
283
284
static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
285
447
                                           SourceLocation Loc, Expr *E) {
286
447
  ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
287
447
  if (R.isInvalid())
288
0
    return ExprError();
289
447
  return buildOperatorCoawaitCall(SemaRef, Loc, E,
290
447
                                  cast<UnresolvedLookupExpr>(R.get()));
291
447
}
292
293
static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
294
1.04k
                              MultiExprArg CallArgs) {
295
1.04k
  StringRef Name = S.Context.BuiltinInfo.getName(Id);
296
1.04k
  LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName);
297
1.04k
  S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true);
298
1.04k
299
1.04k
  auto *BuiltInDecl = R.getAsSingle<FunctionDecl>();
300
1.04k
  assert(BuiltInDecl && "failed to find builtin declaration");
301
1.04k
302
1.04k
  ExprResult DeclRef =
303
1.04k
      S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(), VK_LValue, Loc);
304
1.04k
  assert(DeclRef.isUsable() && "Builtin reference cannot fail");
305
1.04k
306
1.04k
  ExprResult Call =
307
1.04k
      S.BuildCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc);
308
1.04k
309
1.04k
  assert(!Call.isInvalid() && "Call to builtin cannot fail!");
310
1.04k
  return Call.get();
311
1.04k
}
312
313
static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
314
549
                                       SourceLocation Loc) {
315
549
  QualType CoroHandleType = lookupCoroutineHandleType(S, PromiseType, Loc);
316
549
  if (CoroHandleType.isNull())
317
1
    return ExprError();
318
548
319
548
  DeclContext *LookupCtx = S.computeDeclContext(CoroHandleType);
320
548
  LookupResult Found(S, &S.PP.getIdentifierTable().get("from_address"), Loc,
321
548
                     Sema::LookupOrdinaryName);
322
548
  if (!S.LookupQualifiedName(Found, LookupCtx)) {
323
1
    S.Diag(Loc, diag::err_coroutine_handle_missing_member)
324
1
        << "from_address";
325
1
    return ExprError();
326
1
  }
327
547
328
547
  Expr *FramePtr =
329
547
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
330
547
331
547
  CXXScopeSpec SS;
332
547
  ExprResult FromAddr =
333
547
      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
334
547
  if (FromAddr.isInvalid())
335
0
    return ExprError();
336
547
337
547
  return S.BuildCallExpr(nullptr, FromAddr.get(), Loc, FramePtr, Loc);
338
547
}
339
340
struct ReadySuspendResumeResult {
341
  enum AwaitCallType { ACT_Ready, ACT_Suspend, ACT_Resume };
342
  Expr *Results[3];
343
  OpaqueValueExpr *OpaqueValue;
344
  bool IsInvalid;
345
};
346
347
static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
348
2.62k
                                  StringRef Name, MultiExprArg Args) {
349
2.62k
  DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
350
2.62k
351
2.62k
  // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
352
2.62k
  CXXScopeSpec SS;
353
2.62k
  ExprResult Result = S.BuildMemberReferenceExpr(
354
2.62k
      Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
355
2.62k
      SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
356
2.62k
      /*Scope=*/nullptr);
357
2.62k
  if (Result.isInvalid())
358
18
    return ExprError();
359
2.60k
360
2.60k
  // We meant exactly what we asked for. No need for typo correction.
361
2.60k
  if (auto *TE = dyn_cast<TypoExpr>(Result.get())) {
362
3
    S.clearDelayedTypo(TE);
363
3
    S.Diag(Loc, diag::err_no_member)
364
3
        << NameInfo.getName() << Base->getType()->getAsCXXRecordDecl()
365
3
        << Base->getSourceRange();
366
3
    return ExprError();
367
3
  }
368
2.60k
369
2.60k
  return S.BuildCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
370
2.60k
}
371
372
// See if return type is coroutine-handle and if so, invoke builtin coro-resume
373
// on its address. This is to enable experimental support for coroutine-handle
374
// returning await_suspend that results in a guaranteed tail call to the target
375
// coroutine.
376
static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
377
533
                           SourceLocation Loc) {
378
533
  if (RetType->isReferenceType())
379
2
    return nullptr;
380
531
  Type const *T = RetType.getTypePtr();
381
531
  if (!T->isClassType() && !T->isStructureType())
382
530
    return nullptr;
383
1
384
1
  // FIXME: Add convertability check to coroutine_handle<>. Possibly via
385
1
  // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
386
1
  // a private function in SemaExprCXX.cpp
387
1
388
1
  ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", None);
389
1
  if (AddressExpr.isInvalid())
390
0
    return nullptr;
391
1
392
1
  Expr *JustAddress = AddressExpr.get();
393
1
  // FIXME: Check that the type of AddressExpr is void*
394
1
  return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume,
395
1
                          JustAddress);
396
1
}
397
398
/// Build calls to await_ready, await_suspend, and await_resume for a co_await
399
/// expression.
400
static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
401
549
                                                  SourceLocation Loc, Expr *E) {
402
549
  OpaqueValueExpr *Operand = new (S.Context)
403
549
      OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
404
549
405
549
  // Assume invalid until we see otherwise.
406
549
  ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true};
407
549
408
549
  ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc);
409
549
  if (CoroHandleRes.isInvalid())
410
2
    return Calls;
411
547
  Expr *CoroHandle = CoroHandleRes.get();
412
547
413
547
  const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
414
547
  MultiExprArg Args[] = {None, CoroHandle, None};
415
2.14k
  for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; 
++I1.60k
) {
416
1.61k
    ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]);
417
1.61k
    if (Result.isInvalid())
418
14
      return Calls;
419
1.60k
    Calls.Results[I] = Result.get();
420
1.60k
  }
421
547
422
547
  // Assume the calls are valid; all further checking should make them invalid.
423
547
  Calls.IsInvalid = false;
424
533
425
533
  using ACT = ReadySuspendResumeResult::AwaitCallType;
426
533
  CallExpr *AwaitReady = cast<CallExpr>(Calls.Results[ACT::ACT_Ready]);
427
533
  if (!AwaitReady->getType()->isDependentType()) {
428
533
    // [expr.await]p3 [...]
429
533
    // — await-ready is the expression e.await_ready(), contextually converted
430
533
    // to bool.
431
533
    ExprResult Conv = S.PerformContextuallyConvertToBool(AwaitReady);
432
533
    if (Conv.isInvalid()) {
433
1
      S.Diag(AwaitReady->getDirectCallee()->getBeginLoc(),
434
1
             diag::note_await_ready_no_bool_conversion);
435
1
      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
436
1
          << AwaitReady->getDirectCallee() << E->getSourceRange();
437
1
      Calls.IsInvalid = true;
438
1
    }
439
533
    Calls.Results[ACT::ACT_Ready] = Conv.get();
440
533
  }
441
533
  CallExpr *AwaitSuspend = cast<CallExpr>(Calls.Results[ACT::ACT_Suspend]);
442
533
  if (!AwaitSuspend->getType()->isDependentType()) {
443
533
    // [expr.await]p3 [...]
444
533
    //   - await-suspend is the expression e.await_suspend(h), which shall be
445
533
    //     a prvalue of type void or bool.
446
533
    QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
447
533
448
533
    // Experimental support for coroutine_handle returning await_suspend.
449
533
    if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc))
450
1
      Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
451
532
    else {
452
532
      // non-class prvalues always have cv-unqualified types
453
532
      if (RetType->isReferenceType() ||
454
532
          
(530
!RetType->isBooleanType()530
&&
!RetType->isVoidType()526
)) {
455
3
        S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
456
3
               diag::err_await_suspend_invalid_return_type)
457
3
            << RetType;
458
3
        S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
459
3
            << AwaitSuspend->getDirectCallee();
460
3
        Calls.IsInvalid = true;
461
3
      }
462
532
    }
463
533
  }
464
533
465
533
  return Calls;
466
547
}
467
468
static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
469
                                   SourceLocation Loc, StringRef Name,
470
1.00k
                                   MultiExprArg Args) {
471
1.00k
472
1.00k
  // Form a reference to the promise.
473
1.00k
  ExprResult PromiseRef = S.BuildDeclRefExpr(
474
1.00k
      Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
475
1.00k
  if (PromiseRef.isInvalid())
476
0
    return ExprError();
477
1.00k
478
1.00k
  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
479
1.00k
}
480
481
267
VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
482
267
  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
483
267
  auto *FD = cast<FunctionDecl>(CurContext);
484
267
  bool IsThisDependentType = [&] {
485
267
    if (auto *MD = dyn_cast_or_null<CXXMethodDecl>(FD))
486
83
      return MD->isInstance() && 
MD->getThisType()->isDependentType()71
;
487
184
    else
488
184
      return false;
489
267
  }();
490
267
491
267
  QualType T = FD->getType()->isDependentType() || 
IsThisDependentType227
492
267
                   ? 
Context.DependentTy53
493
267
                   : 
lookupPromiseType(*this, FD, Loc)214
;
494
267
  if (T.isNull())
495
13
    return nullptr;
496
254
497
254
  auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
498
254
                             &PP.getIdentifierTable().get("__promise"), T,
499
254
                             Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
500
254
  CheckVariableDeclarationType(VD);
501
254
  if (VD->isInvalidDecl())
502
0
    return nullptr;
503
254
504
254
  auto *ScopeInfo = getCurFunction();
505
254
  // Build a list of arguments, based on the coroutine functions arguments,
506
254
  // that will be passed to the promise type's constructor.
507
254
  llvm::SmallVector<Expr *, 4> CtorArgExprs;
508
254
509
254
  // Add implicit object parameter.
510
254
  if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
511
80
    if (MD->isInstance() && 
!isLambdaCallOperator(MD)69
) {
512
55
      ExprResult ThisExpr = ActOnCXXThis(Loc);
513
55
      if (ThisExpr.isInvalid())
514
0
        return nullptr;
515
55
      ThisExpr = CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
516
55
      if (ThisExpr.isInvalid())
517
0
        return nullptr;
518
55
      CtorArgExprs.push_back(ThisExpr.get());
519
55
    }
520
80
  }
521
254
522
254
  auto &Moves = ScopeInfo->CoroutineParameterMoves;
523
254
  for (auto *PD : FD->parameters()) {
524
199
    if (PD->getType()->isDependentType())
525
42
      continue;
526
157
527
157
    auto RefExpr = ExprEmpty();
528
157
    auto Move = Moves.find(PD);
529
157
    assert(Move != Moves.end() &&
530
157
           "Coroutine function parameter not inserted into move map");
531
157
    // If a reference to the function parameter exists in the coroutine
532
157
    // frame, use that reference.
533
157
    auto *MoveDecl =
534
157
        cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
535
157
    RefExpr =
536
157
        BuildDeclRefExpr(MoveDecl, MoveDecl->getType().getNonReferenceType(),
537
157
                         ExprValueKind::VK_LValue, FD->getLocation());
538
157
    if (RefExpr.isInvalid())
539
0
      return nullptr;
540
157
    CtorArgExprs.push_back(RefExpr.get());
541
157
  }
542
254
543
254
  // Create an initialization sequence for the promise type using the
544
254
  // constructor arguments, wrapped in a parenthesized list expression.
545
254
  Expr *PLE = ParenListExpr::Create(Context, FD->getLocation(),
546
254
                                    CtorArgExprs, FD->getLocation());
547
254
  InitializedEntity Entity = InitializedEntity::InitializeVariable(VD);
548
254
  InitializationKind Kind = InitializationKind::CreateForInit(
549
254
      VD->getLocation(), /*DirectInit=*/true, PLE);
550
254
  InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs,
551
254
                                 /*TopLevelOfInitList=*/false,
552
254
                                 /*TreatUnavailableAsInvalid=*/false);
553
254
554
254
  // Attempt to initialize the promise type with the arguments.
555
254
  // If that fails, fall back to the promise type's default constructor.
556
254
  if (InitSeq) {
557
147
    ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs);
558
147
    if (Result.isInvalid()) {
559
0
      VD->setInvalidDecl();
560
147
    } else if (Result.get()) {
561
147
      VD->setInit(MaybeCreateExprWithCleanups(Result.get()));
562
147
      VD->setInitStyle(VarDecl::CallInit);
563
147
      CheckCompleteVariableDeclaration(VD);
564
147
    }
565
147
  } else
566
107
    ActOnUninitializedDecl(VD);
567
254
568
254
  FD->addDecl(VD);
569
254
  return VD;
570
254
}
571
572
/// Check that this is a context in which a coroutine suspension can appear.
573
static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
574
                                                StringRef Keyword,
575
1.32k
                                                bool IsImplicit = false) {
576
1.32k
  if (!isValidCoroutineContext(S, Loc, Keyword))
577
9
    return nullptr;
578
1.31k
579
1.31k
  assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
580
1.31k
581
1.31k
  auto *ScopeInfo = S.getCurFunction();
582
1.31k
  assert(ScopeInfo && "missing function scope for function");
583
1.31k
584
1.31k
  if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && 
!IsImplicit375
)
585
267
    ScopeInfo->setFirstCoroutineStmt(Loc, Keyword);
586
1.31k
587
1.31k
  if (ScopeInfo->CoroutinePromise)
588
1.10k
    return ScopeInfo;
589
214
590
214
  if (!S.buildCoroutineParameterMoves(Loc))
591
1
    return nullptr;
592
213
593
213
  ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
594
213
  if (!ScopeInfo->CoroutinePromise)
595
13
    return nullptr;
596
200
597
200
  return ScopeInfo;
598
200
}
599
600
bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc,
601
287
                                   StringRef Keyword) {
602
287
  if (!checkCoroutineContext(*this, KWLoc, Keyword))
603
23
    return false;
604
264
  auto *ScopeInfo = getCurFunction();
605
264
  assert(ScopeInfo->CoroutinePromise);
606
264
607
264
  // If we have existing coroutine statements then we have already built
608
264
  // the initial and final suspend points.
609
264
  if (!ScopeInfo->NeedsCoroutineSuspends)
610
64
    return true;
611
200
612
200
  ScopeInfo->setNeedsCoroutineSuspends(false);
613
200
614
200
  auto *Fn = cast<FunctionDecl>(CurContext);
615
200
  SourceLocation Loc = Fn->getLocation();
616
200
  // Build the initial suspend point
617
393
  auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
618
393
    ExprResult Suspend =
619
393
        buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None);
620
393
    if (Suspend.isInvalid())
621
3
      return StmtError();
622
390
    Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get());
623
390
    if (Suspend.isInvalid())
624
0
      return StmtError();
625
390
    Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(),
626
390
                                       /*IsImplicit*/ true);
627
390
    Suspend = ActOnFinishFullExpr(Suspend.get(), /*DiscardedValue*/ false);
628
390
    if (Suspend.isInvalid()) {
629
6
      Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required)
630
6
          << ((Name == "initial_suspend") ? 
05
:
11
);
631
6
      Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
632
6
      return StmtError();
633
6
    }
634
384
    return cast<Stmt>(Suspend.get());
635
384
  };
636
200
637
200
  StmtResult InitSuspend = buildSuspends("initial_suspend");
638
200
  if (InitSuspend.isInvalid())
639
7
    return true;
640
193
641
193
  StmtResult FinalSuspend = buildSuspends("final_suspend");
642
193
  if (FinalSuspend.isInvalid())
643
2
    return true;
644
191
645
191
  ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
646
191
647
191
  return true;
648
191
}
649
650
// Recursively walks up the scope hierarchy until either a 'catch' or a function
651
// scope is found, whichever comes first.
652
172
static bool isWithinCatchScope(Scope *S) {
653
172
  // 'co_await' and 'co_yield' keywords are disallowed within catch blocks, but
654
172
  // lambdas that use 'co_await' are allowed. The loop below ends when a
655
172
  // function scope is found in order to ensure the following behavior:
656
172
  //
657
172
  // void foo() {      // <- function scope
658
172
  //   try {           //
659
172
  //     co_await x;   // <- 'co_await' is OK within a function scope
660
172
  //   } catch {       // <- catch scope
661
172
  //     co_await x;   // <- 'co_await' is not OK within a catch scope
662
172
  //     []() {        // <- function scope
663
172
  //       co_await x; // <- 'co_await' is OK within a function scope
664
172
  //     }();
665
172
  //   }
666
172
  // }
667
198
  while (S && 
!(S->getFlags() & Scope::FnScope)192
) {
668
29
    if (S->getFlags() & Scope::CatchScope)
669
3
      return true;
670
26
    S = S->getParent();
671
26
  }
672
172
  
return false169
;
673
172
}
674
675
// [expr.await]p2, emphasis added: "An await-expression shall appear only in
676
// a *potentially evaluated* expression within the compound-statement of a
677
// function-body *outside of a handler* [...] A context within a function
678
// where an await-expression can appear is called a suspension context of the
679
// function."
680
static void checkSuspensionContext(Sema &S, SourceLocation Loc,
681
172
                                   StringRef Keyword) {
682
172
  // First emphasis of [expr.await]p2: must be a potentially evaluated context.
683
172
  // That is, 'co_await' and 'co_yield' cannot appear in subexpressions of
684
172
  // \c sizeof.
685
172
  if (S.isUnevaluatedContext())
686
6
    S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
687
172
688
172
  // Second emphasis of [expr.await]p2: must be outside of an exception handler.
689
172
  if (isWithinCatchScope(S.getCurScope()))
690
3
    S.Diag(Loc, diag::err_coroutine_within_handler) << Keyword;
691
172
}
692
693
125
/// Parser action for 'co_await <expr>'.
///
/// Marks the enclosing function as a coroutine, checks the suspension context,
/// resolves a placeholder-typed operand, looks up 'operator co_await', and
/// forwards to BuildUnresolvedCoawaitExpr.
ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
  const bool Started = ActOnCoroutineBodyStart(S, Loc, "co_await");
  if (!Started) {
    // Run delayed typo correction on the operand before bailing out.
    CorrectDelayedTyposInExpr(E);
    return ExprError();
  }

  checkSuspensionContext(*this, Loc, "co_await");

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  ExprResult CoawaitLookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
  if (CoawaitLookup.isInvalid())
    return ExprError();

  auto *Candidates = cast<UnresolvedLookupExpr>(CoawaitLookup.get());
  return BuildUnresolvedCoawaitExpr(Loc, E, Candidates);
}
712
713
/// Builds a 'co_await' whose 'operator co_await' candidates (\p Lookup) are
/// not yet resolved.
///
/// With a dependent promise type, a DependentCoawaitExpr is returned.
/// Otherwise the operand is first rewritten as p.await_transform(E) when the
/// promise declares 'await_transform'. Then 'operator co_await' is applied and
/// the result is handed to BuildResolvedCoawaitExpr.
ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
                                            UnresolvedLookupExpr *Lookup) {
  auto *ScopeInfo = checkCoroutineContext(*this, Loc, "co_await");
  if (!ScopeInfo)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  auto *PromiseVar = ScopeInfo->CoroutinePromise;
  if (PromiseVar->getType()->isDependentType()) {
    Expr *DependentAwait =
        new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
    return DependentAwait;
  }

  auto *PromiseRD = PromiseVar->getType()->getAsCXXRecordDecl();
  const bool HasAwaitTransform =
      lookupMember(*this, "await_transform", PromiseRD, Loc);
  if (HasAwaitTransform) {
    ExprResult Transformed =
        buildPromiseCall(*this, PromiseVar, Loc, "await_transform", E);
    if (Transformed.isInvalid()) {
      Diag(Loc,
           diag::note_coroutine_promise_implicit_await_transform_required_here)
          << E->getSourceRange();
      return ExprError();
    }
    E = Transformed.get();
  }

  ExprResult Operand = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
  if (Operand.isInvalid())
    return ExprError();
  return BuildResolvedCoawaitExpr(Loc, Operand.get());
}
750
751
/// Builds a 'co_await' on an awaitable whose 'operator co_await' has already
/// been applied.
///
/// A dependent operand yields a dependent CoawaitExpr. Otherwise a prvalue
/// operand is materialized, and the await_ready / await_suspend /
/// await_resume calls are built before forming the CoawaitExpr.
ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
                                          bool IsImplicit) {
  auto *Coro = checkCoroutineContext(*this, Loc, "co_await", IsImplicit);
  if (Coro == nullptr)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  if (E->getType()->isDependentType()) {
    Expr *DependentAwait =
        new (Context) CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
    return DependentAwait;
  }

  // Materialize a temporary operand as an lvalue so that it can be used
  // multiple times.
  if (E->getValueKind() == VK_RValue)
    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);

  // The member call expressions start at the operand, and the 'co_await'
  // token precedes it. So the operand's own location is used for the calls.
  const SourceLocation OperandLoc = E->getExprLoc();

  ReadySuspendResumeResult Calls =
      buildCoawaitCalls(*this, Coro->CoroutinePromise, OperandLoc, E);
  if (Calls.IsInvalid)
    return ExprError();

  Expr *Await = new (Context)
      CoawaitExpr(Loc, E, Calls.Results[0], Calls.Results[1], Calls.Results[2],
                  Calls.OpaqueValue, IsImplicit);
  return Await;
}
791
792
63
/// Parser action for 'co_yield <expr>'.
///
/// Forms p.yield_value(E), applies 'operator co_await' to that result, and
/// builds the CoyieldExpr from it.
ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
  const bool Started = ActOnCoroutineBodyStart(S, Loc, "co_yield");
  if (!Started) {
    // Run delayed typo correction on the operand before bailing out.
    CorrectDelayedTyposInExpr(E);
    return ExprError();
  }

  checkSuspensionContext(*this, Loc, "co_yield");

  auto *PromiseVar = getCurFunction()->CoroutinePromise;
  ExprResult Yielded =
      buildPromiseCall(*this, PromiseVar, Loc, "yield_value", E);
  if (Yielded.isInvalid())
    return ExprError();

  ExprResult Awaitable =
      buildOperatorCoawaitCall(*this, S, Loc, Yielded.get());
  if (Awaitable.isInvalid())
    return ExprError();

  return BuildCoyieldExpr(Loc, Awaitable.get());
}
813
76
/// Builds a CoyieldExpr around the awaitable \p E.
///
/// A dependent operand yields a dependent CoyieldExpr. Otherwise a prvalue
/// operand is materialized, and the await_ready / await_suspend /
/// await_resume calls are built before forming the CoyieldExpr.
ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
  auto *Coro = checkCoroutineContext(*this, Loc, "co_yield");
  if (Coro == nullptr)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  if (E->getType()->isDependentType()) {
    Expr *DependentYield =
        new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
    return DependentYield;
  }

  // Materialize a temporary operand as an lvalue so that it can be used
  // multiple times.
  if (E->getValueKind() == VK_RValue)
    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);

  // The 'co_yield' keyword location is passed to buildCoawaitCalls.
  ReadySuspendResumeResult Calls =
      buildCoawaitCalls(*this, Coro->CoroutinePromise, Loc, E);
  if (Calls.IsInvalid)
    return ExprError();

  Expr *Yield = new (Context)
      CoyieldExpr(Loc, E, Calls.Results[0], Calls.Results[1], Calls.Results[2],
                  Calls.OpaqueValue);
  return Yield;
}
846
847
91
/// Parser action for 'co_return [expr];'.
///
/// Marks the enclosing function as a coroutine and forwards to
/// BuildCoreturnStmt. On failure, runs delayed typo correction on the operand
/// and returns an error.
StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
  if (ActOnCoroutineBodyStart(S, Loc, "co_return"))
    return BuildCoreturnStmt(Loc, E);

  CorrectDelayedTyposInExpr(E);
  return StmtError();
}
854
855
/// Builds a CoreturnStmt for 'co_return [E]'.
///
/// A braced-init-list operand or a non-void operand is passed to
/// p.return_value(E). Otherwise (no operand, or a void operand)
/// p.return_void() is called, and the operand is wrapped by
/// MakeFullDiscardedValueExpr. When the operand has a copy-elision candidate,
/// it is first run through move-or-copy initialization.
StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
                                   bool IsImplicit) {
  auto *ScopeInfo = checkCoroutineContext(*this, Loc, "co_return", IsImplicit);
  if (!ScopeInfo)
    return StmtError();

  // Resolve placeholder operands, except overload sets, which are deliberately
  // left untouched here.
  const bool NeedsPlaceholderCheck =
      E && E->getType()->isPlaceholderType() &&
      !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload);
  if (NeedsPlaceholderCheck) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return StmtError();
    E = Checked.get();
  }

  // Move the return value if we can.
  if (E) {
    if (auto Candidate =
            getCopyElisionCandidate(E->getType(), E, CES_AsIfByStdMove)) {
      InitializedEntity Entity =
          InitializedEntity::InitializeResult(Loc, E->getType(), Candidate);
      ExprResult Moved =
          PerformMoveOrCopyInitialization(Entity, Candidate, E->getType(), E);
      if (Moved.get())
        E = Moved.get();
    }
  }

  // FIXME: If the operand is a reference to a variable that's about to go out
  // of scope, we should treat the operand as an xvalue for this overload
  // resolution.
  VarDecl *PromiseVar = ScopeInfo->CoroutinePromise;
  const bool UseReturnValue =
      E && (isa<InitListExpr>(E) || !E->getType()->isVoidType());

  ExprResult PromiseCall;
  if (UseReturnValue) {
    PromiseCall = buildPromiseCall(*this, PromiseVar, Loc, "return_value", E);
  } else {
    E = MakeFullDiscardedValueExpr(E).get();
    PromiseCall = buildPromiseCall(*this, PromiseVar, Loc, "return_void", None);
  }
  if (PromiseCall.isInvalid())
    return StmtError();

  Expr *FinishedCall =
      ActOnFinishFullExpr(PromiseCall.get(), /*DiscardedValue*/ false).get();

  Stmt *Coreturn = new (Context) CoreturnStmt(Loc, E, FinishedCall, IsImplicit);
  return Coreturn;
}
900
901
/// Look up the std::nothrow object.
902
6
static Expr *buildStdNoThrowDeclRef(Sema &S, SourceLocation Loc) {
903
6
  NamespaceDecl *Std = S.getStdNamespace();
904
6
  assert(Std && "Should already be diagnosed");
905
6
906
6
  LookupResult Result(S, &S.PP.getIdentifierTable().get("nothrow"), Loc,
907
6
                      Sema::LookupOrdinaryName);
908
6
  if (!S.LookupQualifiedName(Result, Std)) {
909
0
    // FIXME: <experimental/coroutine> should have been included already.
910
0
    // If we require it to include <new> then this diagnostic is no longer
911
0
    // needed.
912
0
    S.Diag(Loc, diag::err_implicit_coroutine_std_nothrow_type_not_found);
913
0
    return nullptr;
914
0
  }
915
6
916
6
  auto *VD = Result.getAsSingle<VarDecl>();
917
6
  if (!VD) {
918
0
    Result.suppressDiagnostics();
919
0
    // We found something weird. Complain about the first thing we found.
920
0
    NamedDecl *Found = *Result.begin();
921
0
    S.Diag(Found->getLocation(), diag::err_malformed_std_nothrow);
922
0
    return nullptr;
923
0
  }
924
6
925
6
  ExprResult DR = S.BuildDeclRefExpr(VD, VD->getType(), VK_LValue, Loc);
926
6
  if (DR.isInvalid())
927
0
    return nullptr;
928
6
929
6
  return DR.get();
930
6
}
931
932
// Find an appropriate delete for the promise.
933
static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc,
934
165
                                          QualType PromiseType) {
935
165
  FunctionDecl *OperatorDelete = nullptr;
936
165
937
165
  DeclarationName DeleteName =
938
165
      S.Context.DeclarationNames.getCXXOperatorName(OO_Delete);
939
165
940
165
  auto *PointeeRD = PromiseType->getAsCXXRecordDecl();
941
165
  assert(PointeeRD && "PromiseType must be a CxxRecordDecl type");
942
165
943
165
  if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete))
944
0
    return nullptr;
945
165
946
165
  if (!OperatorDelete) {
947
163
    // Look for a global declaration.
948
163
    const bool CanProvideSize = S.isCompleteType(Loc, PromiseType);
949
163
    const bool Overaligned = false;
950
163
    OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize,
951
163
                                                     Overaligned, DeleteName);
952
163
  }
953
165
  S.MarkFunctionReferenced(Loc, OperatorDelete);
954
165
  return OperatorDelete;
955
165
}
956
957
958
267
/// Finishes semantic analysis of a coroutine body.
///
/// Diagnoses plain 'return' statements, builds the implicit coroutine
/// statements with a CoroutineStmtBuilder, and replaces \p Body with the
/// resulting CoroutineBodyStmt. On failure, \p FD is marked invalid.
void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
  FunctionScopeInfo *Fn = getCurFunction();
  assert(Fn && Fn->isCoroutine() && "not a coroutine");

  if (Body == nullptr) {
    assert(FD->isInvalidDecl() &&
           "a null body is only allowed for invalid declarations");
    return;
  }

  // Coroutine keywords were used, but the promise type could not be built.
  if (!Fn->CoroutinePromise) {
    FD->setInvalidDecl();
    return;
  }

  // The body has already been transformed into a CoroutineBodyStmt, so there
  // is nothing left to do.
  if (isa<CoroutineBodyStmt>(Body))
    return;

  // Coroutines [stmt.return]p1:
  //   A return statement shall not appear in a coroutine.
  if (Fn->FirstReturnLoc.isValid()) {
    assert(Fn->FirstCoroutineStmtLoc.isValid() &&
           "first coroutine location not set");
    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
    Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
        << Fn->getFirstCoroutineStmtKeyword();
  }

  CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body);
  const bool Built = !Builder.isInvalid() && Builder.buildStatements();
  if (!Built) {
    FD->setInvalidDecl();
    return;
  }

  // Wrap everything in the coroutine body statement.
  Body = CoroutineBodyStmt::Create(Context, Builder);
}
992
993
// Captures the coroutine's parameter moves and, for a non-dependent promise
// type, its record declaration. Then it builds the promise DeclStmt and the
// initial/final suspend expressions. IsValid reflects whether both of those
// steps succeeded.
CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD,
                                           sema::FunctionScopeInfo &Fn,
                                           Stmt *Body)
    : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
      IsPromiseDependentType(
          !Fn.CoroutinePromise ||
          Fn.CoroutinePromise->getType()->isDependentType()) {
  this->Body = Body;

  // Collect the move statements, in map order, into the owned vector, and
  // point ParamMoves at it.
  for (const auto &Entry : Fn.CoroutineParameterMoves)
    this->ParamMovesVector.push_back(Entry.second);
  this->ParamMoves = this->ParamMovesVector;

  // PromiseRecordDecl is only available once the promise type is concrete.
  if (!IsPromiseDependentType) {
    PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
    assert(PromiseRecordDecl && "Type should have already been checked");
  }

  this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend();
}
1012
1013
191
bool CoroutineStmtBuilder::buildStatements() {
1014
191
  assert(this->IsValid && "coroutine already invalid");
1015
191
  this->IsValid = makeReturnObject();
1016
191
  if (this->IsValid && 
!IsPromiseDependentType190
)
1017
141
    buildDependentStatements();
1018
191
  return this->IsValid;
1019
191
}
1020
1021
178
bool CoroutineStmtBuilder::buildDependentStatements() {
1022
178
  assert(this->IsValid && "coroutine already invalid");
1023
178
  assert(!this->IsPromiseDependentType &&
1024
178
         "coroutine cannot have a dependent promise type");
1025
178
  this->IsValid = makeOnException() && 
makeOnFallthrough()175
&&
1026
178
                  
makeGroDeclAndReturnStmt()172
&&
makeReturnOnAllocFailure()170
&&
1027
178
                  
makeNewAndDeleteExpr()167
;
1028
178
  return this->IsValid;
1029
178
}
1030
1031
246
bool CoroutineStmtBuilder::makePromiseStmt() {
1032
246
  // Form a declaration statement for the promise declaration, so that AST
1033
246
  // visitors can more easily find it.
1034
246
  StmtResult PromiseStmt =
1035
246
      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
1036
246
  if (PromiseStmt.isInvalid())
1037
0
    return false;
1038
246
1039
246
  this->Promise = PromiseStmt.get();
1040
246
  return true;
1041
246
}
1042
1043
246
bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() {
1044
246
  if (Fn.hasInvalidCoroutineSuspends())
1045
9
    return false;
1046
237
  this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
1047
237
  this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second);
1048
237
  return true;
1049
237
}
1050
1051
static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
1052
                                     CXXRecordDecl *PromiseRecordDecl,
1053
12
                                     FunctionScopeInfo &Fn) {
1054
12
  auto Loc = E->getExprLoc();
1055
12
  if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) {
1056
12
    auto *Decl = DeclRef->getDecl();
1057
12
    if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) {
1058
12
      if (Method->isStatic())
1059
11
        return true;
1060
1
      else
1061
1
        Loc = Decl->getLocation();
1062
12
    }
1063
12
  }
1064
12
1065
12
  S.Diag(
1066
1
      Loc,
1067
1
      diag::err_coroutine_promise_get_return_object_on_allocation_failure)
1068
1
      << PromiseRecordDecl;
1069
1
  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1070
1
      << Fn.getFirstCoroutineStmtKeyword();
1071
1
  return false;
1072
12
}
1073
1074
170
bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
1075
170
  assert(!IsPromiseDependentType &&
1076
170
         "cannot make statement while the promise type is dependent");
1077
170
1078
170
  // [dcl.fct.def.coroutine]/8
1079
170
  // The unqualified-id get_return_object_on_allocation_failure is looked up in
1080
170
  // the scope of class P by class member access lookup (3.4.5). ...
1081
170
  // If an allocation function returns nullptr, ... the coroutine return value
1082
170
  // is obtained by a call to ... get_return_object_on_allocation_failure().
1083
170
1084
170
  DeclarationName DN =
1085
170
      S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
1086
170
  LookupResult Found(S, DN, Loc, Sema::LookupMemberName);
1087
170
  if (!S.LookupQualifiedName(Found, PromiseRecordDecl))
1088
158
    return true;
1089
12
1090
12
  CXXScopeSpec SS;
1091
12
  ExprResult DeclNameExpr =
1092
12
      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
1093
12
  if (DeclNameExpr.isInvalid())
1094
0
    return false;
1095
12
1096
12
  if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn))
1097
1
    return false;
1098
11
1099
11
  ExprResult ReturnObjectOnAllocationFailure =
1100
11
      S.BuildCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc);
1101
11
  if (ReturnObjectOnAllocationFailure.isInvalid())
1102
0
    return false;
1103
11
1104
11
  StmtResult ReturnStmt =
1105
11
      S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get());
1106
11
  if (ReturnStmt.isInvalid()) {
1107
2
    S.Diag(Found.getFoundDecl()->getLocation(), diag::note_member_declared_here)
1108
2
        << DN;
1109
2
    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1110
2
        << Fn.getFirstCoroutineStmtKeyword();
1111
2
    return false;
1112
2
  }
1113
9
1114
9
  this->ReturnStmtOnAllocFailure = ReturnStmt.get();
1115
9
  return true;
1116
9
}
1117
1118
167
// Builds the coroutine frame allocation and deallocation expressions, stored in
// this->Allocate and this->Deallocate.
//
// Allocation calls the selected operator new with __builtin_coro_size()
// followed by any placement arguments that survived the lookup steps below.
// Deallocation calls the operator delete chosen by findDeleteForPromise with
// __builtin_coro_free(__builtin_coro_frame()). The frame size is added as a
// second argument when that operator delete takes more than one parameter.
//
// When ReturnStmtOnAllocFailure has been set, the selected operator new must be
// non-throwing. Otherwise it is diagnosed and false is returned. Returns false
// on any other failure as well. Diagnostics come from the callees, except the
// nothrow diagnostic emitted here.
bool CoroutineStmtBuilder::makeNewAndDeleteExpr() {
  // Form and check allocation and deallocation calls.
  assert(!IsPromiseDependentType &&
         "cannot make statement while the promise type is dependent");
  QualType PromiseType = Fn.CoroutinePromise->getType();

  if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
    return false;

  // A get_return_object_on_allocation_failure return path requires a
  // non-throwing allocation function (checked after lookup below).
  const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;

  // [dcl.fct.def.coroutine]/7
  // Lookup allocation functions using a parameter list composed of the
  // requested size of the coroutine state being allocated, followed by
  // the coroutine function's arguments. If a matching allocation function
  // exists, use it. Otherwise, use an allocation function that just takes
  // the requested size.

  FunctionDecl *OperatorNew = nullptr;
  FunctionDecl *OperatorDelete = nullptr;
  FunctionDecl *UnusedResult = nullptr;
  bool PassAlignment = false;
  SmallVector<Expr *, 1> PlacementArgs;

  // [dcl.fct.def.coroutine]/7
  // "The allocation function’s name is looked up in the scope of P.
  // [...] If the lookup finds an allocation function in the scope of P,
  // overload resolution is performed on a function call created by assembling
  // an argument list. The first argument is the amount of space requested,
  // and has type std::size_t. The lvalues p1 ... pn are the succeeding
  // arguments."
  //
  // ...where "p1 ... pn" are defined earlier as:
  //
  // [dcl.fct.def.coroutine]/3
  // "For a coroutine f that is a non-static member function, let P1 denote the
  // type of the implicit object parameter (13.3.1) and P2 ... Pn be the types
  // of the function parameters; otherwise let P1 ... Pn be the types of the
  // function parameters. Let p1 ... pn be lvalues denoting those objects."
  if (auto *MD = dyn_cast<CXXMethodDecl>(&FD)) {
    // Lambda call operators are excluded, so no '*this' argument is passed
    // for them.
    if (MD->isInstance() && !isLambdaCallOperator(MD)) {
      ExprResult ThisExpr = S.ActOnCXXThis(Loc);
      if (ThisExpr.isInvalid())
        return false;
      ThisExpr = S.CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
      if (ThisExpr.isInvalid())
        return false;
      PlacementArgs.push_back(ThisExpr.get());
    }
  }
  for (auto *PD : FD.parameters()) {
    // Parameters of dependent type are not added to the argument list.
    if (PD->getType()->isDependentType())
      continue;

    // Build a reference to the parameter.
    auto PDLoc = PD->getLocation();
    ExprResult PDRefExpr =
        S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
                           ExprValueKind::VK_LValue, PDLoc);
    if (PDRefExpr.isInvalid())
      return false;

    PlacementArgs.push_back(PDRefExpr.get());
  }
  // First attempt: class-scope operator new with the full argument list, with
  // diagnostics suppressed.
  S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
                            /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                            /*isArray*/ false, PassAlignment, PlacementArgs,
                            OperatorNew, UnusedResult, /*Diagnose*/ false);

  // [dcl.fct.def.coroutine]/7
  // "If no matching function is found, overload resolution is performed again
  // on a function call created by passing just the amount of space required as
  // an argument of type std::size_t."
  if (!OperatorNew && !PlacementArgs.empty()) {
    PlacementArgs.clear();
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                              /*isArray*/ false, PassAlignment, PlacementArgs,
                              OperatorNew, UnusedResult, /*Diagnose*/ false);
  }

  // [dcl.fct.def.coroutine]/7
  // "The allocation function’s name is looked up in the scope of P. If this
  // lookup fails, the allocation function’s name is looked up in the global
  // scope."
  if (!OperatorNew) {
    // Note: this global lookup does not pass Diagnose=false, unlike the two
    // class-scope lookups above.
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Global,
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                              /*isArray*/ false, PassAlignment, PlacementArgs,
                              OperatorNew, UnusedResult);
  }

  // An operator new whose declaration context is not a class was found
  // outside the promise class.
  bool IsGlobalOverload =
      OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
  // If we didn't find a class-local new declaration and a non-throwing new is
  // required, then we need to look up the non-throwing operator (called with
  // std::nothrow as its placement argument) instead.
  if (RequiresNoThrowAlloc && (!OperatorNew || IsGlobalOverload)) {
    auto *StdNoThrow = buildStdNoThrowDeclRef(S, Loc);
    if (!StdNoThrow)
      return false;
    PlacementArgs = {StdNoThrow};
    OperatorNew = nullptr;
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Both,
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                              /*isArray*/ false, PassAlignment, PlacementArgs,
                              OperatorNew, UnusedResult);
  }

  if (!OperatorNew)
    return false;

  // With a get_return_object_on_allocation_failure path, the selected
  // operator new must be declared non-throwing.
  if (RequiresNoThrowAlloc) {
    const auto *FT = OperatorNew->getType()->castAs<FunctionProtoType>();
    if (!FT->isNothrow(/*ResultIfDependent*/ false)) {
      S.Diag(OperatorNew->getLocation(),
             diag::err_coroutine_promise_new_requires_nothrow)
          << OperatorNew;
      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
          << OperatorNew;
      return false;
    }
  }

  if ((OperatorDelete = findDeleteForPromise(S, Loc, PromiseType)) == nullptr)
    return false;

  Expr *FramePtr =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});

  Expr *FrameSize =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {});

  // Make new call: operator new(__builtin_coro_size(), placement args...).

  ExprResult NewRef =
      S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc);
  if (NewRef.isInvalid())
    return false;

  SmallVector<Expr *, 2> NewArgs(1, FrameSize);
  for (auto Arg : PlacementArgs)
    NewArgs.push_back(Arg);

  ExprResult NewExpr =
      S.BuildCallExpr(S.getCurScope(), NewRef.get(), Loc, NewArgs, Loc);
  NewExpr = S.ActOnFinishFullExpr(NewExpr.get(), /*DiscardedValue*/ false);
  if (NewExpr.isInvalid())
    return false;

  // Make delete call: operator delete(__builtin_coro_free(frame) [, size]).

  QualType OpDeleteQualType = OperatorDelete->getType();

  ExprResult DeleteRef =
      S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc);
  if (DeleteRef.isInvalid())
    return false;

  Expr *CoroFree =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr});

  SmallVector<Expr *, 2> DeleteArgs{CoroFree};

  // Check if we need to pass the size.
  const auto *OpDeleteType =
      OpDeleteQualType.getTypePtr()->castAs<FunctionProtoType>();
  if (OpDeleteType->getNumParams() > 1)
    DeleteArgs.push_back(FrameSize);

  ExprResult DeleteExpr =
      S.BuildCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc);
  DeleteExpr =
      S.ActOnFinishFullExpr(DeleteExpr.get(), /*DiscardedValue*/ false);
  if (DeleteExpr.isInvalid())
    return false;

  this->Allocate = NewExpr.get();
  this->Deallocate = DeleteExpr.get();

  return true;
}
1300
1301
175
bool CoroutineStmtBuilder::makeOnFallthrough() {
1302
175
  assert(!IsPromiseDependentType &&
1303
175
         "cannot make statement while the promise type is dependent");
1304
175
1305
175
  // [dcl.fct.def.coroutine]/4
1306
175
  // The unqualified-ids 'return_void' and 'return_value' are looked up in
1307
175
  // the scope of class P. If both are found, the program is ill-formed.
1308
175
  bool HasRVoid, HasRValue;
1309
175
  LookupResult LRVoid =
1310
175
      lookupMember(S, "return_void", PromiseRecordDecl, Loc, HasRVoid);
1311
175
  LookupResult LRValue =
1312
175
      lookupMember(S, "return_value", PromiseRecordDecl, Loc, HasRValue);
1313
175
1314
175
  StmtResult Fallthrough;
1315
175
  if (HasRVoid && 
HasRValue124
) {
1316
2
    // FIXME Improve this diagnostic
1317
2
    S.Diag(FD.getLocation(),
1318
2
           diag::err_coroutine_promise_incompatible_return_functions)
1319
2
        << PromiseRecordDecl;
1320
2
    S.Diag(LRVoid.getRepresentativeDecl()->getLocation(),
1321
2
           diag::note_member_first_declared_here)
1322
2
        << LRVoid.getLookupName();
1323
2
    S.Diag(LRValue.getRepresentativeDecl()->getLocation(),
1324
2
           diag::note_member_first_declared_here)
1325
2
        << LRValue.getLookupName();
1326
2
    return false;
1327
173
  } else if (!HasRVoid && 
!HasRValue51
) {
1328
1
    // FIXME: The PDTS currently specifies this case as UB, not ill-formed.
1329
1
    // However we still diagnose this as an error since until the PDTS is fixed.
1330
1
    S.Diag(FD.getLocation(),
1331
1
           diag::err_coroutine_promise_requires_return_function)
1332
1
        << PromiseRecordDecl;
1333
1
    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1334
1
        << PromiseRecordDecl;
1335
1
    return false;
1336
172
  } else if (HasRVoid) {
1337
122
    // If the unqualified-id return_void is found, flowing off the end of a
1338
122
    // coroutine is equivalent to a co_return with no operand. Otherwise,
1339
122
    // flowing off the end of a coroutine results in undefined behavior.
1340
122
    Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr,
1341
122
                                      /*IsImplicit*/false);
1342
122
    Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
1343
122
    if (Fallthrough.isInvalid())
1344
0
      return false;
1345
172
  }
1346
172
1347
172
  this->OnFallthrough = Fallthrough.get();
1348
172
  return true;
1349
172
}
1350
1351
178
bool CoroutineStmtBuilder::makeOnException() {
1352
178
  // Try to form 'p.unhandled_exception();'
1353
178
  assert(!IsPromiseDependentType &&
1354
178
         "cannot make statement while the promise type is dependent");
1355
178
1356
178
  const bool RequireUnhandledException = S.getLangOpts().CXXExceptions;
1357
178
1358
178
  if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) {
1359
26
    auto DiagID =
1360
26
        RequireUnhandledException
1361
26
            ? 
diag::err_coroutine_promise_unhandled_exception_required2
1362
26
            : diag::
1363
24
                  warn_coroutine_promise_unhandled_exception_required_with_exceptions;
1364
26
    S.Diag(Loc, DiagID) << PromiseRecordDecl;
1365
26
    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1366
26
        << PromiseRecordDecl;
1367
26
    return !RequireUnhandledException;
1368
26
  }
1369
152
1370
152
  // If exceptions are disabled, don't try to build OnException.
1371
152
  if (!S.getLangOpts().CXXExceptions)
1372
36
    return true;
1373
116
1374
116
  ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc,
1375
116
                                                   "unhandled_exception", None);
1376
116
  UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc,
1377
116
                                             /*DiscardedValue*/ false);
1378
116
  if (UnhandledException.isInvalid())
1379
0
    return false;
1380
116
1381
116
  // Since the body of the coroutine will be wrapped in try-catch, it will
1382
116
  // be incompatible with SEH __try if present in a function.
1383
116
  if (!S.getLangOpts().Borland && Fn.FirstSEHTryLoc.isValid()) {
1384
1
    S.Diag(Fn.FirstSEHTryLoc, diag::err_seh_in_a_coroutine_with_cxx_exceptions);
1385
1
    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1386
1
        << Fn.getFirstCoroutineStmtKeyword();
1387
1
    return false;
1388
1
  }
1389
115
1390
115
  this->OnException = UnhandledException.get();
1391
115
  return true;
1392
115
}
1393
1394
191
bool CoroutineStmtBuilder::makeReturnObject() {
1395
191
  // Build implicit 'p.get_return_object()' expression and form initialization
1396
191
  // of return type from it.
1397
191
  ExprResult ReturnObject =
1398
191
      buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None);
1399
191
  if (ReturnObject.isInvalid())
1400
1
    return false;
1401
190
1402
190
  this->ReturnValue = ReturnObject.get();
1403
190
  return true;
1404
190
}
1405
1406
2
/// Attach follow-up notes to a previously emitted error about \p E.
///
/// If \p E is a member call, a note points at the called method's declaration.
/// A second note always points at the first coroutine statement of \p Fn,
/// naming its keyword.
static void noteMemberDeclaredHere(Sema &S, Expr *E, FunctionScopeInfo &Fn) {
  auto *MemberCall = dyn_cast<CXXMemberCallExpr>(E);
  if (MemberCall != nullptr) {
    auto *Callee = MemberCall->getMethodDecl();
    S.Diag(Callee->getLocation(), diag::note_member_declared_here) << Callee;
  }

  // Point at the statement that made this function a coroutine.
  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
      << Fn.getFirstCoroutineStmtKeyword();
}
1415
1416
172
bool CoroutineStmtBuilder::makeGroDeclAndReturnStmt() {
1417
172
  assert(!IsPromiseDependentType &&
1418
172
         "cannot make statement while the promise type is dependent");
1419
172
  assert(this->ReturnValue && "ReturnValue must be already formed");
1420
172
1421
172
  QualType const GroType = this->ReturnValue->getType();
1422
172
  assert(!GroType->isDependentType() &&
1423
172
         "get_return_object type must no longer be dependent");
1424
172
1425
172
  QualType const FnRetType = FD.getReturnType();
1426
172
  assert(!FnRetType->isDependentType() &&
1427
172
         "get_return_object type must no longer be dependent");
1428
172
1429
172
  if (FnRetType->isVoidType()) {
1430
67
    ExprResult Res =
1431
67
        S.ActOnFinishFullExpr(this->ReturnValue, Loc, /*DiscardedValue*/ false);
1432
67
    if (Res.isInvalid())
1433
0
      return false;
1434
67
1435
67
    this->ResultDecl = Res.get();
1436
67
    return true;
1437
67
  }
1438
105
1439
105
  if (GroType->isVoidType()) {
1440
1
    // Trigger a nice error message.
1441
1
    InitializedEntity Entity =
1442
1
        InitializedEntity::InitializeResult(Loc, FnRetType, false);
1443
1
    S.PerformMoveOrCopyInitialization(Entity, nullptr, FnRetType, ReturnValue);
1444
1
    noteMemberDeclaredHere(S, ReturnValue, Fn);
1445
1
    return false;
1446
1
  }
1447
104
1448
104
  auto *GroDecl = VarDecl::Create(
1449
104
      S.Context, &FD, FD.getLocation(), FD.getLocation(),
1450
104
      &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
1451
104
      S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
1452
104
1453
104
  S.CheckVariableDeclarationType(GroDecl);
1454
104
  if (GroDecl->isInvalidDecl())
1455
0
    return false;
1456
104
1457
104
  InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
1458
104
  ExprResult Res = S.PerformMoveOrCopyInitialization(Entity, nullptr, GroType,
1459
104
                                                     this->ReturnValue);
1460
104
  if (Res.isInvalid())
1461
0
    return false;
1462
104
1463
104
  Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false);
1464
104
  if (Res.isInvalid())
1465
0
    return false;
1466
104
1467
104
  S.AddInitializerToDecl(GroDecl, Res.get(),
1468
104
                         /*DirectInit=*/false);
1469
104
1470
104
  S.FinalizeDeclaration(GroDecl);
1471
104
1472
104
  // Form a declaration statement for the return declaration, so that AST
1473
104
  // visitors can more easily find it.
1474
104
  StmtResult GroDeclStmt =
1475
104
      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
1476
104
  if (GroDeclStmt.isInvalid())
1477
0
    return false;
1478
104
1479
104
  this->ResultDecl = GroDeclStmt.get();
1480
104
1481
104
  ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
1482
104
  if (declRef.isInvalid())
1483
0
    return false;
1484
104
1485
104
  StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
1486
104
  if (ReturnStmt.isInvalid()) {
1487
1
    noteMemberDeclaredHere(S, ReturnValue, Fn);
1488
1
    return false;
1489
1
  }
1490
103
  if (cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl)
1491
90
    GroDecl->setNRVOVariable(true);
1492
103
1493
103
  this->ReturnStmt = ReturnStmt.get();
1494
103
  return true;
1495
103
}
1496
1497
// Create a static_cast\<T&&>(expr).
1498
52
static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) {
1499
52
  if (T.isNull())
1500
52
    T = E->getType();
1501
52
  QualType TargetType = S.BuildReferenceType(
1502
52
      T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName());
1503
52
  SourceLocation ExprLoc = E->getBeginLoc();
1504
52
  TypeSourceInfo *TargetLoc =
1505
52
      S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc);
1506
52
1507
52
  return S
1508
52
      .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E,
1509
52
                         SourceRange(ExprLoc, ExprLoc), E->getSourceRange())
1510
52
      .get();
1511
52
}
1512
1513
/// Build a variable declaration for move parameter.
1514
static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type,
1515
161
                             IdentifierInfo *II) {
1516
161
  TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc);
1517
161
  VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type,
1518
161
                                  TInfo, SC_None);
1519
161
  Decl->setImplicit();
1520
161
  return Decl;
1521
161
}
1522
1523
// Build statements that move coroutine function parameters to the coroutine
1524
// frame, and store them on the function scope info.
1525
268
bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) {
1526
268
  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
1527
268
  auto *FD = cast<FunctionDecl>(CurContext);
1528
268
1529
268
  auto *ScopeInfo = getCurFunction();
1530
268
  if (!ScopeInfo->CoroutineParameterMoves.empty())
1531
1
    return false;
1532
267
1533
267
  for (auto *PD : FD->parameters()) {
1534
203
    if (PD->getType()->isDependentType())
1535
42
      continue;
1536
161
1537
161
    ExprResult PDRefExpr =
1538
161
        BuildDeclRefExpr(PD, PD->getType().getNonReferenceType(),
1539
161
                         ExprValueKind::VK_LValue, Loc); // FIXME: scope?
1540
161
    if (PDRefExpr.isInvalid())
1541
0
      return false;
1542
161
1543
161
    Expr *CExpr = nullptr;
1544
161
    if (PD->getType()->getAsCXXRecordDecl() ||
1545
161
        
PD->getType()->isRValueReferenceType()117
)
1546
52
      CExpr = castForMoving(*this, PDRefExpr.get());
1547
109
    else
1548
109
      CExpr = PDRefExpr.get();
1549
161
1550
161
    auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
1551
161
    AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
1552
161
1553
161
    // Convert decl to a statement.
1554
161
    StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
1555
161
    if (Stmt.isInvalid())
1556
0
      return false;
1557
161
1558
161
    ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
1559
161
  }
1560
267
  return true;
1561
267
}
1562
1563
43
/// Create a CoroutineBodyStmt from \p Args. Returns StmtError() if creation
/// yields a null statement.
StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) {
  if (CoroutineBodyStmt *Body = CoroutineBodyStmt::Create(Context, Args))
    return Body;
  return StmtError();
}
1569
1570
/// Find the `std::experimental::coroutine_traits` class template, caching a
/// successful result in StdCoroutineTraitsCache.
///
/// Returns null without emitting a diagnostic when the std::experimental
/// namespace cannot be found (callers presumably diagnose that case — verify
/// at call sites). Returns null and diagnoses when `coroutine_traits` is
/// missing from that namespace, or when the name does not denote a single
/// class template. Failed lookups are not cached.
ClassTemplateDecl *Sema::lookupCoroutineTraits(SourceLocation KwLoc,
                                               SourceLocation FuncLoc) {
  // Already resolved by an earlier coroutine.
  if (StdCoroutineTraitsCache)
    return StdCoroutineTraitsCache;

  auto *StdExp = lookupStdExperimentalNamespace();
  if (!StdExp)
    return nullptr;

  LookupResult Result(*this, &PP.getIdentifierTable().get("coroutine_traits"),
                      FuncLoc, LookupOrdinaryName);
  if (!LookupQualifiedName(Result, StdExp)) {
    Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
        << "std::experimental::coroutine_traits";
    return nullptr;
  }

  StdCoroutineTraitsCache = Result.getAsSingle<ClassTemplateDecl>();
  if (StdCoroutineTraitsCache)
    return StdCoroutineTraitsCache;

  // Something named coroutine_traits exists, but it is not a single class
  // template.
  Result.suppressDiagnostics();
  NamedDecl *Found = *Result.begin();
  Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
  return nullptr;
}