Coverage Report

Created: 2017-10-03 07:32

/Users/buildslave/jenkins/sharedspace/clang-stage2-coverage-R@2/llvm/tools/clang/lib/Sema/SemaCoroutine.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- SemaCoroutine.cpp - Semantic Analysis for Coroutines -------------===//
2
//
3
//                     The LLVM Compiler Infrastructure
4
//
5
// This file is distributed under the University of Illinois Open Source
6
// License. See LICENSE.TXT for details.
7
//
8
//===----------------------------------------------------------------------===//
9
//
10
//  This file implements semantic analysis for C++ Coroutines.
11
//
12
//===----------------------------------------------------------------------===//
13
14
#include "CoroutineStmtBuilder.h"
15
#include "clang/AST/Decl.h"
16
#include "clang/AST/ExprCXX.h"
17
#include "clang/AST/StmtCXX.h"
18
#include "clang/Lex/Preprocessor.h"
19
#include "clang/Sema/Initialization.h"
20
#include "clang/Sema/Overload.h"
21
#include "clang/Sema/SemaInternal.h"
22
23
using namespace clang;
24
using namespace sema;
25
26
/// Perform a qualified member-name lookup of \p Name inside \p RD.
///
/// Diagnostics are suppressed on the returned result (e.g. when a private
/// member is selected), because the same problems are diagnosed again when the
/// call itself is built.
/// \param[out] Res receives the value returned by LookupQualifiedName.
static LookupResult lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
                                 SourceLocation Loc, bool &Res) {
  DeclarationName MemberName = S.PP.getIdentifierInfo(Name);
  LookupResult Lookup(S, MemberName, Loc, Sema::LookupMemberName);
  Lookup.suppressDiagnostics();
  Res = S.LookupQualifiedName(Lookup, RD);
  return Lookup;
}
36
37
static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
38
231
                         SourceLocation Loc) {
39
231
  bool Res;
40
231
  lookupMember(S, Name, RD, Loc, Res);
41
231
  return Res;
42
231
}
43
44
/// Look up the std::coroutine_traits<...>::promise_type for the given
45
/// function type.
46
static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD,
47
172
                                  SourceLocation KwLoc) {
48
172
  const FunctionProtoType *FnType = FD->getType()->castAs<FunctionProtoType>();
49
172
  const SourceLocation FuncLoc = FD->getLocation();
50
172
  // FIXME: Cache std::coroutine_traits once we've found it.
51
172
  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
52
172
  if (
!StdExp172
) {
53
4
    S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
54
4
        << "std::experimental::coroutine_traits";
55
4
    return QualType();
56
4
  }
57
168
58
168
  LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
59
168
                      FuncLoc, Sema::LookupOrdinaryName);
60
168
  if (
!S.LookupQualifiedName(Result, StdExp)168
) {
61
0
    S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
62
0
        << "std::experimental::coroutine_traits";
63
0
    return QualType();
64
0
  }
65
168
66
168
  ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>();
67
168
  if (
!CoroTraits168
) {
68
0
    Result.suppressDiagnostics();
69
0
    // We found something weird. Complain about the first thing we found.
70
0
    NamedDecl *Found = *Result.begin();
71
0
    S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
72
0
    return QualType();
73
0
  }
74
168
75
168
  // Form template argument list for coroutine_traits<R, P1, P2, ...> according
76
168
  // to [dcl.fct.def.coroutine]3
77
168
  TemplateArgumentListInfo Args(KwLoc, KwLoc);
78
312
  auto AddArg = [&](QualType T) {
79
312
    Args.addArgument(TemplateArgumentLoc(
80
312
        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc)));
81
312
  };
82
168
  AddArg(FnType->getReturnType());
83
168
  // If the function is a non-static member function, add the type
84
168
  // of the implicit object parameter before the formal parameters.
85
168
  if (auto *
MD168
= dyn_cast<CXXMethodDecl>(FD)) {
86
46
    if (
MD->isInstance()46
) {
87
35
      // [over.match.funcs]4
88
35
      // For non-static member functions, the type of the implicit object
89
35
      // parameter is
90
35
      //  -- "lvalue reference to cv X" for functions declared without a
91
35
      //      ref-qualifier or with the & ref-qualifier
92
35
      //  -- "rvalue reference to cv X" for functions declared with the &&
93
35
      //      ref-qualifier
94
35
      QualType T =
95
35
          MD->getThisType(S.Context)->getAs<PointerType>()->getPointeeType();
96
35
      T = FnType->getRefQualifier() == RQ_RValue
97
7
              ? S.Context.getRValueReferenceType(T)
98
28
              : S.Context.getLValueReferenceType(T, /*SpelledAsLValue*/ true);
99
35
      AddArg(T);
100
35
    }
101
46
  }
102
168
  for (QualType T : FnType->getParamTypes())
103
109
    AddArg(T);
104
168
105
168
  // Build the template-id.
106
168
  QualType CoroTrait =
107
168
      S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args);
108
168
  if (CoroTrait.isNull())
109
0
    return QualType();
110
168
  
if (168
S.RequireCompleteType(KwLoc, CoroTrait,
111
168
                            diag::err_coroutine_type_missing_specialization))
112
0
    return QualType();
113
168
114
168
  auto *RD = CoroTrait->getAsCXXRecordDecl();
115
168
  assert(RD && "specialization of class template is not a class?");
116
168
117
168
  // Look up the ::promise_type member.
118
168
  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc,
119
168
                 Sema::LookupOrdinaryName);
120
168
  S.LookupQualifiedName(R, RD);
121
168
  auto *Promise = R.getAsSingle<TypeDecl>();
122
168
  if (
!Promise168
) {
123
5
    S.Diag(FuncLoc,
124
5
           diag::err_implied_std_coroutine_traits_promise_type_not_found)
125
5
        << RD;
126
5
    return QualType();
127
5
  }
128
163
  // The promise type is required to be a class type.
129
163
  QualType PromiseType = S.Context.getTypeDeclType(Promise);
130
163
131
163
  auto buildElaboratedType = [&]() {
132
163
    auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
133
163
    NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
134
163
                                      CoroTrait.getTypePtr());
135
163
    return S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
136
163
  };
137
163
138
163
  if (
!PromiseType->getAsCXXRecordDecl()163
) {
139
1
    S.Diag(FuncLoc,
140
1
           diag::err_implied_std_coroutine_traits_promise_type_not_class)
141
1
        << buildElaboratedType();
142
1
    return QualType();
143
1
  }
144
162
  
if (162
S.RequireCompleteType(FuncLoc, buildElaboratedType(),
145
162
                            diag::err_coroutine_promise_type_incomplete))
146
1
    return QualType();
147
161
148
161
  return PromiseType;
149
161
}
150
151
/// Look up the std::experimental::coroutine_handle<PromiseType>.
152
static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
153
445
                                          SourceLocation Loc) {
154
445
  if (PromiseType.isNull())
155
0
    return QualType();
156
445
157
445
  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
158
445
  assert(StdExp && "Should already be diagnosed");
159
445
160
445
  LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"),
161
445
                      Loc, Sema::LookupOrdinaryName);
162
445
  if (
!S.LookupQualifiedName(Result, StdExp)445
) {
163
1
    S.Diag(Loc, diag::err_implied_coroutine_type_not_found)
164
1
        << "std::experimental::coroutine_handle";
165
1
    return QualType();
166
1
  }
167
444
168
444
  ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
169
444
  if (
!CoroHandle444
) {
170
0
    Result.suppressDiagnostics();
171
0
    // We found something weird. Complain about the first thing we found.
172
0
    NamedDecl *Found = *Result.begin();
173
0
    S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
174
0
    return QualType();
175
0
  }
176
444
177
444
  // Form template argument list for coroutine_handle<Promise>.
178
444
  TemplateArgumentListInfo Args(Loc, Loc);
179
444
  Args.addArgument(TemplateArgumentLoc(
180
444
      TemplateArgument(PromiseType),
181
444
      S.Context.getTrivialTypeSourceInfo(PromiseType, Loc)));
182
444
183
444
  // Build the template-id.
184
444
  QualType CoroHandleType =
185
444
      S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args);
186
444
  if (CoroHandleType.isNull())
187
0
    return QualType();
188
444
  
if (444
S.RequireCompleteType(Loc, CoroHandleType,
189
444
                            diag::err_coroutine_type_missing_specialization))
190
0
    return QualType();
191
444
192
444
  return CoroHandleType;
193
444
}
194
195
static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
196
1.08k
                                    StringRef Keyword) {
197
1.08k
  // 'co_await' and 'co_yield' are not permitted in unevaluated operands.
198
1.08k
  if (
S.isUnevaluatedContext()1.08k
) {
199
6
    S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
200
6
    return false;
201
6
  }
202
1.07k
203
1.07k
  // Any other usage must be within a function.
204
1.07k
  auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
205
1.07k
  if (
!FD1.07k
) {
206
0
    S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
207
0
                    ? diag::err_coroutine_objc_method
208
0
                    : diag::err_coroutine_outside_function) << Keyword;
209
0
    return false;
210
0
  }
211
1.07k
212
1.07k
  // An enumeration for mapping the diagnostic type to the correct diagnostic
213
1.07k
  // selection index.
214
1.07k
  enum InvalidFuncDiag {
215
1.07k
    DiagCtor = 0,
216
1.07k
    DiagDtor,
217
1.07k
    DiagCopyAssign,
218
1.07k
    DiagMoveAssign,
219
1.07k
    DiagMain,
220
1.07k
    DiagConstexpr,
221
1.07k
    DiagAutoRet,
222
1.07k
    DiagVarargs,
223
1.07k
  };
224
1.07k
  bool Diagnosed = false;
225
13
  auto DiagInvalid = [&](InvalidFuncDiag ID) {
226
13
    S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword;
227
13
    Diagnosed = true;
228
13
    return false;
229
13
  };
230
1.07k
231
1.07k
  // Diagnose when a constructor, destructor, copy/move assignment operator,
232
1.07k
  // or the function 'main' are declared as a coroutine.
233
1.07k
  auto *MD = dyn_cast<CXXMethodDecl>(FD);
234
1.07k
  if (
MD && 1.07k
isa<CXXConstructorDecl>(MD)271
)
235
2
    return DiagInvalid(DiagCtor);
236
1.07k
  else 
if (1.07k
MD && 1.07k
isa<CXXDestructorDecl>(MD)269
)
237
1
    return DiagInvalid(DiagDtor);
238
1.07k
  else 
if (1.07k
MD && 1.07k
MD->isCopyAssignmentOperator()268
)
239
2
    return DiagInvalid(DiagCopyAssign);
240
1.06k
  else 
if (1.06k
MD && 1.06k
MD->isMoveAssignmentOperator()266
)
241
2
    return DiagInvalid(DiagMoveAssign);
242
1.06k
  else 
if (1.06k
FD->isMain()1.06k
)
243
1
    return DiagInvalid(DiagMain);
244
1.06k
245
1.06k
  // Emit a diagnostics for each of the following conditions which is not met.
246
1.06k
  
if (1.06k
FD->isConstexpr()1.06k
)
247
2
    DiagInvalid(DiagConstexpr);
248
1.06k
  if (FD->getReturnType()->isUndeducedType())
249
2
    DiagInvalid(DiagAutoRet);
250
1.06k
  if (FD->isVariadic())
251
1
    DiagInvalid(DiagVarargs);
252
1.08k
253
1.08k
  return !Diagnosed;
254
1.08k
}
255
256
/// Build an UnresolvedLookupExpr holding every 'operator co_await' found by
/// ordinary lookup from scope \p S. The expression is created with
/// RequiresADL set, so argument-dependent lookup also applies when it is
/// resolved.
static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
                                                 SourceLocation Loc) {
  DeclarationName CoawaitName =
      SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
  LookupResult Candidates(SemaRef, CoawaitName, SourceLocation(),
                          Sema::LookupOperatorName);
  SemaRef.LookupName(Candidates, S);

  assert(!Candidates.isAmbiguous() && "Operator lookup cannot be ambiguous");
  const auto &Found = Candidates.asUnresolvedSet();

  // Treat the lookup as overloaded if there are several candidates, or if
  // the single candidate is a function template.
  bool IsOverloaded = Found.size() > 1;
  if (!IsOverloaded && Found.size() == 1)
    IsOverloaded = isa<FunctionTemplateDecl>(*Found.begin());

  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
      SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
      DeclarationNameInfo(CoawaitName, Loc), /*RequiresADL*/ true,
      IsOverloaded, Found.begin(), Found.end());
  assert(CoawaitOp);
  return CoawaitOp;
}
276
277
/// Build a call to 'operator co_await' if there is a suitable operator for
278
/// the given expression.
279
static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
280
                                           Expr *E,
281
438
                                           UnresolvedLookupExpr *Lookup) {
282
438
  UnresolvedSet<16> Functions;
283
438
  Functions.append(Lookup->decls_begin(), Lookup->decls_end());
284
438
  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
285
438
}
286
287
/// Look up 'operator co_await' from scope \p S, then build a call to it on
/// \p E.
static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
                                           SourceLocation Loc, Expr *E) {
  ExprResult LookupRes = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
  if (LookupRes.isInvalid())
    return ExprError();
  auto *Lookup = cast<UnresolvedLookupExpr>(LookupRes.get());
  return buildOperatorCoawaitCall(SemaRef, Loc, E, Lookup);
}
295
296
/// Build a call to builtin \p Id with \p CallArgs.
///
/// The builtin is looked up at translation-unit scope, with builtin creation
/// allowed. The lookup, the reference, and the call are each asserted to
/// succeed.
static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
                              MultiExprArg CallArgs) {
  StringRef BuiltinName = S.Context.BuiltinInfo.getName(Id);
  IdentifierInfo &BuiltinII = S.Context.Idents.get(BuiltinName);
  LookupResult Lookup(S, &BuiltinII, Loc, Sema::LookupOrdinaryName);
  S.LookupName(Lookup, S.TUScope, /*AllowBuiltinCreation=*/true);

  auto *BuiltinFD = Lookup.getAsSingle<FunctionDecl>();
  assert(BuiltinFD && "failed to find builtin declaration");

  ExprResult Callee =
      S.BuildDeclRefExpr(BuiltinFD, BuiltinFD->getType(), VK_LValue, Loc);
  assert(Callee.isUsable() && "Builtin reference cannot fail");

  ExprResult Call =
      S.ActOnCallExpr(/*Scope=*/nullptr, Callee.get(), Loc, CallArgs, Loc);
  assert(!Call.isInvalid() && "Call to builtin cannot fail!");
  return Call.get();
}
315
316
/// Build the expression
///   std::experimental::coroutine_handle<PromiseType>::from_address(
///       __builtin_coro_frame())
/// Returns ExprError on failure. A missing 'from_address' member is
/// diagnosed here.
static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
                                       SourceLocation Loc) {
  QualType HandleTy = lookupCoroutineHandleType(S, PromiseType, Loc);
  if (HandleTy.isNull())
    return ExprError();

  DeclContext *HandleCtx = S.computeDeclContext(HandleTy);
  LookupResult FromAddrLookup(S,
                              &S.PP.getIdentifierTable().get("from_address"),
                              Loc, Sema::LookupOrdinaryName);
  if (!S.LookupQualifiedName(FromAddrLookup, HandleCtx)) {
    S.Diag(Loc, diag::err_coroutine_handle_missing_member)
        << "from_address";
    return ExprError();
  }

  // The argument: a pointer to the current coroutine frame.
  Expr *FramePtr =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});

  CXXScopeSpec SS;
  ExprResult FromAddrRef =
      S.BuildDeclarationNameExpr(SS, FromAddrLookup, /*NeedsADL=*/false);
  if (FromAddrRef.isInvalid())
    return ExprError();

  return S.ActOnCallExpr(nullptr, FromAddrRef.get(), Loc, FramePtr, Loc);
}
342
343
/// The await_ready / await_suspend / await_resume calls built by
/// buildCoawaitCalls for one co_await or co_yield operand.
struct ReadySuspendResumeResult {
344
  /// Indices into Results, in the order the calls are built.
  enum AwaitCallType { ACT_Ready, ACT_Suspend, ACT_Resume };
345
  /// The built calls, indexed by AwaitCallType.
  Expr *Results[3];
346
  /// Opaque value standing for the awaited operand; used as the object
  /// expression of each call.
  OpaqueValueExpr *OpaqueValue;
347
  /// Set when building or checking any of the calls failed.
  bool IsInvalid;
348
};
349
350
/// Build the call \p Base.Name(Args).
///
/// Returns ExprError if the member reference cannot be formed. Otherwise
/// returns the result of ActOnCallExpr.
static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
                                  StringRef Name, MultiExprArg Args) {
  IdentifierInfo &MemberII = S.PP.getIdentifierTable().get(Name);
  DeclarationNameInfo MemberName(&MemberII, Loc);

  // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
  CXXScopeSpec SS;
  ExprResult MemberRef = S.BuildMemberReferenceExpr(
      Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
      /*TemplateKWLoc=*/SourceLocation(), /*FirstQualifierInScope=*/nullptr,
      MemberName, /*TemplateArgs=*/nullptr, /*Scope=*/nullptr);
  if (MemberRef.isInvalid())
    return ExprError();

  return S.ActOnCallExpr(nullptr, MemberRef.get(), Loc, Args, Loc, nullptr);
}
365
366
// See if return type is coroutine-handle and if so, invoke builtin coro-resume
367
// on its address. This is to enable experimental support for coroutine-handle
368
// returning await_suspend that results in a guaranteed tail call to the target
369
// coroutine.
370
static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
                           SourceLocation Loc) {
  // Only a non-reference class or struct return type is treated as a handle.
  if (RetType->isReferenceType())
    return nullptr;
  const Type *RetTy = RetType.getTypePtr();
  if (!(RetTy->isClassType() || RetTy->isStructureType()))
    return nullptr;

  // FIXME: Add convertibility check to coroutine_handle<>. Possibly via
  // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
  // a private function in SemaExprCXX.cpp

  // Equivalent of __builtin_coro_resume(E.address()).
  ExprResult Address = buildMemberCall(S, E, Loc, "address", None);
  if (Address.isInvalid())
    return nullptr;

  // FIXME: Check that the type of AddressExpr is void*
  Expr *AddressVal = Address.get();
  return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume,
                          AddressVal);
}
391
392
/// Build calls to await_ready, await_suspend, and await_resume for a co_await
393
/// expression.
394
static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
395
445
                                                  SourceLocation Loc, Expr *E) {
396
445
  OpaqueValueExpr *Operand = new (S.Context)
397
445
      OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
398
445
399
445
  // Assume invalid until we see otherwise.
400
445
  ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true};
401
445
402
445
  ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc);
403
445
  if (CoroHandleRes.isInvalid())
404
2
    return Calls;
405
443
  Expr *CoroHandle = CoroHandleRes.get();
406
443
407
443
  const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
408
443
  MultiExprArg Args[] = {None, CoroHandle, None};
409
1.73k
  for (size_t I = 0, N = llvm::array_lengthof(Funcs); 
I != N1.73k
;
++I1.29k
) {
410
1.30k
    ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]);
411
1.30k
    if (Result.isInvalid())
412
11
      return Calls;
413
1.29k
    Calls.Results[I] = Result.get();
414
1.29k
  }
415
443
416
443
  // Assume the calls are valid; all further checking should make them invalid.
417
432
  Calls.IsInvalid = false;
418
432
419
432
  using ACT = ReadySuspendResumeResult::AwaitCallType;
420
432
  CallExpr *AwaitReady = cast<CallExpr>(Calls.Results[ACT::ACT_Ready]);
421
432
  if (
!AwaitReady->getType()->isDependentType()432
) {
422
432
    // [expr.await]p3 [...]
423
432
    // — await-ready is the expression e.await_ready(), contextually converted
424
432
    // to bool.
425
432
    ExprResult Conv = S.PerformContextuallyConvertToBool(AwaitReady);
426
432
    if (
Conv.isInvalid()432
) {
427
1
      S.Diag(AwaitReady->getDirectCallee()->getLocStart(),
428
1
             diag::note_await_ready_no_bool_conversion);
429
1
      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
430
1
          << AwaitReady->getDirectCallee() << E->getSourceRange();
431
1
      Calls.IsInvalid = true;
432
1
    }
433
432
    Calls.Results[ACT::ACT_Ready] = Conv.get();
434
432
  }
435
432
  CallExpr *AwaitSuspend = cast<CallExpr>(Calls.Results[ACT::ACT_Suspend]);
436
432
  if (
!AwaitSuspend->getType()->isDependentType()432
) {
437
432
    // [expr.await]p3 [...]
438
432
    //   - await-suspend is the expression e.await_suspend(h), which shall be
439
432
    //     a prvalue of type void or bool.
440
432
    QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
441
432
442
432
    // Experimental support for coroutine_handle returning await_suspend.
443
432
    if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc))
444
1
      Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
445
431
    else {
446
431
      // non-class prvalues always have cv-unqualified types
447
431
      if (RetType->isReferenceType() ||
448
431
          
(!RetType->isBooleanType() && 429
!RetType->isVoidType()426
)) {
449
3
        S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
450
3
               diag::err_await_suspend_invalid_return_type)
451
3
            << RetType;
452
3
        S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
453
3
            << AwaitSuspend->getDirectCallee();
454
3
        Calls.IsInvalid = true;
455
3
      }
456
431
    }
457
432
  }
458
432
459
432
  return Calls;
460
445
}
461
462
/// Build the call \p Promise.Name(Args) on the coroutine promise variable.
static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
                                   SourceLocation Loc, StringRef Name,
                                   MultiExprArg Args) {
  // Refer to the promise as an lvalue of its non-reference type.
  QualType PromiseRefTy = Promise->getType().getNonReferenceType();
  ExprResult PromiseRef =
      S.BuildDeclRefExpr(Promise, PromiseRefTy, VK_LValue, Loc);
  if (PromiseRef.isInvalid())
    return ExprError();

  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
}
474
475
215
/// Create the implicit '__promise' variable for the current coroutine and add
/// it to the function's declaration context.
///
/// The variable's type is DependentTy when the function type, or the type of
/// 'this', is dependent. Otherwise it is the type found by lookupPromiseType.
/// Returns nullptr if no promise type could be determined or the variable is
/// invalid.
VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
  auto *FD = cast<FunctionDecl>(CurContext);

  // FD is known non-null here, so a plain dyn_cast is enough.
  bool IsThisDependentType = false;
  if (auto *MD = dyn_cast<CXXMethodDecl>(FD))
    IsThisDependentType =
        MD->isInstance() && MD->getThisType(Context)->isDependentType();

  QualType PromiseTy;
  if (FD->getType()->isDependentType() || IsThisDependentType)
    PromiseTy = Context.DependentTy;
  else
    PromiseTy = lookupPromiseType(*this, FD, Loc);
  if (PromiseTy.isNull())
    return nullptr;

  IdentifierInfo &PromiseName = PP.getIdentifierTable().get("__promise");
  TypeSourceInfo *PromiseTSI = Context.getTrivialTypeSourceInfo(PromiseTy, Loc);
  auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
                             &PromiseName, PromiseTy, PromiseTSI, SC_None);
  CheckVariableDeclarationType(VD);
  if (VD->isInvalidDecl())
    return nullptr;

  ActOnUninitializedDecl(VD);
  FD->addDecl(VD);
  assert(!VD->isInvalidDecl());
  return VD;
}
502
503
/// Check that this is a context in which a coroutine suspension can appear.
504
static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
505
                                                StringRef Keyword,
506
1.08k
                                                bool IsImplicit = false) {
507
1.08k
  if (!isValidCoroutineContext(S, Loc, Keyword))
508
18
    return nullptr;
509
1.06k
510
1.08k
  assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
511
1.06k
512
1.06k
  auto *ScopeInfo = S.getCurFunction();
513
1.06k
  assert(ScopeInfo && "missing function scope for function");
514
1.06k
515
1.06k
  if (
ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && 1.06k
!IsImplicit313
)
516
215
    ScopeInfo->setFirstCoroutineStmt(Loc, Keyword);
517
1.06k
518
1.06k
  if (ScopeInfo->CoroutinePromise)
519
896
    return ScopeInfo;
520
166
521
166
  ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
522
166
  if (!ScopeInfo->CoroutinePromise)
523
11
    return nullptr;
524
155
525
155
  return ScopeInfo;
526
155
}
527
528
bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc,
529
238
                                   StringRef Keyword) {
530
238
  if (!checkCoroutineContext(*this, KWLoc, Keyword))
531
29
    return false;
532
209
  auto *ScopeInfo = getCurFunction();
533
209
  assert(ScopeInfo->CoroutinePromise);
534
209
535
209
  // If we have existing coroutine statements then we have already built
536
209
  // the initial and final suspend points.
537
209
  if (!ScopeInfo->NeedsCoroutineSuspends)
538
54
    return true;
539
155
540
155
  ScopeInfo->setNeedsCoroutineSuspends(false);
541
155
542
155
  auto *Fn = cast<FunctionDecl>(CurContext);
543
155
  SourceLocation Loc = Fn->getLocation();
544
155
  // Build the initial suspend point
545
304
  auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
546
304
    ExprResult Suspend =
547
304
        buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None);
548
304
    if (Suspend.isInvalid())
549
1
      return StmtError();
550
303
    Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get());
551
303
    if (Suspend.isInvalid())
552
0
      return StmtError();
553
303
    Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(),
554
303
                                       /*IsImplicit*/ true);
555
303
    Suspend = ActOnFinishFullExpr(Suspend.get());
556
303
    if (
Suspend.isInvalid()303
) {
557
6
      Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required)
558
6
          << ((Name == "initial_suspend") ? 
05
:
11
);
559
6
      Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
560
6
      return StmtError();
561
6
    }
562
297
    return cast<Stmt>(Suspend.get());
563
297
  };
564
155
565
155
  StmtResult InitSuspend = buildSuspends("initial_suspend");
566
155
  if (InitSuspend.isInvalid())
567
6
    return true;
568
149
569
149
  StmtResult FinalSuspend = buildSuspends("final_suspend");
570
149
  if (FinalSuspend.isInvalid())
571
1
    return true;
572
148
573
148
  ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
574
148
575
148
  return true;
576
148
}
577
578
104
/// Handle an explicit 'co_await E'.
///
/// Starts the coroutine body and resolves a placeholder operand. It then
/// builds an unresolved co_await from the 'operator co_await' candidates
/// visible from \p S.
ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
  if (!ActOnCoroutineBodyStart(S, Loc, "co_await")) {
    CorrectDelayedTyposInExpr(E);
    return ExprError();
  }

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  ExprResult LookupRes = buildOperatorCoawaitLookupExpr(*this, S, Loc);
  if (LookupRes.isInvalid())
    return ExprError();

  auto *Lookup = cast<UnresolvedLookupExpr>(LookupRes.get());
  return BuildUnresolvedCoawaitExpr(Loc, E, Lookup);
}
595
596
/// Build a co_await whose 'operator co_await' has not been resolved yet.
///
/// If the promise type is dependent, the result is a DependentCoawaitExpr.
/// Otherwise:
///   1. if the promise declares an await_transform member, the operand is
///      passed through it;
///   2. 'operator co_await' is resolved from \p Lookup;
///   3. the result is handed to BuildResolvedCoawaitExpr.
ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
                                            UnresolvedLookupExpr *Lookup) {
  auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
  if (!FSI)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  VarDecl *Promise = FSI->CoroutinePromise;
  QualType PromiseTy = Promise->getType();
  if (PromiseTy->isDependentType()) {
    Expr *Dependent =
        new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
    return Dependent;
  }

  // Apply <promise>.await_transform(E) when the promise provides it.
  CXXRecordDecl *PromiseRD = PromiseTy->getAsCXXRecordDecl();
  if (lookupMember(*this, "await_transform", PromiseRD, Loc)) {
    ExprResult Transformed =
        buildPromiseCall(*this, Promise, Loc, "await_transform", E);
    if (Transformed.isInvalid()) {
      Diag(Loc,
           diag::note_coroutine_promise_implicit_await_transform_required_here)
          << E->getSourceRange();
      return ExprError();
    }
    E = Transformed.get();
  }

  ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
  if (Awaitable.isInvalid())
    return ExprError();

  return BuildResolvedCoawaitExpr(Loc, Awaitable.get());
}
633
634
/// Build a co_await expression on \p E, the already-resolved awaitable.
///
/// A dependent operand yields a dependent CoawaitExpr. Otherwise a prvalue
/// operand is materialized as a temporary, the await_* calls are built, and a
/// CoawaitExpr is returned. \p IsImplicit is stored on the expression. It
/// also stops \p Loc from being recorded as the first coroutine statement.
ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
                                          bool IsImplicit) {
  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit);
  if (!Coroutine)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  if (E->getType()->isDependentType()) {
    Expr *Dependent =
        new (Context) CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
    return Dependent;
  }

  // If the expression is a temporary, materialize it as an lvalue so that the
  // three await_* calls can all refer to it.
  if (E->getValueKind() == VK_RValue)
    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);

  ReadySuspendResumeResult Calls =
      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E);
  if (Calls.IsInvalid)
    return ExprError();

  Expr *Res = new (Context)
      CoawaitExpr(Loc, E, Calls.Results[0], Calls.Results[1], Calls.Results[2],
                  Calls.OpaqueValue, IsImplicit);
  return Res;
}
669
670
60
/// Handle an explicit 'co_yield E'.
///
/// Starts the coroutine body, builds <promise>.yield_value(E), and applies
/// 'operator co_await' to the result. BuildCoyieldExpr then finishes the
/// expression.
ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
  if (!ActOnCoroutineBodyStart(S, Loc, "co_yield")) {
    CorrectDelayedTyposInExpr(E);
    return ExprError();
  }

  VarDecl *Promise = getCurFunction()->CoroutinePromise;
  ExprResult Yielded = buildPromiseCall(*this, Promise, Loc, "yield_value", E);
  if (Yielded.isInvalid())
    return ExprError();

  ExprResult Awaitable =
      buildOperatorCoawaitCall(*this, S, Loc, Yielded.get());
  if (Awaitable.isInvalid())
    return ExprError();

  return BuildCoyieldExpr(Loc, Awaitable.get());
}
689
68
/// Build a co_yield expression from the awaitable \p E.
///
/// When called from ActOnCoyieldExpr, \p E is 'operator co_await' applied to
/// the result of yield_value. A dependent operand yields a dependent
/// CoyieldExpr. Otherwise a prvalue operand is materialized as a temporary,
/// the await_* calls are built, and a CoyieldExpr is returned.
ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
  if (!Coroutine)
    return ExprError();

  if (E->getType()->isPlaceholderType()) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return ExprError();
    E = Checked.get();
  }

  if (E->getType()->isDependentType()) {
    Expr *Dependent = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
    return Dependent;
  }

  // Materialize a temporary operand as an lvalue so that the three await_*
  // calls can all refer to it.
  if (E->getValueKind() == VK_RValue)
    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);

  ReadySuspendResumeResult Calls =
      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E);
  if (Calls.IsInvalid)
    return ExprError();

  Expr *Res = new (Context)
      CoyieldExpr(Loc, E, Calls.Results[0], Calls.Results[1], Calls.Results[2],
                  Calls.OpaqueValue);
  return Res;
}
722
723
66
/// Handle an explicit 'co_return [E]'.
///
/// Starts the coroutine body, then builds the statement with
/// BuildCoreturnStmt.
StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
  if (ActOnCoroutineBodyStart(S, Loc, "co_return"))
    return BuildCoreturnStmt(Loc, E);

  CorrectDelayedTyposInExpr(E);
  return StmtError();
}
730
731
/// Build a co_return statement with optional operand \p E.
///
/// A braced initializer list, or an operand of non-void type, is passed to
/// <promise>.return_value(E). Otherwise any operand is turned into a
/// discarded-value full-expression and <promise>.return_void() is called.
StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
                                   bool IsImplicit) {
  auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit);
  if (!FSI)
    return StmtError();

  // Resolve placeholder operands, except unresolved overload sets, which are
  // passed through unchanged.
  if (E && E->getType()->isPlaceholderType() &&
      !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
    ExprResult Checked = CheckPlaceholderExpr(E);
    if (Checked.isInvalid())
      return StmtError();
    E = Checked.get();
  }

  // FIXME: If the operand is a reference to a variable that's about to go out
  // of scope, we should treat the operand as an xvalue for this overload
  // resolution.
  VarDecl *Promise = FSI->CoroutinePromise;
  bool UsesReturnValue =
      E && (isa<InitListExpr>(E) || !E->getType()->isVoidType());

  ExprResult PromiseCall;
  if (UsesReturnValue) {
    PromiseCall = buildPromiseCall(*this, Promise, Loc, "return_value", E);
  } else {
    E = MakeFullDiscardedValueExpr(E).get();
    PromiseCall = buildPromiseCall(*this, Promise, Loc, "return_void", None);
  }
  if (PromiseCall.isInvalid())
    return StmtError();

  Expr *PCE = ActOnFinishFullExpr(PromiseCall.get()).get();
  Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit);
  return Res;
}
763
764
/// Look up the std::nothrow object.
765
5
static Expr *buildStdNoThrowDeclRef(Sema &S, SourceLocation Loc) {
766
5
  NamespaceDecl *Std = S.getStdNamespace();
767
5
  assert(Std && "Should already be diagnosed");
768
5
769
5
  LookupResult Result(S, &S.PP.getIdentifierTable().get("nothrow"), Loc,
770
5
                      Sema::LookupOrdinaryName);
771
5
  if (
!S.LookupQualifiedName(Result, Std)5
) {
772
0
    // FIXME: <experimental/coroutine> should have been included already.
773
0
    // If we require it to include <new> then this diagnostic is no longer
774
0
    // needed.
775
0
    S.Diag(Loc, diag::err_implicit_coroutine_std_nothrow_type_not_found);
776
0
    return nullptr;
777
0
  }
778
5
779
5
  auto *VD = Result.getAsSingle<VarDecl>();
780
5
  if (
!VD5
) {
781
0
    Result.suppressDiagnostics();
782
0
    // We found something weird. Complain about the first thing we found.
783
0
    NamedDecl *Found = *Result.begin();
784
0
    S.Diag(Found->getLocation(), diag::err_malformed_std_nothrow);
785
0
    return nullptr;
786
0
  }
787
5
788
5
  ExprResult DR = S.BuildDeclRefExpr(VD, VD->getType(), VK_LValue, Loc);
789
5
  if (DR.isInvalid())
790
0
    return nullptr;
791
5
792
5
  return DR.get();
793
5
}
794
795
// Find an appropriate delete for the promise.
796
static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc,
797
125
                                          QualType PromiseType) {
798
125
  FunctionDecl *OperatorDelete = nullptr;
799
125
800
125
  DeclarationName DeleteName =
801
125
      S.Context.DeclarationNames.getCXXOperatorName(OO_Delete);
802
125
803
125
  auto *PointeeRD = PromiseType->getAsCXXRecordDecl();
804
125
  assert(PointeeRD && "PromiseType must be a CxxRecordDecl type");
805
125
806
125
  if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete))
807
0
    return nullptr;
808
125
809
125
  
if (125
!OperatorDelete125
) {
810
123
    // Look for a global declaration.
811
123
    const bool CanProvideSize = S.isCompleteType(Loc, PromiseType);
812
123
    const bool Overaligned = false;
813
123
    OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize,
814
123
                                                     Overaligned, DeleteName);
815
123
  }
816
125
  S.MarkFunctionReferenced(Loc, OperatorDelete);
817
125
  return OperatorDelete;
818
125
}
819
820
821
215
/// Finish semantic analysis of a coroutine body by replacing \p Body with a
/// CoroutineBodyStmt that carries the implicit coroutine statements.
///
/// \p FD is marked invalid when no promise variable was built, or when the
/// implicit statements cannot be formed. A 'return' statement inside the
/// coroutine is diagnosed. A body that is already a CoroutineBodyStmt is left
/// untouched.
void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
  FunctionScopeInfo *Fn = getCurFunction();
  assert(Fn && Fn->isCoroutine() && "not a coroutine");

  if (Body == nullptr) {
    assert(FD->isInvalidDecl() &&
           "a null body is only allowed for invalid declarations");
    return;
  }

  // Coroutine keywords were used, but the promise could not be built.
  if (Fn->CoroutinePromise == nullptr) {
    FD->setInvalidDecl();
    return;
  }

  // The body has already been transformed; nothing left to do.
  if (isa<CoroutineBodyStmt>(Body))
    return;

  // Coroutines [stmt.return]p1:
  //   A return statement shall not appear in a coroutine.
  if (Fn->FirstReturnLoc.isValid()) {
    assert(Fn->FirstCoroutineStmtLoc.isValid() &&
           "first coroutine location not set");
    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
    Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
        << Fn->getFirstCoroutineStmtKeyword();
  }

  CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body);
  const bool Built = !Builder.isInvalid() && Builder.buildStatements();
  if (!Built) {
    FD->setInvalidDecl();
    return;
  }

  // Swap in the wrapper statement that holds the whole coroutine.
  Body = CoroutineBodyStmt::Create(Context, Builder);
}
855
856
/// Prepare to build the implicit statements of coroutine \p FD.
///
/// The promise type counts as dependent if there is no promise variable or
/// its type is dependent. When it is not dependent, the promise's record
/// declaration is cached. The promise DeclStmt and the initial/final suspend
/// expressions are formed eagerly; the builder is valid only if both succeed.
CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD,
                                           sema::FunctionScopeInfo &Fn,
                                           Stmt *Body)
    : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
      IsPromiseDependentType(
          !Fn.CoroutinePromise ||
          Fn.CoroutinePromise->getType()->isDependentType()) {
  this->Body = Body;
  if (!IsPromiseDependentType) {
    QualType PromiseTy = Fn.CoroutinePromise->getType();
    PromiseRecordDecl = PromiseTy->getAsCXXRecordDecl();
    assert(PromiseRecordDecl && "Type should have already been checked");
  }
  const bool HavePromiseStmt = makePromiseStmt();
  this->IsValid = HavePromiseStmt && makeInitialAndFinalSuspend();
}
870
871
148
bool CoroutineStmtBuilder::buildStatements() {
872
148
  assert(this->IsValid && "coroutine already invalid");
873
147
  this->IsValid = makeReturnObject() && makeParamMoves();
874
148
  if (
this->IsValid && 148
!IsPromiseDependentType147
)
875
104
    buildDependentStatements();
876
148
  return this->IsValid;
877
148
}
878
879
140
bool CoroutineStmtBuilder::buildDependentStatements() {
880
140
  assert(this->IsValid && "coroutine already invalid");
881
140
  assert(!this->IsPromiseDependentType &&
882
140
         "coroutine cannot have a dependent promise type");
883
135
  this->IsValid = makeOnException() && makeOnFallthrough() &&
884
140
                  
makeGroDeclAndReturnStmt()132
&&
makeReturnOnAllocFailure()130
&&
885
127
                  makeNewAndDeleteExpr();
886
140
  return this->IsValid;
887
140
}
888
889
37
bool CoroutineStmtBuilder::buildParameterMoves() {
890
37
  assert(this->IsValid && "coroutine already invalid");
891
37
  assert(this->ParamMoves.empty() && "param moves already built");
892
37
  return this->IsValid = makeParamMoves();
893
37
}
894
895
196
bool CoroutineStmtBuilder::makePromiseStmt() {
896
196
  // Form a declaration statement for the promise declaration, so that AST
897
196
  // visitors can more easily find it.
898
196
  StmtResult PromiseStmt =
899
196
      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
900
196
  if (PromiseStmt.isInvalid())
901
0
    return false;
902
196
903
196
  this->Promise = PromiseStmt.get();
904
196
  return true;
905
196
}
906
907
196
bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() {
908
196
  if (Fn.hasInvalidCoroutineSuspends())
909
7
    return false;
910
189
  this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
911
189
  this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second);
912
189
  return true;
913
189
}
914
915
static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
916
                                     CXXRecordDecl *PromiseRecordDecl,
917
10
                                     FunctionScopeInfo &Fn) {
918
10
  auto Loc = E->getExprLoc();
919
10
  if (auto *
DeclRef10
= dyn_cast_or_null<DeclRefExpr>(E)) {
920
10
    auto *Decl = DeclRef->getDecl();
921
10
    if (CXXMethodDecl *
Method10
= dyn_cast_or_null<CXXMethodDecl>(Decl)) {
922
10
      if (Method->isStatic())
923
9
        return true;
924
10
      else
925
1
        Loc = Decl->getLocation();
926
10
    }
927
10
  }
928
10
929
1
  S.Diag(
930
1
      Loc,
931
1
      diag::err_coroutine_promise_get_return_object_on_allocation_failure)
932
1
      << PromiseRecordDecl;
933
1
  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
934
1
      << Fn.getFirstCoroutineStmtKeyword();
935
1
  return false;
936
10
}
937
938
130
bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
939
130
  assert(!IsPromiseDependentType &&
940
130
         "cannot make statement while the promise type is dependent");
941
130
942
130
  // [dcl.fct.def.coroutine]/8
943
130
  // The unqualified-id get_return_object_on_allocation_failure is looked up in
944
130
  // the scope of class P by class member access lookup (3.4.5). ...
945
130
  // If an allocation function returns nullptr, ... the coroutine return value
946
130
  // is obtained by a call to ... get_return_object_on_allocation_failure().
947
130
948
130
  DeclarationName DN =
949
130
      S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
950
130
  LookupResult Found(S, DN, Loc, Sema::LookupMemberName);
951
130
  if (!S.LookupQualifiedName(Found, PromiseRecordDecl))
952
120
    return true;
953
10
954
10
  CXXScopeSpec SS;
955
10
  ExprResult DeclNameExpr =
956
10
      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
957
10
  if (DeclNameExpr.isInvalid())
958
0
    return false;
959
10
960
10
  
if (10
!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn)10
)
961
1
    return false;
962
9
963
9
  ExprResult ReturnObjectOnAllocationFailure =
964
9
      S.ActOnCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc);
965
9
  if (ReturnObjectOnAllocationFailure.isInvalid())
966
0
    return false;
967
9
968
9
  StmtResult ReturnStmt =
969
9
      S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get());
970
9
  if (
ReturnStmt.isInvalid()9
) {
971
2
    S.Diag(Found.getFoundDecl()->getLocation(), diag::note_member_declared_here)
972
2
        << DN;
973
2
    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
974
2
        << Fn.getFirstCoroutineStmtKeyword();
975
2
    return false;
976
2
  }
977
7
978
7
  this->ReturnStmtOnAllocFailure = ReturnStmt.get();
979
7
  return true;
980
7
}
981
982
127
bool CoroutineStmtBuilder::makeNewAndDeleteExpr() {
983
127
  // Form and check allocation and deallocation calls.
984
127
  assert(!IsPromiseDependentType &&
985
127
         "cannot make statement while the promise type is dependent");
986
127
  QualType PromiseType = Fn.CoroutinePromise->getType();
987
127
988
127
  if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
989
0
    return false;
990
127
991
127
  const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;
992
127
993
127
  // FIXME: Add support for stateful allocators.
994
127
995
127
  FunctionDecl *OperatorNew = nullptr;
996
127
  FunctionDecl *OperatorDelete = nullptr;
997
127
  FunctionDecl *UnusedResult = nullptr;
998
127
  bool PassAlignment = false;
999
127
  SmallVector<Expr *, 1> PlacementArgs;
1000
127
1001
127
  S.FindAllocationFunctions(Loc, SourceRange(),
1002
127
                            /*UseGlobal*/ false, PromiseType,
1003
127
                            /*isArray*/ false, PassAlignment, PlacementArgs,
1004
127
                            OperatorNew, UnusedResult);
1005
127
1006
127
  bool IsGlobalOverload =
1007
127
      OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
1008
127
  // If we didn't find a class-local new declaration and non-throwing new
1009
127
  // was is required then we need to lookup the non-throwing global operator
1010
127
  // instead.
1011
127
  if (
RequiresNoThrowAlloc && 127
(!OperatorNew || 7
IsGlobalOverload7
)) {
1012
5
    auto *StdNoThrow = buildStdNoThrowDeclRef(S, Loc);
1013
5
    if (!StdNoThrow)
1014
0
      return false;
1015
5
    PlacementArgs = {StdNoThrow};
1016
5
    OperatorNew = nullptr;
1017
5
    S.FindAllocationFunctions(Loc, SourceRange(),
1018
5
                              /*UseGlobal*/ true, PromiseType,
1019
5
                              /*isArray*/ false, PassAlignment, PlacementArgs,
1020
5
                              OperatorNew, UnusedResult);
1021
5
  }
1022
127
1023
127
  assert(OperatorNew && "expected definition of operator new to be found");
1024
127
1025
127
  if (
RequiresNoThrowAlloc127
) {
1026
7
    const auto *FT = OperatorNew->getType()->getAs<FunctionProtoType>();
1027
7
    if (
!FT->isNothrow(S.Context, /*ResultIfDependent*/ false)7
) {
1028
2
      S.Diag(OperatorNew->getLocation(),
1029
2
             diag::err_coroutine_promise_new_requires_nothrow)
1030
2
          << OperatorNew;
1031
2
      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
1032
2
          << OperatorNew;
1033
2
      return false;
1034
2
    }
1035
125
  }
1036
125
1037
125
  
if (125
(OperatorDelete = findDeleteForPromise(S, Loc, PromiseType)) == nullptr125
)
1038
0
    return false;
1039
125
1040
125
  Expr *FramePtr =
1041
125
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
1042
125
1043
125
  Expr *FrameSize =
1044
125
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {});
1045
125
1046
125
  // Make new call.
1047
125
1048
125
  ExprResult NewRef =
1049
125
      S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc);
1050
125
  if (NewRef.isInvalid())
1051
0
    return false;
1052
125
1053
125
  SmallVector<Expr *, 2> NewArgs(1, FrameSize);
1054
125
  for (auto Arg : PlacementArgs)
1055
5
    NewArgs.push_back(Arg);
1056
125
1057
125
  ExprResult NewExpr =
1058
125
      S.ActOnCallExpr(S.getCurScope(), NewRef.get(), Loc, NewArgs, Loc);
1059
125
  NewExpr = S.ActOnFinishFullExpr(NewExpr.get());
1060
125
  if (NewExpr.isInvalid())
1061
0
    return false;
1062
125
1063
125
  // Make delete call.
1064
125
1065
125
  QualType OpDeleteQualType = OperatorDelete->getType();
1066
125
1067
125
  ExprResult DeleteRef =
1068
125
      S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc);
1069
125
  if (DeleteRef.isInvalid())
1070
0
    return false;
1071
125
1072
125
  Expr *CoroFree =
1073
125
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr});
1074
125
1075
125
  SmallVector<Expr *, 2> DeleteArgs{CoroFree};
1076
125
1077
125
  // Check if we need to pass the size.
1078
125
  const auto *OpDeleteType =
1079
125
      OpDeleteQualType.getTypePtr()->getAs<FunctionProtoType>();
1080
125
  if (OpDeleteType->getNumParams() > 1)
1081
1
    DeleteArgs.push_back(FrameSize);
1082
125
1083
125
  ExprResult DeleteExpr =
1084
125
      S.ActOnCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc);
1085
125
  DeleteExpr = S.ActOnFinishFullExpr(DeleteExpr.get());
1086
125
  if (DeleteExpr.isInvalid())
1087
0
    return false;
1088
125
1089
125
  this->Allocate = NewExpr.get();
1090
125
  this->Deallocate = DeleteExpr.get();
1091
125
1092
125
  return true;
1093
125
}
1094
1095
135
bool CoroutineStmtBuilder::makeOnFallthrough() {
1096
135
  assert(!IsPromiseDependentType &&
1097
135
         "cannot make statement while the promise type is dependent");
1098
135
1099
135
  // [dcl.fct.def.coroutine]/4
1100
135
  // The unqualified-ids 'return_void' and 'return_value' are looked up in
1101
135
  // the scope of class P. If both are found, the program is ill-formed.
1102
135
  bool HasRVoid, HasRValue;
1103
135
  LookupResult LRVoid =
1104
135
      lookupMember(S, "return_void", PromiseRecordDecl, Loc, HasRVoid);
1105
135
  LookupResult LRValue =
1106
135
      lookupMember(S, "return_value", PromiseRecordDecl, Loc, HasRValue);
1107
135
1108
135
  StmtResult Fallthrough;
1109
135
  if (
HasRVoid && 135
HasRValue100
) {
1110
2
    // FIXME Improve this diagnostic
1111
2
    S.Diag(FD.getLocation(),
1112
2
           diag::err_coroutine_promise_incompatible_return_functions)
1113
2
        << PromiseRecordDecl;
1114
2
    S.Diag(LRVoid.getRepresentativeDecl()->getLocation(),
1115
2
           diag::note_member_first_declared_here)
1116
2
        << LRVoid.getLookupName();
1117
2
    S.Diag(LRValue.getRepresentativeDecl()->getLocation(),
1118
2
           diag::note_member_first_declared_here)
1119
2
        << LRValue.getLookupName();
1120
2
    return false;
1121
133
  } else 
if (133
!HasRVoid && 133
!HasRValue35
) {
1122
1
    // FIXME: The PDTS currently specifies this case as UB, not ill-formed.
1123
1
    // However we still diagnose this as an error since until the PDTS is fixed.
1124
1
    S.Diag(FD.getLocation(),
1125
1
           diag::err_coroutine_promise_requires_return_function)
1126
1
        << PromiseRecordDecl;
1127
1
    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1128
1
        << PromiseRecordDecl;
1129
1
    return false;
1130
132
  } else 
if (132
HasRVoid132
) {
1131
98
    // If the unqualified-id return_void is found, flowing off the end of a
1132
98
    // coroutine is equivalent to a co_return with no operand. Otherwise,
1133
98
    // flowing off the end of a coroutine results in undefined behavior.
1134
98
    Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr,
1135
98
                                      /*IsImplicit*/false);
1136
98
    Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
1137
98
    if (Fallthrough.isInvalid())
1138
0
      return false;
1139
132
  }
1140
132
1141
132
  this->OnFallthrough = Fallthrough.get();
1142
132
  return true;
1143
132
}
1144
1145
140
bool CoroutineStmtBuilder::makeOnException() {
1146
140
  // Try to form 'p.unhandled_exception();'
1147
140
  assert(!IsPromiseDependentType &&
1148
140
         "cannot make statement while the promise type is dependent");
1149
140
1150
140
  const bool RequireUnhandledException = S.getLangOpts().CXXExceptions;
1151
140
1152
140
  if (
!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)140
) {
1153
24
    auto DiagID =
1154
24
        RequireUnhandledException
1155
2
            ? diag::err_coroutine_promise_unhandled_exception_required
1156
22
            : diag::
1157
22
                  warn_coroutine_promise_unhandled_exception_required_with_exceptions;
1158
24
    S.Diag(Loc, DiagID) << PromiseRecordDecl;
1159
24
    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1160
24
        << PromiseRecordDecl;
1161
24
    return !RequireUnhandledException;
1162
24
  }
1163
116
1164
116
  // If exceptions are disabled, don't try to build OnException.
1165
116
  
if (116
!S.getLangOpts().CXXExceptions116
)
1166
27
    return true;
1167
89
1168
89
  ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc,
1169
89
                                                   "unhandled_exception", None);
1170
89
  UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc);
1171
89
  if (UnhandledException.isInvalid())
1172
2
    return false;
1173
87
1174
87
  // Since the body of the coroutine will be wrapped in try-catch, it will
1175
87
  // be incompatible with SEH __try if present in a function.
1176
87
  
if (87
!S.getLangOpts().Borland && 87
Fn.FirstSEHTryLoc.isValid()87
) {
1177
1
    S.Diag(Fn.FirstSEHTryLoc, diag::err_seh_in_a_coroutine_with_cxx_exceptions);
1178
1
    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1179
1
        << Fn.getFirstCoroutineStmtKeyword();
1180
1
    return false;
1181
1
  }
1182
86
1183
86
  this->OnException = UnhandledException.get();
1184
86
  return true;
1185
86
}
1186
1187
148
bool CoroutineStmtBuilder::makeReturnObject() {
1188
148
  // Build implicit 'p.get_return_object()' expression and form initialization
1189
148
  // of return type from it.
1190
148
  ExprResult ReturnObject =
1191
148
      buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None);
1192
148
  if (ReturnObject.isInvalid())
1193
1
    return false;
1194
147
1195
147
  this->ReturnValue = ReturnObject.get();
1196
147
  return true;
1197
147
}
1198
1199
2
/// Emit follow-up notes for a failed use of \p E.
///
/// One note points at the called member's declaration, emitted when \p E is a
/// member call whose method declaration is known. Another points at the first
/// coroutine statement of the function described by \p Fn.
static void noteMemberDeclaredHere(Sema &S, Expr *E, FunctionScopeInfo &Fn) {
  if (auto *MbrRef = dyn_cast<CXXMemberCallExpr>(E)) {
    // getMethodDecl() returns null when the callee is not a MemberExpr;
    // skip the member note rather than dereference null.
    if (auto *MethodDecl = MbrRef->getMethodDecl())
      S.Diag(MethodDecl->getLocation(), diag::note_member_declared_here)
          << MethodDecl;
  }
  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
      << Fn.getFirstCoroutineStmtKeyword();
}
1208
1209
132
bool CoroutineStmtBuilder::makeGroDeclAndReturnStmt() {
1210
132
  assert(!IsPromiseDependentType &&
1211
132
         "cannot make statement while the promise type is dependent");
1212
132
  assert(this->ReturnValue && "ReturnValue must be already formed");
1213
132
1214
132
  QualType const GroType = this->ReturnValue->getType();
1215
132
  assert(!GroType->isDependentType() &&
1216
132
         "get_return_object type must no longer be dependent");
1217
132
1218
132
  QualType const FnRetType = FD.getReturnType();
1219
132
  assert(!FnRetType->isDependentType() &&
1220
132
         "get_return_object type must no longer be dependent");
1221
132
1222
132
  if (
FnRetType->isVoidType()132
) {
1223
50
    ExprResult Res = S.ActOnFinishFullExpr(this->ReturnValue, Loc);
1224
50
    if (Res.isInvalid())
1225
0
      return false;
1226
50
1227
50
    this->ResultDecl = Res.get();
1228
50
    return true;
1229
50
  }
1230
82
1231
82
  
if (82
GroType->isVoidType()82
) {
1232
1
    // Trigger a nice error message.
1233
1
    InitializedEntity Entity =
1234
1
        InitializedEntity::InitializeResult(Loc, FnRetType, false);
1235
1
    S.PerformMoveOrCopyInitialization(Entity, nullptr, FnRetType, ReturnValue);
1236
1
    noteMemberDeclaredHere(S, ReturnValue, Fn);
1237
1
    return false;
1238
1
  }
1239
81
1240
81
  auto *GroDecl = VarDecl::Create(
1241
81
      S.Context, &FD, FD.getLocation(), FD.getLocation(),
1242
81
      &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
1243
81
      S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
1244
81
1245
81
  S.CheckVariableDeclarationType(GroDecl);
1246
81
  if (GroDecl->isInvalidDecl())
1247
0
    return false;
1248
81
1249
81
  InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
1250
81
  ExprResult Res = S.PerformMoveOrCopyInitialization(Entity, nullptr, GroType,
1251
81
                                                     this->ReturnValue);
1252
81
  if (Res.isInvalid())
1253
0
    return false;
1254
81
1255
81
  Res = S.ActOnFinishFullExpr(Res.get());
1256
81
  if (Res.isInvalid())
1257
0
    return false;
1258
81
1259
81
  
if (81
GroType == FnRetType81
) {
1260
60
    GroDecl->setNRVOVariable(true);
1261
60
  }
1262
81
1263
81
  S.AddInitializerToDecl(GroDecl, Res.get(),
1264
81
                         /*DirectInit=*/false);
1265
81
1266
81
  S.FinalizeDeclaration(GroDecl);
1267
81
1268
81
  // Form a declaration statement for the return declaration, so that AST
1269
81
  // visitors can more easily find it.
1270
81
  StmtResult GroDeclStmt =
1271
81
      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
1272
81
  if (GroDeclStmt.isInvalid())
1273
0
    return false;
1274
81
1275
81
  this->ResultDecl = GroDeclStmt.get();
1276
81
1277
81
  ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
1278
81
  if (declRef.isInvalid())
1279
0
    return false;
1280
81
1281
81
  StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
1282
81
  if (
ReturnStmt.isInvalid()81
) {
1283
1
    noteMemberDeclaredHere(S, ReturnValue, Fn);
1284
1
    return false;
1285
1
  }
1286
80
1287
80
  this->ReturnStmt = ReturnStmt.get();
1288
80
  return true;
1289
80
}
1290
1291
// Create a static_cast\<T&&>(expr).
1292
34
static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) {
1293
34
  if (T.isNull())
1294
34
    T = E->getType();
1295
34
  QualType TargetType = S.BuildReferenceType(
1296
34
      T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName());
1297
34
  SourceLocation ExprLoc = E->getLocStart();
1298
34
  TypeSourceInfo *TargetLoc =
1299
34
      S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc);
1300
34
1301
34
  return S
1302
34
      .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E,
1303
34
                         SourceRange(ExprLoc, ExprLoc), E->getSourceRange())
1304
34
      .get();
1305
34
}
1306
1307
1308
/// \brief Build a variable declaration for move parameter.
1309
static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type,
1310
34
                             IdentifierInfo *II) {
1311
34
  TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc);
1312
34
  VarDecl *Decl =
1313
34
      VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type, TInfo, SC_None);
1314
34
  Decl->setImplicit();
1315
34
  return Decl;
1316
34
}
1317
1318
184
bool CoroutineStmtBuilder::makeParamMoves() {
1319
137
  for (auto *paramDecl : FD.parameters()) {
1320
137
    auto Ty = paramDecl->getType();
1321
137
    if (Ty->isDependentType())
1322
34
      continue;
1323
103
1324
103
    // No need to copy scalars, llvm will take care of them.
1325
103
    
if (103
Ty->getAsCXXRecordDecl()103
) {
1326
34
      ExprResult ParamRef =
1327
34
          S.BuildDeclRefExpr(paramDecl, paramDecl->getType(),
1328
34
                             ExprValueKind::VK_LValue, Loc); // FIXME: scope?
1329
34
      if (ParamRef.isInvalid())
1330
0
        return false;
1331
34
1332
34
      Expr *RCast = castForMoving(S, ParamRef.get());
1333
34
1334
34
      auto D = buildVarDecl(S, Loc, Ty, paramDecl->getIdentifier());
1335
34
      S.AddInitializerToDecl(D, RCast, /*DirectInit=*/true);
1336
34
1337
34
      // Convert decl to a statement.
1338
34
      StmtResult Stmt = S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(D), Loc, Loc);
1339
34
      if (Stmt.isInvalid())
1340
0
        return false;
1341
34
1342
34
      ParamMovesVector.push_back(Stmt.get());
1343
34
    }
1344
137
  }
1345
184
1346
184
  // Convert to ArrayRef in CtorArgs structure that builder inherits from.
1347
184
  ParamMoves = ParamMovesVector;
1348
184
  return true;
1349
184
}
1350
1351
37
/// Create a CoroutineBodyStmt from the already-built components in \p Args.
/// Returns StmtError() if CoroutineBodyStmt::Create yields null.
StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) {
  if (CoroutineBodyStmt *Body = CoroutineBodyStmt::Create(Context, Args))
    return Body;
  return StmtError();
}