Coverage Report

Created: 2019-07-24 05:18

/Users/buildslave/jenkins/workspace/clang-stage2-coverage-R/llvm/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp
Line
Count
Source (jump to first uncovered line)
1
//===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
// This file contains the AArch64 / Cortex-A57 specific register allocation
9
// constraints for use by the PBQP register allocator.
10
//
11
// It is essentially a transcription of what is contained in
12
// AArch64A57FPLoadBalancing, which tries to use a balanced
13
// mix of odd and even D-registers when performing a critical sequence of
14
// independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15
//===----------------------------------------------------------------------===//
16
17
#define DEBUG_TYPE "aarch64-pbqp"
18
19
#include "AArch64PBQPRegAlloc.h"
20
#include "AArch64.h"
21
#include "AArch64RegisterInfo.h"
22
#include "llvm/CodeGen/LiveIntervals.h"
23
#include "llvm/CodeGen/MachineBasicBlock.h"
24
#include "llvm/CodeGen/MachineFunction.h"
25
#include "llvm/CodeGen/MachineRegisterInfo.h"
26
#include "llvm/CodeGen/RegAllocPBQP.h"
27
#include "llvm/Support/Debug.h"
28
#include "llvm/Support/ErrorHandling.h"
29
#include "llvm/Support/raw_ostream.h"
30
31
using namespace llvm;
32
33
namespace {
34
35
#ifndef NDEBUG
36
bool isFPReg(unsigned reg) {
37
  return AArch64::FPR32RegClass.contains(reg) ||
38
         AArch64::FPR64RegClass.contains(reg) ||
39
         AArch64::FPR128RegClass.contains(reg);
40
}
41
#endif
42
43
155k
bool isOdd(unsigned reg) {
44
155k
  switch (reg) {
45
155k
  default:
46
0
    llvm_unreachable("Register is not from the expected class !");
47
155k
  case AArch64::S1:
48
77.8k
  case AArch64::S3:
49
77.8k
  case AArch64::S5:
50
77.8k
  case AArch64::S7:
51
77.8k
  case AArch64::S9:
52
77.8k
  case AArch64::S11:
53
77.8k
  case AArch64::S13:
54
77.8k
  case AArch64::S15:
55
77.8k
  case AArch64::S17:
56
77.8k
  case AArch64::S19:
57
77.8k
  case AArch64::S21:
58
77.8k
  case AArch64::S23:
59
77.8k
  case AArch64::S25:
60
77.8k
  case AArch64::S27:
61
77.8k
  case AArch64::S29:
62
77.8k
  case AArch64::S31:
63
77.8k
  case AArch64::D1:
64
77.8k
  case AArch64::D3:
65
77.8k
  case AArch64::D5:
66
77.8k
  case AArch64::D7:
67
77.8k
  case AArch64::D9:
68
77.8k
  case AArch64::D11:
69
77.8k
  case AArch64::D13:
70
77.8k
  case AArch64::D15:
71
77.8k
  case AArch64::D17:
72
77.8k
  case AArch64::D19:
73
77.8k
  case AArch64::D21:
74
77.8k
  case AArch64::D23:
75
77.8k
  case AArch64::D25:
76
77.8k
  case AArch64::D27:
77
77.8k
  case AArch64::D29:
78
77.8k
  case AArch64::D31:
79
77.8k
  case AArch64::Q1:
80
77.8k
  case AArch64::Q3:
81
77.8k
  case AArch64::Q5:
82
77.8k
  case AArch64::Q7:
83
77.8k
  case AArch64::Q9:
84
77.8k
  case AArch64::Q11:
85
77.8k
  case AArch64::Q13:
86
77.8k
  case AArch64::Q15:
87
77.8k
  case AArch64::Q17:
88
77.8k
  case AArch64::Q19:
89
77.8k
  case AArch64::Q21:
90
77.8k
  case AArch64::Q23:
91
77.8k
  case AArch64::Q25:
92
77.8k
  case AArch64::Q27:
93
77.8k
  case AArch64::Q29:
94
77.8k
  case AArch64::Q31:
95
77.8k
    return true;
96
77.8k
  case AArch64::S0:
97
77.8k
  case AArch64::S2:
98
77.8k
  case AArch64::S4:
99
77.8k
  case AArch64::S6:
100
77.8k
  case AArch64::S8:
101
77.8k
  case AArch64::S10:
102
77.8k
  case AArch64::S12:
103
77.8k
  case AArch64::S14:
104
77.8k
  case AArch64::S16:
105
77.8k
  case AArch64::S18:
106
77.8k
  case AArch64::S20:
107
77.8k
  case AArch64::S22:
108
77.8k
  case AArch64::S24:
109
77.8k
  case AArch64::S26:
110
77.8k
  case AArch64::S28:
111
77.8k
  case AArch64::S30:
112
77.8k
  case AArch64::D0:
113
77.8k
  case AArch64::D2:
114
77.8k
  case AArch64::D4:
115
77.8k
  case AArch64::D6:
116
77.8k
  case AArch64::D8:
117
77.8k
  case AArch64::D10:
118
77.8k
  case AArch64::D12:
119
77.8k
  case AArch64::D14:
120
77.8k
  case AArch64::D16:
121
77.8k
  case AArch64::D18:
122
77.8k
  case AArch64::D20:
123
77.8k
  case AArch64::D22:
124
77.8k
  case AArch64::D24:
125
77.8k
  case AArch64::D26:
126
77.8k
  case AArch64::D28:
127
77.8k
  case AArch64::D30:
128
77.8k
  case AArch64::Q0:
129
77.8k
  case AArch64::Q2:
130
77.8k
  case AArch64::Q4:
131
77.8k
  case AArch64::Q6:
132
77.8k
  case AArch64::Q8:
133
77.8k
  case AArch64::Q10:
134
77.8k
  case AArch64::Q12:
135
77.8k
  case AArch64::Q14:
136
77.8k
  case AArch64::Q16:
137
77.8k
  case AArch64::Q18:
138
77.8k
  case AArch64::Q20:
139
77.8k
  case AArch64::Q22:
140
77.8k
  case AArch64::Q24:
141
77.8k
  case AArch64::Q26:
142
77.8k
  case AArch64::Q28:
143
77.8k
  case AArch64::Q30:
144
77.8k
    return false;
145
155k
146
155k
  }
147
155k
}
148
149
77.8k
bool haveSameParity(unsigned reg1, unsigned reg2) {
150
77.8k
  assert(isFPReg(reg1) && "Expecting an FP register for reg1");
151
77.8k
  assert(isFPReg(reg2) && "Expecting an FP register for reg2");
152
77.8k
153
77.8k
  return isOdd(reg1) == isOdd(reg2);
154
77.8k
}
155
156
}
157
158
/// Bias the PBQP graph so that \p Rd and \p Ra prefer registers of the same
/// parity.
///
/// If no edge links the two nodes yet, one is created: a pairing costs
/// infinity when the live intervals overlap and the physical registers
/// overlap, 0 when the registers have the same parity, and 1 otherwise.
/// Row 0 and column 0 of the matrix stay at 0.
/// If an edge already exists, every different-parity pairing in a row that is
/// cheaper than the row's largest finite same-parity cost is raised to that
/// cost plus one.
///
/// \param G  PBQP graph to update.
/// \param Rd First register of the pair (node looked up with
///           getNodeIdForVReg).
/// \param Ra Second register of the pair.
/// \returns false when \p Rd equals \p Ra or either one is a physical
///          register; true once the edge has been added or updated.
bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
                                                 unsigned Ra) {
  using AllowedRegs = PBQPRAGraph::NodeMetadata::AllowedRegVector;

  // Nothing to constrain when both operands are the same register.
  if (Rd == Ra)
    return false;

  LiveIntervals &LIs = G.getMetadata().LIS;

  // Physical registers have no node in the graph.
  if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) {
    LLVM_DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd)
                      << '\n');
    LLVM_DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra)
                      << '\n');
    return false;
  }

  PBQPRAGraph::NodeId RdNode = G.getMetadata().getNodeIdForVReg(Rd);
  PBQPRAGraph::NodeId RaNode = G.getMetadata().getNodeIdForVReg(Ra);

  // Pointers (not references) so they can be exchanged if the edge stores the
  // nodes in the opposite order.
  const AllowedRegs *RowRegs = &G.getNodeMetadata(RdNode).getAllowedRegs();
  const AllowedRegs *ColRegs = &G.getNodeMetadata(RaNode).getAllowedRegs();

  const PBQP::PBQPNum Inf = std::numeric_limits<PBQP::PBQPNum>::infinity();
  PBQPRAGraph::EdgeId Edge = G.findEdge(RdNode, RaNode);

  // No edge yet: build a fresh cost matrix carrying both the interference and
  // the parity preference.
  if (Edge == G.invalidEdgeId()) {
    const LiveInterval &RdInterval = LIs.getInterval(Rd);
    const LiveInterval &RaInterval = LIs.getInterval(Ra);
    const bool Interfere = RdInterval.overlaps(RaInterval);

    PBQPRAGraph::RawMatrix NewCosts(RowRegs->size() + 1, ColRegs->size() + 1,
                                    0);
    for (unsigned Row = 0, NumRows = RowRegs->size(); Row != NumRows; ++Row) {
      const unsigned PhysRd = (*RowRegs)[Row];
      for (unsigned Col = 0, NumCols = ColRegs->size(); Col != NumCols;
           ++Col) {
        const unsigned PhysRa = (*ColRegs)[Col];
        if (Interfere && TRI->regsOverlap(PhysRd, PhysRa)) {
          NewCosts[Row + 1][Col + 1] = Inf;
          continue;
        }
        if (haveSameParity(PhysRd, PhysRa))
          NewCosts[Row + 1][Col + 1] = 0.0;
        else
          NewCosts[Row + 1][Col + 1] = 1.0;
      }
    }
    G.addEdge(RdNode, RaNode, std::move(NewCosts));
    return true;
  }

  // The edge may have been created with the nodes the other way round; make
  // rows follow the edge's first node.
  if (G.getEdgeNode1Id(Edge) == RaNode) {
    std::swap(RdNode, RaNode);
    std::swap(RowRegs, ColRegs);
  }

  // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
  PBQPRAGraph::RawMatrix Costs(G.getEdgeCosts(Edge));
  for (unsigned Row = 0, NumRows = RowRegs->size(); Row != NumRows; ++Row) {
    const unsigned PhysRd = (*RowRegs)[Row];

    // Largest finite cost among the same-parity pairings of this row.
    PBQP::PBQPNum SameParityCeiling =
        std::numeric_limits<PBQP::PBQPNum>::min();
    for (unsigned Col = 0, NumCols = ColRegs->size(); Col != NumCols; ++Col) {
      if (!haveSameParity(PhysRd, (*ColRegs)[Col]))
        continue;
      if (Costs[Row + 1][Col + 1] != Inf &&
          Costs[Row + 1][Col + 1] > SameParityCeiling)
        SameParityCeiling = Costs[Row + 1][Col + 1];
    }

    // Push every cheaper different-parity pairing above that ceiling.
    for (unsigned Col = 0, NumCols = ColRegs->size(); Col != NumCols; ++Col) {
      if (haveSameParity(PhysRd, (*ColRegs)[Col]))
        continue;
      if (SameParityCeiling > Costs[Row + 1][Col + 1])
        Costs[Row + 1][Col + 1] = SameParityCeiling + 1.0;
    }
  }
  G.updateEdgeCosts(Edge, std::move(Costs));

  return true;
}
241
242
/// Discourage the chain now held in \p Rd from sharing register parity with
/// the other tracked chains that are live at the same time.
///
/// The set of tracked chains is updated first: if \p Ra is already tracked
/// and differs from \p Rd, the chain moves from \p Ra to \p Rd; if \p Ra is
/// not tracked, a new chain is started at \p Rd. Then, for every other tracked
/// chain whose live interval overlaps Rd's, the cost of each same-parity
/// pairing on the existing edge is raised above the largest finite cost of the
/// different-parity pairings in the same row.
///
/// \param G  PBQP graph whose edge costs are refined in place.
/// \param Rd Register that now carries the chain.
/// \param Ra Register that carried the chain before this instruction.
void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
                                                 unsigned Ra) {
  using AllowedRegs = PBQPRAGraph::NodeMetadata::AllowedRegVector;

  LiveIntervals &LIs = G.getMetadata().LIS;

  // Do some Chain management
  if (Chains.count(Ra)) {
    if (Rd != Ra) {
      LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
                        << " to " << printReg(Rd, TRI) << '\n';);
      Chains.remove(Ra);
      Chains.insert(Rd);
    }
  } else {
    LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
                      << '\n';);
    Chains.insert(Rd);
  }

  // Rd's node never changes across iterations. Keeping it immutable here
  // ensures an orientation fix-up for one chain cannot leak into the next
  // iteration (the previous code swapped this id in place).
  const PBQPRAGraph::NodeId RdNode = G.getMetadata().getNodeIdForVReg(Rd);

  const LiveInterval &ld = LIs.getInterval(Rd);
  for (auto r : Chains) {
    // Skip self
    if (r == Rd)
      continue;

    // Only chains that are live at the same time compete for registers.
    const LiveInterval &lr = LIs.getInterval(r);
    if (!ld.overlaps(lr))
      continue;

    const PBQPRAGraph::NodeId OtherNode = G.getMetadata().getNodeIdForVReg(r);

    // Per-iteration views of the allowed registers: rows first, then columns.
    const AllowedRegs *vRdAllowed = &G.getNodeMetadata(RdNode).getAllowedRegs();
    const AllowedRegs *vRrAllowed =
        &G.getNodeMetadata(OtherNode).getAllowedRegs();

    PBQPRAGraph::EdgeId edge = G.findEdge(RdNode, OtherNode);
    assert(edge != G.invalidEdgeId() &&
           "PBQP error ! The edge should exist !");

    LLVM_DEBUG(dbgs() << "Refining constraint !\n";);

    // If the edge stores the other chain's node first, its rows belong to
    // that node: exchange the row/column register lists for this edge only.
    if (G.getEdgeNode1Id(edge) == OtherNode)
      std::swap(vRdAllowed, vRrAllowed);

    // Enforce that cost is higher with all other Chains of the same parity
    PBQP::Matrix costs(G.getEdgeCosts(edge));
    for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
      unsigned pRd = (*vRdAllowed)[i];

      // Get the maximum cost (excluding unallocatable reg) for all other
      // parity registers.
      // NOTE(review): if PBQPNum is a floating-point type, min() is the
      // smallest positive value rather than the most negative one, so
      // zero-cost same-parity entries are bumped below. Left unchanged to
      // preserve the existing allocation behaviour -- confirm intent.
      PBQP::PBQPNum otherParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
      for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
        unsigned pRa = (*vRrAllowed)[j];
        if (!haveSameParity(pRd, pRa))
          if (costs[i + 1][j + 1] !=
                  std::numeric_limits<PBQP::PBQPNum>::infinity() &&
              costs[i + 1][j + 1] > otherParityMax)
            otherParityMax = costs[i + 1][j + 1];
      }

      // Ensure all registers with same parity have a higher cost
      // than otherParityMax
      for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
        unsigned pRa = (*vRrAllowed)[j];
        if (haveSameParity(pRd, pRa))
          if (otherParityMax > costs[i + 1][j + 1])
            costs[i + 1][j + 1] = otherParityMax + 1.0;
      }
    }
    G.updateEdgeCosts(edge, std::move(costs));
  }
}
318
319
static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
320
130
                                const MachineInstr &MI) {
321
130
  const LiveInterval &LI = LIs.getInterval(reg);
322
130
  SlotIndex SI = LIs.getInstructionIndex(MI);
323
130
  return LI.expiredAt(SI);
324
130
}
325
326
5
/// Walk every instruction of the function behind \p G and add the parity
/// constraints for multiply-accumulate chains.
///
/// Chain tracking is reset at the start of each basic block. Before an
/// instruction is examined, chains whose register is reported as expired at
/// that instruction are dropped. Scalar FMADD/FMSUB/FNMADD/FNMSUB (S and D
/// forms) get an intra-chain constraint between operand 0 and operand 3 and,
/// when that succeeds, an inter-chain constraint. FMLAv2f32/FMLSv2f32 get an
/// inter-chain constraint on operand 0 only.
///
/// \param G PBQP graph to constrain; its metadata supplies the machine
///          function and live intervals.
void A57ChainingConstraint::apply(PBQPRAGraph &G) {
  const MachineFunction &MF = G.getMetadata().MF;
  LiveIntervals &LIs = G.getMetadata().LIS;

  TRI = MF.getSubtarget().getRegisterInfo();
  LLVM_DEBUG(MF.dump());

  for (const auto &MBB: MF) {
    Chains.clear(); // FIXME: really needed ? Could not work at MF level ?

    for (const auto &MI: MBB) {

      // Forget Chains which have expired. Collect them first and remove them
      // only after the walk: removing from Chains while a range-for is
      // iterating over it invalidates the iterators.
      SmallVector<unsigned, 8> toDel;
      for (auto r : Chains) {
        if (regJustKilledBefore(LIs, r, MI)) {
          LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
                     MI.print(dbgs()););
          toDel.push_back(r);
        }
      }
      for (auto r : toDel)
        Chains.remove(r);

      switch (MI.getOpcode()) {
      case AArch64::FMSUBSrrr:
      case AArch64::FMADDSrrr:
      case AArch64::FNMSUBSrrr:
      case AArch64::FNMADDSrrr:
      case AArch64::FMSUBDrrr:
      case AArch64::FMADDDrrr:
      case AArch64::FNMSUBDrrr:
      case AArch64::FNMADDDrrr: {
        unsigned Rd = MI.getOperand(0).getReg();
        unsigned Ra = MI.getOperand(3).getReg();

        // Only refine against other chains when the Rd/Ra pair was accepted.
        if (addIntraChainConstraint(G, Rd, Ra))
          addInterChainConstraint(G, Rd, Ra);
        break;
      }

      case AArch64::FMLAv2f32:
      case AArch64::FMLSv2f32: {
        // Operand 0 is passed as both the new and the previous chain register.
        unsigned Rd = MI.getOperand(0).getReg();
        addInterChainConstraint(G, Rd, Rd);
        break;
      }

      default:
        break;
      }
    }
  }
}