Coverage Report

Created: 2019-04-21 11:35

/Users/buildslave/jenkins/workspace/clang-stage2-coverage-R/llvm/tools/polly/lib/Transform/ScheduleOptimizer.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- Schedule.cpp - Calculate an optimized schedule ---------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This pass generates an entirely new schedule tree from the data dependences
10
// and iteration domains. The new schedule tree is computed in two steps:
11
//
12
// 1) The isl scheduling optimizer is run
13
//
14
// The isl scheduling optimizer creates a new schedule tree that maximizes
15
// parallelism and tileability and minimizes data-dependence distances. The
16
// algorithm used is a modified version of the ``Pluto'' algorithm:
17
//
18
//   U. Bondhugula, A. Hartono, J. Ramanujam, and P. Sadayappan.
19
//   A Practical Automatic Polyhedral Parallelizer and Locality Optimizer.
20
//   In Proceedings of the 2008 ACM SIGPLAN Conference On Programming Language
21
//   Design and Implementation, PLDI ’08, pages 101–113. ACM, 2008.
22
//
23
// 2) A set of post-scheduling transformations is applied on the schedule tree.
24
//
25
// These optimizations include:
26
//
27
//  - Tiling of the innermost tilable bands
28
//  - Prevectorization - The choice of a possible outer loop that is strip-mined
29
//                       to the innermost level to enable inner-loop
30
//                       vectorization.
31
//  - Some optimizations for spatial locality are also planned.
32
//
33
// For a detailed description of the schedule tree itself please see section 6
34
// of:
35
//
36
// Polyhedral AST generation is more than scanning polyhedra
37
// Tobias Grosser, Sven Verdoolaege, Albert Cohen
38
// ACM Transactions on Programming Languages and Systems (TOPLAS),
39
// 37(4), July 2015
40
// http://www.grosser.es/#pub-polyhedral-AST-generation
41
//
42
// This publication also contains a detailed discussion of the different options
43
// for polyhedral loop unrolling, full/partial tile separation and other uses
44
// of the schedule tree.
45
//
46
//===----------------------------------------------------------------------===//
47
48
#include "polly/ScheduleOptimizer.h"
49
#include "polly/CodeGen/CodeGeneration.h"
50
#include "polly/DependenceInfo.h"
51
#include "polly/LinkAllPasses.h"
52
#include "polly/Options.h"
53
#include "polly/ScopInfo.h"
54
#include "polly/ScopPass.h"
55
#include "polly/Simplify.h"
56
#include "polly/Support/ISLOStream.h"
57
#include "llvm/ADT/Statistic.h"
58
#include "llvm/Analysis/TargetTransformInfo.h"
59
#include "llvm/IR/Function.h"
60
#include "llvm/Support/CommandLine.h"
61
#include "llvm/Support/Debug.h"
62
#include "llvm/Support/raw_ostream.h"
63
#include "isl/ctx.h"
64
#include "isl/options.h"
65
#include "isl/printer.h"
66
#include "isl/schedule.h"
67
#include "isl/schedule_node.h"
68
#include "isl/union_map.h"
69
#include "isl/union_set.h"
70
#include <algorithm>
71
#include <cassert>
72
#include <cmath>
73
#include <cstdint>
74
#include <cstdlib>
75
#include <string>
76
#include <vector>
77
78
using namespace llvm;
79
using namespace polly;
80
81
#define DEBUG_TYPE "polly-opt-isl"
82
83
static cl::opt<std::string>
84
    OptimizeDeps("polly-opt-optimize-only",
85
                 cl::desc("Only a certain kind of dependences (all/raw)"),
86
                 cl::Hidden, cl::init("all"), cl::ZeroOrMore,
87
                 cl::cat(PollyCategory));
88
89
static cl::opt<std::string>
90
    SimplifyDeps("polly-opt-simplify-deps",
91
                 cl::desc("Dependences should be simplified (yes/no)"),
92
                 cl::Hidden, cl::init("yes"), cl::ZeroOrMore,
93
                 cl::cat(PollyCategory));
94
95
static cl::opt<int> MaxConstantTerm(
96
    "polly-opt-max-constant-term",
97
    cl::desc("The maximal constant term allowed (-1 is unlimited)"), cl::Hidden,
98
    cl::init(20), cl::ZeroOrMore, cl::cat(PollyCategory));
99
100
static cl::opt<int> MaxCoefficient(
101
    "polly-opt-max-coefficient",
102
    cl::desc("The maximal coefficient allowed (-1 is unlimited)"), cl::Hidden,
103
    cl::init(20), cl::ZeroOrMore, cl::cat(PollyCategory));
104
105
static cl::opt<std::string> FusionStrategy(
106
    "polly-opt-fusion", cl::desc("The fusion strategy to choose (min/max)"),
107
    cl::Hidden, cl::init("min"), cl::ZeroOrMore, cl::cat(PollyCategory));
108
109
static cl::opt<std::string>
110
    MaximizeBandDepth("polly-opt-maximize-bands",
111
                      cl::desc("Maximize the band depth (yes/no)"), cl::Hidden,
112
                      cl::init("yes"), cl::ZeroOrMore, cl::cat(PollyCategory));
113
114
static cl::opt<std::string> OuterCoincidence(
115
    "polly-opt-outer-coincidence",
116
    cl::desc("Try to construct schedules where the outer member of each band "
117
             "satisfies the coincidence constraints (yes/no)"),
118
    cl::Hidden, cl::init("no"), cl::ZeroOrMore, cl::cat(PollyCategory));
119
120
static cl::opt<int> PrevectorWidth(
121
    "polly-prevect-width",
122
    cl::desc(
123
        "The number of loop iterations to strip-mine for pre-vectorization"),
124
    cl::Hidden, cl::init(4), cl::ZeroOrMore, cl::cat(PollyCategory));
125
126
static cl::opt<bool> FirstLevelTiling("polly-tiling",
127
                                      cl::desc("Enable loop tiling"),
128
                                      cl::init(true), cl::ZeroOrMore,
129
                                      cl::cat(PollyCategory));
130
131
static cl::opt<int> LatencyVectorFma(
132
    "polly-target-latency-vector-fma",
133
    cl::desc("The minimal number of cycles between issuing two "
134
             "dependent consecutive vector fused multiply-add "
135
             "instructions."),
136
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));
137
138
static cl::opt<int> ThroughputVectorFma(
139
    "polly-target-throughput-vector-fma",
140
    cl::desc("A throughput of the processor floating-point arithmetic units "
141
             "expressed in the number of vector fused multiply-add "
142
             "instructions per clock cycle."),
143
    cl::Hidden, cl::init(1), cl::ZeroOrMore, cl::cat(PollyCategory));
144
145
// This option, along with --polly-target-2nd-cache-level-associativity,
146
// --polly-target-1st-cache-level-size, and --polly-target-2nd-cache-level-size
147
// represent the parameters of the target cache, which do not have typical
148
// values that can be used by default. However, to apply the pattern matching
149
// optimizations, we use the values of the parameters of Intel Core i7-3820
150
// SandyBridge in case the parameters are not specified or not provided by the
151
// TargetTransformInfo.
152
static cl::opt<int> FirstCacheLevelAssociativity(
153
    "polly-target-1st-cache-level-associativity",
154
    cl::desc("The associativity of the first cache level."), cl::Hidden,
155
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));
156
157
static cl::opt<int> FirstCacheLevelDefaultAssociativity(
158
    "polly-target-1st-cache-level-default-associativity",
159
    cl::desc("The default associativity of the first cache level"
160
             " (if not enough were provided by the TargetTransformInfo)."),
161
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));
162
163
static cl::opt<int> SecondCacheLevelAssociativity(
164
    "polly-target-2nd-cache-level-associativity",
165
    cl::desc("The associativity of the second cache level."), cl::Hidden,
166
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));
167
168
static cl::opt<int> SecondCacheLevelDefaultAssociativity(
169
    "polly-target-2nd-cache-level-default-associativity",
170
    cl::desc("The default associativity of the second cache level"
171
             " (if not enough were provided by the TargetTransformInfo)."),
172
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));
173
174
static cl::opt<int> FirstCacheLevelSize(
175
    "polly-target-1st-cache-level-size",
176
    cl::desc("The size of the first cache level specified in bytes."),
177
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));
178
179
static cl::opt<int> FirstCacheLevelDefaultSize(
180
    "polly-target-1st-cache-level-default-size",
181
    cl::desc("The default size of the first cache level specified in bytes"
182
             " (if not enough were provided by the TargetTransformInfo)."),
183
    cl::Hidden, cl::init(32768), cl::ZeroOrMore, cl::cat(PollyCategory));
184
185
static cl::opt<int> SecondCacheLevelSize(
186
    "polly-target-2nd-cache-level-size",
187
    cl::desc("The size of the second level specified in bytes."), cl::Hidden,
188
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));
189
190
static cl::opt<int> SecondCacheLevelDefaultSize(
191
    "polly-target-2nd-cache-level-default-size",
192
    cl::desc("The default size of the second cache level specified in bytes"
193
             " (if not enough were provided by the TargetTransformInfo)."),
194
    cl::Hidden, cl::init(262144), cl::ZeroOrMore, cl::cat(PollyCategory));
195
196
static cl::opt<int> VectorRegisterBitwidth(
197
    "polly-target-vector-register-bitwidth",
198
    cl::desc("The size in bits of a vector register (if not set, this "
199
             "information is taken from LLVM's target information."),
200
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));
201
202
static cl::opt<int> FirstLevelDefaultTileSize(
203
    "polly-default-tile-size",
204
    cl::desc("The default tile size (if not enough were provided by"
205
             " --polly-tile-sizes)"),
206
    cl::Hidden, cl::init(32), cl::ZeroOrMore, cl::cat(PollyCategory));
207
208
static cl::list<int>
209
    FirstLevelTileSizes("polly-tile-sizes",
210
                        cl::desc("A tile size for each loop dimension, filled "
211
                                 "with --polly-default-tile-size"),
212
                        cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
213
                        cl::cat(PollyCategory));
214
215
static cl::opt<bool>
216
    SecondLevelTiling("polly-2nd-level-tiling",
217
                      cl::desc("Enable a 2nd level loop of loop tiling"),
218
                      cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
219
220
static cl::opt<int> SecondLevelDefaultTileSize(
221
    "polly-2nd-level-default-tile-size",
222
    cl::desc("The default 2nd-level tile size (if not enough were provided by"
223
             " --polly-2nd-level-tile-sizes)"),
224
    cl::Hidden, cl::init(16), cl::ZeroOrMore, cl::cat(PollyCategory));
225
226
static cl::list<int>
227
    SecondLevelTileSizes("polly-2nd-level-tile-sizes",
228
                         cl::desc("A tile size for each loop dimension, filled "
229
                                  "with --polly-default-tile-size"),
230
                         cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
231
                         cl::cat(PollyCategory));
232
233
static cl::opt<bool> RegisterTiling("polly-register-tiling",
234
                                    cl::desc("Enable register tiling"),
235
                                    cl::init(false), cl::ZeroOrMore,
236
                                    cl::cat(PollyCategory));
237
238
static cl::opt<int> RegisterDefaultTileSize(
239
    "polly-register-tiling-default-tile-size",
240
    cl::desc("The default register tile size (if not enough were provided by"
241
             " --polly-register-tile-sizes)"),
242
    cl::Hidden, cl::init(2), cl::ZeroOrMore, cl::cat(PollyCategory));
243
244
static cl::opt<int> PollyPatternMatchingNcQuotient(
245
    "polly-pattern-matching-nc-quotient",
246
    cl::desc("Quotient that is obtained by dividing Nc, the parameter of the"
247
             "macro-kernel, by Nr, the parameter of the micro-kernel"),
248
    cl::Hidden, cl::init(256), cl::ZeroOrMore, cl::cat(PollyCategory));
249
250
static cl::list<int>
251
    RegisterTileSizes("polly-register-tile-sizes",
252
                      cl::desc("A tile size for each loop dimension, filled "
253
                               "with --polly-register-tile-size"),
254
                      cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
255
                      cl::cat(PollyCategory));
256
257
static cl::opt<bool>
258
    PMBasedOpts("polly-pattern-matching-based-opts",
259
                cl::desc("Perform optimizations based on pattern matching"),
260
                cl::init(true), cl::ZeroOrMore, cl::cat(PollyCategory));
261
262
static cl::opt<bool> OptimizedScops(
263
    "polly-optimized-scops",
264
    cl::desc("Polly - Dump polyhedral description of Scops optimized with "
265
             "the isl scheduling optimizer and the set of post-scheduling "
266
             "transformations is applied on the schedule tree"),
267
    cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
268
269
STATISTIC(ScopsProcessed, "Number of scops processed");
270
STATISTIC(ScopsRescheduled, "Number of scops rescheduled");
271
STATISTIC(ScopsOptimized, "Number of scops optimized");
272
273
STATISTIC(NumAffineLoopsOptimized, "Number of affine loops optimized");
274
STATISTIC(NumBoxedLoopsOptimized, "Number of boxed loops optimized");
275
276
#define THREE_STATISTICS(VARNAME, DESC)                                        \
277
  static Statistic VARNAME[3] = {                                              \
278
      {DEBUG_TYPE, #VARNAME "0", DESC " (original)", {0}, {false}},            \
279
      {DEBUG_TYPE, #VARNAME "1", DESC " (after scheduler)", {0}, {false}},     \
280
      {DEBUG_TYPE, #VARNAME "2", DESC " (after optimizer)", {0}, {false}}}
281
282
THREE_STATISTICS(NumBands, "Number of bands");
283
THREE_STATISTICS(NumBandMembers, "Number of band members");
284
THREE_STATISTICS(NumCoincident, "Number of coincident band members");
285
THREE_STATISTICS(NumPermutable, "Number of permutable bands");
286
THREE_STATISTICS(NumFilters, "Number of filter nodes");
287
THREE_STATISTICS(NumExtension, "Number of extension nodes");
288
289
STATISTIC(FirstLevelTileOpts, "Number of first level tiling applied");
290
STATISTIC(SecondLevelTileOpts, "Number of second level tiling applied");
291
STATISTIC(RegisterTileOpts, "Number of register tiling applied");
292
STATISTIC(PrevectOpts, "Number of strip-mining for prevectorization applied");
293
STATISTIC(MatMulOpts,
294
          "Number of matrix multiplication patterns detected and optimized");
295
296
/// Create an isl::union_set, which describes the isolate option based on
297
/// IsolateDomain.
298
///
299
/// @param IsolateDomain An isl::set whose @p OutDimsNum last dimensions should
300
///                      belong to the current band node.
301
/// @param OutDimsNum    A number of dimensions that should belong to
302
///                      the current band node.
303
static isl::union_set getIsolateOptions(isl::set IsolateDomain,
304
10
                                        unsigned OutDimsNum) {
305
10
  unsigned Dims = IsolateDomain.dim(isl::dim::set);
306
10
  assert(OutDimsNum <= Dims &&
307
10
         "The isl::set IsolateDomain is used to describe the range of schedule "
308
10
         "dimensions values, which should be isolated. Consequently, the "
309
10
         "number of its dimensions should be greater than or equal to the "
310
10
         "number of the schedule dimensions.");
311
10
  isl::map IsolateRelation = isl::map::from_domain(IsolateDomain);
312
10
  IsolateRelation = IsolateRelation.move_dims(isl::dim::out, 0, isl::dim::in,
313
10
                                              Dims - OutDimsNum, OutDimsNum);
314
10
  isl::set IsolateOption = IsolateRelation.wrap();
315
10
  isl::id Id = isl::id::alloc(IsolateOption.get_ctx(), "isolate", nullptr);
316
10
  IsolateOption = IsolateOption.set_tuple_id(Id);
317
10
  return isl::union_set(IsolateOption);
318
10
}
319
320
namespace {
321
/// Create an isl::union_set, which describes the specified option for the
322
/// dimension of the current node.
323
///
324
/// @param Ctx    An isl::ctx, which is used to create the isl::union_set.
325
/// @param Option The name of the option.
326
10
isl::union_set getDimOptions(isl::ctx Ctx, const char *Option) {
327
10
  isl::space Space(Ctx, 0, 1);
328
10
  auto DimOption = isl::set::universe(Space);
329
10
  auto Id = isl::id::alloc(Ctx, Option, nullptr);
330
10
  DimOption = DimOption.set_tuple_id(Id);
331
10
  return isl::union_set(DimOption);
332
10
}
333
} // namespace
334
335
/// Create an isl::union_set, which describes the option of the form
336
/// [isolate[] -> unroll[x]].
337
///
338
/// @param Ctx An isl::ctx, which is used to create the isl::union_set.
339
3
static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) {
340
3
  isl::space Space = isl::space(Ctx, 0, 0, 1);
341
3
  isl::map UnrollIsolatedSetOption = isl::map::universe(Space);
342
3
  isl::id DimInId = isl::id::alloc(Ctx, "isolate", nullptr);
343
3
  isl::id DimOutId = isl::id::alloc(Ctx, "unroll", nullptr);
344
3
  UnrollIsolatedSetOption =
345
3
      UnrollIsolatedSetOption.set_tuple_id(isl::dim::in, DimInId);
346
3
  UnrollIsolatedSetOption =
347
3
      UnrollIsolatedSetOption.set_tuple_id(isl::dim::out, DimOutId);
348
3
  return UnrollIsolatedSetOption.wrap();
349
3
}
350
351
/// Make the last dimension of Set to take values from 0 to VectorWidth - 1.
352
///
353
/// @param Set         A set, which should be modified.
354
/// @param VectorWidth A parameter, which determines the constraint.
355
13
static isl::set addExtentConstraints(isl::set Set, int VectorWidth) {
356
13
  unsigned Dims = Set.dim(isl::dim::set);
357
13
  isl::space Space = Set.get_space();
358
13
  isl::local_space LocalSpace = isl::local_space(Space);
359
13
  isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
360
13
  ExtConstr = ExtConstr.set_constant_si(0);
361
13
  ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1);
362
13
  Set = Set.add_constraint(ExtConstr);
363
13
  ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
364
13
  ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1);
365
13
  ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1);
366
13
  return Set.add_constraint(ExtConstr);
367
13
}
368
369
13
/// Compute the loop prefixes of @p ScheduleRange whose full innermost extent
/// [0, VectorWidth - 1] is contained in @p ScheduleRange.
///
/// @param ScheduleRange The range of a schedule.
/// @param VectorWidth   The width of the innermost extent.
isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth) {
  unsigned NumDims = ScheduleRange.dim(isl::dim::set);
  // Drop all constraints on the innermost dimension to obtain the prefixes.
  isl::set Prefixes = ScheduleRange.drop_constraints_involving_dims(
      isl::dim::set, NumDims - 1, 1);
  // A prefix is "bad" if extending it with the full [0, VectorWidth - 1]
  // innermost range leaves the original schedule range.
  isl::set FullExtent = addExtentConstraints(Prefixes, VectorWidth);
  isl::set Bad = FullExtent.subtract(ScheduleRange);
  Bad = Bad.project_out(isl::dim::set, NumDims - 1, 1);
  Prefixes = Prefixes.project_out(isl::dim::set, NumDims - 1, 1);
  // Keep only the prefixes that admit a full extent.
  return Prefixes.subtract(Bad);
}
379
380
/// Attach "isolate" and "atomic" ast-build options to the band @p Node,
/// isolating the prefixes computed by getPartialTilePrefixes.
///
/// @param Node        The band node to annotate.
/// @param VectorWidth The width used to compute the isolatable prefixes.
isl::schedule_node
ScheduleTreeOptimizer::isolateFullPartialTiles(isl::schedule_node Node,
                                               int VectorWidth) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  // Descend two levels to read the prefix schedule below the band.
  Node = Node.child(0).child(0);
  isl::map PrefixSchedule =
      isl::map::from_union_map(Node.get_prefix_schedule_relation());
  isl::set IsolateDomain =
      getPartialTilePrefixes(PrefixSchedule.range(), VectorWidth);
  isl::union_set AtomicOption =
      getDimOptions(IsolateDomain.get_ctx(), "atomic");
  isl::union_set IsolateOption = getIsolateOptions(IsolateDomain, 1);
  // Climb back to the original band and install the combined options.
  Node = Node.parent().parent();
  return Node.band_set_ast_build_options(IsolateOption.unite(AtomicOption));
}
396
397
/// Strip-mine the band member @p DimToVectorize by @p VectorWidth and mark
/// the resulting innermost loop with a "SIMD" mark node.
///
/// @param Node           The band node containing the member to vectorize.
/// @param DimToVectorize The index of the band member to strip-mine.
/// @param VectorWidth    The strip-mine width.
/// @return               The node carrying the inserted "SIMD" mark.
isl::schedule_node ScheduleTreeOptimizer::prevectSchedBand(
    isl::schedule_node Node, unsigned DimToVectorize, int VectorWidth) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);

  auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get()));
  auto ScheduleDimensions = Space.dim(isl::dim::set);
  assert(DimToVectorize < ScheduleDimensions);

  // Split the band so the member to vectorize gets a band of its own:
  // first split off the preceding members ...
  if (DimToVectorize > 0) {
    Node = isl::manage(
        isl_schedule_node_band_split(Node.release(), DimToVectorize));
    Node = Node.child(0);
  }
  // ... then split off the following members.
  if (DimToVectorize < ScheduleDimensions - 1)
    Node = isl::manage(isl_schedule_node_band_split(Node.release(), 1));
  // Tile the now single-member band by VectorWidth and isolate the prefixes
  // that have a full vector tile.
  Space = isl::manage(isl_schedule_node_band_get_space(Node.get()));
  auto Sizes = isl::multi_val::zero(Space);
  Sizes = Sizes.set_val(0, isl::val(Node.get_ctx(), VectorWidth));
  Node =
      isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release()));
  Node = isolateFullPartialTiles(Node, VectorWidth);
  Node = Node.child(0);
  // Make sure the "trivially vectorizable loop" is not unrolled. Otherwise,
  // we will have troubles to match it in the backend.
  Node = Node.band_set_ast_build_options(
      isl::union_set(Node.get_ctx(), "{ unroll[x]: 1 = 0 }"));
  // Sink the point band to the innermost position of the subtree.
  Node = isl::manage(isl_schedule_node_band_sink(Node.release()));
  Node = Node.child(0);
  if (isl_schedule_node_get_type(Node.get()) == isl_schedule_node_leaf)
    Node = Node.parent();
  auto LoopMarker = isl::id::alloc(Node.get_ctx(), "SIMD", nullptr);
  PrevectOpts++; // Statistic: one more strip-mining for prevectorization.
  return Node.insert_mark(LoopMarker);
}
431
432
/// Tile the band @p Node using @p TileSizes (padded with @p DefaultTileSize)
/// and surround the result with "<Identifier> - Tiles" and
/// "<Identifier> - Points" mark nodes.
///
/// @param Node            The band node to tile.
/// @param Identifier      Prefix for the inserted mark-node names.
/// @param TileSizes       Per-dimension tile sizes; missing entries fall back
///                        to @p DefaultTileSize.
/// @param DefaultTileSize The fallback tile size.
/// @return                The child of the point-band mark node.
isl::schedule_node ScheduleTreeOptimizer::tileNode(isl::schedule_node Node,
                                                   const char *Identifier,
                                                   ArrayRef<int> TileSizes,
                                                   int DefaultTileSize) {
  auto BandSpace = isl::manage(isl_schedule_node_band_get_space(Node.get()));
  auto NumDims = BandSpace.dim(isl::dim::set);
  auto Sizes = isl::multi_val::zero(BandSpace);
  // Take the user-provided size for each dimension; pad with the default.
  for (unsigned Dim = 0; Dim < NumDims; Dim++) {
    int Size = Dim < TileSizes.size() ? TileSizes[Dim] : DefaultTileSize;
    Sizes = Sizes.set_val(Dim, isl::val(Node.get_ctx(), Size));
  }
  std::string Name(Identifier);
  // Mark the tile loops, tile, then mark the point loops.
  Node = Node.insert_mark(
      isl::id::alloc(Node.get_ctx(), Name + " - Tiles", nullptr));
  Node = Node.child(0);
  Node =
      isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release()));
  Node = Node.child(0);
  Node = Node.insert_mark(
      isl::id::alloc(Node.get_ctx(), Name + " - Points", nullptr));
  return Node.child(0);
}
458
459
/// Apply register tiling to @p Node and request unrolling of the resulting
/// point band via the "{unroll[x]}" ast-build option.
///
/// @param Node            The band node to register-tile.
/// @param TileSizes       Per-dimension register tile sizes.
/// @param DefaultTileSize Fallback register tile size.
isl::schedule_node ScheduleTreeOptimizer::applyRegisterTiling(
    isl::schedule_node Node, ArrayRef<int> TileSizes, int DefaultTileSize) {
  Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize);
  return Node.band_set_ast_build_options(
      isl::union_set(Node.get_ctx(), "{unroll[x]}"));
}
465
466
23
/// Check whether the band @p Node is "simple innermost": its only child is
/// either a leaf, or a sequence whose members are all filters directly over
/// leaves.
///
/// @param Node A band node with exactly one child.
static bool isSimpleInnermostBand(const isl::schedule_node &Node) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  assert(isl_schedule_node_n_children(Node.get()) == 1);

  isl::schedule_node Child = Node.child(0);
  switch (isl_schedule_node_get_type(Child.get())) {
  case isl_schedule_node_leaf:
    return true;
  case isl_schedule_node_sequence:
    break;
  default:
    return false;
  }

  // Every member of the sequence must be a filter over a leaf.
  int NumMembers = isl_schedule_node_n_children(Child.get());
  for (int Idx = 0; Idx < NumMembers; ++Idx) {
    isl::schedule_node Member = Child.child(Idx);
    if (isl_schedule_node_get_type(Member.get()) != isl_schedule_node_filter)
      return false;
    if (isl_schedule_node_get_type(Member.child(0).get()) !=
        isl_schedule_node_leaf)
      return false;
  }
  return true;
}
491
492
158
/// Decide whether @p Node is a band that should be tiled: a permutable band
/// with a single child, more than one schedule dimension, and a simple
/// innermost body.
///
/// @param Node The schedule node to inspect.
bool ScheduleTreeOptimizer::isTileableBandNode(isl::schedule_node Node) {
  // Must be a permutable band with exactly one child.
  if (isl_schedule_node_get_type(Node.get()) != isl_schedule_node_band ||
      isl_schedule_node_n_children(Node.get()) != 1 ||
      !isl_schedule_node_band_get_permutable(Node.get()))
    return false;

  // Tiling a single-dimension band is pointless.
  auto BandSpace = isl::manage(isl_schedule_node_band_get_space(Node.get()));
  if (BandSpace.dim(isl::dim::set) <= 1)
    return false;

  return isSimpleInnermostBand(Node);
}
510
511
/// Apply the standard band optimizations — the enabled tiling levels,
/// register tiling, and prevectorization — to the band @p Node.
///
/// @param Node The band node to optimize.
/// @param User Callback payload (unused here).
/// @return     The optimized schedule node.
__isl_give isl::schedule_node
ScheduleTreeOptimizer::standardBandOpts(isl::schedule_node Node, void *User) {
  if (FirstLevelTiling) {
    Node = tileNode(Node, "1st level tiling", FirstLevelTileSizes,
                    FirstLevelDefaultTileSize);
    FirstLevelTileOpts++;
  }

  if (SecondLevelTiling) {
    Node = tileNode(Node, "2nd level tiling", SecondLevelTileSizes,
                    SecondLevelDefaultTileSize);
    SecondLevelTileOpts++;
  }

  if (RegisterTiling) {
    Node =
        applyRegisterTiling(Node, RegisterTileSizes, RegisterDefaultTileSize);
    RegisterTileOpts++;
  }

  // Prevectorization is only useful when a vectorizer is enabled.
  if (PollyVectorizerChoice == VECTORIZER_NONE)
    return Node;

  auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get()));
  auto Dims = Space.dim(isl::dim::set);

  // Strip-mine the innermost coincident band member (searching from the
  // inside out) to enable inner-loop vectorization.
  for (int i = Dims - 1; i >= 0; i--)
    if (Node.band_member_get_coincident(i)) {
      Node = prevectSchedBand(Node, i, PrevectorWidth);
      break;
    }

  return Node;
}
545
546
/// Permute the two dimensions of the isl map.
547
///
548
/// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
549
/// have type @p DimType.
550
///
551
/// @param Map     The isl map to be modified.
552
/// @param DimType The type of the dimensions.
553
/// @param DstPos  The first dimension.
554
/// @param SrcPos  The second dimension.
555
/// @return        The modified map.
556
isl::map permuteDimensions(isl::map Map, isl::dim DimType, unsigned DstPos,
                           unsigned SrcPos) {
  assert(DstPos < Map.dim(DimType) && SrcPos < Map.dim(DimType));
  if (DstPos == SrcPos)
    return Map;
  // Save the tuple ids of both sides so they can be restored after the
  // move_dims calls below.
  isl::id DimId;
  if (Map.has_tuple_id(DimType))
    DimId = Map.get_tuple_id(DimType);
  // FreeDim is the opposite tuple, used as scratch space for the swap.
  auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in;
  isl::id FreeDimId;
  if (Map.has_tuple_id(FreeDim))
    FreeDimId = Map.get_tuple_id(FreeDim);
  auto MaxDim = std::max(DstPos, SrcPos);
  auto MinDim = std::min(DstPos, SrcPos);
  // Park both dimensions at the front of the scratch tuple — the larger
  // index first, so that the smaller index stays valid — then reinsert them
  // at each other's positions. The call order is load-bearing.
  Map = Map.move_dims(FreeDim, 0, DimType, MaxDim, 1);
  Map = Map.move_dims(FreeDim, 0, DimType, MinDim, 1);
  Map = Map.move_dims(DimType, MinDim, FreeDim, 1, 1);
  Map = Map.move_dims(DimType, MaxDim, FreeDim, 0, 1);
  // Restore the saved tuple ids, if any were set.
  if (DimId)
    Map = Map.set_tuple_id(DimType, DimId);
  if (FreeDimId)
    Map = Map.set_tuple_id(FreeDim, FreeDimId);
  return Map;
}
580
581
/// Check the form of the access relation.
582
///
583
/// Check that the access relation @p AccMap has the form M[i][j], where i
584
/// is a @p FirstPos and j is a @p SecondPos.
585
///
586
/// @param AccMap    The access relation to be checked.
587
/// @param FirstPos  The index of the input dimension that is mapped to
588
///                  the first output dimension.
589
/// @param SecondPos The index of the input dimension that is mapped to the
590
///                  second output dimension.
591
/// @return          True in case @p AccMap has the expected form and false,
592
///                  otherwise.
593
static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
594
28
                               int &SecondPos) {
595
28
  isl::space Space = AccMap.get_space();
596
28
  isl::map Universe = isl::map::universe(Space);
597
28
598
28
  if (Space.dim(isl::dim::out) != 2)
599
0
    return false;
600
28
601
28
  // MatMul has the form:
602
28
  // for (i = 0; i < N; i++)
603
28
  //   for (j = 0; j < M; j++)
604
28
  //     for (k = 0; k < P; k++)
605
28
  //       C[i, j] += A[i, k] * B[k, j]
606
28
  //
607
28
  // Permutation of three outer loops: 3! = 6 possibilities.
608
28
  int FirstDims[] = {0, 0, 1, 1, 2, 2};
609
28
  int SecondDims[] = {1, 2, 2, 0, 0, 1};
610
124
  for (int i = 0; i < 6; 
i += 196
) {
611
112
    auto PossibleMatMul =
612
112
        Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
613
112
            .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1);
614
112
615
112
    AccMap = AccMap.intersect_domain(Domain);
616
112
    PossibleMatMul = PossibleMatMul.intersect_domain(Domain);
617
112
618
112
    // If AccMap spans entire domain (Non-partial write),
619
112
    // compute FirstPos and SecondPos.
620
112
    // If AccMap != PossibleMatMul here (the two maps have been gisted at
621
112
    // this point), it means that the writes are not complete, or in other
622
112
    // words, it is a Partial write and Partial writes must be rejected.
623
112
    if (AccMap.is_equal(PossibleMatMul)) {
624
28
      if (FirstPos != -1 && 
FirstPos != FirstDims[i]24
)
625
8
        continue;
626
20
      FirstPos = FirstDims[i];
627
20
      if (SecondPos != -1 && 
SecondPos != SecondDims[i]16
)
628
4
        continue;
629
16
      SecondPos = SecondDims[i];
630
16
      return true;
631
16
    }
632
112
  }
633
28
634
28
  
return false12
;
635
28
}
636
637
/// Does the memory access represent a non-scalar operand of the matrix
/// multiplication.
///
/// Check that the memory access @p MemAccess is the read access to a non-scalar
/// operand of the matrix multiplication or its result.
///
/// @param MemAccess The memory access to be checked.
/// @param MMI       Parameters of the matrix multiplication operands; on a
///                  match the corresponding operand field (ReadFromC, A or B)
///                  and possibly the loop indices i/j/k are updated.
/// @return          True in case the memory access represents the read access
///                  to a non-scalar operand of the matrix multiplication and
///                  false, otherwise.
static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess,
                                        MatMulInfoTy &MMI) {
  if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
    return false;
  auto AccMap = MemAccess->getLatestAccessRelation();
  isl::set StmtDomain = MemAccess->getStatement()->getDomain();
  // Try to classify the access as C[i][j], A[i][k] or B[k][j] — in that
  // order. Each isMatMulOperandAcc call may also fix the loop indices
  // MMI.i/j/k used by the subsequent checks, so the order matters.
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
    MMI.ReadFromC = MemAccess;
    return true;
  }
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
    MMI.A = MemAccess;
    return true;
  }
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
    MMI.B = MemAccess;
    return true;
  }
  return false;
}
668
669
/// Check accesses to operands of the matrix multiplication.
670
///
671
/// Check that accesses of the SCoP statement, which corresponds to
672
/// the partial schedule @p PartialSchedule, are scalar in terms of loops
673
/// containing the matrix multiplication, in case they do not represent
674
/// accesses to the non-scalar operands of the matrix multiplication or
675
/// its result.
676
///
677
/// @param  PartialSchedule The partial schedule of the SCoP statement.
678
/// @param  MMI             Parameters of the matrix multiplication operands.
679
/// @return                 True in case the corresponding SCoP statement
680
///                         represents matrix multiplication and false,
681
///                         otherwise.
682
static bool containsOnlyMatrMultAcc(isl::map PartialSchedule,
683
4
                                    MatMulInfoTy &MMI) {
684
4
  auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in);
685
4
  auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user());
686
4
  unsigned OutDimNum = PartialSchedule.dim(isl::dim::out);
687
4
  assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest "
688
4
                          "and, consequently, the corresponding scheduling "
689
4
                          "functions have at least three dimensions.");
690
4
  auto MapI =
691
4
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.i, OutDimNum - 1);
692
4
  auto MapJ =
693
4
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.j, OutDimNum - 1);
694
4
  auto MapK =
695
4
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.k, OutDimNum - 1);
696
4
697
4
  auto Accesses = getAccessesInOrder(*Stmt);
698
18
  for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; 
MemA++14
) {
699
14
    auto *MemAccessPtr = *MemA;
700
14
    if (MemAccessPtr->isLatestArrayKind() && 
MemAccessPtr != MMI.WriteToC12
&&
701
14
        
!isMatMulNonScalarReadAccess(MemAccessPtr, MMI)12
&&
702
14
        
!(MemAccessPtr->isStrideZero(MapI))0
&&
703
14
        
MemAccessPtr->isStrideZero(MapJ)0
&&
MemAccessPtr->isStrideZero(MapK)0
)
704
0
      return false;
705
14
  }
706
4
  return true;
707
4
}
708
709
/// Check for dependencies corresponding to the matrix multiplication.
710
///
711
/// Check that there is only true dependence of the form
712
/// S(..., k, ...) -> S(..., k + 1, …), where S is the SCoP statement
713
/// represented by @p Schedule and k is @p Pos. Such a dependence corresponds
714
/// to the dependency produced by the matrix multiplication.
715
///
716
/// @param  Schedule The schedule of the SCoP statement.
717
/// @param  D The SCoP dependencies.
718
/// @param  Pos The parameter to describe an acceptable true dependence.
719
///             In case it has a negative value, try to determine its
720
///             acceptable value.
721
/// @return True in case dependencies correspond to the matrix multiplication
722
///         and false, otherwise.
723
static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D,
724
4
                                  int &Pos) {
725
4
  isl::union_map Dep = D->getDependences(Dependences::TYPE_RAW);
726
4
  isl::union_map Red = D->getDependences(Dependences::TYPE_RED);
727
4
  if (Red)
728
4
    Dep = Dep.unite(Red);
729
4
  auto DomainSpace = Schedule.get_space().domain();
730
4
  auto Space = DomainSpace.map_from_domain_and_range(DomainSpace);
731
4
  auto Deltas = Dep.extract_map(Space).deltas();
732
4
  int DeltasDimNum = Deltas.dim(isl::dim::set);
733
16
  for (int i = 0; i < DeltasDimNum; 
i++12
) {
734
12
    auto Val = Deltas.plain_get_val_if_fixed(isl::dim::set, i);
735
12
    Pos = Pos < 0 && Val.is_one() ? 
i4
:
Pos8
;
736
12
    if (Val.is_nan() || !(Val.is_zero() || 
(4
i == Pos4
&&
Val.is_one()4
)))
737
0
      return false;
738
12
  }
739
4
  if (DeltasDimNum == 0 || Pos < 0)
740
0
    return false;
741
4
  return true;
742
4
}
743
744
/// Check if the SCoP statement could probably be optimized with analytical
745
/// modeling.
746
///
747
/// containsMatrMult tries to determine whether the following conditions
748
/// are true:
749
/// 1. The last memory access modeling an array, MA1, represents writing to
750
///    memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or
751
///    S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement
752
///    under consideration.
753
/// 2. There is only one loop-carried true dependency, and it has the
754
///    form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no
755
///    loop-carried or anti dependencies.
756
/// 3. SCoP contains three access relations, MA2, MA3, and MA4 that represent
757
///    reading from memory and have the form S(..., i3, ...) -> M(i1, i3),
758
///    S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively,
759
///    and all memory accesses of the SCoP that are different from MA1, MA2,
760
///    MA3, and MA4 have stride 0, if the innermost loop is exchanged with any
761
///    of loops i1, i2 and i3.
762
///
763
/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
764
///        to check.
765
/// @D     The SCoP dependencies.
766
/// @MMI   Parameters of the matrix multiplication operands.
767
static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
768
4
                             MatMulInfoTy &MMI) {
769
4
  auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
770
4
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
771
4
  if (Stmt->size() <= 1)
772
0
    return false;
773
4
774
4
  auto Accesses = getAccessesInOrder(*Stmt);
775
4
  for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); 
MemA--0
) {
776
4
    auto *MemAccessPtr = *MemA;
777
4
    if (!MemAccessPtr->isLatestArrayKind())
778
0
      continue;
779
4
    if (!MemAccessPtr->isWrite())
780
0
      return false;
781
4
    auto AccMap = MemAccessPtr->getLatestAccessRelation();
782
4
    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
783
0
      return false;
784
4
    MMI.WriteToC = MemAccessPtr;
785
4
    break;
786
4
  }
787
4
788
4
  if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k))
789
0
    return false;
790
4
791
4
  if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI))
792
0
    return false;
793
4
794
4
  if (!MMI.A || !MMI.B || !MMI.ReadFromC)
795
0
    return false;
796
4
  return true;
797
4
}
798
799
/// Permute two dimensions of the band node.
800
///
801
/// Permute FirstDim and SecondDim dimensions of the Node.
802
///
803
/// @param Node The band node to be modified.
804
/// @param FirstDim The first dimension to be permuted.
805
/// @param SecondDim The second dimension to be permuted.
806
static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node,
807
                                                    unsigned FirstDim,
808
22
                                                    unsigned SecondDim) {
809
22
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band &&
810
22
         isl_schedule_node_band_n_member(Node.get()) >
811
22
             std::max(FirstDim, SecondDim));
812
22
  auto PartialSchedule =
813
22
      isl::manage(isl_schedule_node_band_get_partial_schedule(Node.get()));
814
22
  auto PartialScheduleFirstDim = PartialSchedule.get_union_pw_aff(FirstDim);
815
22
  auto PartialScheduleSecondDim = PartialSchedule.get_union_pw_aff(SecondDim);
816
22
  PartialSchedule =
817
22
      PartialSchedule.set_union_pw_aff(SecondDim, PartialScheduleFirstDim);
818
22
  PartialSchedule =
819
22
      PartialSchedule.set_union_pw_aff(FirstDim, PartialScheduleSecondDim);
820
22
  Node = isl::manage(isl_schedule_node_delete(Node.release()));
821
22
  return Node.insert_partial_schedule(PartialSchedule);
822
22
}
823
824
// Create the BLIS micro-kernel: register-tile the two innermost dimensions
// by Mr x Nr and swap them, so the Nr loop becomes innermost.
isl::schedule_node ScheduleTreeOptimizer::createMicroKernel(
    isl::schedule_node Node, MicroKernelParamsTy MicroKernelParams) {
  Node = applyRegisterTiling(Node, {MicroKernelParams.Mr, MicroKernelParams.Nr},
                             1);
  // Move up to the point-loop band created by the tiling.
  Node = Node.parent().parent();
  return permuteBandNodeDimensions(Node, 0, 1).child(0).child(0);
}
831
832
// Create the BLIS macro-kernel: tile the three innermost band members by
// Mc x Nc x Kc and reorder the tile loops into the BLIS loop order.
isl::schedule_node ScheduleTreeOptimizer::createMacroKernel(
    isl::schedule_node Node, MacroKernelParamsTy MacroKernelParams) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  // Degenerate parameters: tiling by one is the identity, so skip it.
  if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 &&
      MacroKernelParams.Kc == 1)
    return Node;
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  // Only the last three members are tiled; all others get tile size 1.
  std::vector<int> TileSizes(DimOutNum, 1);
  TileSizes[DimOutNum - 3] = MacroKernelParams.Mc;
  TileSizes[DimOutNum - 2] = MacroKernelParams.Nc;
  TileSizes[DimOutNum - 1] = MacroKernelParams.Kc;
  Node = tileNode(Node, "1st level tiling", TileSizes, 1);
  Node = Node.parent().parent();
  // Rotate the three tile loops into the BLIS macro-kernel order.
  Node = permuteBandNodeDimensions(Node, DimOutNum - 2, DimOutNum - 1);
  Node = permuteBandNodeDimensions(Node, DimOutNum - 3, DimOutNum - 1);
  return Node.child(0).child(0);
}
849
850
/// Get the size of the widest type of the matrix multiplication operands
851
/// in bytes, including alignment padding.
852
///
853
/// @param MMI Parameters of the matrix multiplication operands.
854
/// @return The size of the widest type of the matrix multiplication operands
855
///         in bytes, including alignment padding.
856
3
static uint64_t getMatMulAlignTypeSize(MatMulInfoTy MMI) {
857
3
  auto *S = MMI.A->getStatement()->getParent();
858
3
  auto &DL = S->getFunction().getParent()->getDataLayout();
859
3
  auto ElementSizeA = DL.getTypeAllocSize(MMI.A->getElementType());
860
3
  auto ElementSizeB = DL.getTypeAllocSize(MMI.B->getElementType());
861
3
  auto ElementSizeC = DL.getTypeAllocSize(MMI.WriteToC->getElementType());
862
3
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
863
3
}
864
865
/// Get the size of the widest type of the matrix multiplication operands
866
/// in bits.
867
///
868
/// @param MMI Parameters of the matrix multiplication operands.
869
/// @return The size of the widest type of the matrix multiplication operands
870
///         in bits.
871
4
static uint64_t getMatMulTypeSize(MatMulInfoTy MMI) {
872
4
  auto *S = MMI.A->getStatement()->getParent();
873
4
  auto &DL = S->getFunction().getParent()->getDataLayout();
874
4
  auto ElementSizeA = DL.getTypeSizeInBits(MMI.A->getElementType());
875
4
  auto ElementSizeB = DL.getTypeSizeInBits(MMI.B->getElementType());
876
4
  auto ElementSizeC = DL.getTypeSizeInBits(MMI.WriteToC->getElementType());
877
4
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
878
4
}
879
880
/// Get parameters of the BLIS micro kernel.
881
///
882
/// We choose the Mr and Nr parameters of the micro kernel to be large enough
883
/// such that no stalls caused by the combination of latencies and dependencies
884
/// are introduced during the updates of the resulting matrix of the matrix
885
/// multiplication. However, they should also be as small as possible to
886
/// release more registers for entries of multiplied matrices.
887
///
888
/// @param TTI Target Transform Info.
889
/// @param MMI Parameters of the matrix multiplication operands.
890
/// @return The structure of type MicroKernelParamsTy.
891
/// @see MicroKernelParamsTy
892
static struct MicroKernelParamsTy
893
4
getMicroKernelParams(const TargetTransformInfo *TTI, MatMulInfoTy MMI) {
894
4
  assert(TTI && "The target transform info should be provided.");
895
4
896
4
  // Nvec - Number of double-precision floating-point numbers that can be hold
897
4
  // by a vector register. Use 2 by default.
898
4
  long RegisterBitwidth = VectorRegisterBitwidth;
899
4
900
4
  if (RegisterBitwidth == -1)
901
0
    RegisterBitwidth = TTI->getRegisterBitWidth(true);
902
4
  auto ElementSize = getMatMulTypeSize(MMI);
903
4
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
904
4
                            "operands should be greater than zero.");
905
4
  auto Nvec = RegisterBitwidth / ElementSize;
906
4
  if (Nvec == 0)
907
0
    Nvec = 2;
908
4
  int Nr =
909
4
      ceil(sqrt(Nvec * LatencyVectorFma * ThroughputVectorFma) / Nvec) * Nvec;
910
4
  int Mr = ceil(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr);
911
4
  return {Mr, Nr};
912
4
}
913
914
namespace {
915
/// Determine parameters of the target cache.
916
///
917
/// @param TTI Target Transform Info.
918
4
void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) {
919
4
  auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D;
920
4
  auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D;
921
4
  if (FirstCacheLevelSize == -1) {
922
0
    if (TTI->getCacheSize(L1DCache).hasValue())
923
0
      FirstCacheLevelSize = TTI->getCacheSize(L1DCache).getValue();
924
0
    else
925
0
      FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize);
926
0
  }
927
4
  if (SecondCacheLevelSize == -1) {
928
1
    if (TTI->getCacheSize(L2DCache).hasValue())
929
1
      SecondCacheLevelSize = TTI->getCacheSize(L2DCache).getValue();
930
0
    else
931
0
      SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize);
932
1
  }
933
4
  if (FirstCacheLevelAssociativity == -1) {
934
1
    if (TTI->getCacheAssociativity(L1DCache).hasValue())
935
1
      FirstCacheLevelAssociativity =
936
1
          TTI->getCacheAssociativity(L1DCache).getValue();
937
0
    else
938
0
      FirstCacheLevelAssociativity =
939
0
          static_cast<int>(FirstCacheLevelDefaultAssociativity);
940
1
  }
941
4
  if (SecondCacheLevelAssociativity == -1) {
942
1
    if (TTI->getCacheAssociativity(L2DCache).hasValue())
943
1
      SecondCacheLevelAssociativity =
944
1
          TTI->getCacheAssociativity(L2DCache).getValue();
945
0
    else
946
0
      SecondCacheLevelAssociativity =
947
0
          static_cast<int>(SecondCacheLevelDefaultAssociativity);
948
1
  }
949
4
}
950
} // namespace
951
952
/// Get parameters of the BLIS macro kernel.
953
///
954
/// During the computation of matrix multiplication, blocks of partitioned
955
/// matrices are mapped to different layers of the memory hierarchy.
956
/// To optimize data reuse, blocks should be ideally kept in cache between
957
/// iterations. Since parameters of the macro kernel determine sizes of these
958
/// blocks, there are upper and lower bounds on these parameters.
959
///
960
/// @param TTI Target Transform Info.
961
/// @param MicroKernelParams Parameters of the micro-kernel
962
///                          to be taken into account.
963
/// @param MMI Parameters of the matrix multiplication operands.
964
/// @return The structure of type MacroKernelParamsTy.
965
/// @see MacroKernelParamsTy
966
/// @see MicroKernelParamsTy
967
static struct MacroKernelParamsTy
968
getMacroKernelParams(const llvm::TargetTransformInfo *TTI,
969
                     const MicroKernelParamsTy &MicroKernelParams,
970
4
                     MatMulInfoTy MMI) {
971
4
  getTargetCacheParameters(TTI);
972
4
  // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf,
973
4
  // it requires information about the first two levels of a cache to determine
974
4
  // all the parameters of a macro-kernel. It also checks that an associativity
975
4
  // degree of a cache level is greater than two. Otherwise, another algorithm
976
4
  // for determination of the parameters should be used.
977
4
  if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 &&
978
4
        FirstCacheLevelSize > 0 && 
SecondCacheLevelSize > 03
&&
979
4
        
FirstCacheLevelAssociativity > 23
&&
SecondCacheLevelAssociativity > 23
))
980
1
    return {1, 1, 1};
981
3
  // The quotient should be greater than zero.
982
3
  if (PollyPatternMatchingNcQuotient <= 0)
983
0
    return {1, 1, 1};
984
3
  int Car = floor(
985
3
      (FirstCacheLevelAssociativity - 1) /
986
3
      (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr));
987
3
988
3
  // Car can be computed to be zero since it is floor to int.
989
3
  // On Mac OS, division by 0 does not raise a signal. This causes negative
990
3
  // tile sizes to be computed. Prevent division by Cac==0 by early returning
991
3
  // if this happens.
992
3
  if (Car == 0)
993
0
    return {1, 1, 1};
994
3
995
3
  auto ElementSize = getMatMulAlignTypeSize(MMI);
996
3
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
997
3
                            "operands should be greater than zero.");
998
3
  int Kc = (Car * FirstCacheLevelSize) /
999
3
           (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize);
1000
3
  double Cac =
1001
3
      static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) /
1002
3
      SecondCacheLevelSize;
1003
3
  int Mc = floor((SecondCacheLevelAssociativity - 2) / Cac);
1004
3
  int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr;
1005
3
1006
3
  assert(Mc > 0 && Nc > 0 && Kc > 0 &&
1007
3
         "Matrix block sizes should be  greater than zero");
1008
3
  return {Mc, Nc, Kc};
1009
3
}
1010
1011
/// Create an access relation that is specific to
1012
///        the matrix multiplication pattern.
1013
///
1014
/// Create an access relation of the following form:
1015
/// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ]
1016
/// where I is @p FirstDim, J is @p SecondDim.
1017
///
1018
/// It can be used, for example, to create relations that helps to consequently
1019
/// access elements of operands of a matrix multiplication after creation of
1020
/// the BLIS micro and macro kernels.
1021
///
1022
/// @see ScheduleTreeOptimizer::createMicroKernel
1023
/// @see ScheduleTreeOptimizer::createMacroKernel
1024
///
1025
/// Subsequently, the described access relation is applied to the range of
1026
/// @p MapOldIndVar, that is used to map original induction variables to
1027
/// the ones, which are produced by schedule transformations. It helps to
1028
/// define relations using a new space and, at the same time, keep them
1029
/// in the original one.
1030
///
1031
/// @param MapOldIndVar The relation, which maps original induction variables
1032
///                     to the ones, which are produced by schedule
1033
///                     transformations.
1034
/// @param FirstDim, SecondDim The input dimensions that are used to define
1035
///        the specified access relation.
1036
/// @return The specified access relation.
1037
isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim,
1038
6
                         unsigned SecondDim) {
1039
6
  auto AccessRelSpace = isl::space(MapOldIndVar.get_ctx(), 0, 9, 3);
1040
6
  auto AccessRel = isl::map::universe(AccessRelSpace);
1041
6
  AccessRel = AccessRel.equate(isl::dim::in, FirstDim, isl::dim::out, 0);
1042
6
  AccessRel = AccessRel.equate(isl::dim::in, 5, isl::dim::out, 1);
1043
6
  AccessRel = AccessRel.equate(isl::dim::in, SecondDim, isl::dim::out, 2);
1044
6
  return MapOldIndVar.apply_range(AccessRel);
1045
6
}
1046
1047
// Graft an extension node built from @p ExtensionMap before @p Node.
isl::schedule_node createExtensionNode(isl::schedule_node Node,
                                       isl::map ExtensionMap) {
  auto Extension = isl::union_map(ExtensionMap);
  auto NewNode = isl::schedule_node::from_extension(Extension);
  return Node.graft_before(NewNode);
}
1053
1054
/// Apply the packing transformation.
1055
///
1056
/// The packing transformation can be described as a data-layout
1057
/// transformation that requires to introduce a new array, copy data
1058
/// to the array, and change memory access locations to reference the array.
1059
/// It can be used to ensure that elements of the new array are read in-stride
1060
/// access, aligned to cache lines boundaries, and preloaded into certain cache
1061
/// levels.
1062
///
1063
/// As an example let us consider the packing of the array A that would help
1064
/// to read its elements with in-stride access. An access to the array A
1065
/// is represented by an access relation that has the form
1066
/// S[i, j, k] -> A[i, k]. The scheduling function of the SCoP statement S has
1067
/// the form S[i,j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr),
1068
/// k mod Kc, j mod Nr, i mod Mr].
1069
///
1070
/// To ensure that elements of the array A are read in-stride access, we add
1071
/// a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using
1072
/// Scop::createScopArrayInfo, change the access relation
1073
/// S[i, j, k] -> A[i, k] to
1074
/// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using
1075
/// MemoryAccess::setNewAccessRelation, and copy the data to the array, using
1076
/// the copy statement created by Scop::addScopStmt.
1077
///
1078
/// @param Node The schedule node to be optimized.
1079
/// @param MapOldIndVar The relation, which maps original induction variables
1080
///                     to the ones, which are produced by schedule
1081
///                     transformations.
1082
/// @param MicroParams, MacroParams Parameters of the BLIS kernel
1083
///                                 to be taken into account.
1084
/// @param MMI Parameters of the matrix multiplication operands.
1085
/// @return The optimized schedule node.
1086
static isl::schedule_node
1087
optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar,
1088
                                 MicroKernelParamsTy MicroParams,
1089
                                 MacroKernelParamsTy MacroParams,
1090
3
                                 MatMulInfoTy &MMI) {
1091
3
  auto InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
1092
3
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
1093
3
1094
3
  // Create a copy statement that corresponds to the memory access to the
1095
3
  // matrix B, the second operand of the matrix multiplication.
1096
3
  Node = Node.parent().parent().parent().parent().parent().parent();
1097
3
  Node = isl::manage(isl_schedule_node_band_split(Node.release(), 2)).child(0);
1098
3
  auto AccRel = getMatMulAccRel(MapOldIndVar, 3, 7);
1099
3
  unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr;
1100
3
  unsigned SecondDimSize = MacroParams.Kc;
1101
3
  unsigned ThirdDimSize = MicroParams.Nr;
1102
3
  auto *SAI = Stmt->getParent()->createScopArrayInfo(
1103
3
      MMI.B->getElementType(), "Packed_B",
1104
3
      {FirstDimSize, SecondDimSize, ThirdDimSize});
1105
3
  AccRel = AccRel.set_tuple_id(isl::dim::out, SAI->getBasePtrId());
1106
3
  auto OldAcc = MMI.B->getLatestAccessRelation();
1107
3
  MMI.B->setNewAccessRelation(AccRel);
1108
3
  auto ExtMap = MapOldIndVar.project_out(isl::dim::out, 2,
1109
3
                                         MapOldIndVar.dim(isl::dim::out) - 2);
1110
3
  ExtMap = ExtMap.reverse();
1111
3
  ExtMap = ExtMap.fix_si(isl::dim::out, MMI.i, 0);
1112
3
  auto Domain = Stmt->getDomain();
1113
3
1114
3
  // Restrict the domains of the copy statements to only execute when also its
1115
3
  // originating statement is executed.
1116
3
  auto DomainId = Domain.get_tuple_id();
1117
3
  auto *NewStmt = Stmt->getParent()->addScopStmt(
1118
3
      OldAcc, MMI.B->getLatestAccessRelation(), Domain);
1119
3
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, DomainId);
1120
3
  ExtMap = ExtMap.intersect_range(Domain);
1121
3
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, NewStmt->getDomainId());
1122
3
  Node = createExtensionNode(Node, ExtMap);
1123
3
1124
3
  // Create a copy statement that corresponds to the memory access
1125
3
  // to the matrix A, the first operand of the matrix multiplication.
1126
3
  Node = Node.child(0);
1127
3
  AccRel = getMatMulAccRel(MapOldIndVar, 4, 6);
1128
3
  FirstDimSize = MacroParams.Mc / MicroParams.Mr;
1129
3
  ThirdDimSize = MicroParams.Mr;
1130
3
  SAI = Stmt->getParent()->createScopArrayInfo(
1131
3
      MMI.A->getElementType(), "Packed_A",
1132
3
      {FirstDimSize, SecondDimSize, ThirdDimSize});
1133
3
  AccRel = AccRel.set_tuple_id(isl::dim::out, SAI->getBasePtrId());
1134
3
  OldAcc = MMI.A->getLatestAccessRelation();
1135
3
  MMI.A->setNewAccessRelation(AccRel);
1136
3
  ExtMap = MapOldIndVar.project_out(isl::dim::out, 3,
1137
3
                                    MapOldIndVar.dim(isl::dim::out) - 3);
1138
3
  ExtMap = ExtMap.reverse();
1139
3
  ExtMap = ExtMap.fix_si(isl::dim::out, MMI.j, 0);
1140
3
  NewStmt = Stmt->getParent()->addScopStmt(
1141
3
      OldAcc, MMI.A->getLatestAccessRelation(), Domain);
1142
3
1143
3
  // Restrict the domains of the copy statements to only execute when also its
1144
3
  // originating statement is executed.
1145
3
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, DomainId);
1146
3
  ExtMap = ExtMap.intersect_range(Domain);
1147
3
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, NewStmt->getDomainId());
1148
3
  Node = createExtensionNode(Node, ExtMap);
1149
3
  return Node.child(0).child(0).child(0).child(0).child(0);
1150
3
}
1151
1152
/// Get a relation mapping induction variables produced by schedule
1153
/// transformations to the original ones.
1154
///
1155
/// @param Node The schedule node produced as the result of creation
1156
///        of the BLIS kernels.
1157
/// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel
1158
///                                             to be taken into account.
1159
/// @return  The relation mapping original induction variables to the ones
1160
///          produced by schedule transformation.
1161
/// @see ScheduleTreeOptimizer::createMicroKernel
1162
/// @see ScheduleTreeOptimizer::createMacroKernel
1163
/// @see getMacroKernelParams
1164
isl::map
1165
getInductionVariablesSubstitution(isl::schedule_node Node,
1166
                                  MicroKernelParamsTy MicroKernelParams,
1167
3
                                  MacroKernelParamsTy MacroKernelParams) {
1168
3
  auto Child = Node.child(0);
1169
3
  auto UnMapOldIndVar = Child.get_prefix_schedule_union_map();
1170
3
  auto MapOldIndVar = isl::map::from_union_map(UnMapOldIndVar);
1171
3
  if (MapOldIndVar.dim(isl::dim::out) > 9)
1172
0
    return MapOldIndVar.project_out(isl::dim::out, 0,
1173
0
                                    MapOldIndVar.dim(isl::dim::out) - 9);
1174
3
  return MapOldIndVar;
1175
3
}
1176
1177
/// Isolate a set of partial tile prefixes and unroll the isolated part.
1178
///
1179
/// The set should ensure that it contains only partial tile prefixes that have
1180
/// exactly Mr x Nr iterations of the two innermost loops produced by
1181
/// the optimization of the matrix multiplication. Mr and Nr are parameters of
1182
/// the micro-kernel.
1183
///
1184
/// In case of parametric bounds, this helps to auto-vectorize the unrolled
1185
/// innermost loops, using the SLP vectorizer.
1186
///
1187
/// @param Node              The schedule node to be modified.
1188
/// @param MicroKernelParams Parameters of the micro-kernel
1189
///                          to be taken into account.
1190
/// @return The modified isl_schedule_node.
1191
static isl::schedule_node
1192
isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node,
1193
3
                                 struct MicroKernelParamsTy MicroKernelParams) {
1194
3
  isl::schedule_node Child = Node.get_child(0);
1195
3
  isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation();
1196
3
  isl::set Prefix = isl::map::from_union_map(UnMapOldIndVar).range();
1197
3
  unsigned Dims = Prefix.dim(isl::dim::set);
1198
3
  Prefix = Prefix.project_out(isl::dim::set, Dims - 1, 1);
1199
3
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Nr);
1200
3
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Mr);
1201
3
1202
3
  isl::union_set IsolateOption =
1203
3
      getIsolateOptions(Prefix.add_dims(isl::dim::set, 3), 3);
1204
3
  isl::ctx Ctx = Node.get_ctx();
1205
3
  auto Options = IsolateOption.unite(getDimOptions(Ctx, "unroll"));
1206
3
  Options = Options.unite(getUnrollIsolatedSetOptions(Ctx));
1207
3
  Node = Node.band_set_ast_build_options(Options);
1208
3
  Node = Node.parent().parent().parent();
1209
3
  IsolateOption = getIsolateOptions(Prefix, 3);
1210
3
  Options = IsolateOption.unite(getDimOptions(Ctx, "separate"));
1211
3
  Node = Node.band_set_ast_build_options(Options);
1212
3
  Node = Node.child(0).child(0).child(0);
1213
3
  return Node;
1214
3
}
1215
1216
/// Mark @p BasePtr with "Inter iteration alias-free" mark node.
1217
///
1218
/// @param Node The child of the mark node to be inserted.
1219
/// @param BasePtr The pointer to be marked.
1220
/// @return The modified isl_schedule_node.
1221
static isl::schedule_node markInterIterationAliasFree(isl::schedule_node Node,
1222
4
                                                      Value *BasePtr) {
1223
4
  if (!BasePtr)
1224
0
    return Node;
1225
4
1226
4
  auto Id =
1227
4
      isl::id::alloc(Node.get_ctx(), "Inter iteration alias-free", BasePtr);
1228
4
  return Node.insert_mark(Id).child(0);
1229
4
}
1230
1231
/// Insert "Loop Vectorizer Disabled" mark node.
1232
///
1233
/// @param Node The child of the mark node to be inserted.
1234
/// @return The modified isl_schedule_node.
1235
3
static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) {
1236
3
  auto Id = isl::id::alloc(Node.get_ctx(), "Loop Vectorizer Disabled", nullptr);
1237
3
  return Node.insert_mark(Id).child(0);
1238
3
}
1239
1240
/// Restore the initial ordering of dimensions of the band node
1241
///
1242
/// In case the band node represents all the dimensions of the iteration
1243
/// domain, recreate the band node to restore the initial ordering of the
1244
/// dimensions.
1245
///
1246
/// @param Node The band node to be modified.
1247
/// @return The modified schedule node.
1248
static isl::schedule_node
1249
4
getBandNodeWithOriginDimOrder(isl::schedule_node Node) {
1250
4
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
1251
4
  if (isl_schedule_node_get_type(Node.child(0).get()) != isl_schedule_node_leaf)
1252
0
    return Node;
1253
4
  auto Domain = Node.get_universe_domain();
1254
4
  assert(isl_union_set_n_set(Domain.get()) == 1);
1255
4
  if (Node.get_schedule_depth() != 0 ||
1256
4
      (isl::set(Domain).dim(isl::dim::set) !=
1257
4
       isl_schedule_node_band_n_member(Node.get())))
1258
0
    return Node;
1259
4
  Node = isl::manage(isl_schedule_node_delete(Node.copy()));
1260
4
  auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff();
1261
4
  auto PartialScheduleMultiPwAff =
1262
4
      isl::multi_union_pw_aff(PartialSchedulePwAff);
1263
4
  PartialScheduleMultiPwAff =
1264
4
      PartialScheduleMultiPwAff.reset_tuple_id(isl::dim::set);
1265
4
  return Node.insert_partial_schedule(PartialScheduleMultiPwAff);
1266
4
}
1267
1268
// Optimize a detected matrix multiplication: canonicalize the loop order to
// (i, j, k), build the BLIS macro and micro kernels, and — when non-trivial
// macro-kernel parameters were found — isolate/unroll the inner loops and
// apply the operand packing transformation.
isl::schedule_node
ScheduleTreeOptimizer::optimizeMatMulPattern(isl::schedule_node Node,
                                             const TargetTransformInfo *TTI,
                                             MatMulInfoTy &MMI) {
  assert(TTI && "The target transform info should be provided.");
  Node = markInterIterationAliasFree(
      Node, MMI.WriteToC->getLatestScopArrayInfo()->getBasePtr());
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions.");
  Node = getBandNodeWithOriginDimOrder(Node);
  // Move i, j, k into the three innermost positions, tracking how the
  // earlier swaps displace the j and k dimensions.
  Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3);
  int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j;
  int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k;
  Node = permuteBandNodeDimensions(Node, NewJ, DimOutNum - 2);
  NewK = NewK == DimOutNum - 2 ? NewJ : NewK;
  Node = permuteBandNodeDimensions(Node, NewK, DimOutNum - 1);
  auto MicroKernelParams = getMicroKernelParams(TTI, MMI);
  auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI);
  Node = createMacroKernel(Node, MacroKernelParams);
  Node = createMicroKernel(Node, MicroKernelParams);
  // Trivial macro-kernel parameters: stop after the kernels.
  if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 ||
      MacroKernelParams.Kc == 1)
    return Node;
  auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams,
                                                        MacroKernelParams);
  if (!MapOldIndVar)
    return Node;
  Node = markLoopVectorizerDisabled(Node.parent()).child(0);
  Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams);
  return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
                                          MacroKernelParams, MMI);
}
1302
1303
bool ScheduleTreeOptimizer::isMatrMultPattern(isl::schedule_node Node,
1304
                                              const Dependences *D,
1305
11
                                              MatMulInfoTy &MMI) {
1306
11
  auto PartialSchedule = isl::manage(
1307
11
      isl_schedule_node_band_get_partial_schedule_union_map(Node.get()));
1308
11
  Node = Node.child(0);
1309
11
  auto LeafType = isl_schedule_node_get_type(Node.get());
1310
11
  Node = Node.parent();
1311
11
  if (LeafType != isl_schedule_node_leaf ||
1312
11
      isl_schedule_node_band_n_member(Node.get()) < 3 ||
1313
11
      
Node.get_schedule_depth() != 04
||
1314
11
      
isl_union_map_n_map(PartialSchedule.get()) != 14
)
1315
7
    return false;
1316
4
  auto NewPartialSchedule = isl::map::from_union_map(PartialSchedule);
1317
4
  if (containsMatrMult(NewPartialSchedule, D, MMI))
1318
4
    return true;
1319
0
  return false;
1320
0
}
1321
1322
__isl_give isl_schedule_node *
ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
                                    void *User) {
  // Callback invoked for every node of the schedule tree. Only tileable band
  // nodes are transformed; all other nodes are returned unchanged.
  // Note: isl::manage_copy() borrows Node without taking ownership, so the
  // raw pointer can still be returned to the caller below.
  if (!isTileableBandNode(isl::manage_copy(Node)))
    return Node;

  const OptimizerAdditionalInfoTy *OAI =
      static_cast<const OptimizerAdditionalInfoTy *>(User);

  // If pattern-based optimizations are enabled and this band implements a
  // matrix multiplication, apply the specialized matmul optimization; MMI is
  // filled in by isMatrMultPattern().
  MatMulInfoTy MMI;
  if (PMBasedOpts && User &&
      isMatrMultPattern(isl::manage_copy(Node), OAI->D, MMI)) {
    LLVM_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
    MatMulOpts++;
    // isl::manage() takes ownership of Node; release() hands the result back
    // to the caller, as required by the __isl_give annotation.
    return optimizeMatMulPattern(isl::manage(Node), OAI->TTI, MMI).release();
  }

  // Otherwise apply the standard band optimizations (e.g. tiling,
  // prevectorization).
  return standardBandOpts(isl::manage(Node), User).release();
}
1341
1342
isl::schedule
ScheduleTreeOptimizer::optimizeSchedule(isl::schedule Schedule,
                                        const OptimizerAdditionalInfoTy *OAI) {
  // Transform the tree starting at its root node and re-derive the schedule
  // from the rewritten root.
  return optimizeScheduleNode(Schedule.get_root(), OAI).get_schedule();
}
1349
1350
isl::schedule_node ScheduleTreeOptimizer::optimizeScheduleNode(
    isl::schedule_node Node, const OptimizerAdditionalInfoTy *OAI) {
  // The C callback interface takes a non-const void*, so strip constness
  // explicitly when smuggling the additional info through the user pointer.
  void *UserPtr = const_cast<void *>(static_cast<const void *>(OAI));
  // Walk all descendants bottom-up, letting optimizeBand rewrite each node.
  isl_schedule_node *Optimized = isl_schedule_node_map_descendant_bottom_up(
      Node.release(), optimizeBand, UserPtr);
  return isl::manage(Optimized);
}
1357
1358
bool ScheduleTreeOptimizer::isProfitableSchedule(Scop &S,
1359
12
                                                 isl::schedule NewSchedule) {
1360
12
  // To understand if the schedule has been optimized we check if the schedule
1361
12
  // has changed at all.
1362
12
  // TODO: We can improve this by tracking if any necessarily beneficial
1363
12
  // transformations have been performed. This can e.g. be tiling, loop
1364
12
  // interchange, or ...) We can track this either at the place where the
1365
12
  // transformation has been performed or, in case of automatic ILP based
1366
12
  // optimizations, by comparing (yet to be defined) performance metrics
1367
12
  // before/after the scheduling optimizer
1368
12
  // (e.g., #stride-one accesses)
1369
12
  if (S.containsExtensionNode(NewSchedule))
1370
3
    return true;
1371
9
  auto NewScheduleMap = NewSchedule.get_map();
1372
9
  auto OldSchedule = S.getSchedule();
1373
9
  assert(OldSchedule && "Only IslScheduleOptimizer can insert extension nodes "
1374
9
                        "that make Scop::getSchedule() return nullptr.");
1375
9
  bool changed = !OldSchedule.is_equal(NewScheduleMap);
1376
9
  return changed;
1377
9
}
1378
1379
namespace {

/// Legacy pass-manager pass that computes a new, optimized schedule for each
/// SCoP via isl's scheduler plus Polly's post-scheduling transformations.
class IslScheduleOptimizer : public ScopPass {
public:
  static char ID;

  explicit IslScheduleOptimizer() : ScopPass(ID) {}

  // LastSchedule is an owning raw isl pointer; release it on destruction.
  ~IslScheduleOptimizer() override { isl_schedule_free(LastSchedule); }

  /// Optimize the schedule of the SCoP @p S.
  bool runOnScop(Scop &S) override;

  /// Print the new schedule for the SCoP @p S.
  void printScop(raw_ostream &OS, Scop &S) const override;

  /// Register all analyses and transformation required.
  void getAnalysisUsage(AnalysisUsage &AU) const override;

  /// Release the internal memory.
  void releaseMemory() override {
    // isl_schedule_free() accepts nullptr, so this is safe to call
    // repeatedly; reset the pointer to keep the destructor a no-op after.
    isl_schedule_free(LastSchedule);
    LastSchedule = nullptr;
  }

private:
  // The most recently computed schedule, kept for printScop(); owned by this
  // pass and freed in releaseMemory()/the destructor.
  isl_schedule *LastSchedule = nullptr;
};
} // namespace
1408
1409
// Pass identification: the legacy pass manager uses the address of ID, not
// its value.
char IslScheduleOptimizer::ID = 0;
1410
1411
/// Collect statistics for the schedule tree.
///
/// @param Schedule The schedule tree to analyze. If not a schedule tree it is
/// ignored.
/// @param Version  The version of the schedule tree that is analyzed.
///                 0 for the original schedule tree before any transformation.
///                 1 for the schedule tree after isl's rescheduling.
///                 2 for the schedule tree after optimizations are applied
///                 (tiling, pattern matching)
static void walkScheduleTreeForStatistics(isl::schedule Schedule, int Version) {
  auto Root = Schedule.get_root();
  if (!Root)
    return;

  // Visit every node of the tree top-down; the lambda bumps the per-version
  // statistic counters for the node kinds we are interested in.
  isl_schedule_node_foreach_descendant_top_down(
      Root.get(),
      [](__isl_keep isl_schedule_node *nodeptr, void *user) -> isl_bool {
        isl::schedule_node Node = isl::manage_copy(nodeptr);
        // The version index is passed through the opaque user pointer.
        int Version = *static_cast<int *>(user);

        switch (isl_schedule_node_get_type(Node.get())) {
        case isl_schedule_node_band: {
          NumBands[Version]++;
          if (isl_schedule_node_band_get_permutable(Node.get()) ==
              isl_bool_true)
            NumPermutable[Version]++;

          // Count band members and how many of them are marked coincident
          // (i.e. parallel within the band).
          int CountMembers = isl_schedule_node_band_n_member(Node.get());
          NumBandMembers[Version] += CountMembers;
          for (int i = 0; i < CountMembers; i += 1) {
            if (Node.band_member_get_coincident(i))
              NumCoincident[Version]++;
          }
          break;
        }

        case isl_schedule_node_filter:
          NumFilters[Version]++;
          break;

        case isl_schedule_node_extension:
          NumExtension[Version]++;
          break;

        default:
          break;
        }

        // isl_bool_true continues the traversal into this node's children.
        return isl_bool_true;
      },
      &Version);
}
1463
1464
13
bool IslScheduleOptimizer::runOnScop(Scop &S) {
  // Skip SCoPs in case they're already optimised by PPCGCodeGeneration
  if (S.isToBeSkipped())
    return false;

  // Skip empty SCoPs but still allow code generation as it will delete the
  // loops present but not needed.
  if (S.getSize() == 0) {
    S.markAsOptimized();
    return false;
  }

  const Dependences &D =
      getAnalysis<DependenceInfo>().getDependences(Dependences::AL_Statement);

  // The dependences must live in the same isl context as the SCoP, otherwise
  // they belong to a different SCoP and cannot be used here.
  if (D.getSharedIslCtx() != S.getSharedIslCtx()) {
    LLVM_DEBUG(dbgs() << "DependenceInfo for another SCoP/isl_ctx\n");
    return false;
  }

  if (!D.hasValidDependences())
    return false;

  // Drop any schedule kept from a previous run before computing a new one.
  isl_schedule_free(LastSchedule);
  LastSchedule = nullptr;

  // Build input data.
  int ValidityKinds =
      Dependences::TYPE_RAW | Dependences::TYPE_WAR | Dependences::TYPE_WAW;
  int ProximityKinds;

  // Select which dependence kinds the scheduler should try to keep close
  // (proximity), controlled by the -polly-opt-optimize-deps option.
  if (OptimizeDeps == "all")
    ProximityKinds =
        Dependences::TYPE_RAW | Dependences::TYPE_WAR | Dependences::TYPE_WAW;
  else if (OptimizeDeps == "raw")
    ProximityKinds = Dependences::TYPE_RAW;
  else {
    errs() << "Do not know how to optimize for '" << OptimizeDeps << "'"
           << " Falling back to optimizing all dependences.\n";
    ProximityKinds =
        Dependences::TYPE_RAW | Dependences::TYPE_WAR | Dependences::TYPE_WAW;
  }

  isl::union_set Domain = S.getDomains();

  if (!Domain)
    return false;

  ScopsProcessed++;
  walkScheduleTreeForStatistics(S.getScheduleTree(), 0);

  isl::union_map Validity = D.getDependences(ValidityKinds);
  isl::union_map Proximity = D.getDependences(ProximityKinds);

  // Simplify the dependences by removing the constraints introduced by the
  // domains. This can speed up the scheduling time significantly, as large
  // constant coefficients will be removed from the dependences. The
  // introduction of some additional dependences reduces the possible
  // transformations, but in most cases, such transformation do not seem to be
  // interesting anyway. In some cases this option may stop the scheduler to
  // find any schedule.
  if (SimplifyDeps == "yes") {
    Validity = Validity.gist_domain(Domain);
    Validity = Validity.gist_range(Domain);
    Proximity = Proximity.gist_domain(Domain);
    Proximity = Proximity.gist_range(Domain);
  } else if (SimplifyDeps != "no") {
    errs() << "warning: Option -polly-opt-simplify-deps should either be 'yes' "
              "or 'no'. Falling back to default: 'yes'\n";
  }

  LLVM_DEBUG(dbgs() << "\n\nCompute schedule from: ");
  LLVM_DEBUG(dbgs() << "Domain := " << Domain << ";\n");
  LLVM_DEBUG(dbgs() << "Proximity := " << Proximity << ";\n");
  LLVM_DEBUG(dbgs() << "Validity := " << Validity << ";\n");

  // Translate the string-valued command-line options into the integer flags
  // that isl's scheduler expects, warning on unknown values.
  unsigned IslSerializeSCCs;

  if (FusionStrategy == "max") {
    IslSerializeSCCs = 0;
  } else if (FusionStrategy == "min") {
    IslSerializeSCCs = 1;
  } else {
    errs() << "warning: Unknown fusion strategy. Falling back to maximal "
              "fusion.\n";
    IslSerializeSCCs = 0;
  }

  int IslMaximizeBands;

  if (MaximizeBandDepth == "yes") {
    IslMaximizeBands = 1;
  } else if (MaximizeBandDepth == "no") {
    IslMaximizeBands = 0;
  } else {
    errs() << "warning: Option -polly-opt-maximize-bands should either be 'yes'"
              " or 'no'. Falling back to default: 'yes'\n";
    IslMaximizeBands = 1;
  }

  int IslOuterCoincidence;

  if (OuterCoincidence == "yes") {
    IslOuterCoincidence = 1;
  } else if (OuterCoincidence == "no") {
    IslOuterCoincidence = 0;
  } else {
    errs() << "warning: Option -polly-opt-outer-coincidence should either be "
              "'yes' or 'no'. Falling back to default: 'no'\n";
    IslOuterCoincidence = 0;
  }

  isl_ctx *Ctx = S.getIslCtx().get();

  // Configure isl's scheduler through context-global options.
  isl_options_set_schedule_outer_coincidence(Ctx, IslOuterCoincidence);
  isl_options_set_schedule_serialize_sccs(Ctx, IslSerializeSCCs);
  isl_options_set_schedule_maximize_band_depth(Ctx, IslMaximizeBands);
  isl_options_set_schedule_max_constant_term(Ctx, MaxConstantTerm);
  isl_options_set_schedule_max_coefficient(Ctx, MaxCoefficient);
  isl_options_set_tile_scale_tile_loops(Ctx, 0);

  // Let isl continue (returning a null schedule) instead of aborting when
  // scheduling fails; restore the previous error behavior afterwards.
  auto OnErrorStatus = isl_options_get_on_error(Ctx);
  isl_options_set_on_error(Ctx, ISL_ON_ERROR_CONTINUE);

  auto SC = isl::schedule_constraints::on_domain(Domain);
  SC = SC.set_proximity(Proximity);
  SC = SC.set_validity(Validity);
  SC = SC.set_coincidence(Validity);
  auto Schedule = SC.compute_schedule();
  isl_options_set_on_error(Ctx, OnErrorStatus);

  walkScheduleTreeForStatistics(Schedule, 1);

  // In cases the scheduler is not able to optimize the code, we just do not
  // touch the schedule.
  if (!Schedule)
    return false;

  ScopsRescheduled++;

  LLVM_DEBUG({
    auto *P = isl_printer_to_str(Ctx);
    P = isl_printer_set_yaml_style(P, ISL_YAML_STYLE_BLOCK);
    P = isl_printer_print_schedule(P, Schedule.get());
    auto *str = isl_printer_get_str(P);
    dbgs() << "NewScheduleTree: \n" << str << "\n";
    free(str);
    isl_printer_free(P);
  });

  // Apply Polly's post-scheduling transformations (tiling, pattern matching)
  // on top of the schedule computed by isl.
  Function &F = S.getFunction();
  auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  const OptimizerAdditionalInfoTy OAI = {TTI, const_cast<Dependences *>(&D)};
  auto NewSchedule = ScheduleTreeOptimizer::optimizeSchedule(Schedule, &OAI);
  walkScheduleTreeForStatistics(NewSchedule, 2);

  if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewSchedule))
    return false;

  auto ScopStats = S.getStatistics();
  ScopsOptimized++;
  NumAffineLoopsOptimized += ScopStats.NumAffineLoops;
  NumBoxedLoopsOptimized += ScopStats.NumBoxedLoops;

  S.setScheduleTree(NewSchedule);
  S.markAsOptimized();

  if (OptimizedScops)
    errs() << S;

  // The schedule is stored in the Scop itself, not in LLVM IR, so from the
  // pass manager's perspective nothing was modified.
  return false;
}
1636
1637
10
/// Print the last schedule computed by runOnScop(), or "n/a" if no schedule
/// has been computed (or it has been released).
void IslScheduleOptimizer::printScop(raw_ostream &OS, Scop &) const {
  isl_printer *p;
  char *ScheduleStr;

  OS << "Calculated schedule:\n";

  if (!LastSchedule) {
    OS << "n/a\n";
    return;
  }

  p = isl_printer_to_str(isl_schedule_get_ctx(LastSchedule));
  p = isl_printer_print_schedule(p, LastSchedule);
  ScheduleStr = isl_printer_get_str(p);
  isl_printer_free(p);

  OS << ScheduleStr << "\n";

  // isl_printer_get_str() transfers ownership of a malloc'd buffer to the
  // caller (the debug dump in runOnScop frees its string the same way);
  // release it here to avoid leaking on every print.
  free(ScheduleStr);
}
1655
1656
13
void IslScheduleOptimizer::getAnalysisUsage(AnalysisUsage &AU) const {
  // Requirements shared by all Scop passes.
  ScopPass::getAnalysisUsage(AU);
  // Scheduling needs dependence information and target cost modeling.
  AU.addRequired<DependenceInfo>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
  // Rescheduling does not invalidate the computed dependences.
  AU.addPreserved<DependenceInfo>();
}
1663
1664
0
// Factory function used by Polly's pass registration; the caller takes
// ownership of the returned pass.
Pass *polly::createIslScheduleOptimizerPass() {
  return new IslScheduleOptimizer();
}
1667
1668
9.54k
// Register the pass and its analysis dependencies with the legacy pass
// manager under the command-line name "polly-opt-isl".
INITIALIZE_PASS_BEGIN(IslScheduleOptimizer, "polly-opt-isl",
                      "Polly - Optimize schedule of SCoP", false, false);
INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
INITIALIZE_PASS_DEPENDENCY(ScopInfoRegionPass);
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass);
INITIALIZE_PASS_END(IslScheduleOptimizer, "polly-opt-isl",
                    "Polly - Optimize schedule of SCoP", false, false)