Coverage Report

Created: 2021-09-21 08:58

/Users/buildslave/jenkins/workspace/coverage/llvm-project/clang/utils/TableGen/RISCVVEmitter.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- RISCVVEmitter.cpp - Generate riscv_vector.h for use with clang -----===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This tablegen backend is responsible for emitting riscv_vector.h which
10
// includes a declaration and definition of each intrinsic functions specified
11
// in https://github.com/riscv/rvv-intrinsic-doc.
12
//
13
// See also the documentation in include/clang/Basic/riscv_vector.td.
14
//
15
//===----------------------------------------------------------------------===//
16
17
#include "llvm/ADT/ArrayRef.h"
18
#include "llvm/ADT/SmallSet.h"
19
#include "llvm/ADT/StringExtras.h"
20
#include "llvm/ADT/StringMap.h"
21
#include "llvm/ADT/StringSet.h"
22
#include "llvm/ADT/Twine.h"
23
#include "llvm/TableGen/Error.h"
24
#include "llvm/TableGen/Record.h"
25
#include <numeric>
26
27
using namespace llvm;
28
using BasicType = char;
29
using VScaleVal = Optional<unsigned>;
30
31
namespace {
32
33
// Exponential LMUL
34
struct LMULType {
35
  int Log2LMUL;
36
  LMULType(int Log2LMUL);
37
  // Return the C/C++ string representation of LMUL
38
  std::string str() const;
39
  Optional<unsigned> getScale(unsigned ElementBitwidth) const;
40
  void MulLog2LMUL(int Log2LMUL);
41
  LMULType &operator*=(uint32_t RHS);
42
};
43
44
// This class is compact representation of a valid and invalid RVVType.
45
class RVVType {
46
  enum ScalarTypeKind : uint32_t {
47
    Void,
48
    Size_t,
49
    Ptrdiff_t,
50
    UnsignedLong,
51
    SignedLong,
52
    Boolean,
53
    SignedInteger,
54
    UnsignedInteger,
55
    Float,
56
    Invalid,
57
  };
58
  BasicType BT;
59
  ScalarTypeKind ScalarType = Invalid;
60
  LMULType LMUL;
61
  bool IsPointer = false;
62
  // IsConstant indices are "int", but have the constant expression.
63
  bool IsImmediate = false;
64
  // Const qualifier for pointer to const object or object of const type.
65
  bool IsConstant = false;
66
  unsigned ElementBitwidth = 0;
67
  VScaleVal Scale = 0;
68
  bool Valid;
69
70
  std::string BuiltinStr;
71
  std::string ClangBuiltinStr;
72
  std::string Str;
73
  std::string ShortStr;
74
75
public:
76
0
  RVVType() : RVVType(BasicType(), 0, StringRef()) {}
77
  RVVType(BasicType BT, int Log2LMUL, StringRef prototype);
78
79
  // Return the string representation of a type, which is an encoded string for
80
  // passing to the BUILTIN() macro in Builtins.def.
81
0
  const std::string &getBuiltinStr() const { return BuiltinStr; }
82
83
  // Return the clang buitlin type for RVV vector type which are used in the
84
  // riscv_vector.h header file.
85
0
  const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; }
86
87
  // Return the C/C++ string representation of a type for use in the
88
  // riscv_vector.h header file.
89
0
  const std::string &getTypeStr() const { return Str; }
90
91
  // Return the short name of a type for C/C++ name suffix.
92
0
  const std::string &getShortStr() {
93
    // Not all types are used in short name, so compute the short name by
94
    // demanded.
95
0
    if (ShortStr.empty())
96
0
      initShortStr();
97
0
    return ShortStr;
98
0
  }
99
100
0
  bool isValid() const { return Valid; }
101
0
  bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; }
102
0
  bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; }
103
0
  bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
104
0
  bool isSignedInteger() const {
105
0
    return ScalarType == ScalarTypeKind::SignedInteger;
106
0
  }
107
0
  bool isFloatVector(unsigned Width) const {
108
0
    return isVector() && isFloat() && ElementBitwidth == Width;
109
0
  }
110
0
  bool isFloat(unsigned Width) const {
111
0
    return isFloat() && ElementBitwidth == Width;
112
0
  }
113
114
private:
115
  // Verify RVV vector type and set Valid.
116
  bool verifyType() const;
117
118
  // Creates a type based on basic types of TypeRange
119
  void applyBasicType();
120
121
  // Applies a prototype modifier to the current type. The result maybe an
122
  // invalid type.
123
  void applyModifier(StringRef prototype);
124
125
  // Compute and record a string for legal type.
126
  void initBuiltinStr();
127
  // Compute and record a builtin RVV vector type string.
128
  void initClangBuiltinStr();
129
  // Compute and record a type string for used in the header.
130
  void initTypeStr();
131
  // Compute and record a short name of a type for C/C++ name suffix.
132
  void initShortStr();
133
};
134
135
using RVVTypePtr = RVVType *;
136
using RVVTypes = std::vector<RVVTypePtr>;
137
138
enum RISCVExtension : uint8_t {
139
  Basic = 0,
140
  F = 1 << 1,
141
  D = 1 << 2,
142
  Zfh = 1 << 3,
143
  Zvamo = 1 << 4,
144
  Zvlsseg = 1 << 5,
145
};
146
147
// TODO refactor RVVIntrinsic class design after support all intrinsic
148
// combination. This represents an instantiation of an intrinsic with a
149
// particular type and prototype
150
class RVVIntrinsic {
151
152
private:
153
  std::string Name; // Builtin name
154
  std::string MangledName;
155
  std::string IRName;
156
  bool HasSideEffects;
157
  bool IsMask;
158
  bool HasMaskedOffOperand;
159
  bool HasVL;
160
  bool HasNoMaskedOverloaded;
161
  bool HasAutoDef; // There is automiatic definition in header
162
  std::string ManualCodegen;
163
  RVVTypePtr OutputType; // Builtin output type
164
  RVVTypes InputTypes;   // Builtin input types
165
  // The types we use to obtain the specific LLVM intrinsic. They are index of
166
  // InputTypes. -1 means the return type.
167
  std::vector<int64_t> IntrinsicTypes;
168
  uint8_t RISCVExtensions = 0;
169
  unsigned NF = 1;
170
171
public:
172
  RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName,
173
               StringRef MangledSuffix, StringRef IRName, bool HasSideEffects,
174
               bool IsMask, bool HasMaskedOffOperand, bool HasVL,
175
               bool HasNoMaskedOverloaded, bool HasAutoDef,
176
               StringRef ManualCodegen, const RVVTypes &Types,
177
               const std::vector<int64_t> &IntrinsicTypes,
178
               StringRef RequiredExtension, unsigned NF);
179
0
  ~RVVIntrinsic() = default;
180
181
0
  StringRef getName() const { return Name; }
182
0
  StringRef getMangledName() const { return MangledName; }
183
0
  bool hasSideEffects() const { return HasSideEffects; }
184
0
  bool hasMaskedOffOperand() const { return HasMaskedOffOperand; }
185
0
  bool hasVL() const { return HasVL; }
186
0
  bool hasNoMaskedOverloaded() const { return HasNoMaskedOverloaded; }
187
0
  bool hasManualCodegen() const { return !ManualCodegen.empty(); }
188
0
  bool hasAutoDef() const { return HasAutoDef; }
189
0
  bool isMask() const { return IsMask; }
190
0
  StringRef getIRName() const { return IRName; }
191
0
  StringRef getManualCodegen() const { return ManualCodegen; }
192
0
  uint8_t getRISCVExtensions() const { return RISCVExtensions; }
193
0
  unsigned getNF() const { return NF; }
194
195
  // Return the type string for a BUILTIN() macro in Builtins.def.
196
  std::string getBuiltinTypeStr() const;
197
198
  // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should
199
  // init the RVVIntrinsic ID and IntrinsicTypes.
200
  void emitCodeGenSwitchBody(raw_ostream &o) const;
201
202
  // Emit the macros for mapping C/C++ intrinsic function to builtin functions.
203
  void emitIntrinsicMacro(raw_ostream &o) const;
204
205
  // Emit the mangled function definition.
206
  void emitMangledFuncDef(raw_ostream &o) const;
207
};
208
209
class RVVEmitter {
210
private:
211
  RecordKeeper &Records;
212
  std::string HeaderCode;
213
  // Concat BasicType, LMUL and Proto as key
214
  StringMap<RVVType> LegalTypes;
215
  StringSet<> IllegalTypes;
216
217
public:
218
0
  RVVEmitter(RecordKeeper &R) : Records(R) {}
219
220
  /// Emit riscv_vector.h
221
  void createHeader(raw_ostream &o);
222
223
  /// Emit all the __builtin prototypes and code needed by Sema.
224
  void createBuiltins(raw_ostream &o);
225
226
  /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
227
  void createCodeGen(raw_ostream &o);
228
229
  std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);
230
231
private:
232
  /// Create all intrinsics and add them to \p Out
233
  void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
234
  /// Compute output and input types by applying different config (basic type
235
  /// and LMUL with type transformers). It also record result of type in legal
236
  /// or illegal set to avoid compute the  same config again. The result maybe
237
  /// have illegal RVVType.
238
  Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
239
                                  ArrayRef<std::string> PrototypeSeq);
240
  Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);
241
242
  /// Emit Acrh predecessor definitions and body, assume the element of Defs are
243
  /// sorted by extension.
244
  void emitArchMacroAndBody(
245
      std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &o,
246
      std::function<void(raw_ostream &, const RVVIntrinsic &)>);
247
248
  // Emit the architecture preprocessor definitions. Return true when emits
249
  // non-empty string.
250
  bool emitExtDefStr(uint8_t Extensions, raw_ostream &o);
251
  // Slice Prototypes string into sub prototype string and process each sub
252
  // prototype string individually in the Handler.
253
  void parsePrototypes(StringRef Prototypes,
254
                       std::function<void(StringRef)> Handler);
255
};
256
257
} // namespace
258
259
//===----------------------------------------------------------------------===//
260
// Type implementation
261
//===----------------------------------------------------------------------===//
262
263
0
LMULType::LMULType(int NewLog2LMUL) : Log2LMUL(NewLog2LMUL) {
  // Only LMUL values 1/8 .. 8 (Log2LMUL in [-3, 3]) are representable.
  assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
}
268
269
0
std::string LMULType::str() const {
270
0
  if (Log2LMUL < 0)
271
0
    return "mf" + utostr(1ULL << (-Log2LMUL));
272
0
  return "m" + utostr(1ULL << Log2LMUL);
273
0
}
274
275
0
VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
  // For the supported SEWs the multiplier is LMUL * 64 / SEW, so its log2
  // is Log2LMUL + 6 - log2(SEW). Any other width keeps exponent 0.
  int Log2ScaleResult = 0;
  switch (ElementBitwidth) {
  case 8:
  case 16:
  case 32:
  case 64:
    Log2ScaleResult =
        Log2LMUL + 6 - static_cast<int>(Log2_32(ElementBitwidth));
    break;
  default:
    break;
  }
  // A multiplier below 1 cannot be represented.
  if (Log2ScaleResult < 0)
    return None;
  return 1 << Log2ScaleResult;
}
298
299
0
void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
300
301
0
LMULType &LMULType::operator*=(uint32_t RHS) {
  assert(isPowerOf2_32(RHS));
  // LMUL is stored as an exponent, so scaling by 2^k shifts it by k.
  MulLog2LMUL(Log2_32(RHS));
  return *this;
}
306
307
RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
    : BT(BT), LMUL(Log2LMUL) {
  // Derive the element type from BT, then let the prototype modifier reshape
  // it; the outcome may be an invalid type.
  applyBasicType();
  applyModifier(prototype);

  Valid = verifyType();
  if (!Valid)
    return;

  // Only valid types get string spellings; the Clang builtin vector type
  // name exists only for vector types.
  initBuiltinStr();
  initTypeStr();
  if (isVector())
    initClangBuiltinStr();
}
320
321
// clang-format off
322
// Boolean types are encoded by the ratio n = SEW/LMUL
323
// SEW/LMUL | 64        | 32        | 16        | 8        | 4         | 2         | 1
324
// c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t  | vbool2_t  | vbool1_t
325
// IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1   | nxv32i1   | nxv64i1
326
327
// type\lmul | 1/8    | 1/4      | 1/2     | 1       | 2        | 4        | 8
328
// --------  |------  | -------- | ------- | ------- | -------- | -------- | --------
329
// i64       | N/A    | N/A      | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
330
// i32       | N/A    | N/A      | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
331
// i16       | N/A    | nxv1i16  | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
332
// i8        | nxv1i8 | nxv2i8   | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
333
// double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
334
// float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
335
// half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
336
// clang-format on
337
338
0
bool RVVType::verifyType() const {
339
0
  if (ScalarType == Invalid)
340
0
    return false;
341
0
  if (isScalar())
342
0
    return true;
343
0
  if (!Scale.hasValue())
344
0
    return false;
345
0
  if (isFloat() && ElementBitwidth == 8)
346
0
    return false;
347
0
  unsigned V = Scale.getValue();
348
0
  switch (ElementBitwidth) {
349
0
  case 1:
350
0
  case 8:
351
    // Check Scale is 1,2,4,8,16,32,64
352
0
    return (V <= 64 && isPowerOf2_32(V));
353
0
  case 16:
354
    // Check Scale is 1,2,4,8,16,32
355
0
    return (V <= 32 && isPowerOf2_32(V));
356
0
  case 32:
357
    // Check Scale is 1,2,4,8,16
358
0
    return (V <= 16 && isPowerOf2_32(V));
359
0
  case 64:
360
    // Check Scale is 1,2,4,8
361
0
    return (V <= 8 && isPowerOf2_32(V));
362
0
  }
363
0
  return false;
364
0
}
365
366
0
void RVVType::initBuiltinStr() {
367
0
  assert(isValid() && "RVVType is invalid");
368
0
  switch (ScalarType) {
369
0
  case ScalarTypeKind::Void:
370
0
    BuiltinStr = "v";
371
0
    return;
372
0
  case ScalarTypeKind::Size_t:
373
0
    BuiltinStr = "z";
374
0
    if (IsImmediate)
375
0
      BuiltinStr = "I" + BuiltinStr;
376
0
    if (IsPointer)
377
0
      BuiltinStr += "*";
378
0
    return;
379
0
  case ScalarTypeKind::Ptrdiff_t:
380
0
    BuiltinStr = "Y";
381
0
    return;
382
0
  case ScalarTypeKind::UnsignedLong:
383
0
    BuiltinStr = "ULi";
384
0
    return;
385
0
  case ScalarTypeKind::SignedLong:
386
0
    BuiltinStr = "Li";
387
0
    return;
388
0
  case ScalarTypeKind::Boolean:
389
0
    assert(ElementBitwidth == 1);
390
0
    BuiltinStr += "b";
391
0
    break;
392
0
  case ScalarTypeKind::SignedInteger:
393
0
  case ScalarTypeKind::UnsignedInteger:
394
0
    switch (ElementBitwidth) {
395
0
    case 8:
396
0
      BuiltinStr += "c";
397
0
      break;
398
0
    case 16:
399
0
      BuiltinStr += "s";
400
0
      break;
401
0
    case 32:
402
0
      BuiltinStr += "i";
403
0
      break;
404
0
    case 64:
405
0
      BuiltinStr += "Wi";
406
0
      break;
407
0
    default:
408
0
      llvm_unreachable("Unhandled ElementBitwidth!");
409
0
    }
410
0
    if (isSignedInteger())
411
0
      BuiltinStr = "S" + BuiltinStr;
412
0
    else
413
0
      BuiltinStr = "U" + BuiltinStr;
414
0
    break;
415
0
  case ScalarTypeKind::Float:
416
0
    switch (ElementBitwidth) {
417
0
    case 16:
418
0
      BuiltinStr += "x";
419
0
      break;
420
0
    case 32:
421
0
      BuiltinStr += "f";
422
0
      break;
423
0
    case 64:
424
0
      BuiltinStr += "d";
425
0
      break;
426
0
    default:
427
0
      llvm_unreachable("Unhandled ElementBitwidth!");
428
0
    }
429
0
    break;
430
0
  default:
431
0
    llvm_unreachable("ScalarType is invalid!");
432
0
  }
433
0
  if (IsImmediate)
434
0
    BuiltinStr = "I" + BuiltinStr;
435
0
  if (isScalar()) {
436
0
    if (IsConstant)
437
0
      BuiltinStr += "C";
438
0
    if (IsPointer)
439
0
      BuiltinStr += "*";
440
0
    return;
441
0
  }
442
0
  BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr;
443
  // Pointer to vector types. Defined for Zvlsseg load intrinsics.
444
  // Zvlsseg load intrinsics have pointer type arguments to store the loaded
445
  // vector values.
446
0
  if (IsPointer)
447
0
    BuiltinStr += "*";
448
0
}
449
450
0
void RVVType::initClangBuiltinStr() {
451
0
  assert(isValid() && "RVVType is invalid");
452
0
  assert(isVector() && "Handle Vector type only");
453
454
0
  ClangBuiltinStr = "__rvv_";
455
0
  switch (ScalarType) {
456
0
  case ScalarTypeKind::Boolean:
457
0
    ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t";
458
0
    return;
459
0
  case ScalarTypeKind::Float:
460
0
    ClangBuiltinStr += "float";
461
0
    break;
462
0
  case ScalarTypeKind::SignedInteger:
463
0
    ClangBuiltinStr += "int";
464
0
    break;
465
0
  case ScalarTypeKind::UnsignedInteger:
466
0
    ClangBuiltinStr += "uint";
467
0
    break;
468
0
  default:
469
0
    llvm_unreachable("ScalarTypeKind is invalid");
470
0
  }
471
0
  ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
472
0
}
473
474
0
void RVVType::initTypeStr() {
475
0
  assert(isValid() && "RVVType is invalid");
476
477
0
  if (IsConstant)
478
0
    Str += "const ";
479
480
0
  auto getTypeString = [&](StringRef TypeStr) {
481
0
    if (isScalar())
482
0
      return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
483
0
    return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
484
0
        .str();
485
0
  };
486
487
0
  switch (ScalarType) {
488
0
  case ScalarTypeKind::Void:
489
0
    Str = "void";
490
0
    return;
491
0
  case ScalarTypeKind::Size_t:
492
0
    Str = "size_t";
493
0
    if (IsPointer)
494
0
      Str += " *";
495
0
    return;
496
0
  case ScalarTypeKind::Ptrdiff_t:
497
0
    Str = "ptrdiff_t";
498
0
    return;
499
0
  case ScalarTypeKind::UnsignedLong:
500
0
    Str = "unsigned long";
501
0
    return;
502
0
  case ScalarTypeKind::SignedLong:
503
0
    Str = "long";
504
0
    return;
505
0
  case ScalarTypeKind::Boolean:
506
0
    if (isScalar())
507
0
      Str += "bool";
508
0
    else
509
      // Vector bool is special case, the formulate is
510
      // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
511
0
      Str += "vbool" + utostr(64 / Scale.getValue()) + "_t";
512
0
    break;
513
0
  case ScalarTypeKind::Float:
514
0
    if (isScalar()) {
515
0
      if (ElementBitwidth == 64)
516
0
        Str += "double";
517
0
      else if (ElementBitwidth == 32)
518
0
        Str += "float";
519
0
      else if (ElementBitwidth == 16)
520
0
        Str += "_Float16";
521
0
      else
522
0
        llvm_unreachable("Unhandled floating type.");
523
0
    } else
524
0
      Str += getTypeString("float");
525
0
    break;
526
0
  case ScalarTypeKind::SignedInteger:
527
0
    Str += getTypeString("int");
528
0
    break;
529
0
  case ScalarTypeKind::UnsignedInteger:
530
0
    Str += getTypeString("uint");
531
0
    break;
532
0
  default:
533
0
    llvm_unreachable("ScalarType is invalid!");
534
0
  }
535
0
  if (IsPointer)
536
0
    Str += " *";
537
0
}
538
539
0
void RVVType::initShortStr() {
540
0
  switch (ScalarType) {
541
0
  case ScalarTypeKind::Boolean:
542
0
    assert(isVector());
543
0
    ShortStr = "b" + utostr(64 / Scale.getValue());
544
0
    return;
545
0
  case ScalarTypeKind::Float:
546
0
    ShortStr = "f" + utostr(ElementBitwidth);
547
0
    break;
548
0
  case ScalarTypeKind::SignedInteger:
549
0
    ShortStr = "i" + utostr(ElementBitwidth);
550
0
    break;
551
0
  case ScalarTypeKind::UnsignedInteger:
552
0
    ShortStr = "u" + utostr(ElementBitwidth);
553
0
    break;
554
0
  default:
555
0
    PrintFatalError("Unhandled case!");
556
0
  }
557
0
  if (isVector())
558
0
    ShortStr += LMUL.str();
559
0
}
560
561
0
void RVVType::applyBasicType() {
562
0
  switch (BT) {
563
0
  case 'c':
564
0
    ElementBitwidth = 8;
565
0
    ScalarType = ScalarTypeKind::SignedInteger;
566
0
    break;
567
0
  case 's':
568
0
    ElementBitwidth = 16;
569
0
    ScalarType = ScalarTypeKind::SignedInteger;
570
0
    break;
571
0
  case 'i':
572
0
    ElementBitwidth = 32;
573
0
    ScalarType = ScalarTypeKind::SignedInteger;
574
0
    break;
575
0
  case 'l':
576
0
    ElementBitwidth = 64;
577
0
    ScalarType = ScalarTypeKind::SignedInteger;
578
0
    break;
579
0
  case 'x':
580
0
    ElementBitwidth = 16;
581
0
    ScalarType = ScalarTypeKind::Float;
582
0
    break;
583
0
  case 'f':
584
0
    ElementBitwidth = 32;
585
0
    ScalarType = ScalarTypeKind::Float;
586
0
    break;
587
0
  case 'd':
588
0
    ElementBitwidth = 64;
589
0
    ScalarType = ScalarTypeKind::Float;
590
0
    break;
591
0
  default:
592
0
    PrintFatalError("Unhandled type code!");
593
0
  }
594
0
  assert(ElementBitwidth != 0 && "Bad element bitwidth!");
595
0
}
596
597
0
// Applies a prototype modifier string to the current type. The last character
// is the primitive transformer. It may be preceded by at most one complex
// transformer "(Kind:Value)" and by any number of single-letter non-primitive
// transformers. Malformed transformers are reported with PrintFatalError. A
// combination that yields no legal type sets ScalarType to Invalid.
void RVVType::applyModifier(StringRef Transformer) {
  if (Transformer.empty())
    return;
  // Handle primitive type transformer
  auto PType = Transformer.back();
  switch (PType) {
  case 'e':
    Scale = 0;
    break;
  case 'v':
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case 'w':
    ElementBitwidth *= 2;
    LMUL *= 2;
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case 'q':
    ElementBitwidth *= 4;
    LMUL *= 4;
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case 'o':
    ElementBitwidth *= 8;
    LMUL *= 8;
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case 'm':
    ScalarType = ScalarTypeKind::Boolean;
    Scale = LMUL.getScale(ElementBitwidth);
    ElementBitwidth = 1;
    break;
  case '0':
    ScalarType = ScalarTypeKind::Void;
    break;
  case 'z':
    ScalarType = ScalarTypeKind::Size_t;
    break;
  case 't':
    ScalarType = ScalarTypeKind::Ptrdiff_t;
    break;
  case 'u':
    ScalarType = ScalarTypeKind::UnsignedLong;
    break;
  case 'l':
    ScalarType = ScalarTypeKind::SignedLong;
    break;
  default:
    PrintFatalError("Illegal primitive type transformers!");
  }
  Transformer = Transformer.drop_back();

  // Parse Value as a base-10 integer into Result. StringRef::getAsInteger
  // returns true on failure and leaves Result unassigned. An unchecked call
  // would therefore let a malformed value from the .td file feed an
  // uninitialized integer into the LMUL/SEW computation below.
  auto ParseIntegerOrDie = [](StringRef Value, auto &Result) {
    if (Value.getAsInteger(10, Result))
      PrintFatalError("Illegal integer in complex type transformer: " + Value);
  };

  // Extract and compute complex type transformer. It can only appear one time.
  if (Transformer.startswith("(")) {
    size_t Idx = Transformer.find(')');
    assert(Idx != StringRef::npos);
    StringRef ComplexType = Transformer.slice(1, Idx);
    Transformer = Transformer.drop_front(Idx + 1);
    assert(Transformer.find('(') == StringRef::npos &&
           "Only allow one complex type transformer");

    // Recompute Scale and reject transformer combinations that are not
    // allowed together with a complex transformer.
    auto UpdateAndCheckComplexProto = [&]() {
      Scale = LMUL.getScale(ElementBitwidth);
      const StringRef VectorPrototypes("vwqom");
      if (!VectorPrototypes.contains(PType))
        PrintFatalError("Complex type transformer only supports vector type!");
      if (Transformer.find_first_of("PCKWS") != StringRef::npos)
        PrintFatalError(
            "Illegal type transformer for Complex type transformer");
    };
    // Replace LMUL with the parsed value if Compare(New, Old) holds;
    // otherwise mark the type invalid and return false.
    auto ComputeFixedLog2LMUL =
        [&](StringRef Value,
            std::function<bool(const int32_t &, const int32_t &)> Compare) {
          int32_t Log2LMUL;
          ParseIntegerOrDie(Value, Log2LMUL);
          if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
            ScalarType = Invalid;
            return false;
          }
          // Update new LMUL
          LMUL = LMULType(Log2LMUL);
          UpdateAndCheckComplexProto();
          return true;
        };
    auto ComplexTT = ComplexType.split(":");
    if (ComplexTT.first == "Log2EEW") {
      uint32_t Log2EEW;
      ParseIntegerOrDie(ComplexTT.second, Log2EEW);
      // update new elmul = (eew/sew) * lmul. The exponent difference is
      // negative when EEW < SEW, so compute it in signed arithmetic.
      LMUL.MulLog2LMUL(static_cast<int>(Log2EEW) -
                       static_cast<int>(Log2_32(ElementBitwidth)));
      // update new eew
      ElementBitwidth = 1 << Log2EEW;
      ScalarType = ScalarTypeKind::SignedInteger;
      UpdateAndCheckComplexProto();
    } else if (ComplexTT.first == "FixedSEW") {
      uint32_t NewSEW;
      ParseIntegerOrDie(ComplexTT.second, NewSEW);
      // Set invalid type if src and dst SEW are same.
      if (ElementBitwidth == NewSEW) {
        ScalarType = Invalid;
        return;
      }
      // Update new SEW
      ElementBitwidth = NewSEW;
      UpdateAndCheckComplexProto();
    } else if (ComplexTT.first == "LFixedLog2LMUL") {
      // New LMUL should be larger than old
      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
        return;
    } else if (ComplexTT.first == "SFixedLog2LMUL") {
      // New LMUL should be smaller than old
      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
        return;
    } else {
      PrintFatalError("Illegal complex type transformers!");
    }
  }

  // Compute the remain type transformers
  for (char I : Transformer) {
    switch (I) {
    case 'P':
      if (IsConstant)
        PrintFatalError("'P' transformer cannot be used after 'C'");
      if (IsPointer)
        PrintFatalError("'P' transformer cannot be used twice");
      IsPointer = true;
      break;
    case 'C':
      if (IsConstant)
        PrintFatalError("'C' transformer cannot be used twice");
      IsConstant = true;
      break;
    case 'K':
      IsImmediate = true;
      break;
    case 'U':
      ScalarType = ScalarTypeKind::UnsignedInteger;
      break;
    case 'I':
      ScalarType = ScalarTypeKind::SignedInteger;
      break;
    case 'F':
      ScalarType = ScalarTypeKind::Float;
      break;
    case 'S':
      LMUL = LMULType(0);
      // Update ElementBitwidth need to update Scale too.
      Scale = LMUL.getScale(ElementBitwidth);
      break;
    default:
      PrintFatalError("Illegal non-primitive type transformer!");
    }
  }
}
752
753
//===----------------------------------------------------------------------===//
754
// RVVIntrinsic implementation
755
//===----------------------------------------------------------------------===//
756
RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
                           StringRef NewMangledName, StringRef MangledSuffix,
                           StringRef IRName, bool HasSideEffects, bool IsMask,
                           bool HasMaskedOffOperand, bool HasVL,
                           bool HasNoMaskedOverloaded, bool HasAutoDef,
                           StringRef ManualCodegen, const RVVTypes &OutInTypes,
                           const std::vector<int64_t> &NewIntrinsicTypes,
                           StringRef RequiredExtension, unsigned NF)
    : Name(NewName.str()),
      // Without an explicit mangled name, use the part of the name before
      // the first '_'.
      MangledName(NewMangledName.empty() ? NewName.split("_").first.str()
                                         : NewMangledName.str()),
      IRName(IRName), HasSideEffects(HasSideEffects), IsMask(IsMask),
      HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL),
      HasNoMaskedOverloaded(HasNoMaskedOverloaded), HasAutoDef(HasAutoDef),
      ManualCodegen(ManualCodegen.str()),
      // The first type is the output; the remaining ones are the inputs.
      OutputType(OutInTypes[0]),
      InputTypes(OutInTypes.begin() + 1, OutInTypes.end()),
      IntrinsicTypes(NewIntrinsicTypes), NF(NF) {
  if (!Suffix.empty())
    Name += "_" + Suffix.str();
  if (!MangledSuffix.empty())
    MangledName += "_" + MangledSuffix.str();
  // Masked variants get an extra "_m" on the builtin name only.
  if (IsMask)
    Name += "_m";

  // A floating-point element type, scalar or vector, requires the matching
  // floating-point extension.
  for (RVVTypePtr T : OutInTypes) {
    if (T->isFloat(16))
      RISCVExtensions |= RISCVExtension::Zfh;
    else if (T->isFloat(32))
      RISCVExtensions |= RISCVExtension::F;
    else if (T->isFloat(64))
      RISCVExtensions |= RISCVExtension::D;
  }
  if (RequiredExtension == "Zvamo")
    RISCVExtensions |= RISCVExtension::Zvamo;
  if (RequiredExtension == "Zvlsseg")
    RISCVExtensions |= RISCVExtension::Zvlsseg;

  // IntrinsicTypes index the unmasked operand list. The masked-off operand
  // always comes first, so shift every operand index (not the -1 return-type
  // marker) by NF.
  if (IsMask && HasMaskedOffOperand) {
    for (int64_t &Idx : IntrinsicTypes)
      if (Idx >= 0)
        Idx += NF;
  }
}
810
811
0
std::string RVVIntrinsic::getBuiltinTypeStr() const {
812
0
  std::string S;
813
0
  S += OutputType->getBuiltinStr();
814
0
  for (const auto &T : InputTypes) {
815
0
    S += T->getBuiltinStr();
816
0
  }
817
0
  return S;
818
0
}
819
820
0
// Emits the body of this builtin's case in EmitRISCVBuiltinExpr. It sets the
// intrinsic ID, NF when it is at least 2, and the IntrinsicTypes used to
// select the LLVM intrinsic overload. When the intrinsic has manual codegen,
// that code is emitted verbatim instead.
void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const {
  if (!getIRName().empty())
    OS << "  ID = Intrinsic::riscv_" + getIRName() + ";\n";
  if (NF >= 2)
    OS << "  NF = " + utostr(getNF()) + ";\n";
  if (hasManualCodegen()) {
    OS << ManualCodegen;
    OS << "break;\n";
    return;
  }

  // For masked builtins, move operand 0 (presumably the mask) to the end of
  // the operand list, or to just before VL when VL is present. VL is always
  // the last operand.
  if (isMask()) {
    if (hasVL()) {
      OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
    } else {
      OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n";
    }
  }

  OS << "  IntrinsicTypes = {";
  ListSeparator LS;
  for (const auto &Idx : IntrinsicTypes) {
    if (Idx == -1)
      OS << LS << "ResultType";
    else
      OS << LS << "Ops[" << Idx << "]->getType()";
  }

  // VL could be i64 or i32, so its type is part of the overload. VL is always
  // the last operand. Use the same separator so that an empty IntrinsicTypes
  // list does not produce the invalid initializer "{, ...}".
  if (hasVL())
    OS << LS << "Ops.back()->getType()";
  OS << "};\n";
  OS << "  break;\n";
}
855
856
0
void RVVIntrinsic::emitIntrinsicMacro(raw_ostream &OS) const {
  const unsigned NumOperands = InputTypes.size();

  // Macro head: "#define <name>(op0, op1, ...) \"
  OS << "#define " << getName() << "(";
  ListSeparator ParamSep;
  for (unsigned Op = 0; Op != NumOperands; ++Op)
    OS << ParamSep << "op" << Op;
  OS << ") \\\n";

  // Expansion: call the builtin, casting each operand to its parameter type.
  OS << "__builtin_rvv_" << getName() << "(";
  ListSeparator ArgSep;
  for (unsigned Op = 0; Op != NumOperands; ++Op)
    OS << ArgSep << "(" << InputTypes[Op]->getTypeStr() << ")(op" << Op
       << ")";
  OS << ")\n";
}
872
873
0
void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const {
  // Declare the overloaded (mangled) name as an alias of the builtin.
  OS << "__attribute__((clang_builtin_alias(" << "__builtin_rvv_" << getName()
     << ")))\n";
  OS << OutputType->getTypeStr() << " " << getMangledName() << "(";

  // Parameters are named op0, op1, ... in declaration order.
  ListSeparator Sep;
  unsigned Op = 0;
  for (RVVTypePtr Input : InputTypes)
    OS << Sep << Input->getTypeStr() << " op" << Op++;
  OS << ");\n\n";
}
885
886
//===----------------------------------------------------------------------===//
887
// RVVEmitter implementation
888
//===----------------------------------------------------------------------===//
889
0
/// Emit riscv_vector.h. The output contains, in order:
/// - the license banner and include guard;
/// - the prerequisite check for __riscv_vector;
/// - any record-provided HeaderCode;
/// - typedefs for every legal RVV type;
/// - one macro per intrinsic;
/// - the overloaded declarations.
/// Intrinsics that require the same extensions are grouped under one guard.
void RVVEmitter::createHeader(raw_ostream &OS) {
  OS << "/*===---- riscv_vector.h - RISC-V V-extension RVVIntrinsics "
        "-------------------===\n"
        " *\n"
        " *\n"
        " * Part of the LLVM Project, under the Apache License v2.0 with LLVM "
        "Exceptions.\n"
        " * See https://llvm.org/LICENSE.txt for license information.\n"
        " * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n"
        " *\n"
        " *===-----------------------------------------------------------------"
        "------===\n"
        " */\n\n";

  OS << "#ifndef __RISCV_VECTOR_H\n";
  OS << "#define __RISCV_VECTOR_H\n\n";

  OS << "#include <stdint.h>\n";
  OS << "#include <stddef.h>\n\n";

  OS << "#ifndef __riscv_vector\n";
  OS << "#error \"Vector intrinsics require the vector extension.\"\n";
  OS << "#endif\n\n";

  OS << "#ifdef __cplusplus\n";
  OS << "extern \"C\" {\n";
  OS << "#endif\n\n";

  std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
  createRVVIntrinsics(Defs);

  // Print header code collected from records (filled in by
  // createRVVIntrinsics).
  if (!HeaderCode.empty()) {
    OS << HeaderCode;
  }

  auto printType = [&](auto T) {
    OS << "typedef " << T->getClangBuiltinStr() << " " << T->getTypeStr()
       << ";\n";
  };

  constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3};
  // Print RVV boolean (mask) types.
  for (int Log2LMUL : Log2LMULs) {
    auto T = computeType('c', Log2LMUL, "m");
    if (T.hasValue())
      printType(T.getValue());
  }
  // Print RVV integer types: the "v" variant and its "Uv" counterpart.
  for (char I : StringRef("csil")) {
    for (int Log2LMUL : Log2LMULs) {
      auto T = computeType(I, Log2LMUL, "v");
      if (T.hasValue()) {
        printType(T.getValue());
        auto UT = computeType(I, Log2LMUL, "Uv");
        printType(UT.getValue());
      }
    }
  }

  // Print the floating-point vector types of basic type BT, wrapped in
  // "#if defined(<Macro>)" ... "#endif".
  auto printGuardedTypes = [&](BasicType BT, StringRef Macro) {
    OS << "#if defined(" << Macro << ")\n";
    for (int Log2LMUL : Log2LMULs) {
      auto T = computeType(BT, Log2LMUL, "v");
      if (T.hasValue())
        printType(T.getValue());
    }
    OS << "#endif\n";
  };
  printGuardedTypes('x', "__riscv_zfh");
  printGuardedTypes('f', "__riscv_f");
  printGuardedTypes('d', "__riscv_d");
  OS << "\n";

  // Group intrinsics that require the same extensions, so that each group is
  // emitted under a single arch guard macro.
  std::stable_sort(Defs.begin(), Defs.end(),
                   [](const std::unique_ptr<RVVIntrinsic> &A,
                      const std::unique_ptr<RVVIntrinsic> &B) {
                     return A->getRISCVExtensions() < B->getRISCVExtensions();
                   });

  // Print intrinsic functions with macro
  emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
    Inst.emitIntrinsicMacro(OS);
  });

  OS << "#define __riscv_v_intrinsic_overloading 1\n";

  // Print Overloaded APIs
  OS << "#define __rvv_overloaded static inline "
        "__attribute__((__always_inline__, __nodebug__, __overloadable__))\n";

  emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
    if (!Inst.isMask() && !Inst.hasNoMaskedOverloaded())
      return;
    OS << "__rvv_overloaded ";
    Inst.emitMangledFuncDef(OS);
  });

  // Close the extern "C" block. This #endif terminates "#ifdef __cplusplus";
  // the __riscv_vector check was already closed near the top of the header.
  OS << "\n#ifdef __cplusplus\n";
  OS << "}\n";
  OS << "#endif // __cplusplus\n";
  OS << "#endif // __RISCV_VECTOR_H\n";
}
1003
1004
0
/// Emit one RISCVV_BUILTIN(...) entry per intrinsic. Unless the includer
/// already defined RISCVV_BUILTIN, it is defined as a TARGET_BUILTIN that
/// requires the "experimental-v" feature. Builtins without side effects get
/// the "n" attribute string; all others get an empty one.
void RVVEmitter::createBuiltins(raw_ostream &OS) {
  std::vector<std::unique_ptr<RVVIntrinsic>> Intrinsics;
  createRVVIntrinsics(Intrinsics);

  OS << "#if defined(TARGET_BUILTIN) && !defined(RISCVV_BUILTIN)\n"
     << "#define RISCVV_BUILTIN(ID, TYPE, ATTRS) TARGET_BUILTIN(ID, TYPE, "
        "ATTRS, \"experimental-v\")\n"
     << "#endif\n";

  for (const auto &Intrinsic : Intrinsics) {
    OS << "RISCVV_BUILTIN(__builtin_rvv_" << Intrinsic->getName() << ",\""
       << Intrinsic->getBuiltinTypeStr() << "\", ";
    OS << (Intrinsic->hasSideEffects() ? "\"\")\n" : "\"n\")\n");
  }

  OS << "#undef RISCVV_BUILTIN\n";
}
1022
1023
0
/// Emit the body of the builtin-to-IR codegen switch. Intrinsics that share an
/// IR name and ManualCodegen become consecutive case labels with one shared
/// switch body.
void RVVEmitter::createCodeGen(raw_ostream &OS) {
  std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
  createRVVIntrinsics(Defs);
  // With no intrinsics there is no switch to emit. Returning here also keeps
  // Defs.front()/Defs.back() below from touching an empty vector (UB).
  if (Defs.empty())
    return;
  // IR name could be empty, use the stable sort preserves the relative order.
  std::stable_sort(Defs.begin(), Defs.end(),
                   [](const std::unique_ptr<RVVIntrinsic> &A,
                      const std::unique_ptr<RVVIntrinsic> &B) {
                     return A->getIRName() < B->getIRName();
                   });
  // Print switch body when the ir name or ManualCodegen changes from previous
  // iteration.
  RVVIntrinsic *PrevDef = Defs.front().get();
  for (auto &Def : Defs) {
    StringRef CurIRName = Def->getIRName();
    if (CurIRName != PrevDef->getIRName() ||
        (Def->getManualCodegen() != PrevDef->getManualCodegen())) {
      PrevDef->emitCodeGenSwitchBody(OS);
    }
    PrevDef = Def.get();
    OS << "case RISCV::BI__builtin_rvv_" << Def->getName() << ":\n";
  }
  // Close the final group of case labels.
  Defs.back()->emitCodeGenSwitchBody(OS);
  OS << "\n";
}
1047
1048
void RVVEmitter::parsePrototypes(StringRef Prototypes,
1049
0
                                 std::function<void(StringRef)> Handler) {
1050
0
  const StringRef Primaries("evwqom0ztul");
1051
0
  while (!Prototypes.empty()) {
1052
0
    size_t Idx = 0;
1053
    // Skip over complex prototype because it could contain primitive type
1054
    // character.
1055
0
    if (Prototypes[0] == '(')
1056
0
      Idx = Prototypes.find_first_of(')');
1057
0
    Idx = Prototypes.find_first_of(Primaries, Idx);
1058
0
    assert(Idx != StringRef::npos);
1059
0
    Handler(Prototypes.slice(0, Idx + 1));
1060
0
    Prototypes = Prototypes.drop_front(Idx + 1);
1061
0
  }
1062
0
}
1063
1064
std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL,
1065
0
                                     StringRef Prototypes) {
1066
0
  SmallVector<std::string> SuffixStrs;
1067
0
  parsePrototypes(Prototypes, [&](StringRef Proto) {
1068
0
    auto T = computeType(Type, Log2LMUL, Proto);
1069
0
    SuffixStrs.push_back(T.getValue()->getShortStr());
1070
0
  });
1071
0
  return join(SuffixStrs, "_");
1072
0
}
1073
1074
void RVVEmitter::createRVVIntrinsics(
1075
0
    std::vector<std::unique_ptr<RVVIntrinsic>> &Out) {
1076
0
  std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin");
1077
0
  for (auto *R : RV) {
1078
0
    StringRef Name = R->getValueAsString("Name");
1079
0
    StringRef SuffixProto = R->getValueAsString("Suffix");
1080
0
    StringRef MangledName = R->getValueAsString("MangledName");
1081
0
    StringRef MangledSuffixProto = R->getValueAsString("MangledSuffix");
1082
0
    StringRef Prototypes = R->getValueAsString("Prototype");
1083
0
    StringRef TypeRange = R->getValueAsString("TypeRange");
1084
0
    bool HasMask = R->getValueAsBit("HasMask");
1085
0
    bool HasMaskedOffOperand = R->getValueAsBit("HasMaskedOffOperand");
1086
0
    bool HasVL = R->getValueAsBit("HasVL");
1087
0
    bool HasNoMaskedOverloaded = R->getValueAsBit("HasNoMaskedOverloaded");
1088
0
    bool HasSideEffects = R->getValueAsBit("HasSideEffects");
1089
0
    std::vector<int64_t> Log2LMULList = R->getValueAsListOfInts("Log2LMUL");
1090
0
    StringRef ManualCodegen = R->getValueAsString("ManualCodegen");
1091
0
    StringRef ManualCodegenMask = R->getValueAsString("ManualCodegenMask");
1092
0
    std::vector<int64_t> IntrinsicTypes =
1093
0
        R->getValueAsListOfInts("IntrinsicTypes");
1094
0
    StringRef RequiredExtension = R->getValueAsString("RequiredExtension");
1095
0
    StringRef IRName = R->getValueAsString("IRName");
1096
0
    StringRef IRNameMask = R->getValueAsString("IRNameMask");
1097
0
    unsigned NF = R->getValueAsInt("NF");
1098
1099
0
    StringRef HeaderCodeStr = R->getValueAsString("HeaderCode");
1100
0
    bool HasAutoDef = HeaderCodeStr.empty();
1101
0
    if (!HeaderCodeStr.empty()) {
1102
0
      HeaderCode += HeaderCodeStr.str();
1103
0
    }
1104
    // Parse prototype and create a list of primitive type with transformers
1105
    // (operand) in ProtoSeq. ProtoSeq[0] is output operand.
1106
0
    SmallVector<std::string> ProtoSeq;
1107
0
    parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) {
1108
0
      ProtoSeq.push_back(Proto.str());
1109
0
    });
1110
1111
    // Compute Builtin types
1112
0
    SmallVector<std::string> ProtoMaskSeq = ProtoSeq;
1113
0
    if (HasMask) {
1114
      // If HasMaskedOffOperand, insert result type as first input operand.
1115
0
      if (HasMaskedOffOperand) {
1116
0
        if (NF == 1) {
1117
0
          ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, ProtoSeq[0]);
1118
0
        } else {
1119
          // Convert
1120
          // (void, op0 address, op1 address, ...)
1121
          // to
1122
          // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1123
0
          for (unsigned I = 0; I < NF; ++I)
1124
0
            ProtoMaskSeq.insert(
1125
0
                ProtoMaskSeq.begin() + NF + 1,
1126
0
                ProtoSeq[1].substr(1)); // Use substr(1) to skip '*'
1127
0
        }
1128
0
      }
1129
0
      if (HasMaskedOffOperand && NF > 1) {
1130
        // Convert
1131
        // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1132
        // to
1133
        // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1134
        // ...)
1135
0
        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m");
1136
0
      } else {
1137
        // If HasMask, insert 'm' as first input operand.
1138
0
        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m");
1139
0
      }
1140
0
    }
1141
    // If HasVL, append 'z' to last operand
1142
0
    if (HasVL) {
1143
0
      ProtoSeq.push_back("z");
1144
0
      ProtoMaskSeq.push_back("z");
1145
0
    }
1146
1147
    // Create Intrinsics for each type and LMUL.
1148
0
    for (char I : TypeRange) {
1149
0
      for (int Log2LMUL : Log2LMULList) {
1150
0
        Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq);
1151
        // Ignored to create new intrinsic if there are any illegal types.
1152
0
        if (!Types.hasValue())
1153
0
          continue;
1154
1155
0
        auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto);
1156
0
        auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto);
1157
        // Create a non-mask intrinsic
1158
0
        Out.push_back(std::make_unique<RVVIntrinsic>(
1159
0
            Name, SuffixStr, MangledName, MangledSuffixStr, IRName,
1160
0
            HasSideEffects, /*IsMask=*/false, /*HasMaskedOffOperand=*/false,
1161
0
            HasVL, HasNoMaskedOverloaded, HasAutoDef, ManualCodegen,
1162
0
            Types.getValue(), IntrinsicTypes, RequiredExtension, NF));
1163
0
        if (HasMask) {
1164
          // Create a mask intrinsic
1165
0
          Optional<RVVTypes> MaskTypes =
1166
0
              computeTypes(I, Log2LMUL, NF, ProtoMaskSeq);
1167
0
          Out.push_back(std::make_unique<RVVIntrinsic>(
1168
0
              Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask,
1169
0
              HasSideEffects, /*IsMask=*/true, HasMaskedOffOperand, HasVL,
1170
0
              HasNoMaskedOverloaded, HasAutoDef, ManualCodegenMask,
1171
0
              MaskTypes.getValue(), IntrinsicTypes, RequiredExtension, NF));
1172
0
        }
1173
0
      } // end for Log2LMULList
1174
0
    }   // end for TypeRange
1175
0
  }
1176
0
}
1177
1178
/// Compute the type of every operand prototype in PrototypeSeq for basic type
/// BT at the given Log2LMUL. Returns None if LMUL * NF exceeds 8, or if any
/// single operand type is illegal.
Optional<RVVTypes>
RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
                         ArrayRef<std::string> PrototypeSeq) {
  // LMUL x NF must not exceed 8. Only positive Log2LMUL values are checked.
  if (Log2LMUL > 0 && NF * (1 << Log2LMUL) > 8)
    return llvm::None;

  RVVTypes Types;
  for (const std::string &Proto : PrototypeSeq) {
    Optional<RVVTypePtr> Ty = computeType(BT, Log2LMUL, Proto);
    if (!Ty)
      return llvm::None;
    Types.push_back(*Ty);
  }
  return Types;
}
1195
1196
/// Return a pointer to the cached RVVType for (BT, Log2LMUL, Proto), or None
/// if that combination is invalid. Both legal and illegal results are
/// memoized in LegalTypes / IllegalTypes, keyed by their concatenation.
Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL,
                                             StringRef Proto) {
  std::string Idx = (Twine(BT) + Twine(Log2LMUL) + Proto).str();
  // Search first
  auto It = LegalTypes.find(Idx);
  if (It != LegalTypes.end())
    return &(It->second);
  if (IllegalTypes.count(Idx))
    return llvm::None;
  // Compute type and record the result.
  RVVType T(BT, Log2LMUL, Proto);
  if (T.isValid()) {
    // Idx is known to be absent, so the insert succeeds. Return a pointer
    // through the iterator it yields instead of hashing Idx a second time.
    auto Inserted = LegalTypes.insert({Idx, T});
    return &(Inserted.first->second);
  }
  // Record illegal type index.
  IllegalTypes.insert(Idx);
  return llvm::None;
}
1216
1217
void RVVEmitter::emitArchMacroAndBody(
1218
    std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS,
1219
0
    std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) {
1220
0
  uint8_t PrevExt = (*Defs.begin())->getRISCVExtensions();
1221
0
  bool NeedEndif = emitExtDefStr(PrevExt, OS);
1222
0
  for (auto &Def : Defs) {
1223
0
    uint8_t CurExt = Def->getRISCVExtensions();
1224
0
    if (CurExt != PrevExt) {
1225
0
      if (NeedEndif)
1226
0
        OS << "#endif\n\n";
1227
0
      NeedEndif = emitExtDefStr(CurExt, OS);
1228
0
      PrevExt = CurExt;
1229
0
    }
1230
0
    if (Def->hasAutoDef())
1231
0
      PrintBody(OS, *Def);
1232
0
  }
1233
0
  if (NeedEndif)
1234
0
    OS << "#endif\n\n";
1235
0
}
1236
1237
0
bool RVVEmitter::emitExtDefStr(uint8_t Extents, raw_ostream &OS) {
1238
0
  if (Extents == RISCVExtension::Basic)
1239
0
    return false;
1240
0
  OS << "#if ";
1241
0
  ListSeparator LS(" && ");
1242
0
  if (Extents & RISCVExtension::F)
1243
0
    OS << LS << "defined(__riscv_f)";
1244
0
  if (Extents & RISCVExtension::D)
1245
0
    OS << LS << "defined(__riscv_d)";
1246
0
  if (Extents & RISCVExtension::Zfh)
1247
0
    OS << LS << "defined(__riscv_zfh)";
1248
0
  if (Extents & RISCVExtension::Zvamo)
1249
0
    OS << LS << "defined(__riscv_zvamo)";
1250
0
  if (Extents & RISCVExtension::Zvlsseg)
1251
0
    OS << LS << "defined(__riscv_zvlsseg)";
1252
0
  OS << "\n";
1253
0
  return true;
1254
0
}
1255
1256
namespace clang {
1257
0
void EmitRVVHeader(RecordKeeper &Records, raw_ostream &OS) {
1258
0
  RVVEmitter(Records).createHeader(OS);
1259
0
}
1260
1261
0
void EmitRVVBuiltins(RecordKeeper &Records, raw_ostream &OS) {
1262
0
  RVVEmitter(Records).createBuiltins(OS);
1263
0
}
1264
1265
0
void EmitRVVBuiltinCG(RecordKeeper &Records, raw_ostream &OS) {
1266
0
  RVVEmitter(Records).createCodeGen(OS);
1267
0
}
1268
1269
} // End namespace clang