Coverage Report

Created: 2019-07-24 05:18

/Users/buildslave/jenkins/workspace/clang-stage2-coverage-R/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// Instrumentation-based profile-guided optimization
10
//
11
//===----------------------------------------------------------------------===//
12
13
#include "CodeGenPGO.h"
14
#include "CodeGenFunction.h"
15
#include "CoverageMappingGen.h"
16
#include "clang/AST/RecursiveASTVisitor.h"
17
#include "clang/AST/StmtVisitor.h"
18
#include "llvm/IR/Intrinsics.h"
19
#include "llvm/IR/MDBuilder.h"
20
#include "llvm/Support/Endian.h"
21
#include "llvm/Support/FileSystem.h"
22
#include "llvm/Support/MD5.h"
23
24
static llvm::cl::opt<bool>
25
    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
26
                         llvm::cl::desc("Enable value profiling"),
27
                         llvm::cl::Hidden, llvm::cl::init(false));
28
29
using namespace clang;
30
using namespace CodeGen;
31
32
void CodeGenPGO::setFuncName(StringRef Name,
33
483
                             llvm::GlobalValue::LinkageTypes Linkage) {
34
483
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
35
483
  FuncName = llvm::getPGOFuncName(
36
483
      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
37
483
      PGOReader ? 
PGOReader->getVersion()169
:
llvm::IndexedInstrProf::Version314
);
38
483
39
483
  // If we're generating a profile, create a variable for the name.
40
483
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
41
314
    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
42
483
}
43
44
466
void CodeGenPGO::setFuncName(llvm::Function *Fn) {
45
466
  setFuncName(Fn->getName(), Fn->getLinkage());
46
466
  // Create PGOFuncName meta data.
47
466
  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
48
466
}
49
50
/// The version of the PGO hash algorithm.
51
enum PGOHashVersion : unsigned {
52
  PGO_HASH_V1,
53
  PGO_HASH_V2,
54
55
  // Keep this set to the latest hash version.
56
  PGO_HASH_LATEST = PGO_HASH_V2
57
};
58
59
namespace {
60
/// Stable hasher for PGO region counters.
61
///
62
/// PGOHash produces a stable hash of a given function's control flow.
63
///
64
/// Changing the output of this hash will invalidate all previously generated
65
/// profiles -- i.e., don't do it.
66
///
67
/// \note  When this hash does eventually change (years?), we still need to
68
/// support old hashes.  We'll need to pull in the version number from the
69
/// profile data format and use the matching hash function.
70
class PGOHash {
71
  uint64_t Working;
72
  unsigned Count;
73
  PGOHashVersion HashVersion;
74
  llvm::MD5 MD5;
75
76
  static const int NumBitsPerType = 6;
77
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
78
  static const unsigned TooBig = 1u << NumBitsPerType;
79
80
public:
81
  /// Hash values for AST nodes.
82
  ///
83
  /// Distinct values for AST nodes that have region counters attached.
84
  ///
85
  /// These values must be stable.  All new members must be added at the end,
86
  /// and no members should be removed.  Changing the enumeration value for an
87
  /// AST node will affect the hash of every function that contains that node.
88
  enum HashType : unsigned char {
89
    None = 0,
90
    LabelStmt = 1,
91
    WhileStmt,
92
    DoStmt,
93
    ForStmt,
94
    CXXForRangeStmt,
95
    ObjCForCollectionStmt,
96
    SwitchStmt,
97
    CaseStmt,
98
    DefaultStmt,
99
    IfStmt,
100
    CXXTryStmt,
101
    CXXCatchStmt,
102
    ConditionalOperator,
103
    BinaryOperatorLAnd,
104
    BinaryOperatorLOr,
105
    BinaryConditionalOperator,
106
    // The preceding values are available with PGO_HASH_V1.
107
108
    EndOfScope,
109
    IfThenBranch,
110
    IfElseBranch,
111
    GotoStmt,
112
    IndirectGotoStmt,
113
    BreakStmt,
114
    ContinueStmt,
115
    ReturnStmt,
116
    ThrowExpr,
117
    UnaryOperatorLNot,
118
    BinaryOperatorLT,
119
    BinaryOperatorGT,
120
    BinaryOperatorLE,
121
    BinaryOperatorGE,
122
    BinaryOperatorEQ,
123
    BinaryOperatorNE,
124
    // The preceding values are available with PGO_HASH_V2.
125
126
    // Keep this last.  It's for the static assert that follows.
127
    LastHashType
128
  };
129
  static_assert(LastHashType <= TooBig, "Too many types in HashType");
130
131
  PGOHash(PGOHashVersion HashVersion)
132
466
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
133
  void combine(HashType Type);
134
  uint64_t finalize();
135
16.9k
  PGOHashVersion getHashVersion() const { return HashVersion; }
136
};
137
const int PGOHash::NumBitsPerType;
138
const unsigned PGOHash::NumTypesPerWord;
139
const unsigned PGOHash::TooBig;
140
141
/// Get the PGO hash version used in the given indexed profile.
142
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
143
169
                                        CodeGenModule &CGM) {
144
169
  if (PGOReader->getVersion() <= 4)
145
24
    return PGO_HASH_V1;
146
145
  return PGO_HASH_V2;
147
145
}
148
149
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
150
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
151
  using Base = RecursiveASTVisitor<MapRegionCounters>;
152
153
  /// The next counter value to assign.
154
  unsigned NextCounter;
155
  /// The function hash.
156
  PGOHash Hash;
157
  /// The map of statements to counters.
158
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
159
160
  MapRegionCounters(PGOHashVersion HashVersion,
161
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
162
466
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
163
164
  // Blocks and lambdas are handled as separate functions, so we need not
165
  // traverse them in the parent context.
166
2
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
167
4
  bool TraverseLambdaExpr(LambdaExpr *LE) {
168
4
    // Traverse the captures, but not the body.
169
4
    for (const auto &C : zip(LE->captures(), LE->capture_inits()))
170
2
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
171
4
    return true;
172
4
  }
173
5
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
174
175
1.04k
  bool VisitDecl(const Decl *D) {
176
1.04k
    switch (D->getKind()) {
177
1.04k
    default:
178
578
      break;
179
1.04k
    case Decl::Function:
180
466
    case Decl::CXXMethod:
181
466
    case Decl::CXXConstructor:
182
466
    case Decl::CXXDestructor:
183
466
    case Decl::CXXConversion:
184
466
    case Decl::ObjCMethod:
185
466
    case Decl::Block:
186
466
    case Decl::Captured:
187
466
      CounterMap[D->getBody()] = NextCounter++;
188
466
      break;
189
1.04k
    }
190
1.04k
    return true;
191
1.04k
  }
192
193
  /// If \p S gets a fresh counter, update the counter mappings. Return the
194
  /// V1 hash of \p S.
195
8.79k
  PGOHash::HashType updateCounterMappings(Stmt *S) {
196
8.79k
    auto Type = getHashType(PGO_HASH_V1, S);
197
8.79k
    if (Type != PGOHash::None)
198
919
      CounterMap[S] = NextCounter++;
199
8.79k
    return Type;
200
8.79k
  }
201
202
  /// Include \p S in the function hash.
203
8.79k
  bool VisitStmt(Stmt *S) {
204
8.79k
    auto Type = updateCounterMappings(S);
205
8.79k
    if (Hash.getHashVersion() != PGO_HASH_V1)
206
7.51k
      Type = getHashType(Hash.getHashVersion(), S);
207
8.79k
    if (Type != PGOHash::None)
208
1.48k
      Hash.combine(Type);
209
8.79k
    return true;
210
8.79k
  }
211
212
327
  bool TraverseIfStmt(IfStmt *If) {
213
327
    // If we used the V1 hash, use the default traversal.
214
327
    if (Hash.getHashVersion() == PGO_HASH_V1)
215
68
      return Base::TraverseIfStmt(If);
216
259
217
259
    // Otherwise, keep track of which branch we're in while traversing.
218
259
    VisitStmt(If);
219
562
    for (Stmt *CS : If->children()) {
220
562
      if (!CS)
221
0
        continue;
222
562
      if (CS == If->getThen())
223
259
        Hash.combine(PGOHash::IfThenBranch);
224
303
      else if (CS == If->getElse())
225
42
        Hash.combine(PGOHash::IfElseBranch);
226
562
      TraverseStmt(CS);
227
562
    }
228
259
    Hash.combine(PGOHash::EndOfScope);
229
259
    return true;
230
259
  }
231
232
// If the statement type \p N is nestable, and its nesting impacts profile
233
// stability, define a custom traversal which tracks the end of the statement
234
// in the hash (provided we're not using the V1 hash).
235
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
236
312
  bool Traverse##N(N *S) {                                                     \
237
312
    Base::Traverse##N(S);                                                      \
238
312
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
239
312
      
Hash.combine(PGOHash::EndOfScope)272
; \
240
312
    return true;                                                               \
241
312
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseCXXCatchStmt(clang::CXXCatchStmt*)
Line
Count
Source
236
18
  bool Traverse##N(N *S) {                                                     \
237
18
    Base::Traverse##N(S);                                                      \
238
18
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
239
18
      Hash.combine(PGOHash::EndOfScope);                                       \
240
18
    return true;                                                               \
241
18
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseCXXForRangeStmt(clang::CXXForRangeStmt*)
Line
Count
Source
236
12
  bool Traverse##N(N *S) {                                                     \
237
12
    Base::Traverse##N(S);                                                      \
238
12
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
239
12
      Hash.combine(PGOHash::EndOfScope);                                       \
240
12
    return true;                                                               \
241
12
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseCXXTryStmt(clang::CXXTryStmt*)
Line
Count
Source
236
16
  bool Traverse##N(N *S) {                                                     \
237
16
    Base::Traverse##N(S);                                                      \
238
16
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
239
16
      Hash.combine(PGOHash::EndOfScope);                                       \
240
16
    return true;                                                               \
241
16
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseDoStmt(clang::DoStmt*)
Line
Count
Source
236
31
  bool Traverse##N(N *S) {                                                     \
237
31
    Base::Traverse##N(S);                                                      \
238
31
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
239
31
      
Hash.combine(PGOHash::EndOfScope)25
; \
240
31
    return true;                                                               \
241
31
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseForStmt(clang::ForStmt*)
Line
Count
Source
236
160
  bool Traverse##N(N *S) {                                                     \
237
160
    Base::Traverse##N(S);                                                      \
238
160
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
239
160
      
Hash.combine(PGOHash::EndOfScope)138
; \
240
160
    return true;                                                               \
241
160
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseObjCForCollectionStmt(clang::ObjCForCollectionStmt*)
Line
Count
Source
236
11
  bool Traverse##N(N *S) {                                                     \
237
11
    Base::Traverse##N(S);                                                      \
238
11
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
239
11
      Hash.combine(PGOHash::EndOfScope);                                       \
240
11
    return true;                                                               \
241
11
  }
CodeGenPGO.cpp:(anonymous namespace)::MapRegionCounters::TraverseWhileStmt(clang::WhileStmt*)
Line
Count
Source
236
64
  bool Traverse##N(N *S) {                                                     \
237
64
    Base::Traverse##N(S);                                                      \
238
64
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
239
64
      
Hash.combine(PGOHash::EndOfScope)52
; \
240
64
    return true;                                                               \
241
64
  }
242
243
  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
244
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
245
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
246
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
247
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
248
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
249
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
250
251
  /// Get version \p HashVersion of the PGO hash for \p S.
252
16.3k
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
253
16.3k
    switch (S->getStmtClass()) {
254
16.3k
    default:
255
13.3k
      break;
256
16.3k
    case Stmt::LabelStmtClass:
257
82
      return PGOHash::LabelStmt;
258
16.3k
    case Stmt::WhileStmtClass:
259
116
      return PGOHash::WhileStmt;
260
16.3k
    case Stmt::DoStmtClass:
261
56
      return PGOHash::DoStmt;
262
16.3k
    case Stmt::ForStmtClass:
263
298
      return PGOHash::ForStmt;
264
16.3k
    case Stmt::CXXForRangeStmtClass:
265
24
      return PGOHash::CXXForRangeStmt;
266
16.3k
    case Stmt::ObjCForCollectionStmtClass:
267
22
      return PGOHash::ObjCForCollectionStmt;
268
16.3k
    case Stmt::SwitchStmtClass:
269
74
      return PGOHash::SwitchStmt;
270
16.3k
    case Stmt::CaseStmtClass:
271
142
      return PGOHash::CaseStmt;
272
16.3k
    case Stmt::DefaultStmtClass:
273
42
      return PGOHash::DefaultStmt;
274
16.3k
    case Stmt::IfStmtClass:
275
586
      return PGOHash::IfStmt;
276
16.3k
    case Stmt::CXXTryStmtClass:
277
32
      return PGOHash::CXXTryStmt;
278
16.3k
    case Stmt::CXXCatchStmtClass:
279
36
      return PGOHash::CXXCatchStmt;
280
16.3k
    case Stmt::ConditionalOperatorClass:
281
32
      return PGOHash::ConditionalOperator;
282
16.3k
    case Stmt::BinaryConditionalOperatorClass:
283
10
      return PGOHash::BinaryConditionalOperator;
284
16.3k
    case Stmt::BinaryOperatorClass: {
285
1.37k
      const BinaryOperator *BO = cast<BinaryOperator>(S);
286
1.37k
      if (BO->getOpcode() == BO_LAnd)
287
46
        return PGOHash::BinaryOperatorLAnd;
288
1.32k
      if (BO->getOpcode() == BO_LOr)
289
46
        return PGOHash::BinaryOperatorLOr;
290
1.28k
      if (HashVersion == PGO_HASH_V2) {
291
583
        switch (BO->getOpcode()) {
292
583
        default:
293
347
          break;
294
583
        case BO_LT:
295
150
          return PGOHash::BinaryOperatorLT;
296
583
        case BO_GT:
297
24
          return PGOHash::BinaryOperatorGT;
298
583
        case BO_LE:
299
9
          return PGOHash::BinaryOperatorLE;
300
583
        case BO_GE:
301
4
          return PGOHash::BinaryOperatorGE;
302
583
        case BO_EQ:
303
44
          return PGOHash::BinaryOperatorEQ;
304
583
        case BO_NE:
305
5
          return PGOHash::BinaryOperatorNE;
306
1.04k
        }
307
1.04k
      }
308
1.04k
      break;
309
1.04k
    }
310
14.4k
    }
311
14.4k
312
14.4k
    if (HashVersion == PGO_HASH_V2) {
313
6.55k
      switch (S->getStmtClass()) {
314
6.55k
      default:
315
6.01k
        break;
316
6.55k
      case Stmt::GotoStmtClass:
317
29
        return PGOHash::GotoStmt;
318
6.55k
      case Stmt::IndirectGotoStmtClass:
319
2
        return PGOHash::IndirectGotoStmt;
320
6.55k
      case Stmt::BreakStmtClass:
321
48
        return PGOHash::BreakStmt;
322
6.55k
      case Stmt::ContinueStmtClass:
323
16
        return PGOHash::ContinueStmt;
324
6.55k
      case Stmt::ReturnStmtClass:
325
211
        return PGOHash::ReturnStmt;
326
6.55k
      case Stmt::CXXThrowExprClass:
327
5
        return PGOHash::ThrowExpr;
328
6.55k
      case Stmt::UnaryOperatorClass: {
329
229
        const UnaryOperator *UO = cast<UnaryOperator>(S);
330
229
        if (UO->getOpcode() == UO_LNot)
331
17
          return PGOHash::UnaryOperatorLNot;
332
212
        break;
333
212
      }
334
6.55k
      }
335
6.55k
    }
336
14.1k
337
14.1k
    return PGOHash::None;
338
14.1k
  }
339
};
340
341
/// A StmtVisitor that propagates the raw counts through the AST and
342
/// records the count at statements where the value may change.
343
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
344
  /// PGO state.
345
  CodeGenPGO &PGO;
346
347
  /// A flag that is set when the current count should be recorded on the
348
  /// next statement, such as at the exit of a loop.
349
  bool RecordNextStmtCount;
350
351
  /// The count at the current location in the traversal.
352
  uint64_t CurrentCount;
353
354
  /// The map of statements to count values.
355
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
356
357
  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
358
  struct BreakContinue {
359
    uint64_t BreakCount;
360
    uint64_t ContinueCount;
361
150
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
362
  };
363
  SmallVector<BreakContinue, 8> BreakContinueStack;
364
365
  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
366
                      CodeGenPGO &PGO)
367
169
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
368
369
3.43k
  void RecordStmtCount(const Stmt *S) {
370
3.43k
    if (RecordNextStmtCount) {
371
291
      CountMap[S] = CurrentCount;
372
291
      RecordNextStmtCount = false;
373
291
    }
374
3.43k
  }
375
376
  /// Set and return the current count.
377
1.16k
  uint64_t setCount(uint64_t Count) {
378
1.16k
    CurrentCount = Count;
379
1.16k
    return Count;
380
1.16k
  }
381
382
2.94k
  void VisitStmt(const Stmt *S) {
383
2.94k
    RecordStmtCount(S);
384
2.94k
    for (const Stmt *Child : S->children())
385
2.30k
      if (Child)
386
2.30k
        this->Visit(Child);
387
2.94k
  }
388
389
165
  void VisitFunctionDecl(const FunctionDecl *D) {
390
165
    // Counter tracks entry to the function body.
391
165
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
392
165
    CountMap[D->getBody()] = BodyCount;
393
165
    Visit(D->getBody());
394
165
  }
395
396
  // Skip lambda expressions. We visit these as FunctionDecls when we're
397
  // generating them and aren't interested in the body when generating a
398
  // parent context.
399
1
  void VisitLambdaExpr(const LambdaExpr *LE) {}
400
401
2
  void VisitCapturedDecl(const CapturedDecl *D) {
402
2
    // Counter tracks entry to the capture body.
403
2
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
404
2
    CountMap[D->getBody()] = BodyCount;
405
2
    Visit(D->getBody());
406
2
  }
407
408
1
  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
409
1
    // Counter tracks entry to the method body.
410
1
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
411
1
    CountMap[D->getBody()] = BodyCount;
412
1
    Visit(D->getBody());
413
1
  }
414
415
1
  void VisitBlockDecl(const BlockDecl *D) {
416
1
    // Counter tracks entry to the block body.
417
1
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
418
1
    CountMap[D->getBody()] = BodyCount;
419
1
    Visit(D->getBody());
420
1
  }
421
422
64
  void VisitReturnStmt(const ReturnStmt *S) {
423
64
    RecordStmtCount(S);
424
64
    if (S->getRetValue())
425
55
      Visit(S->getRetValue());
426
64
    CurrentCount = 0;
427
64
    RecordNextStmtCount = true;
428
64
  }
429
430
2
  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
431
2
    RecordStmtCount(E);
432
2
    if (E->getSubExpr())
433
2
      Visit(E->getSubExpr());
434
2
    CurrentCount = 0;
435
2
    RecordNextStmtCount = true;
436
2
  }
437
438
26
  void VisitGotoStmt(const GotoStmt *S) {
439
26
    RecordStmtCount(S);
440
26
    CurrentCount = 0;
441
26
    RecordNextStmtCount = true;
442
26
  }
443
444
27
  void VisitLabelStmt(const LabelStmt *S) {
445
27
    RecordNextStmtCount = false;
446
27
    // Counter tracks the block following the label.
447
27
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
448
27
    CountMap[S] = BlockCount;
449
27
    Visit(S->getSubStmt());
450
27
  }
451
452
37
  void VisitBreakStmt(const BreakStmt *S) {
453
37
    RecordStmtCount(S);
454
37
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
455
37
    BreakContinueStack.back().BreakCount += CurrentCount;
456
37
    CurrentCount = 0;
457
37
    RecordNextStmtCount = true;
458
37
  }
459
460
12
  void VisitContinueStmt(const ContinueStmt *S) {
461
12
    RecordStmtCount(S);
462
12
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
463
12
    BreakContinueStack.back().ContinueCount += CurrentCount;
464
12
    CurrentCount = 0;
465
12
    RecordNextStmtCount = true;
466
12
  }
467
468
30
  void VisitWhileStmt(const WhileStmt *S) {
469
30
    RecordStmtCount(S);
470
30
    uint64_t ParentCount = CurrentCount;
471
30
472
30
    BreakContinueStack.push_back(BreakContinue());
473
30
    // Visit the body region first so the break/continue adjustments can be
474
30
    // included when visiting the condition.
475
30
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
476
30
    CountMap[S->getBody()] = CurrentCount;
477
30
    Visit(S->getBody());
478
30
    uint64_t BackedgeCount = CurrentCount;
479
30
480
30
    // ...then go back and propagate counts through the condition. The count
481
30
    // at the start of the condition is the sum of the incoming edges,
482
30
    // the backedge from the end of the loop body, and the edges from
483
30
    // continue statements.
484
30
    BreakContinue BC = BreakContinueStack.pop_back_val();
485
30
    uint64_t CondCount =
486
30
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
487
30
    CountMap[S->getCond()] = CondCount;
488
30
    Visit(S->getCond());
489
30
    setCount(BC.BreakCount + CondCount - BodyCount);
490
30
    RecordNextStmtCount = true;
491
30
  }
492
493
19
  void VisitDoStmt(const DoStmt *S) {
494
19
    RecordStmtCount(S);
495
19
    uint64_t LoopCount = PGO.getRegionCount(S);
496
19
497
19
    BreakContinueStack.push_back(BreakContinue());
498
19
    // The count doesn't include the fallthrough from the parent scope. Add it.
499
19
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
500
19
    CountMap[S->getBody()] = BodyCount;
501
19
    Visit(S->getBody());
502
19
    uint64_t BackedgeCount = CurrentCount;
503
19
504
19
    BreakContinue BC = BreakContinueStack.pop_back_val();
505
19
    // The count at the start of the condition is equal to the count at the
506
19
    // end of the body, plus any continues.
507
19
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
508
19
    CountMap[S->getCond()] = CondCount;
509
19
    Visit(S->getCond());
510
19
    setCount(BC.BreakCount + CondCount - LoopCount);
511
19
    RecordNextStmtCount = true;
512
19
  }
513
514
69
  void VisitForStmt(const ForStmt *S) {
515
69
    RecordStmtCount(S);
516
69
    if (S->getInit())
517
66
      Visit(S->getInit());
518
69
519
69
    uint64_t ParentCount = CurrentCount;
520
69
521
69
    BreakContinueStack.push_back(BreakContinue());
522
69
    // Visit the body region first. (This is basically the same as a while
523
69
    // loop; see further comments in VisitWhileStmt.)
524
69
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
525
69
    CountMap[S->getBody()] = BodyCount;
526
69
    Visit(S->getBody());
527
69
    uint64_t BackedgeCount = CurrentCount;
528
69
    BreakContinue BC = BreakContinueStack.pop_back_val();
529
69
530
69
    // The increment is essentially part of the body but it needs to include
531
69
    // the count for all the continue statements.
532
69
    if (S->getInc()) {
533
69
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
534
69
      CountMap[S->getInc()] = IncCount;
535
69
      Visit(S->getInc());
536
69
    }
537
69
538
69
    // ...then go back and propagate counts through the condition.
539
69
    uint64_t CondCount =
540
69
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
541
69
    if (S->getCond()) {
542
69
      CountMap[S->getCond()] = CondCount;
543
69
      Visit(S->getCond());
544
69
    }
545
69
    setCount(BC.BreakCount + CondCount - BodyCount);
546
69
    RecordNextStmtCount = true;
547
69
  }
548
549
9
  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
550
9
    RecordStmtCount(S);
551
9
    if (S->getInit())
552
0
      Visit(S->getInit());
553
9
    Visit(S->getLoopVarStmt());
554
9
    Visit(S->getRangeStmt());
555
9
    Visit(S->getBeginStmt());
556
9
    Visit(S->getEndStmt());
557
9
558
9
    uint64_t ParentCount = CurrentCount;
559
9
    BreakContinueStack.push_back(BreakContinue());
560
9
    // Visit the body region first. (This is basically the same as a while
561
9
    // loop; see further comments in VisitWhileStmt.)
562
9
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
563
9
    CountMap[S->getBody()] = BodyCount;
564
9
    Visit(S->getBody());
565
9
    uint64_t BackedgeCount = CurrentCount;
566
9
    BreakContinue BC = BreakContinueStack.pop_back_val();
567
9
568
9
    // The increment is essentially part of the body but it needs to include
569
9
    // the count for all the continue statements.
570
9
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
571
9
    CountMap[S->getInc()] = IncCount;
572
9
    Visit(S->getInc());
573
9
574
9
    // ...then go back and propagate counts through the condition.
575
9
    uint64_t CondCount =
576
9
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
577
9
    CountMap[S->getCond()] = CondCount;
578
9
    Visit(S->getCond());
579
9
    setCount(BC.BreakCount + CondCount - BodyCount);
580
9
    RecordNextStmtCount = true;
581
9
  }
582
583
5
  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
584
5
    RecordStmtCount(S);
585
5
    Visit(S->getElement());
586
5
    uint64_t ParentCount = CurrentCount;
587
5
    BreakContinueStack.push_back(BreakContinue());
588
5
    // Counter tracks the body of the loop.
589
5
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
590
5
    CountMap[S->getBody()] = BodyCount;
591
5
    Visit(S->getBody());
592
5
    uint64_t BackedgeCount = CurrentCount;
593
5
    BreakContinue BC = BreakContinueStack.pop_back_val();
594
5
595
5
    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
596
5
             BodyCount);
597
5
    RecordNextStmtCount = true;
598
5
  }
599
600
18
  void VisitSwitchStmt(const SwitchStmt *S) {
601
18
    RecordStmtCount(S);
602
18
    if (S->getInit())
603
0
      Visit(S->getInit());
604
18
    Visit(S->getCond());
605
18
    CurrentCount = 0;
606
18
    BreakContinueStack.push_back(BreakContinue());
607
18
    Visit(S->getBody());
608
18
    // If the switch is inside a loop, add the continue counts.
609
18
    BreakContinue BC = BreakContinueStack.pop_back_val();
610
18
    if (!BreakContinueStack.empty())
611
13
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
612
18
    // Counter tracks the exit block of the switch.
613
18
    setCount(PGO.getRegionCount(S));
614
18
    RecordNextStmtCount = true;
615
18
  }
616
617
61
  void VisitSwitchCase(const SwitchCase *S) {
618
61
    RecordNextStmtCount = false;
619
61
    // Counter for this particular case. This counts only jumps from the
620
61
    // switch header and does not include fallthrough from the case before
621
61
    // this one.
622
61
    uint64_t CaseCount = PGO.getRegionCount(S);
623
61
    setCount(CurrentCount + CaseCount);
624
61
    // We need the count without fallthrough in the mapping, so it's more useful
625
61
    // for branch probabilities.
626
61
    CountMap[S] = CaseCount;
627
61
    RecordNextStmtCount = true;
628
61
    Visit(S->getSubStmt());
629
61
  }
630
631
146
  void VisitIfStmt(const IfStmt *S) {
632
146
    RecordStmtCount(S);
633
146
    uint64_t ParentCount = CurrentCount;
634
146
    if (S->getInit())
635
0
      Visit(S->getInit());
636
146
    Visit(S->getCond());
637
146
638
146
    // Counter tracks the "then" part of an if statement. The count for
639
146
    // the "else" part, if it exists, will be calculated from this counter.
640
146
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
641
146
    CountMap[S->getThen()] = ThenCount;
642
146
    Visit(S->getThen());
643
146
    uint64_t OutCount = CurrentCount;
644
146
645
146
    uint64_t ElseCount = ParentCount - ThenCount;
646
146
    if (S->getElse()) {
647
16
      setCount(ElseCount);
648
16
      CountMap[S->getElse()] = ElseCount;
649
16
      Visit(S->getElse());
650
16
      OutCount += CurrentCount;
651
16
    } else
652
130
      OutCount += ElseCount;
653
146
    setCount(OutCount);
654
146
    RecordNextStmtCount = true;
655
146
  }
656
657
8
  void VisitCXXTryStmt(const CXXTryStmt *S) {
658
8
    RecordStmtCount(S);
659
8
    Visit(S->getTryBlock());
660
16
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; 
++I8
)
661
8
      Visit(S->getHandler(I));
662
8
    // Counter tracks the continuation block of the try statement.
663
8
    setCount(PGO.getRegionCount(S));
664
8
    RecordNextStmtCount = true;
665
8
  }
666
667
8
  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
668
8
    RecordNextStmtCount = false;
669
8
    // Counter tracks the catch statement's handler block.
670
8
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
671
8
    CountMap[S] = CatchCount;
672
8
    Visit(S->getHandlerBlock());
673
8
  }
674
675
7
  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
676
7
    RecordStmtCount(E);
677
7
    uint64_t ParentCount = CurrentCount;
678
7
    Visit(E->getCond());
679
7
680
7
    // Counter tracks the "true" part of a conditional operator. The
681
7
    // count in the "false" part will be calculated from this counter.
682
7
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
683
7
    CountMap[E->getTrueExpr()] = TrueCount;
684
7
    Visit(E->getTrueExpr());
685
7
    uint64_t OutCount = CurrentCount;
686
7
687
7
    uint64_t FalseCount = setCount(ParentCount - TrueCount);
688
7
    CountMap[E->getFalseExpr()] = FalseCount;
689
7
    Visit(E->getFalseExpr());
690
7
    OutCount += CurrentCount;
691
7
692
7
    setCount(OutCount);
693
7
    RecordNextStmtCount = true;
694
7
  }
695
696
19
  void VisitBinLAnd(const BinaryOperator *E) {
697
19
    RecordStmtCount(E);
698
19
    uint64_t ParentCount = CurrentCount;
699
19
    Visit(E->getLHS());
700
19
    // Counter tracks the right hand side of a logical and operator.
701
19
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
702
19
    CountMap[E->getRHS()] = RHSCount;
703
19
    Visit(E->getRHS());
704
19
    setCount(ParentCount + RHSCount - CurrentCount);
705
19
    RecordNextStmtCount = true;
706
19
  }
707
708
18
  void VisitBinLOr(const BinaryOperator *E) {
709
18
    RecordStmtCount(E);
710
18
    uint64_t ParentCount = CurrentCount;
711
18
    Visit(E->getLHS());
712
18
    // Counter tracks the right hand side of a logical or operator.
713
18
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
714
18
    CountMap[E->getRHS()] = RHSCount;
715
18
    Visit(E->getRHS());
716
18
    setCount(ParentCount + RHSCount - CurrentCount);
717
18
    RecordNextStmtCount = true;
718
18
  }
719
};
720
} // end anonymous namespace
721
722
2.31k
/// Fold one more statement type into the running hash.
///
/// Types are shifted into \c Working, \c NumBitsPerType bits at a time. Once
/// \c Count reaches a non-zero multiple of \c NumTypesPerWord, the pending
/// word is converted with endian::byte_swap<uint64_t, little>, fed to MD5 and
/// reset before the new type is added.
void PGOHash::combine(HashType Type) {
  // Type 0 is never combined, and every type must stay below TooBig.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // A full word is pending: hand it to MD5 before starting the next one.
  const bool WordIsFull = Count != 0 && Count % NumTypesPerWord == 0;
  if (WordIsFull) {
    using namespace llvm::support;
    uint64_t PendingWord = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef(reinterpret_cast<uint8_t *>(&PendingWord),
                                  sizeof(PendingWord)));
    Working = 0;
  }

  // Shift the new type into the low bits of the working word.
  Working = (Working << NumBitsPerType) | Type;
  ++Count;
}
739
740
466
uint64_t PGOHash::finalize() {
741
466
  // Use Working as the hash directly if we never used MD5.
742
466
  if (Count <= NumTypesPerWord)
743
414
    // No need to byte swap here, since none of the math was endian-dependent.
744
414
    // This number will be byte-swapped as required on endianness transitions,
745
414
    // so we will see the same value on the other side.
746
414
    return Working;
747
52
748
52
  // Check for remaining work in Working.
749
52
  if (Working)
750
52
    MD5.update(Working);
751
52
752
52
  // Finalize the MD5 and return the hash.
753
52
  llvm::MD5::MD5Result Result;
754
52
  MD5.final(Result);
755
52
  using namespace llvm::support;
756
52
  return Result.low();
757
52
}
758
759
371k
/// Set up region counters for the function generated from \p GD.
///
/// Does nothing for declarations without a body, when neither
/// instrumentation nor a PGO reader is active, for implicit declarations,
/// and for the constructor/destructor variants filtered out below.
/// Otherwise it names the function, maps its counters, optionally emits the
/// coverage mapping and, when a PGO reader is present, loads and applies the
/// profile counts.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  const bool Instrumenting = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *Reader = CGM.getPGOReader();
  const bool ProfilingInactive = !Instrumenting && !Reader;
  if (ProfilingInactive || D->isImplicit())
    return;

  // Constructors and destructors may be represented by several functions in
  // IR. Only the base variant is instrumented; the others are implemented by
  // delegation to it and would otherwise be counted twice.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants())
    if (const auto *Ctor = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(Ctor))
        return;
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);

  // The remaining steps consume profile data, so they need a reader.
  if (!Reader)
    return;
  SourceManager &SrcMgr = CGM.getContext().getSourceManager();
  loadRegionCounts(Reader, SrcMgr.isInMainFile(D->getLocation()));
  computeRegionCounts(D);
  applyFunctionAttributes(Reader, Fn);
}
795
796
466
void CodeGenPGO::mapRegionCounters(const Decl *D) {
797
466
  // Use the latest hash version when inserting instrumentation, but use the
798
466
  // version in the indexed profile if we're reading PGO data.
799
466
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
800
466
  if (auto *PGOReader = CGM.getPGOReader())
801
169
    HashVersion = getPGOHashVersion(PGOReader, CGM);
802
466
803
466
  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
804
466
  MapRegionCounters Walker(HashVersion, *RegionCounterMap);
805
466
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
806
456
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
807
10
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
808
3
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
809
7
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
810
2
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
811
5
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
812
5
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
813
466
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
814
466
  NumRegionCounters = Walker.NextCounter;
815
466
  FunctionHash = Walker.Hash.finalize();
816
466
}
817
818
180
bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
819
180
  if (!D->getBody())
820
0
    return true;
821
180
822
180
  // Don't map the functions in system headers.
823
180
  const auto &SM = CGM.getContext().getSourceManager();
824
180
  auto Loc = D->getBody()->getBeginLoc();
825
180
  return SM.isInSystemHeader(Loc);
826
180
}
827
828
162
void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
829
162
  if (skipRegionMappingForDecl(D))
830
0
    return;
831
162
832
162
  std::string CoverageMapping;
833
162
  llvm::raw_string_ostream OS(CoverageMapping);
834
162
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
835
162
                                CGM.getContext().getSourceManager(),
836
162
                                CGM.getLangOpts(), RegionCounterMap.get());
837
162
  MappingGen.emitCounterMapping(D, OS);
838
162
  OS.flush();
839
162
840
162
  if (CoverageMapping.empty())
841
1
    return;
842
161
843
161
  CGM.getCoverageMapping()->addFunctionMappingRecord(
844
161
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
845
161
}
846
847
void
848
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
849
18
                                    llvm::GlobalValue::LinkageTypes Linkage) {
850
18
  if (skipRegionMappingForDecl(D))
851
1
    return;
852
17
853
17
  std::string CoverageMapping;
854
17
  llvm::raw_string_ostream OS(CoverageMapping);
855
17
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
856
17
                                CGM.getContext().getSourceManager(),
857
17
                                CGM.getLangOpts());
858
17
  MappingGen.emitEmptyMapping(D, OS);
859
17
  OS.flush();
860
17
861
17
  if (CoverageMapping.empty())
862
0
    return;
863
17
864
17
  setFuncName(Name, Linkage);
865
17
  CGM.getCoverageMapping()->addFunctionMappingRecord(
866
17
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
867
17
}
868
869
169
void CodeGenPGO::computeRegionCounts(const Decl *D) {
870
169
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
871
169
  ComputeRegionCounts Walker(*StmtCountMap, *this);
872
169
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
873
165
    Walker.VisitFunctionDecl(FD);
874
4
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
875
1
    Walker.VisitObjCMethodDecl(MD);
876
3
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
877
1
    Walker.VisitBlockDecl(BD);
878
2
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
879
2
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
880
169
}
881
882
void
883
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
884
169
                                    llvm::Function *Fn) {
885
169
  if (!haveRegionCounts())
886
14
    return;
887
155
888
155
  uint64_t FunctionCount = getRegionCount(nullptr);
889
155
  Fn->setEntryCount(FunctionCount);
890
155
}
891
892
void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
893
730
                                      llvm::Value *StepV) {
894
730
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
895
28
    return;
896
702
  if (!Builder.GetInsertBlock())
897
4
    return;
898
698
899
698
  unsigned Counter = (*RegionCounterMap)[S];
900
698
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
901
698
902
698
  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
903
698
                         Builder.getInt64(FunctionHash),
904
698
                         Builder.getInt32(NumRegionCounters),
905
698
                         Builder.getInt32(Counter), StepV};
906
698
  if (!StepV)
907
697
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
908
697
                       makeArrayRef(Args, 4));
909
1
  else
910
1
    Builder.CreateCall(
911
1
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
912
1
        makeArrayRef(Args));
913
698
}
914
915
// This method either inserts a call to the profile run-time during
916
// instrumentation or puts profile data into metadata for PGO use.
917
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
918
21.3k
    llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
919
21.3k
920
21.3k
  if (!EnableValueProfiling)
921
21.2k
    return;
922
4
923
4
  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
924
0
    return;
925
4
926
4
  if (isa<llvm::Constant>(ValuePtr))
927
1
    return;
928
3
929
3
  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
930
3
  if (InstrumentValueSites && RegionCounterMap) {
931
3
    auto BuilderInsertPoint = Builder.saveIP();
932
3
    Builder.SetInsertPoint(ValueSite);
933
3
    llvm::Value *Args[5] = {
934
3
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
935
3
        Builder.getInt64(FunctionHash),
936
3
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
937
3
        Builder.getInt32(ValueKind),
938
3
        Builder.getInt32(NumValueSites[ValueKind]++)
939
3
    };
940
3
    Builder.CreateCall(
941
3
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
942
3
    Builder.restoreIP(BuilderInsertPoint);
943
3
    return;
944
3
  }
945
0
946
0
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
947
0
  if (PGOReader && haveRegionCounts()) {
948
0
    // We record the top most called three functions at each call site.
949
0
    // Profile metadata contains "VP" string identifying this metadata
950
0
    // as value profiling data, then a uint32_t value for the value profiling
951
0
    // kind, a uint64_t value for the total number of times the call is
952
0
    // executed, followed by the function hash and execution count (uint64_t)
953
0
    // pairs for each function.
954
0
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
955
0
      return;
956
0
957
0
    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
958
0
                            (llvm::InstrProfValueKind)ValueKind,
959
0
                            NumValueSites[ValueKind]);
960
0
961
0
    NumValueSites[ValueKind]++;
962
0
  }
963
0
}
964
965
void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
966
169
                                  bool IsInMainFile) {
967
169
  CGM.getPGOStats().addVisited(IsInMainFile);
968
169
  RegionCounts.clear();
969
169
  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
970
169
      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
971
169
  if (auto E = RecordExpected.takeError()) {
972
14
    auto IPE = llvm::InstrProfError::take(std::move(E));
973
14
    if (IPE == llvm::instrprof_error::unknown_function)
974
5
      CGM.getPGOStats().addMissing(IsInMainFile);
975
9
    else if (IPE == llvm::instrprof_error::hash_mismatch)
976
9
      CGM.getPGOStats().addMismatched(IsInMainFile);
977
0
    else if (IPE == llvm::instrprof_error::malformed)
978
0
      // TODO: Consider a more specific warning for this case.
979
0
      CGM.getPGOStats().addMismatched(IsInMainFile);
980
14
    return;
981
14
  }
982
155
  ProfRecord =
983
155
      llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
984
155
  RegionCounts = ProfRecord->Counts;
985
155
}
986
987
/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
///
/// \param MaxWeight the largest weight that will be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX, otherwise
/// MaxWeight / UINT32_MAX + 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Weights that already fit need no scaling at all.
  if (MaxWeight < UINT32_MAX)
    return 1;
  // The extra 1 makes MaxWeight / Scale land strictly below UINT32_MAX.
  return MaxWeight / UINT32_MAX + 1;
}
994
995
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \param Weight the 64-bit weight to scale.
/// \param Scale the non-zero divisor to apply.
/// \returns Weight / Scale + 1 as a 32-bit value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight, so that the scaled value fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The narrowing is intentional and guarded by the assertion above.
  return static_cast<uint32_t>(Scaled);
}
1010
1011
/// Build branch-weight metadata for a two-way branch.
///
/// \returns nullptr when both counts are zero; otherwise the node produced
/// by MDBuilder::createBranchWeights from the two scaled counts.
llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
                                                    uint64_t FalseCount) {
  // With no counts on either side there is nothing to describe.
  if (TrueCount == 0 && FalseCount == 0)
    return nullptr;

  // Both counts share one divisor, derived from the larger of the two.
  const uint64_t Divisor =
      calculateWeightScale(std::max(TrueCount, FalseCount));

  llvm::MDBuilder Builder(CGM.getLLVMContext());
  return Builder.createBranchWeights(scaleBranchWeight(TrueCount, Divisor),
                                     scaleBranchWeight(FalseCount, Divisor));
}
1024
1025
/// Build branch-weight metadata from a list of counts.
///
/// \returns nullptr when fewer than two weights are given or when every
/// weight is zero; otherwise the node produced by
/// MDBuilder::createBranchWeights from the scaled weights.
llvm::MDNode *
CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
  // Fewer than two elements cannot form meaningful weights.
  if (Weights.size() < 2)
    return nullptr;

  // A zero maximum means every weight is zero.
  const uint64_t Largest = *std::max_element(Weights.begin(), Weights.end());
  if (Largest == 0)
    return nullptr;

  // One divisor, derived from the largest weight, is applied to all of them.
  const uint64_t Divisor = calculateWeightScale(Largest);

  SmallVector<uint32_t, 16> Scaled;
  Scaled.reserve(Weights.size());
  for (const uint64_t Weight : Weights)
    Scaled.push_back(scaleBranchWeight(Weight, Divisor));

  llvm::MDBuilder Builder(CGM.getLLVMContext());
  return Builder.createBranchWeights(Scaled);
}
1047
1048
/// Build branch-weight metadata for a loop condition.
///
/// \param Cond the loop condition whose statement count is looked up.
/// \param LoopCount the count used as the "true" weight.
/// \returns nullptr when there are no region counts, when no count is
/// recorded for \p Cond, or when that count is zero. Otherwise the weights
/// for \p LoopCount against the remainder of the condition count, clamped so
/// the remainder cannot go below zero.
llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
                                                           uint64_t LoopCount) {
  if (!PGO.haveRegionCounts())
    return nullptr;
  Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
  assert(CondCount.hasValue() && "missing expected loop condition count");
  // The assertion disappears in release builds, so check for a missing count
  // explicitly rather than dereferencing an empty Optional.
  if (!CondCount.hasValue() || *CondCount == 0)
    return nullptr;
  // std::max keeps the subtraction from wrapping if LoopCount exceeds the
  // condition count.
  return createProfileWeights(LoopCount,
                              std::max(*CondCount, LoopCount) - LoopCount);
}