Coverage Report

Created: 2020-02-15 09:57

/Users/buildslave/jenkins/workspace/coverage/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// Instrumentation-based profile-guided optimization
10
//
11
//===----------------------------------------------------------------------===//
12
13
#include "CodeGenPGO.h"
14
#include "CodeGenFunction.h"
15
#include "CoverageMappingGen.h"
16
#include "clang/AST/RecursiveASTVisitor.h"
17
#include "clang/AST/StmtVisitor.h"
18
#include "llvm/IR/Intrinsics.h"
19
#include "llvm/IR/MDBuilder.h"
20
#include "llvm/Support/CommandLine.h"
21
#include "llvm/Support/Endian.h"
22
#include "llvm/Support/FileSystem.h"
23
#include "llvm/Support/MD5.h"
24
25
// Hidden command-line flag "-enable-value-profiling" (off by default).
// Its readers are outside this part of the file.
static llvm::cl::opt<bool> EnableValueProfiling(
    "enable-value-profiling", llvm::cl::ZeroOrMore,
    llvm::cl::desc("Enable value profiling"), llvm::cl::Hidden,
    llvm::cl::init(false));
29
30
using namespace clang;
31
using namespace CodeGen;
32
33
void CodeGenPGO::setFuncName(StringRef Name,
34
511
                             llvm::GlobalValue::LinkageTypes Linkage) {
35
511
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36
511
  FuncName = llvm::getPGOFuncName(
37
511
      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38
511
      PGOReader ? 
PGOReader->getVersion()187
:
llvm::IndexedInstrProf::Version324
);
39
511
40
511
  // If we're generating a profile, create a variable for the name.
41
511
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
42
324
    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43
511
}
44
45
494
/// Derive FuncName from \p Fn's own name and linkage, then create the
/// PGOFuncName metadata for \p Fn via llvm::createPGOFuncNameMetadata.
void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  const StringRef RawName = Fn->getName();
  const llvm::GlobalValue::LinkageTypes FnLinkage = Fn->getLinkage();
  setFuncName(RawName, FnLinkage);

  // Record the computed PGO name on the IR function.
  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
}
50
51
/// The version of the PGO hash algorithm.
///
/// Which version is used for a function read from an indexed profile is
/// selected by getPGOHashVersion from the profile's format version.
enum PGOHashVersion : unsigned {
  /// Hashes only the node kinds that also receive region counters.
  PGO_HASH_V1 = 0,
  /// Additionally hashes scope ends, if/else branches, jumps, returns,
  /// throws, logical not, and comparison operators.
  PGO_HASH_V2 = 1,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V2
};
59
60
namespace {
61
/// Stable hasher for PGO region counters.
62
///
63
/// PGOHash produces a stable hash of a given function's control flow.
64
///
65
/// Changing the output of this hash will invalidate all previously generated
66
/// profiles -- i.e., don't do it.
67
///
68
/// \note  When this hash does eventually change (years?), we still need to
69
/// support old hashes.  We'll need to pull in the version number from the
70
/// profile data format and use the matching hash function.
71
class PGOHash {
72
  uint64_t Working;
73
  unsigned Count;
74
  PGOHashVersion HashVersion;
75
  llvm::MD5 MD5;
76
77
  static const int NumBitsPerType = 6;
78
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
79
  static const unsigned TooBig = 1u << NumBitsPerType;
80
81
public:
82
  /// Hash values for AST nodes.
83
  ///
84
  /// Distinct values for AST nodes that have region counters attached.
85
  ///
86
  /// These values must be stable.  All new members must be added at the end,
87
  /// and no members should be removed.  Changing the enumeration value for an
88
  /// AST node will affect the hash of every function that contains that node.
89
  enum HashType : unsigned char {
90
    None = 0,
91
    LabelStmt = 1,
92
    WhileStmt,
93
    DoStmt,
94
    ForStmt,
95
    CXXForRangeStmt,
96
    ObjCForCollectionStmt,
97
    SwitchStmt,
98
    CaseStmt,
99
    DefaultStmt,
100
    IfStmt,
101
    CXXTryStmt,
102
    CXXCatchStmt,
103
    ConditionalOperator,
104
    BinaryOperatorLAnd,
105
    BinaryOperatorLOr,
106
    BinaryConditionalOperator,
107
    // The preceding values are available with PGO_HASH_V1.
108
109
    EndOfScope,
110
    IfThenBranch,
111
    IfElseBranch,
112
    GotoStmt,
113
    IndirectGotoStmt,
114
    BreakStmt,
115
    ContinueStmt,
116
    ReturnStmt,
117
    ThrowExpr,
118
    UnaryOperatorLNot,
119
    BinaryOperatorLT,
120
    BinaryOperatorGT,
121
    BinaryOperatorLE,
122
    BinaryOperatorGE,
123
    BinaryOperatorEQ,
124
    BinaryOperatorNE,
125
    // The preceding values are available with PGO_HASH_V2.
126
127
    // Keep this last.  It's for the static assert that follows.
128
    LastHashType
129
  };
130
  static_assert(LastHashType <= TooBig, "Too many types in HashType");
131
132
  PGOHash(PGOHashVersion HashVersion)
133
494
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
134
  void combine(HashType Type);
135
  uint64_t finalize();
136
19.0k
  PGOHashVersion getHashVersion() const { return HashVersion; }
137
};
138
const int PGOHash::NumBitsPerType;
139
const unsigned PGOHash::NumTypesPerWord;
140
const unsigned PGOHash::TooBig;
141
142
/// Get the PGO hash version used in the given indexed profile.
143
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
144
187
                                        CodeGenModule &CGM) {
145
187
  if (PGOReader->getVersion() <= 4)
146
24
    return PGO_HASH_V1;
147
163
  return PGO_HASH_V2;
148
163
}
149
150
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
151
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
152
  using Base = RecursiveASTVisitor<MapRegionCounters>;
153
154
  /// The next counter value to assign.
155
  unsigned NextCounter;
156
  /// The function hash.
157
  PGOHash Hash;
158
  /// The map of statements to counters.
159
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
160
161
  MapRegionCounters(PGOHashVersion HashVersion,
162
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
163
494
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
164
165
  // Blocks and lambdas are handled as separate functions, so we need not
166
  // traverse them in the parent context.
167
2
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
168
4
  bool TraverseLambdaExpr(LambdaExpr *LE) {
169
4
    // Traverse the captures, but not the body.
170
4
    for (auto C : zip(LE->captures(), LE->capture_inits()))
171
2
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
172
4
    return true;
173
4
  }
174
5
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
175
176
1.13k
  bool VisitDecl(const Decl *D) {
177
1.13k
    switch (D->getKind()) {
178
642
    default:
179
642
      break;
180
494
    case Decl::Function:
181
494
    case Decl::CXXMethod:
182
494
    case Decl::CXXConstructor:
183
494
    case Decl::CXXDestructor:
184
494
    case Decl::CXXConversion:
185
494
    case Decl::ObjCMethod:
186
494
    case Decl::Block:
187
494
    case Decl::Captured:
188
494
      CounterMap[D->getBody()] = NextCounter++;
189
494
      break;
190
1.13k
    }
191
1.13k
    return true;
192
1.13k
  }
193
194
  /// If \p S gets a fresh counter, update the counter mappings. Return the
195
  /// V1 hash of \p S.
196
9.83k
  PGOHash::HashType updateCounterMappings(Stmt *S) {
197
9.83k
    auto Type = getHashType(PGO_HASH_V1, S);
198
9.83k
    if (Type != PGOHash::None)
199
1.01k
      CounterMap[S] = NextCounter++;
200
9.83k
    return Type;
201
9.83k
  }
202
203
  /// Include \p S in the function hash.
204
9.83k
  bool VisitStmt(Stmt *S) {
205
9.83k
    auto Type = updateCounterMappings(S);
206
9.83k
    if (Hash.getHashVersion() != PGO_HASH_V1)
207
8.55k
      Type = getHashType(Hash.getHashVersion(), S);
208
9.83k
    if (Type != PGOHash::None)
209
1.66k
      Hash.combine(Type);
210
9.83k
    return true;
211
9.83k
  }
212
213
361
  bool TraverseIfStmt(IfStmt *If) {
214
361
    // If we used the V1 hash, use the default traversal.
215
361
    if (Hash.getHashVersion() == PGO_HASH_V1)
216
68
      return Base::TraverseIfStmt(If);
217
293
218
293
    // Otherwise, keep track of which branch we're in while traversing.
219
293
    VisitStmt(If);
220
640
    for (Stmt *CS : If->children()) {
221
640
      if (!CS)
222
0
        continue;
223
640
      if (CS == If->getThen())
224
293
        Hash.combine(PGOHash::IfThenBranch);
225
347
      else if (CS == If->getElse())
226
52
        Hash.combine(PGOHash::IfElseBranch);
227
640
      TraverseStmt(CS);
228
640
    }
229
293
    Hash.combine(PGOHash::EndOfScope);
230
293
    return true;
231
293
  }
232
233
// If the statement type \p N is nestable, and its nesting impacts profile
234
// stability, define a custom traversal which tracks the end of the statement
235
// in the hash (provided we're not using the V1 hash).
236
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
237
348
  bool Traverse##N(N *S) {                                                     \
238
348
    Base::Traverse##N(S);                                                      \
239
348
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
240
348
      
Hash.combine(PGOHash::EndOfScope)308
; \
241
348
    return true;                                                               \
242
348
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseCXXCatchStmt(clang::CXXCatchStmt*)
Line
Count
Source
237
26
  bool Traverse##N(N *S) {                                                     \
238
26
    Base::Traverse##N(S);                                                      \
239
26
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
240
26
      Hash.combine(PGOHash::EndOfScope);                                       \
241
26
    return true;                                                               \
242
26
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseCXXForRangeStmt(clang::CXXForRangeStmt*)
Line
Count
Source
237
12
  bool Traverse##N(N *S) {                                                     \
238
12
    Base::Traverse##N(S);                                                      \
239
12
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
240
12
      Hash.combine(PGOHash::EndOfScope);                                       \
241
12
    return true;                                                               \
242
12
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseCXXTryStmt(clang::CXXTryStmt*)
Line
Count
Source
237
24
  bool Traverse##N(N *S) {                                                     \
238
24
    Base::Traverse##N(S);                                                      \
239
24
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
240
24
      Hash.combine(PGOHash::EndOfScope);                                       \
241
24
    return true;                                                               \
242
24
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseDoStmt(clang::DoStmt*)
Line
Count
Source
237
31
  bool Traverse##N(N *S) {                                                     \
238
31
    Base::Traverse##N(S);                                                      \
239
31
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
240
31
      
Hash.combine(PGOHash::EndOfScope)25
; \
241
31
    return true;                                                               \
242
31
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseForStmt(clang::ForStmt*)
Line
Count
Source
237
180
  bool Traverse##N(N *S) {                                                     \
238
180
    Base::Traverse##N(S);                                                      \
239
180
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
240
180
      
Hash.combine(PGOHash::EndOfScope)158
; \
241
180
    return true;                                                               \
242
180
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseObjCForCollectionStmt(clang::ObjCForCollectionStmt*)
Line
Count
Source
237
11
  bool Traverse##N(N *S) {                                                     \
238
11
    Base::Traverse##N(S);                                                      \
239
11
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
240
11
      Hash.combine(PGOHash::EndOfScope);                                       \
241
11
    return true;                                                               \
242
11
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseWhileStmt(clang::WhileStmt*)
Line
Count
Source
237
64
  bool Traverse##N(N *S) {                                                     \
238
64
    Base::Traverse##N(S);                                                      \
239
64
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
240
64
      
Hash.combine(PGOHash::EndOfScope)52
; \
241
64
    return true;                                                               \
242
64
  }
243
244
  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
245
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
246
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
247
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
248
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
249
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
250
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
251
252
  /// Get version \p HashVersion of the PGO hash for \p S.
253
18.3k
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
254
18.3k
    switch (S->getStmtClass()) {
255
15.0k
    default:
256
15.0k
      break;
257
82
    case Stmt::LabelStmtClass:
258
82
      return PGOHash::LabelStmt;
259
116
    case Stmt::WhileStmtClass:
260
116
      return PGOHash::WhileStmt;
261
56
    case Stmt::DoStmtClass:
262
56
      return PGOHash::DoStmt;
263
338
    case Stmt::ForStmtClass:
264
338
      return PGOHash::ForStmt;
265
24
    case Stmt::CXXForRangeStmtClass:
266
24
      return PGOHash::CXXForRangeStmt;
267
22
    case Stmt::ObjCForCollectionStmtClass:
268
22
      return PGOHash::ObjCForCollectionStmt;
269
84
    case Stmt::SwitchStmtClass:
270
84
      return PGOHash::SwitchStmt;
271
172
    case Stmt::CaseStmtClass:
272
172
      return PGOHash::CaseStmt;
273
52
    case Stmt::DefaultStmtClass:
274
52
      return PGOHash::DefaultStmt;
275
654
    case Stmt::IfStmtClass:
276
654
      return PGOHash::IfStmt;
277
48
    case Stmt::CXXTryStmtClass:
278
48
      return PGOHash::CXXTryStmt;
279
52
    case Stmt::CXXCatchStmtClass:
280
52
      return PGOHash::CXXCatchStmt;
281
34
    case Stmt::ConditionalOperatorClass:
282
34
      return PGOHash::ConditionalOperator;
283
10
    case Stmt::BinaryConditionalOperatorClass:
284
10
      return PGOHash::BinaryConditionalOperator;
285
1.55k
    case Stmt::BinaryOperatorClass: {
286
1.55k
      const BinaryOperator *BO = cast<BinaryOperator>(S);
287
1.55k
      if (BO->getOpcode() == BO_LAnd)
288
46
        return PGOHash::BinaryOperatorLAnd;
289
1.51k
      if (BO->getOpcode() == BO_LOr)
290
46
        return PGOHash::BinaryOperatorLOr;
291
1.46k
      if (HashVersion == PGO_HASH_V2) {
292
675
        switch (BO->getOpcode()) {
293
401
        default:
294
401
          break;
295
179
        case BO_LT:
296
179
          return PGOHash::BinaryOperatorLT;
297
24
        case BO_GT:
298
24
          return PGOHash::BinaryOperatorGT;
299
9
        case BO_LE:
300
9
          return PGOHash::BinaryOperatorLE;
301
8
        case BO_GE:
302
8
          return PGOHash::BinaryOperatorGE;
303
50
        case BO_EQ:
304
50
          return PGOHash::BinaryOperatorEQ;
305
4
        case BO_NE:
306
4
          return PGOHash::BinaryOperatorNE;
307
1.19k
        }
308
1.19k
      }
309
1.19k
      break;
310
1.19k
    }
311
16.2k
    }
312
16.2k
313
16.2k
    if (HashVersion == PGO_HASH_V2) {
314
7.45k
      switch (S->getStmtClass()) {
315
6.85k
      default:
316
6.85k
        break;
317
29
      case Stmt::GotoStmtClass:
318
29
        return PGOHash::GotoStmt;
319
2
      case Stmt::IndirectGotoStmtClass:
320
2
        return PGOHash::IndirectGotoStmt;
321
57
      case Stmt::BreakStmtClass:
322
57
        return PGOHash::BreakStmt;
323
16
      case Stmt::ContinueStmtClass:
324
16
        return PGOHash::ContinueStmt;
325
225
      case Stmt::ReturnStmtClass:
326
225
        return PGOHash::ReturnStmt;
327
17
      case Stmt::CXXThrowExprClass:
328
17
        return PGOHash::ThrowExpr;
329
254
      case Stmt::UnaryOperatorClass: {
330
254
        const UnaryOperator *UO = cast<UnaryOperator>(S);
331
254
        if (UO->getOpcode() == UO_LNot)
332
27
          return PGOHash::UnaryOperatorLNot;
333
227
        break;
334
227
      }
335
7.45k
      }
336
7.45k
    }
337
15.9k
338
15.9k
    return PGOHash::None;
339
15.9k
  }
340
};
341
342
/// A StmtVisitor that propagates the raw counts through the AST and
343
/// records the count at statements where the value may change.
344
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
345
  /// PGO state.
346
  CodeGenPGO &PGO;
347
348
  /// A flag that is set when the current count should be recorded on the
349
  /// next statement, such as at the exit of a loop.
350
  bool RecordNextStmtCount;
351
352
  /// The count at the current location in the traversal.
353
  uint64_t CurrentCount;
354
355
  /// The map of statements to count values.
356
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
357
358
  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
359
  struct BreakContinue {
360
    uint64_t BreakCount;
361
    uint64_t ContinueCount;
362
165
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
363
  };
364
  SmallVector<BreakContinue, 8> BreakContinueStack;
365
366
  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
367
                      CodeGenPGO &PGO)
368
187
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
369
370
4.16k
  void RecordStmtCount(const Stmt *S) {
371
4.16k
    if (RecordNextStmtCount) {
372
331
      CountMap[S] = CurrentCount;
373
331
      RecordNextStmtCount = false;
374
331
    }
375
4.16k
  }
376
377
  /// Set and return the current count.
378
1.30k
  uint64_t setCount(uint64_t Count) {
379
1.30k
    CurrentCount = Count;
380
1.30k
    return Count;
381
1.30k
  }
382
383
3.60k
  void VisitStmt(const Stmt *S) {
384
3.60k
    RecordStmtCount(S);
385
3.60k
    for (const Stmt *Child : S->children())
386
2.89k
      if (Child)
387
2.89k
        this->Visit(Child);
388
3.60k
  }
389
390
183
  void VisitFunctionDecl(const FunctionDecl *D) {
391
183
    // Counter tracks entry to the function body.
392
183
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
393
183
    CountMap[D->getBody()] = BodyCount;
394
183
    Visit(D->getBody());
395
183
  }
396
397
  // Skip lambda expressions. We visit these as FunctionDecls when we're
398
  // generating them and aren't interested in the body when generating a
399
  // parent context.
400
1
  void VisitLambdaExpr(const LambdaExpr *LE) {}
401
402
2
  void VisitCapturedDecl(const CapturedDecl *D) {
403
2
    // Counter tracks entry to the capture body.
404
2
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
405
2
    CountMap[D->getBody()] = BodyCount;
406
2
    Visit(D->getBody());
407
2
  }
408
409
1
  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
410
1
    // Counter tracks entry to the method body.
411
1
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
412
1
    CountMap[D->getBody()] = BodyCount;
413
1
    Visit(D->getBody());
414
1
  }
415
416
1
  void VisitBlockDecl(const BlockDecl *D) {
417
1
    // Counter tracks entry to the block body.
418
1
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
419
1
    CountMap[D->getBody()] = BodyCount;
420
1
    Visit(D->getBody());
421
1
  }
422
423
78
  void VisitReturnStmt(const ReturnStmt *S) {
424
78
    RecordStmtCount(S);
425
78
    if (S->getRetValue())
426
69
      Visit(S->getRetValue());
427
78
    CurrentCount = 0;
428
78
    RecordNextStmtCount = true;
429
78
  }
430
431
8
  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
432
8
    RecordStmtCount(E);
433
8
    if (E->getSubExpr())
434
8
      Visit(E->getSubExpr());
435
8
    CurrentCount = 0;
436
8
    RecordNextStmtCount = true;
437
8
  }
438
439
26
  void VisitGotoStmt(const GotoStmt *S) {
440
26
    RecordStmtCount(S);
441
26
    CurrentCount = 0;
442
26
    RecordNextStmtCount = true;
443
26
  }
444
445
27
  void VisitLabelStmt(const LabelStmt *S) {
446
27
    RecordNextStmtCount = false;
447
27
    // Counter tracks the block following the label.
448
27
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
449
27
    CountMap[S] = BlockCount;
450
27
    Visit(S->getSubStmt());
451
27
  }
452
453
46
  void VisitBreakStmt(const BreakStmt *S) {
454
46
    RecordStmtCount(S);
455
46
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
456
46
    BreakContinueStack.back().BreakCount += CurrentCount;
457
46
    CurrentCount = 0;
458
46
    RecordNextStmtCount = true;
459
46
  }
460
461
12
  void VisitContinueStmt(const ContinueStmt *S) {
462
12
    RecordStmtCount(S);
463
12
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
464
12
    BreakContinueStack.back().ContinueCount += CurrentCount;
465
12
    CurrentCount = 0;
466
12
    RecordNextStmtCount = true;
467
12
  }
468
469
30
  void VisitWhileStmt(const WhileStmt *S) {
470
30
    RecordStmtCount(S);
471
30
    uint64_t ParentCount = CurrentCount;
472
30
473
30
    BreakContinueStack.push_back(BreakContinue());
474
30
    // Visit the body region first so the break/continue adjustments can be
475
30
    // included when visiting the condition.
476
30
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
477
30
    CountMap[S->getBody()] = CurrentCount;
478
30
    Visit(S->getBody());
479
30
    uint64_t BackedgeCount = CurrentCount;
480
30
481
30
    // ...then go back and propagate counts through the condition. The count
482
30
    // at the start of the condition is the sum of the incoming edges,
483
30
    // the backedge from the end of the loop body, and the edges from
484
30
    // continue statements.
485
30
    BreakContinue BC = BreakContinueStack.pop_back_val();
486
30
    uint64_t CondCount =
487
30
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
488
30
    CountMap[S->getCond()] = CondCount;
489
30
    Visit(S->getCond());
490
30
    setCount(BC.BreakCount + CondCount - BodyCount);
491
30
    RecordNextStmtCount = true;
492
30
  }
493
494
19
  void VisitDoStmt(const DoStmt *S) {
495
19
    RecordStmtCount(S);
496
19
    uint64_t LoopCount = PGO.getRegionCount(S);
497
19
498
19
    BreakContinueStack.push_back(BreakContinue());
499
19
    // The count doesn't include the fallthrough from the parent scope. Add it.
500
19
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
501
19
    CountMap[S->getBody()] = BodyCount;
502
19
    Visit(S->getBody());
503
19
    uint64_t BackedgeCount = CurrentCount;
504
19
505
19
    BreakContinue BC = BreakContinueStack.pop_back_val();
506
19
    // The count at the start of the condition is equal to the count at the
507
19
    // end of the body, plus any continues.
508
19
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
509
19
    CountMap[S->getCond()] = CondCount;
510
19
    Visit(S->getCond());
511
19
    setCount(BC.BreakCount + CondCount - LoopCount);
512
19
    RecordNextStmtCount = true;
513
19
  }
514
515
80
  void VisitForStmt(const ForStmt *S) {
516
80
    RecordStmtCount(S);
517
80
    if (S->getInit())
518
77
      Visit(S->getInit());
519
80
520
80
    uint64_t ParentCount = CurrentCount;
521
80
522
80
    BreakContinueStack.push_back(BreakContinue());
523
80
    // Visit the body region first. (This is basically the same as a while
524
80
    // loop; see further comments in VisitWhileStmt.)
525
80
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
526
80
    CountMap[S->getBody()] = BodyCount;
527
80
    Visit(S->getBody());
528
80
    uint64_t BackedgeCount = CurrentCount;
529
80
    BreakContinue BC = BreakContinueStack.pop_back_val();
530
80
531
80
    // The increment is essentially part of the body but it needs to include
532
80
    // the count for all the continue statements.
533
80
    if (S->getInc()) {
534
80
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
535
80
      CountMap[S->getInc()] = IncCount;
536
80
      Visit(S->getInc());
537
80
    }
538
80
539
80
    // ...then go back and propagate counts through the condition.
540
80
    uint64_t CondCount =
541
80
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
542
80
    if (S->getCond()) {
543
80
      CountMap[S->getCond()] = CondCount;
544
80
      Visit(S->getCond());
545
80
    }
546
80
    setCount(BC.BreakCount + CondCount - BodyCount);
547
80
    RecordNextStmtCount = true;
548
80
  }
549
550
9
  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
551
9
    RecordStmtCount(S);
552
9
    if (S->getInit())
553
0
      Visit(S->getInit());
554
9
    Visit(S->getLoopVarStmt());
555
9
    Visit(S->getRangeStmt());
556
9
    Visit(S->getBeginStmt());
557
9
    Visit(S->getEndStmt());
558
9
559
9
    uint64_t ParentCount = CurrentCount;
560
9
    BreakContinueStack.push_back(BreakContinue());
561
9
    // Visit the body region first. (This is basically the same as a while
562
9
    // loop; see further comments in VisitWhileStmt.)
563
9
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
564
9
    CountMap[S->getBody()] = BodyCount;
565
9
    Visit(S->getBody());
566
9
    uint64_t BackedgeCount = CurrentCount;
567
9
    BreakContinue BC = BreakContinueStack.pop_back_val();
568
9
569
9
    // The increment is essentially part of the body but it needs to include
570
9
    // the count for all the continue statements.
571
9
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
572
9
    CountMap[S->getInc()] = IncCount;
573
9
    Visit(S->getInc());
574
9
575
9
    // ...then go back and propagate counts through the condition.
576
9
    uint64_t CondCount =
577
9
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
578
9
    CountMap[S->getCond()] = CondCount;
579
9
    Visit(S->getCond());
580
9
    setCount(BC.BreakCount + CondCount - BodyCount);
581
9
    RecordNextStmtCount = true;
582
9
  }
583
584
5
  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
585
5
    RecordStmtCount(S);
586
5
    Visit(S->getElement());
587
5
    uint64_t ParentCount = CurrentCount;
588
5
    BreakContinueStack.push_back(BreakContinue());
589
5
    // Counter tracks the body of the loop.
590
5
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
591
5
    CountMap[S->getBody()] = BodyCount;
592
5
    Visit(S->getBody());
593
5
    uint64_t BackedgeCount = CurrentCount;
594
5
    BreakContinue BC = BreakContinueStack.pop_back_val();
595
5
596
5
    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
597
5
             BodyCount);
598
5
    RecordNextStmtCount = true;
599
5
  }
600
601
22
  void VisitSwitchStmt(const SwitchStmt *S) {
602
22
    RecordStmtCount(S);
603
22
    if (S->getInit())
604
0
      Visit(S->getInit());
605
22
    Visit(S->getCond());
606
22
    CurrentCount = 0;
607
22
    BreakContinueStack.push_back(BreakContinue());
608
22
    Visit(S->getBody());
609
22
    // If the switch is inside a loop, add the continue counts.
610
22
    BreakContinue BC = BreakContinueStack.pop_back_val();
611
22
    if (!BreakContinueStack.empty())
612
17
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
613
22
    // Counter tracks the exit block of the switch.
614
22
    setCount(PGO.getRegionCount(S));
615
22
    RecordNextStmtCount = true;
616
22
  }
617
618
79
  void VisitSwitchCase(const SwitchCase *S) {
619
79
    RecordNextStmtCount = false;
620
79
    // Counter for this particular case. This counts only jumps from the
621
79
    // switch header and does not include fallthrough from the case before
622
79
    // this one.
623
79
    uint64_t CaseCount = PGO.getRegionCount(S);
624
79
    setCount(CurrentCount + CaseCount);
625
79
    // We need the count without fallthrough in the mapping, so it's more useful
626
79
    // for branch probabilities.
627
79
    CountMap[S] = CaseCount;
628
79
    RecordNextStmtCount = true;
629
79
    Visit(S->getSubStmt());
630
79
  }
631
632
166
  void VisitIfStmt(const IfStmt *S) {
633
166
    RecordStmtCount(S);
634
166
    uint64_t ParentCount = CurrentCount;
635
166
    if (S->getInit())
636
0
      Visit(S->getInit());
637
166
    Visit(S->getCond());
638
166
639
166
    // Counter tracks the "then" part of an if statement. The count for
640
166
    // the "else" part, if it exists, will be calculated from this counter.
641
166
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
642
166
    CountMap[S->getThen()] = ThenCount;
643
166
    Visit(S->getThen());
644
166
    uint64_t OutCount = CurrentCount;
645
166
646
166
    uint64_t ElseCount = ParentCount - ThenCount;
647
166
    if (S->getElse()) {
648
24
      setCount(ElseCount);
649
24
      CountMap[S->getElse()] = ElseCount;
650
24
      Visit(S->getElse());
651
24
      OutCount += CurrentCount;
652
24
    } else
653
142
      OutCount += ElseCount;
654
166
    setCount(OutCount);
655
166
    RecordNextStmtCount = true;
656
166
  }
657
658
12
  void VisitCXXTryStmt(const CXXTryStmt *S) {
659
12
    RecordStmtCount(S);
660
12
    Visit(S->getTryBlock());
661
24
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; 
++I12
)
662
12
      Visit(S->getHandler(I));
663
12
    // Counter tracks the continuation block of the try statement.
664
12
    setCount(PGO.getRegionCount(S));
665
12
    RecordNextStmtCount = true;
666
12
  }
667
668
12
  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
669
12
    RecordNextStmtCount = false;
670
12
    // Counter tracks the catch statement's handler block.
671
12
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
672
12
    CountMap[S] = CatchCount;
673
12
    Visit(S->getHandlerBlock());
674
12
  }
675
676
7
  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
677
7
    RecordStmtCount(E);
678
7
    uint64_t ParentCount = CurrentCount;
679
7
    Visit(E->getCond());
680
7
681
7
    // Counter tracks the "true" part of a conditional operator. The
682
7
    // count in the "false" part will be calculated from this counter.
683
7
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
684
7
    CountMap[E->getTrueExpr()] = TrueCount;
685
7
    Visit(E->getTrueExpr());
686
7
    uint64_t OutCount = CurrentCount;
687
7
688
7
    uint64_t FalseCount = setCount(ParentCount - TrueCount);
689
7
    CountMap[E->getFalseExpr()] = FalseCount;
690
7
    Visit(E->getFalseExpr());
691
7
    OutCount += CurrentCount;
692
7
693
7
    setCount(OutCount);
694
7
    RecordNextStmtCount = true;
695
7
  }
696
697
19
  void VisitBinLAnd(const BinaryOperator *E) {
698
19
    RecordStmtCount(E);
699
19
    uint64_t ParentCount = CurrentCount;
700
19
    Visit(E->getLHS());
701
19
    // Counter tracks the right hand side of a logical and operator.
702
19
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
703
19
    CountMap[E->getRHS()] = RHSCount;
704
19
    Visit(E->getRHS());
705
19
    setCount(ParentCount + RHSCount - CurrentCount);
706
19
    RecordNextStmtCount = true;
707
19
  }
708
709
18
  void VisitBinLOr(const BinaryOperator *E) {
710
18
    RecordStmtCount(E);
711
18
    uint64_t ParentCount = CurrentCount;
712
18
    Visit(E->getLHS());
713
18
    // Counter tracks the right hand side of a logical or operator.
714
18
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
715
18
    CountMap[E->getRHS()] = RHSCount;
716
18
    Visit(E->getRHS());
717
18
    setCount(ParentCount + RHSCount - CurrentCount);
718
18
    RecordNextStmtCount = true;
719
18
  }
720
};
721
} // end anonymous namespace
722
723
2.60k
/// Fold one node kind into the running structural hash.
///
/// Types are packed NumBitsPerType bits at a time into the Working word. When
/// a new type arrives and Working already holds NumTypesPerWord types, the
/// full word is converted to little-endian byte order and fed to MD5. Working
/// is then cleared before the new type is packed.
void PGOHash::combine(HashType Type) {
  // Zero is never a valid type, and every type must stay below TooBig.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  const bool WordIsFull = Count != 0 && Count % NumTypesPerWord == 0;
  if (WordIsFull) {
    using namespace llvm::support;
    uint64_t LittleEndianWord = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef(reinterpret_cast<uint8_t *>(&LittleEndianWord),
                                  sizeof(LittleEndianWord)));
    Working = 0;
  }

  ++Count;
  Working = (Working << NumBitsPerType) | Type;
}
740
741
494
uint64_t PGOHash::finalize() {
742
494
  // Use Working as the hash directly if we never used MD5.
743
494
  if (Count <= NumTypesPerWord)
744
435
    // No need to byte swap here, since none of the math was endian-dependent.
745
435
    // This number will be byte-swapped as required on endianness transitions,
746
435
    // so we will see the same value on the other side.
747
435
    return Working;
748
59
749
59
  // Check for remaining work in Working.
750
59
  if (Working)
751
59
    MD5.update(Working);
752
59
753
59
  // Finalize the MD5 and return the hash.
754
59
  llvm::MD5::MD5Result Result;
755
59
  MD5.final(Result);
756
59
  using namespace llvm::support;
757
59
  return Result.low();
758
59
}
759
760
229k
/// Set up region counters for the function \p Fn emitted for \p GD.
///
/// This maps counters to statements, optionally emits the coverage mapping,
/// and, when a profile is being read, loads counts, computes per-statement
/// counts and applies function attributes. Bodiless, implicit and redundant
/// ctor/dtor variants are skipped.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  const bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  // Nothing to do unless we are instrumenting or consuming a profile.
  if (!(InstrumentRegions || PGOReader))
    return;
  if (D->isImplicit())
    return;

  // A constructor or destructor may be emitted as several IR functions.
  // Non-base constructor variants that can delegate to the base variant are
  // skipped, as are all non-base destructor variants, so that the same body
  // is not counted twice.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    const auto *Ctor = dyn_cast<CXXConstructorDecl>(D);
    if (Ctor && GD.getCtorType() != Ctor_Base &&
        CodeGenFunction::IsConstructorDelegationValid(Ctor))
      return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  setFuncName(Fn);
  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);

  if (!PGOReader)
    return;
  SourceManager &SM = CGM.getContext().getSourceManager();
  loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
  computeRegionCounts(D);
  applyFunctionAttributes(PGOReader, Fn);
}
796
797
494
void CodeGenPGO::mapRegionCounters(const Decl *D) {
798
494
  // Use the latest hash version when inserting instrumentation, but use the
799
494
  // version in the indexed profile if we're reading PGO data.
800
494
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
801
494
  if (auto *PGOReader = CGM.getPGOReader())
802
187
    HashVersion = getPGOHashVersion(PGOReader, CGM);
803
494
804
494
  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
805
494
  MapRegionCounters Walker(HashVersion, *RegionCounterMap);
806
494
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
807
484
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
808
10
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
809
3
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
810
7
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
811
2
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
812
5
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
813
5
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
814
494
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
815
494
  NumRegionCounters = Walker.NextCounter;
816
494
  FunctionHash = Walker.Hash.finalize();
817
494
}
818
819
182
bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
820
182
  if (!D->getBody())
821
0
    return true;
822
182
823
182
  // Don't map the functions in system headers.
824
182
  const auto &SM = CGM.getContext().getSourceManager();
825
182
  auto Loc = D->getBody()->getBeginLoc();
826
182
  return SM.isInSystemHeader(Loc);
827
182
}
828
829
164
void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
830
164
  if (skipRegionMappingForDecl(D))
831
0
    return;
832
164
833
164
  std::string CoverageMapping;
834
164
  llvm::raw_string_ostream OS(CoverageMapping);
835
164
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
836
164
                                CGM.getContext().getSourceManager(),
837
164
                                CGM.getLangOpts(), RegionCounterMap.get());
838
164
  MappingGen.emitCounterMapping(D, OS);
839
164
  OS.flush();
840
164
841
164
  if (CoverageMapping.empty())
842
1
    return;
843
163
844
163
  CGM.getCoverageMapping()->addFunctionMappingRecord(
845
163
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
846
163
}
847
848
void
849
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
850
18
                                    llvm::GlobalValue::LinkageTypes Linkage) {
851
18
  if (skipRegionMappingForDecl(D))
852
1
    return;
853
17
854
17
  std::string CoverageMapping;
855
17
  llvm::raw_string_ostream OS(CoverageMapping);
856
17
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
857
17
                                CGM.getContext().getSourceManager(),
858
17
                                CGM.getLangOpts());
859
17
  MappingGen.emitEmptyMapping(D, OS);
860
17
  OS.flush();
861
17
862
17
  if (CoverageMapping.empty())
863
0
    return;
864
17
865
17
  setFuncName(Name, Linkage);
866
17
  CGM.getCoverageMapping()->addFunctionMappingRecord(
867
17
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
868
17
}
869
870
187
void CodeGenPGO::computeRegionCounts(const Decl *D) {
871
187
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
872
187
  ComputeRegionCounts Walker(*StmtCountMap, *this);
873
187
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
874
183
    Walker.VisitFunctionDecl(FD);
875
4
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
876
1
    Walker.VisitObjCMethodDecl(MD);
877
3
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
878
1
    Walker.VisitBlockDecl(BD);
879
2
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
880
2
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
881
187
}
882
883
void
884
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
885
187
                                    llvm::Function *Fn) {
886
187
  if (!haveRegionCounts())
887
14
    return;
888
173
889
173
  uint64_t FunctionCount = getRegionCount(nullptr);
890
173
  Fn->setEntryCount(FunctionCount);
891
173
}
892
893
void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
894
775
                                      llvm::Value *StepV) {
895
775
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
896
28
    return;
897
747
  if (!Builder.GetInsertBlock())
898
5
    return;
899
742
900
742
  unsigned Counter = (*RegionCounterMap)[S];
901
742
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
902
742
903
742
  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
904
742
                         Builder.getInt64(FunctionHash),
905
742
                         Builder.getInt32(NumRegionCounters),
906
742
                         Builder.getInt32(Counter), StepV};
907
742
  if (!StepV)
908
741
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
909
741
                       makeArrayRef(Args, 4));
910
1
  else
911
1
    Builder.CreateCall(
912
1
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
913
1
        makeArrayRef(Args));
914
742
}
915
916
// This method either inserts a call to the profile run-time during
917
// instrumentation or puts profile data into metadata for PGO use.
918
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
919
18.2k
    llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
920
18.2k
921
18.2k
  if (!EnableValueProfiling)
922
18.2k
    return;
923
4
924
4
  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
925
0
    return;
926
4
927
4
  if (isa<llvm::Constant>(ValuePtr))
928
1
    return;
929
3
930
3
  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
931
3
  if (InstrumentValueSites && RegionCounterMap) {
932
3
    auto BuilderInsertPoint = Builder.saveIP();
933
3
    Builder.SetInsertPoint(ValueSite);
934
3
    llvm::Value *Args[5] = {
935
3
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
936
3
        Builder.getInt64(FunctionHash),
937
3
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
938
3
        Builder.getInt32(ValueKind),
939
3
        Builder.getInt32(NumValueSites[ValueKind]++)
940
3
    };
941
3
    Builder.CreateCall(
942
3
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
943
3
    Builder.restoreIP(BuilderInsertPoint);
944
3
    return;
945
3
  }
946
0
947
0
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
948
0
  if (PGOReader && haveRegionCounts()) {
949
0
    // We record the top most called three functions at each call site.
950
0
    // Profile metadata contains "VP" string identifying this metadata
951
0
    // as value profiling data, then a uint32_t value for the value profiling
952
0
    // kind, a uint64_t value for the total number of times the call is
953
0
    // executed, followed by the function hash and execution count (uint64_t)
954
0
    // pairs for each function.
955
0
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
956
0
      return;
957
0
958
0
    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
959
0
                            (llvm::InstrProfValueKind)ValueKind,
960
0
                            NumValueSites[ValueKind]);
961
0
962
0
    NumValueSites[ValueKind]++;
963
0
  }
964
0
}
965
966
void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
967
187
                                  bool IsInMainFile) {
968
187
  CGM.getPGOStats().addVisited(IsInMainFile);
969
187
  RegionCounts.clear();
970
187
  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
971
187
      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
972
187
  if (auto E = RecordExpected.takeError()) {
973
14
    auto IPE = llvm::InstrProfError::take(std::move(E));
974
14
    if (IPE == llvm::instrprof_error::unknown_function)
975
5
      CGM.getPGOStats().addMissing(IsInMainFile);
976
9
    else if (IPE == llvm::instrprof_error::hash_mismatch)
977
9
      CGM.getPGOStats().addMismatched(IsInMainFile);
978
0
    else if (IPE == llvm::instrprof_error::malformed)
979
0
      // TODO: Consider a more specific warning for this case.
980
0
      CGM.getPGOStats().addMismatched(IsInMainFile);
981
14
    return;
982
14
  }
983
173
  ProfRecord =
984
173
      std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
985
173
  RegionCounts = ProfRecord->Counts;
986
173
}
987
988
/// Compute the divisor used to bring a set of branch weights into 32 bits.
///
/// \p MaxWeight is the largest weight in the set. The returned divisor makes
/// every weight in the set, once divided, strictly less than UINT32_MAX.
/// Weights already below UINT32_MAX need no scaling, so the result is 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight >= UINT32_MAX)
    return 1 + MaxWeight / UINT32_MAX;
  return 1;
}
995
996
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32 bits by dividing by \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so 1 is always added to the result.
/// This also keeps a zero count from producing a zero weight.
///
/// \pre \c Scale is non-zero and was calculated by \a calculateWeightScale()
/// from a maximum weight no smaller than \c Weight. That guarantees the result
/// fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
1011
1012
/// Build branch-weight metadata for a two-way branch from raw 64-bit counts.
/// Returns null when both counts are zero.
llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
                                                    uint64_t FalseCount) {
  if (TrueCount == 0 && FalseCount == 0)
    return nullptr;

  // Both counts share one divisor so their ratio is approximately preserved.
  const uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
  const uint32_t ScaledTrue = scaleBranchWeight(TrueCount, Scale);
  const uint32_t ScaledFalse = scaleBranchWeight(FalseCount, Scale);

  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  return MDHelper.createBranchWeights(ScaledTrue, ScaledFalse);
}
1025
1026
/// Build branch-weight metadata for a multi-way branch from raw 64-bit counts.
/// Returns null for fewer than two weights or when every weight is zero.
llvm::MDNode *
CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
  // Fewer than two weights cannot describe a branch.
  if (Weights.size() < 2)
    return nullptr;

  // All-zero counts carry no information.
  const uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
  if (MaxWeight == 0)
    return nullptr;

  // Every weight uses the same divisor so relative frequencies are kept.
  const uint64_t Scale = calculateWeightScale(MaxWeight);
  SmallVector<uint32_t, 16> ScaledWeights;
  ScaledWeights.reserve(Weights.size());
  for (const uint64_t Raw : Weights)
    ScaledWeights.push_back(scaleBranchWeight(Raw, Scale));

  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  return MDHelper.createBranchWeights(ScaledWeights);
}
1048
1049
/// Build branch-weight metadata for a loop condition.
///
/// The "true" weight is \p LoopCount. The "false" weight is the recorded
/// count of \p Cond minus \p LoopCount, clamped at zero.
///
/// Returns null when no profile counts are loaded, when \p Cond has no
/// recorded count, or when that count is zero.
llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
                                                           uint64_t LoopCount) {
  if (!PGO.haveRegionCounts())
    return nullptr;
  Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
  assert(CondCount.hasValue() && "missing expected loop condition count");
  // Check again in code: in builds without assertions, dereferencing an
  // empty Optional would read an unset value.
  if (!CondCount || *CondCount == 0)
    return nullptr;
  // The clamp keeps the subtraction from wrapping if the loop count exceeds
  // the condition count (e.g. stale or inconsistent profile data).
  return createProfileWeights(LoopCount,
                              std::max(*CondCount, LoopCount) - LoopCount);
}