Coverage Report

Created: 2022-01-18 06:27

/Users/buildslave/jenkins/workspace/coverage/llvm-project/clang/utils/TableGen/RISCVVEmitter.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- RISCVVEmitter.cpp - Generate riscv_vector.h for use with clang -----===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This tablegen backend is responsible for emitting riscv_vector.h which
10
// includes a declaration and definition of each intrinsic function specified
11
// in https://github.com/riscv/rvv-intrinsic-doc.
12
//
13
// See also the documentation in include/clang/Basic/riscv_vector.td.
14
//
15
//===----------------------------------------------------------------------===//
16
17
#include "llvm/ADT/ArrayRef.h"
18
#include "llvm/ADT/SmallSet.h"
19
#include "llvm/ADT/StringExtras.h"
20
#include "llvm/ADT/StringMap.h"
21
#include "llvm/ADT/StringSet.h"
22
#include "llvm/ADT/Twine.h"
23
#include "llvm/TableGen/Error.h"
24
#include "llvm/TableGen/Record.h"
25
#include <numeric>
26
27
using namespace llvm;
28
// Single-character basic type code. RVVType::applyBasicType accepts 'c', 's',
// 'i', 'l' (8/16/32/64-bit signed integers) and 'x', 'f', 'd' (16/32/64-bit
// floating point).
using BasicType = char;
29
// vscale multiple of an RVV type: 0 for scalars, non-zero for vectors, and
// None when the LMUL/element-width combination is not representable.
using VScaleVal = Optional<unsigned>;
30
31
namespace {
32
33
// Exponential LMUL
34
struct LMULType {
35
  int Log2LMUL;
36
  LMULType(int Log2LMUL);
37
  // Return the C/C++ string representation of LMUL
38
  std::string str() const;
39
  Optional<unsigned> getScale(unsigned ElementBitwidth) const;
40
  void MulLog2LMUL(int Log2LMUL);
41
  LMULType &operator*=(uint32_t RHS);
42
};
43
44
// This class is compact representation of a valid and invalid RVVType.
45
class RVVType {
46
  enum ScalarTypeKind : uint32_t {
47
    Void,
48
    Size_t,
49
    Ptrdiff_t,
50
    UnsignedLong,
51
    SignedLong,
52
    Boolean,
53
    SignedInteger,
54
    UnsignedInteger,
55
    Float,
56
    Invalid,
57
  };
58
  BasicType BT;
59
  ScalarTypeKind ScalarType = Invalid;
60
  LMULType LMUL;
61
  bool IsPointer = false;
62
  // IsConstant indices are "int", but have the constant expression.
63
  bool IsImmediate = false;
64
  // Const qualifier for pointer to const object or object of const type.
65
  bool IsConstant = false;
66
  unsigned ElementBitwidth = 0;
67
  VScaleVal Scale = 0;
68
  bool Valid;
69
70
  std::string BuiltinStr;
71
  std::string ClangBuiltinStr;
72
  std::string Str;
73
  std::string ShortStr;
74
75
public:
76
0
  RVVType() : RVVType(BasicType(), 0, StringRef()) {}
77
  RVVType(BasicType BT, int Log2LMUL, StringRef prototype);
78
79
  // Return the string representation of a type, which is an encoded string for
80
  // passing to the BUILTIN() macro in Builtins.def.
81
0
  const std::string &getBuiltinStr() const { return BuiltinStr; }
82
83
  // Return the clang builtin type for RVV vector type which are used in the
84
  // riscv_vector.h header file.
85
0
  const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; }
86
87
  // Return the C/C++ string representation of a type for use in the
88
  // riscv_vector.h header file.
89
0
  const std::string &getTypeStr() const { return Str; }
90
91
  // Return the short name of a type for C/C++ name suffix.
92
0
  const std::string &getShortStr() {
93
    // Not all types are used in short name, so compute the short name by
94
    // demanded.
95
0
    if (ShortStr.empty())
96
0
      initShortStr();
97
0
    return ShortStr;
98
0
  }
99
100
0
  bool isValid() const { return Valid; }
101
0
  bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; }
102
0
  bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; }
103
0
  bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
104
0
  bool isSignedInteger() const {
105
0
    return ScalarType == ScalarTypeKind::SignedInteger;
106
0
  }
107
0
  bool isFloatVector(unsigned Width) const {
108
0
    return isVector() && isFloat() && ElementBitwidth == Width;
109
0
  }
110
0
  bool isFloat(unsigned Width) const {
111
0
    return isFloat() && ElementBitwidth == Width;
112
0
  }
113
114
private:
115
  // Verify RVV vector type and set Valid.
116
  bool verifyType() const;
117
118
  // Creates a type based on basic types of TypeRange
119
  void applyBasicType();
120
121
  // Applies a prototype modifier to the current type. The result maybe an
122
  // invalid type.
123
  void applyModifier(StringRef prototype);
124
125
  // Compute and record a string for legal type.
126
  void initBuiltinStr();
127
  // Compute and record a builtin RVV vector type string.
128
  void initClangBuiltinStr();
129
  // Compute and record a type string for used in the header.
130
  void initTypeStr();
131
  // Compute and record a short name of a type for C/C++ name suffix.
132
  void initShortStr();
133
};
134
135
// Non-owning pointer to an RVVType. NOTE(review): presumably these point into
// RVVEmitter::LegalTypes, which holds the RVVType objects — confirm.
using RVVTypePtr = RVVType *;
136
// List of type pointers. RVVIntrinsic's constructor reads element 0 as the
// output type and the remaining elements as input types.
using RVVTypes = std::vector<RVVTypePtr>;
137
138
// Bit flags for the ISA extensions an intrinsic requires; they are OR-ed into
// RVVIntrinsic's extension mask. Bit 0 is not used by any enumerator.
enum RISCVExtension : uint8_t {
  Basic = 0x00,
  F = 0x02,
  D = 0x04,
  Zfh = 0x08,
  Zvlsseg = 0x10,
  RV64 = 0x20,
};
146
147
// TODO refactor RVVIntrinsic class design after support all intrinsic
148
// combination. This represents an instantiation of an intrinsic with a
149
// particular type and prototype
150
class RVVIntrinsic {
151
152
private:
153
  std::string BuiltinName; // Builtin name
154
  std::string Name;        // C intrinsic name.
155
  std::string MangledName;
156
  std::string IRName;
157
  bool IsMask;
158
  bool HasVL;
159
  bool HasPolicy;
160
  bool HasNoMaskedOverloaded;
161
  bool HasAutoDef; // There is automiatic definition in header
162
  std::string ManualCodegen;
163
  RVVTypePtr OutputType; // Builtin output type
164
  RVVTypes InputTypes;   // Builtin input types
165
  // The types we use to obtain the specific LLVM intrinsic. They are index of
166
  // InputTypes. -1 means the return type.
167
  std::vector<int64_t> IntrinsicTypes;
168
  uint8_t RISCVExtensions = 0;
169
  unsigned NF = 1;
170
171
public:
172
  RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName,
173
               StringRef MangledSuffix, StringRef IRName, bool IsMask,
174
               bool HasMaskedOffOperand, bool HasVL, bool HasPolicy,
175
               bool HasNoMaskedOverloaded, bool HasAutoDef,
176
               StringRef ManualCodegen, const RVVTypes &Types,
177
               const std::vector<int64_t> &IntrinsicTypes,
178
               const std::vector<StringRef> &RequiredExtensions, unsigned NF);
179
0
  ~RVVIntrinsic() = default;
180
181
0
  StringRef getBuiltinName() const { return BuiltinName; }
182
0
  StringRef getName() const { return Name; }
183
0
  StringRef getMangledName() const { return MangledName; }
184
0
  bool hasVL() const { return HasVL; }
185
0
  bool hasPolicy() const { return HasPolicy; }
186
0
  bool hasNoMaskedOverloaded() const { return HasNoMaskedOverloaded; }
187
0
  bool hasManualCodegen() const { return !ManualCodegen.empty(); }
188
0
  bool hasAutoDef() const { return HasAutoDef; }
189
0
  bool isMask() const { return IsMask; }
190
0
  StringRef getIRName() const { return IRName; }
191
0
  StringRef getManualCodegen() const { return ManualCodegen; }
192
0
  uint8_t getRISCVExtensions() const { return RISCVExtensions; }
193
0
  unsigned getNF() const { return NF; }
194
0
  const std::vector<int64_t> &getIntrinsicTypes() const {
195
0
    return IntrinsicTypes;
196
0
  }
197
198
  // Return the type string for a BUILTIN() macro in Builtins.def.
199
  std::string getBuiltinTypeStr() const;
200
201
  // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should
202
  // init the RVVIntrinsic ID and IntrinsicTypes.
203
  void emitCodeGenSwitchBody(raw_ostream &o) const;
204
205
  // Emit the macros for mapping C/C++ intrinsic function to builtin functions.
206
  void emitIntrinsicFuncDef(raw_ostream &o) const;
207
208
  // Emit the mangled function definition.
209
  void emitMangledFuncDef(raw_ostream &o) const;
210
};
211
212
class RVVEmitter {
213
private:
214
  RecordKeeper &Records;
215
  std::string HeaderCode;
216
  // Concat BasicType, LMUL and Proto as key
217
  StringMap<RVVType> LegalTypes;
218
  StringSet<> IllegalTypes;
219
220
public:
221
0
  RVVEmitter(RecordKeeper &R) : Records(R) {}
222
223
  /// Emit riscv_vector.h
224
  void createHeader(raw_ostream &o);
225
226
  /// Emit all the __builtin prototypes and code needed by Sema.
227
  void createBuiltins(raw_ostream &o);
228
229
  /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
230
  void createCodeGen(raw_ostream &o);
231
232
  std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);
233
234
private:
235
  /// Create all intrinsics and add them to \p Out
236
  void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
237
  /// Create Headers and add them to \p Out
238
  void createRVVHeaders(raw_ostream &OS);
239
  /// Compute output and input types by applying different config (basic type
240
  /// and LMUL with type transformers). It also record result of type in legal
241
  /// or illegal set to avoid compute the  same config again. The result maybe
242
  /// have illegal RVVType.
243
  Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
244
                                  ArrayRef<std::string> PrototypeSeq);
245
  Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);
246
247
  /// Emit Acrh predecessor definitions and body, assume the element of Defs are
248
  /// sorted by extension.
249
  void emitArchMacroAndBody(
250
      std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &o,
251
      std::function<void(raw_ostream &, const RVVIntrinsic &)>);
252
253
  // Emit the architecture preprocessor definitions. Return true when emits
254
  // non-empty string.
255
  bool emitExtDefStr(uint8_t Extensions, raw_ostream &o);
256
  // Slice Prototypes string into sub prototype string and process each sub
257
  // prototype string individually in the Handler.
258
  void parsePrototypes(StringRef Prototypes,
259
                       std::function<void(StringRef)> Handler);
260
};
261
262
} // namespace
263
264
//===----------------------------------------------------------------------===//
265
// Type implementation
266
//===----------------------------------------------------------------------===//
267
268
0
// Construct from log2(LMUL); only LMUL values 1/8 .. 8 are accepted.
LMULType::LMULType(int NewLog2LMUL) : Log2LMUL(NewLog2LMUL) {
  // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
  assert(-3 <= NewLog2LMUL && NewLog2LMUL <= 3 && "Bad LMUL number!");
}
273
274
0
std::string LMULType::str() const {
275
0
  if (Log2LMUL < 0)
276
0
    return "mf" + utostr(1ULL << (-Log2LMUL));
277
0
  return "m" + utostr(1ULL << Log2LMUL);
278
0
}
279
280
0
// Compute the vscale multiple 2^(Log2LMUL + log2(64 / ElementBitwidth)) for
// the supported element widths; other widths use exponent 0. Returns None
// when the exponent is negative, i.e. the scale would be below 1.
VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
  int Log2Scale;
  switch (ElementBitwidth) {
  case 8:
    Log2Scale = Log2LMUL + 3;
    break;
  case 16:
    Log2Scale = Log2LMUL + 2;
    break;
  case 32:
    Log2Scale = Log2LMUL + 1;
    break;
  case 64:
    Log2Scale = Log2LMUL;
    break;
  default:
    Log2Scale = 0;
    break;
  }
  if (Log2Scale >= 0)
    return 1 << Log2Scale;
  return None;
}
303
304
0
void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
305
306
0
// Scale LMUL by a power-of-two factor by adding its log2.
LMULType &LMULType::operator*=(uint32_t RHS) {
  assert(isPowerOf2_32(RHS));
  Log2LMUL += Log2_32(RHS);
  return *this;
}
311
312
// Build the type for basic type BT at the given log2 LMUL and apply the
// prototype modifier. String encodings are computed only for valid types,
// and the clang builtin name only for vector types.
RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
    : BT(BT), LMUL(Log2LMUL) {
  applyBasicType();
  applyModifier(prototype);

  Valid = verifyType();
  if (!Valid)
    return;

  initBuiltinStr();
  initTypeStr();
  if (isVector())
    initClangBuiltinStr();
}
325
326
// clang-format off
327
// Boolean (mask) types are encoded by the ratio n = SEW/LMUL, as vbool<n>_t.
328
// 64/n     | 1         | 2         | 4         | 8        | 16        | 32        | 64
329
// c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t  | vbool2_t  | vbool1_t
330
// IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1   | nxv32i1   | nxv64i1
331
332
// type\lmul | 1/8    | 1/4      | 1/2     | 1       | 2        | 4        | 8
333
// --------  |------  | -------- | ------- | ------- | -------- | -------- | --------
334
// i64       | N/A    | N/A      | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
335
// i32       | N/A    | N/A      | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
336
// i16       | N/A    | nxv1i16  | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
337
// i8        | nxv1i8 | nxv2i8   | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
338
// double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
339
// float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
340
// half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
341
// clang-format on
342
343
0
bool RVVType::verifyType() const {
344
0
  if (ScalarType == Invalid)
345
0
    return false;
346
0
  if (isScalar())
347
0
    return true;
348
0
  if (!Scale.hasValue())
349
0
    return false;
350
0
  if (isFloat() && ElementBitwidth == 8)
351
0
    return false;
352
0
  unsigned V = Scale.getValue();
353
0
  switch (ElementBitwidth) {
354
0
  case 1:
355
0
  case 8:
356
    // Check Scale is 1,2,4,8,16,32,64
357
0
    return (V <= 64 && isPowerOf2_32(V));
358
0
  case 16:
359
    // Check Scale is 1,2,4,8,16,32
360
0
    return (V <= 32 && isPowerOf2_32(V));
361
0
  case 32:
362
    // Check Scale is 1,2,4,8,16
363
0
    return (V <= 16 && isPowerOf2_32(V));
364
0
  case 64:
365
    // Check Scale is 1,2,4,8
366
0
    return (V <= 8 && isPowerOf2_32(V));
367
0
  }
368
0
  return false;
369
0
}
370
371
0
void RVVType::initBuiltinStr() {
372
0
  assert(isValid() && "RVVType is invalid");
373
0
  switch (ScalarType) {
374
0
  case ScalarTypeKind::Void:
375
0
    BuiltinStr = "v";
376
0
    return;
377
0
  case ScalarTypeKind::Size_t:
378
0
    BuiltinStr = "z";
379
0
    if (IsImmediate)
380
0
      BuiltinStr = "I" + BuiltinStr;
381
0
    if (IsPointer)
382
0
      BuiltinStr += "*";
383
0
    return;
384
0
  case ScalarTypeKind::Ptrdiff_t:
385
0
    BuiltinStr = "Y";
386
0
    return;
387
0
  case ScalarTypeKind::UnsignedLong:
388
0
    BuiltinStr = "ULi";
389
0
    return;
390
0
  case ScalarTypeKind::SignedLong:
391
0
    BuiltinStr = "Li";
392
0
    return;
393
0
  case ScalarTypeKind::Boolean:
394
0
    assert(ElementBitwidth == 1);
395
0
    BuiltinStr += "b";
396
0
    break;
397
0
  case ScalarTypeKind::SignedInteger:
398
0
  case ScalarTypeKind::UnsignedInteger:
399
0
    switch (ElementBitwidth) {
400
0
    case 8:
401
0
      BuiltinStr += "c";
402
0
      break;
403
0
    case 16:
404
0
      BuiltinStr += "s";
405
0
      break;
406
0
    case 32:
407
0
      BuiltinStr += "i";
408
0
      break;
409
0
    case 64:
410
0
      BuiltinStr += "Wi";
411
0
      break;
412
0
    default:
413
0
      llvm_unreachable("Unhandled ElementBitwidth!");
414
0
    }
415
0
    if (isSignedInteger())
416
0
      BuiltinStr = "S" + BuiltinStr;
417
0
    else
418
0
      BuiltinStr = "U" + BuiltinStr;
419
0
    break;
420
0
  case ScalarTypeKind::Float:
421
0
    switch (ElementBitwidth) {
422
0
    case 16:
423
0
      BuiltinStr += "x";
424
0
      break;
425
0
    case 32:
426
0
      BuiltinStr += "f";
427
0
      break;
428
0
    case 64:
429
0
      BuiltinStr += "d";
430
0
      break;
431
0
    default:
432
0
      llvm_unreachable("Unhandled ElementBitwidth!");
433
0
    }
434
0
    break;
435
0
  default:
436
0
    llvm_unreachable("ScalarType is invalid!");
437
0
  }
438
0
  if (IsImmediate)
439
0
    BuiltinStr = "I" + BuiltinStr;
440
0
  if (isScalar()) {
441
0
    if (IsConstant)
442
0
      BuiltinStr += "C";
443
0
    if (IsPointer)
444
0
      BuiltinStr += "*";
445
0
    return;
446
0
  }
447
0
  BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr;
448
  // Pointer to vector types. Defined for Zvlsseg load intrinsics.
449
  // Zvlsseg load intrinsics have pointer type arguments to store the loaded
450
  // vector values.
451
0
  if (IsPointer)
452
0
    BuiltinStr += "*";
453
0
}
454
455
0
void RVVType::initClangBuiltinStr() {
456
0
  assert(isValid() && "RVVType is invalid");
457
0
  assert(isVector() && "Handle Vector type only");
458
459
0
  ClangBuiltinStr = "__rvv_";
460
0
  switch (ScalarType) {
461
0
  case ScalarTypeKind::Boolean:
462
0
    ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t";
463
0
    return;
464
0
  case ScalarTypeKind::Float:
465
0
    ClangBuiltinStr += "float";
466
0
    break;
467
0
  case ScalarTypeKind::SignedInteger:
468
0
    ClangBuiltinStr += "int";
469
0
    break;
470
0
  case ScalarTypeKind::UnsignedInteger:
471
0
    ClangBuiltinStr += "uint";
472
0
    break;
473
0
  default:
474
0
    llvm_unreachable("ScalarTypeKind is invalid");
475
0
  }
476
0
  ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
477
0
}
478
479
0
void RVVType::initTypeStr() {
480
0
  assert(isValid() && "RVVType is invalid");
481
482
0
  if (IsConstant)
483
0
    Str += "const ";
484
485
0
  auto getTypeString = [&](StringRef TypeStr) {
486
0
    if (isScalar())
487
0
      return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
488
0
    return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
489
0
        .str();
490
0
  };
491
492
0
  switch (ScalarType) {
493
0
  case ScalarTypeKind::Void:
494
0
    Str = "void";
495
0
    return;
496
0
  case ScalarTypeKind::Size_t:
497
0
    Str = "size_t";
498
0
    if (IsPointer)
499
0
      Str += " *";
500
0
    return;
501
0
  case ScalarTypeKind::Ptrdiff_t:
502
0
    Str = "ptrdiff_t";
503
0
    return;
504
0
  case ScalarTypeKind::UnsignedLong:
505
0
    Str = "unsigned long";
506
0
    return;
507
0
  case ScalarTypeKind::SignedLong:
508
0
    Str = "long";
509
0
    return;
510
0
  case ScalarTypeKind::Boolean:
511
0
    if (isScalar())
512
0
      Str += "bool";
513
0
    else
514
      // Vector bool is special case, the formulate is
515
      // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
516
0
      Str += "vbool" + utostr(64 / Scale.getValue()) + "_t";
517
0
    break;
518
0
  case ScalarTypeKind::Float:
519
0
    if (isScalar()) {
520
0
      if (ElementBitwidth == 64)
521
0
        Str += "double";
522
0
      else if (ElementBitwidth == 32)
523
0
        Str += "float";
524
0
      else if (ElementBitwidth == 16)
525
0
        Str += "_Float16";
526
0
      else
527
0
        llvm_unreachable("Unhandled floating type.");
528
0
    } else
529
0
      Str += getTypeString("float");
530
0
    break;
531
0
  case ScalarTypeKind::SignedInteger:
532
0
    Str += getTypeString("int");
533
0
    break;
534
0
  case ScalarTypeKind::UnsignedInteger:
535
0
    Str += getTypeString("uint");
536
0
    break;
537
0
  default:
538
0
    llvm_unreachable("ScalarType is invalid!");
539
0
  }
540
0
  if (IsPointer)
541
0
    Str += " *";
542
0
}
543
544
0
void RVVType::initShortStr() {
545
0
  switch (ScalarType) {
546
0
  case ScalarTypeKind::Boolean:
547
0
    assert(isVector());
548
0
    ShortStr = "b" + utostr(64 / Scale.getValue());
549
0
    return;
550
0
  case ScalarTypeKind::Float:
551
0
    ShortStr = "f" + utostr(ElementBitwidth);
552
0
    break;
553
0
  case ScalarTypeKind::SignedInteger:
554
0
    ShortStr = "i" + utostr(ElementBitwidth);
555
0
    break;
556
0
  case ScalarTypeKind::UnsignedInteger:
557
0
    ShortStr = "u" + utostr(ElementBitwidth);
558
0
    break;
559
0
  default:
560
0
    PrintFatalError("Unhandled case!");
561
0
  }
562
0
  if (isVector())
563
0
    ShortStr += LMUL.str();
564
0
}
565
566
0
void RVVType::applyBasicType() {
567
0
  switch (BT) {
568
0
  case 'c':
569
0
    ElementBitwidth = 8;
570
0
    ScalarType = ScalarTypeKind::SignedInteger;
571
0
    break;
572
0
  case 's':
573
0
    ElementBitwidth = 16;
574
0
    ScalarType = ScalarTypeKind::SignedInteger;
575
0
    break;
576
0
  case 'i':
577
0
    ElementBitwidth = 32;
578
0
    ScalarType = ScalarTypeKind::SignedInteger;
579
0
    break;
580
0
  case 'l':
581
0
    ElementBitwidth = 64;
582
0
    ScalarType = ScalarTypeKind::SignedInteger;
583
0
    break;
584
0
  case 'x':
585
0
    ElementBitwidth = 16;
586
0
    ScalarType = ScalarTypeKind::Float;
587
0
    break;
588
0
  case 'f':
589
0
    ElementBitwidth = 32;
590
0
    ScalarType = ScalarTypeKind::Float;
591
0
    break;
592
0
  case 'd':
593
0
    ElementBitwidth = 64;
594
0
    ScalarType = ScalarTypeKind::Float;
595
0
    break;
596
0
  default:
597
0
    PrintFatalError("Unhandled type code!");
598
0
  }
599
0
  assert(ElementBitwidth != 0 && "Bad element bitwidth!");
600
0
}
601
602
0
// Apply a prototype modifier string to the current type. The string is
// processed in three steps:
// 1. The last character selects the primitive transformer.
// 2. An optional leading "(Kind:Value)" complex transformer is applied.
// 3. Each remaining single-character transformer is applied in order.
// The result may be an invalid type (ScalarType == Invalid), which
// verifyType() rejects later. Malformed modifiers from the TableGen input are
// reported with PrintFatalError rather than read as garbage.
void RVVType::applyModifier(StringRef Transformer) {
  if (Transformer.empty())
    return;
  // Handle primitive type transformer
  auto PType = Transformer.back();
  switch (PType) {
  case 'e':
    // Scalar of the element type.
    Scale = 0;
    break;
  case 'v':
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case 'w':
    // Element width and LMUL are both multiplied by 2.
    ElementBitwidth *= 2;
    LMUL *= 2;
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case 'q':
    ElementBitwidth *= 4;
    LMUL *= 4;
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case 'o':
    ElementBitwidth *= 8;
    LMUL *= 8;
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case 'm':
    // The mask scale comes from the original element width, before the width
    // is reset to 1.
    ScalarType = ScalarTypeKind::Boolean;
    Scale = LMUL.getScale(ElementBitwidth);
    ElementBitwidth = 1;
    break;
  case '0':
    ScalarType = ScalarTypeKind::Void;
    break;
  case 'z':
    ScalarType = ScalarTypeKind::Size_t;
    break;
  case 't':
    ScalarType = ScalarTypeKind::Ptrdiff_t;
    break;
  case 'u':
    ScalarType = ScalarTypeKind::UnsignedLong;
    break;
  case 'l':
    ScalarType = ScalarTypeKind::SignedLong;
    break;
  default:
    PrintFatalError("Illegal primitive type transformers!");
  }
  Transformer = Transformer.drop_back();

  // Extract and compute complex type transformer. It can only appear one time.
  if (Transformer.startswith("(")) {
    size_t Idx = Transformer.find(')');
    // Validated unconditionally: this is TableGen input, and an assert would
    // vanish in release builds.
    if (Idx == StringRef::npos)
      PrintFatalError("Unterminated complex type transformer!");
    StringRef ComplexType = Transformer.slice(1, Idx);
    Transformer = Transformer.drop_front(Idx + 1);
    assert(!Transformer.contains('(') &&
           "Only allow one complex type transformer");

    auto UpdateAndCheckComplexProto = [&]() {
      Scale = LMUL.getScale(ElementBitwidth);
      const StringRef VectorPrototypes("vwqom");
      if (!VectorPrototypes.contains(PType))
        PrintFatalError("Complex type transformer only supports vector type!");
      if (Transformer.find_first_of("PCKWS") != StringRef::npos)
        PrintFatalError(
            "Illegal type transformer for Complex type transformer");
    };
    auto ComputeFixedLog2LMUL =
        [&](StringRef Value,
            std::function<bool(const int32_t &, const int32_t &)> Compare) {
          int32_t Log2LMUL;
          // getAsInteger returns true on failure; never use an unparsed value.
          if (Value.getAsInteger(10, Log2LMUL))
            PrintFatalError("Invalid Log2LMUL in complex type transformer!");
          if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
            ScalarType = Invalid;
            return false;
          }
          // Update new LMUL
          LMUL = LMULType(Log2LMUL);
          UpdateAndCheckComplexProto();
          return true;
        };
    auto ComplexTT = ComplexType.split(":");
    if (ComplexTT.first == "Log2EEW") {
      uint32_t Log2EEW;
      if (ComplexTT.second.getAsInteger(10, Log2EEW))
        PrintFatalError("Invalid Log2EEW in complex type transformer!");
      // update new elmul = (eew/sew) * lmul
      LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
      // update new eew
      ElementBitwidth = 1 << Log2EEW;
      ScalarType = ScalarTypeKind::SignedInteger;
      UpdateAndCheckComplexProto();
    } else if (ComplexTT.first == "FixedSEW") {
      uint32_t NewSEW;
      if (ComplexTT.second.getAsInteger(10, NewSEW))
        PrintFatalError("Invalid FixedSEW in complex type transformer!");
      // Set invalid type if src and dst SEW are same.
      if (ElementBitwidth == NewSEW) {
        ScalarType = Invalid;
        return;
      }
      // Update new SEW
      ElementBitwidth = NewSEW;
      UpdateAndCheckComplexProto();
    } else if (ComplexTT.first == "LFixedLog2LMUL") {
      // New LMUL should be larger than old
      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
        return;
    } else if (ComplexTT.first == "SFixedLog2LMUL") {
      // New LMUL should be smaller than old
      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
        return;
    } else {
      PrintFatalError("Illegal complex type transformers!");
    }
  }

  // Compute the remain type transformers
  for (char I : Transformer) {
    switch (I) {
    case 'P':
      if (IsConstant)
        PrintFatalError("'P' transformer cannot be used after 'C'");
      if (IsPointer)
        PrintFatalError("'P' transformer cannot be used twice");
      IsPointer = true;
      break;
    case 'C':
      if (IsConstant)
        PrintFatalError("'C' transformer cannot be used twice");
      IsConstant = true;
      break;
    case 'K':
      IsImmediate = true;
      break;
    case 'U':
      ScalarType = ScalarTypeKind::UnsignedInteger;
      break;
    case 'I':
      ScalarType = ScalarTypeKind::SignedInteger;
      break;
    case 'F':
      ScalarType = ScalarTypeKind::Float;
      break;
    case 'S':
      LMUL = LMULType(0);
      // Update ElementBitwidth need to update Scale too.
      Scale = LMUL.getScale(ElementBitwidth);
      break;
    default:
      PrintFatalError("Illegal non-primitive type transformer!");
    }
  }
}
757
758
//===----------------------------------------------------------------------===//
759
// RVVIntrinsic implementation
760
//===----------------------------------------------------------------------===//
761
RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
762
                           StringRef NewMangledName, StringRef MangledSuffix,
763
                           StringRef IRName, bool IsMask,
764
                           bool HasMaskedOffOperand, bool HasVL, bool HasPolicy,
765
                           bool HasNoMaskedOverloaded, bool HasAutoDef,
766
                           StringRef ManualCodegen, const RVVTypes &OutInTypes,
767
                           const std::vector<int64_t> &NewIntrinsicTypes,
768
                           const std::vector<StringRef> &RequiredExtensions,
769
                           unsigned NF)
770
    : IRName(IRName), IsMask(IsMask), HasVL(HasVL), HasPolicy(HasPolicy),
771
      HasNoMaskedOverloaded(HasNoMaskedOverloaded), HasAutoDef(HasAutoDef),
772
0
      ManualCodegen(ManualCodegen.str()), NF(NF) {
773
774
  // Init BuiltinName, Name and MangledName
775
0
  BuiltinName = NewName.str();
776
0
  Name = BuiltinName;
777
0
  if (NewMangledName.empty())
778
0
    MangledName = NewName.split("_").first.str();
779
0
  else
780
0
    MangledName = NewMangledName.str();
781
0
  if (!Suffix.empty())
782
0
    Name += "_" + Suffix.str();
783
0
  if (!MangledSuffix.empty())
784
0
    MangledName += "_" + MangledSuffix.str();
785
0
  if (IsMask) {
786
0
    BuiltinName += "_m";
787
0
    Name += "_m";
788
0
  }
789
790
  // Init RISC-V extensions
791
0
  for (const auto &T : OutInTypes) {
792
0
    if (T->isFloatVector(16) || T->isFloat(16))
793
0
      RISCVExtensions |= RISCVExtension::Zfh;
794
0
    else if (T->isFloatVector(32) || T->isFloat(32))
795
0
      RISCVExtensions |= RISCVExtension::F;
796
0
    else if (T->isFloatVector(64) || T->isFloat(64))
797
0
      RISCVExtensions |= RISCVExtension::D;
798
0
  }
799
0
  for (auto Extension : RequiredExtensions) {
800
0
    if (Extension == "Zvlsseg")
801
0
      RISCVExtensions |= RISCVExtension::Zvlsseg;
802
0
    if (Extension == "RV64")
803
0
      RISCVExtensions |= RISCVExtension::RV64;
804
0
  }
805
806
  // Init OutputType and InputTypes
807
0
  OutputType = OutInTypes[0];
808
0
  InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
809
810
  // IntrinsicTypes is nonmasked version index. Need to update it
811
  // if there is maskedoff operand (It is always in first operand).
812
0
  IntrinsicTypes = NewIntrinsicTypes;
813
0
  if (IsMask && HasMaskedOffOperand) {
814
0
    for (auto &I : IntrinsicTypes) {
815
0
      if (I >= 0)
816
0
        I += NF;
817
0
    }
818
0
  }
819
0
}
820
821
0
std::string RVVIntrinsic::getBuiltinTypeStr() const {
822
0
  std::string S;
823
0
  S += OutputType->getBuiltinStr();
824
0
  for (const auto &T : InputTypes) {
825
0
    S += T->getBuiltinStr();
826
0
  }
827
0
  return S;
828
0
}
829
830
0
// Emit one case body of the generated builtin-to-intrinsic switch:
// - set ID (when IRName is non-empty) and NF (when >= 2);
// - for masked forms, rotate the first operand towards the end;
// - fill IntrinsicTypes, then `break;`.
// A manual codegen snippet replaces everything after ID/NF.
void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const {
  if (!getIRName().empty())
    OS << "  ID = Intrinsic::riscv_" + getIRName() + ";\n";
  if (NF >= 2)
    OS << "  NF = " + utostr(getNF()) + ";\n";
  if (hasManualCodegen()) {
    OS << ManualCodegen;
    OS << "break;\n";
    return;
  }

  if (isMask()) {
    // Move the first operand (NOTE(review): presumably the mask) to the end,
    // keeping VL as the last operand when present.
    if (hasVL()) {
      OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
      if (hasPolicy())
        OS << "  Ops.push_back(ConstantInt::get(Ops.back()->getType(),"
              " TAIL_UNDISTURBED));\n";
    } else {
      OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n";
    }
  }

  OS << "  IntrinsicTypes = {";
  ListSeparator LS;
  for (const auto &Idx : IntrinsicTypes) {
    if (Idx == -1)
      OS << LS << "ResultType";
    else
      OS << LS << "Ops[" << Idx << "]->getType()";
  }

  // VL could be i64 or i32, need to encode it in IntrinsicTypes. VL is
  // always last operand. It goes through the same separator so an empty
  // IntrinsicTypes list does not produce an invalid "{, ...}" initializer.
  if (hasVL())
    OS << LS << "Ops.back()->getType()";
  OS << "};\n";
  OS << "  break;\n";
}
868
869
0
void RVVIntrinsic::emitIntrinsicFuncDef(raw_ostream &OS) const {
870
0
  OS << "__attribute__((__clang_builtin_alias__(";
871
0
  OS << "__builtin_rvv_" << getBuiltinName() << ")))\n";
872
0
  OS << OutputType->getTypeStr() << " " << getName() << "(";
873
  // Emit function arguments
874
0
  if (!InputTypes.empty()) {
875
0
    ListSeparator LS;
876
0
    for (unsigned i = 0; i < InputTypes.size(); ++i)
877
0
      OS << LS << InputTypes[i]->getTypeStr();
878
0
  }
879
0
  OS << ");\n";
880
0
}
881
882
0
void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const {
883
0
  OS << "__attribute__((__clang_builtin_alias__(";
884
0
  OS << "__builtin_rvv_" << getBuiltinName() << ")))\n";
885
0
  OS << OutputType->getTypeStr() << " " << getMangledName() << "(";
886
  // Emit function arguments
887
0
  if (!InputTypes.empty()) {
888
0
    ListSeparator LS;
889
0
    for (unsigned i = 0; i < InputTypes.size(); ++i)
890
0
      OS << LS << InputTypes[i]->getTypeStr();
891
0
  }
892
0
  OS << ");\n";
893
0
}
894
895
//===----------------------------------------------------------------------===//
896
// RVVEmitter implementation
897
//===----------------------------------------------------------------------===//
898
0
/// Generate riscv_vector.h.
///
/// The output contains, in order:
/// - the license banner and include guard;
/// - any RVVHeader code;
/// - the RVV typedefs (mask, integer, then FP types guarded by their
///   extension macros);
/// - the non-overloaded intrinsic declarations;
/// - the overloaded intrinsic declarations.
void RVVEmitter::createHeader(raw_ostream &OS) {

  OS << "/*===---- riscv_vector.h - RISC-V V-extension RVVIntrinsics "
        "-------------------===\n"
        " *\n"
        " *\n"
        " * Part of the LLVM Project, under the Apache License v2.0 with LLVM "
        "Exceptions.\n"
        " * See https://llvm.org/LICENSE.txt for license information.\n"
        " * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n"
        " *\n"
        " *===-----------------------------------------------------------------"
        "------===\n"
        " */\n\n";

  OS << "#ifndef __RISCV_VECTOR_H\n";
  OS << "#define __RISCV_VECTOR_H\n\n";

  OS << "#include <stdint.h>\n";
  OS << "#include <stddef.h>\n\n";

  OS << "#ifndef __riscv_vector\n";
  OS << "#error \"Vector intrinsics require the vector extension.\"\n";
  OS << "#endif\n\n";

  OS << "#ifdef __cplusplus\n";
  OS << "extern \"C\" {\n";
  OS << "#endif\n\n";

  createRVVHeaders(OS);

  std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
  createRVVIntrinsics(Defs);

  // Print header code collected from records that provide their own.
  if (!HeaderCode.empty()) {
    OS << HeaderCode;
  }

  auto printType = [&](auto T) {
    OS << "typedef " << T->getClangBuiltinStr() << " " << T->getTypeStr()
       << ";\n";
  };

  constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3};

  // Print every legal type built from basic type BT with prototype Proto,
  // one per LMUL, skipping illegal combinations.
  auto printTypesFor = [&](BasicType BT, StringRef Proto) {
    for (int Log2LMUL : Log2LMULs) {
      auto T = computeType(BT, Log2LMUL, Proto);
      if (T.hasValue())
        printType(T.getValue());
    }
  };

  // Print RVV boolean types.
  printTypesFor('c', "m");

  // Print RVV integer types; for each LMUL the signed type is immediately
  // followed by its unsigned counterpart.
  for (char I : StringRef("csil")) {
    for (int Log2LMUL : Log2LMULs) {
      auto T = computeType(I, Log2LMUL, "v");
      if (!T.hasValue())
        continue;
      printType(T.getValue());
      // Check the unsigned variant as well instead of assuming it is legal
      // whenever the signed one is (an unchecked getValue() would assert).
      auto UT = computeType(I, Log2LMUL, "Uv");
      if (UT.hasValue())
        printType(UT.getValue());
    }
  }

  // Print RVV floating-point types, each under its extension guard.
  OS << "#if defined(__riscv_zfh)\n";
  printTypesFor('x', "v");
  OS << "#endif\n";

  OS << "#if defined(__riscv_f)\n";
  printTypesFor('f', "v");
  OS << "#endif\n";

  OS << "#if defined(__riscv_d)\n";
  printTypesFor('d', "v");
  OS << "#endif\n\n";

  // Group intrinsics that need the same extensions so that each group is
  // wrapped in a single arch guard macro; stable_sort keeps record order
  // within a group.
  llvm::stable_sort(Defs, [](const std::unique_ptr<RVVIntrinsic> &A,
                             const std::unique_ptr<RVVIntrinsic> &B) {
    return A->getRISCVExtensions() < B->getRISCVExtensions();
  });

  OS << "#define __rvv_ai static __inline__\n";

  // Print intrinsic functions with macro
  emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
    OS << "__rvv_ai ";
    Inst.emitIntrinsicFuncDef(OS);
  });

  OS << "#undef __rvv_ai\n\n";

  OS << "#define __riscv_v_intrinsic_overloading 1\n";

  // Print Overloaded APIs
  OS << "#define __rvv_aio static __inline__ "
        "__attribute__((__overloadable__))\n";

  emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
    // Only masked intrinsics, and unmasked ones flagged with
    // HasNoMaskedOverloaded, get an overloaded alias.
    if (!Inst.isMask() && !Inst.hasNoMaskedOverloaded())
      return;
    OS << "__rvv_aio ";
    Inst.emitMangledFuncDef(OS);
  });

  OS << "#undef __rvv_aio\n";

  OS << "\n#ifdef __cplusplus\n";
  OS << "}\n";
  OS << "#endif // __cplusplus\n";
  OS << "#endif // __RISCV_VECTOR_H\n";
}
1020
1021
0
/// Generate the RISCVV_BUILTIN(...) list, one entry per distinct builtin
/// name. Intrinsics that share a builtin name must agree on hasAutoDef and,
/// when auto-defined type strings are not used, on the builtin type string.
void RVVEmitter::createBuiltins(raw_ostream &OS) {
  std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
  createRVVIntrinsics(Defs);

  // First intrinsic seen for each builtin name that has been emitted.
  StringMap<RVVIntrinsic *> BuiltinMap;

  OS << "#if defined(TARGET_BUILTIN) && !defined(RISCVV_BUILTIN)\n";
  OS << "#define RISCVV_BUILTIN(ID, TYPE, ATTRS) TARGET_BUILTIN(ID, TYPE, "
        "ATTRS, \"experimental-v\")\n";
  OS << "#endif\n";
  for (auto &Def : Defs) {
    auto Result =
        BuiltinMap.insert(std::make_pair(Def->getBuiltinName(), Def.get()));
    if (Result.second) {
      // New name: emit it. Auto-defined builtins get an empty type string.
      OS << "RISCVV_BUILTIN(__builtin_rvv_" << Def->getBuiltinName() << ",\"";
      if (!Def->hasAutoDef())
        OS << Def->getBuiltinTypeStr();
      OS << "\", \"n\")\n";
      continue;
    }

    // Duplicate name: it must describe exactly the same builtin.
    RVVIntrinsic *First = Result.first->second;
    if (First->hasAutoDef() != Def->hasAutoDef())
      PrintFatalError("Builtin with same name has different hasAutoDef");
    if (!Def->hasAutoDef() &&
        First->getBuiltinTypeStr() != Def->getBuiltinTypeStr())
      PrintFatalError("Builtin with same name has different type string");
  }
  OS << "#undef RISCVV_BUILTIN\n";
}
1053
1054
0
/// Generate the CGBuiltin switch cases for all RVV builtins.
///
/// Intrinsics are sorted by IR name. Consecutive builtins that share an IR
/// name and ManualCodegen become fall-through `case` labels, followed by a
/// single switch body. Builtins appearing more than once must agree on every
/// property that influences codegen.
void RVVEmitter::createCodeGen(raw_ostream &OS) {
  std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
  createRVVIntrinsics(Defs);
  // Without any intrinsic there is nothing to emit. This also keeps the
  // Defs.begin()/Defs.back() dereferences below well-defined.
  if (Defs.empty())
    return;

  // IR name could be empty; stable sort preserves the relative order.
  llvm::stable_sort(Defs, [](const std::unique_ptr<RVVIntrinsic> &A,
                             const std::unique_ptr<RVVIntrinsic> &B) {
    return A->getIRName() < B->getIRName();
  });

  // Map to keep track of which builtin names have already been emitted.
  StringMap<RVVIntrinsic *> BuiltinMap;

  // Print switch body when the ir name or ManualCodegen changes from previous
  // iteration.
  RVVIntrinsic *PrevDef = Defs.begin()->get();
  for (auto &Def : Defs) {
    StringRef CurIRName = Def->getIRName();
    if (CurIRName != PrevDef->getIRName() ||
        (Def->getManualCodegen() != PrevDef->getManualCodegen())) {
      PrevDef->emitCodeGenSwitchBody(OS);
    }
    PrevDef = Def.get();

    auto P =
        BuiltinMap.insert(std::make_pair(Def->getBuiltinName(), Def.get()));
    if (P.second) {
      OS << "case RISCVVector::BI__builtin_rvv_" << Def->getBuiltinName()
         << ":\n";
      continue;
    }

    // A builtin emitted twice must produce identical codegen.
    if (P.first->second->getIRName() != Def->getIRName())
      PrintFatalError("Builtin with same name has different IRName");
    else if (P.first->second->getManualCodegen() != Def->getManualCodegen())
      PrintFatalError("Builtin with same name has different ManualCodegen");
    else if (P.first->second->getNF() != Def->getNF())
      PrintFatalError("Builtin with same name has different NF");
    else if (P.first->second->isMask() != Def->isMask())
      PrintFatalError("Builtin with same name has different isMask");
    else if (P.first->second->hasVL() != Def->hasVL())
      // Previously misreported as "HasPolicy" (copy-paste error).
      PrintFatalError("Builtin with same name has different HasVL");
    else if (P.first->second->hasPolicy() != Def->hasPolicy())
      PrintFatalError("Builtin with same name has different HasPolicy");
    else if (P.first->second->getIntrinsicTypes() != Def->getIntrinsicTypes())
      PrintFatalError("Builtin with same name has different IntrinsicTypes");
  }
  Defs.back()->emitCodeGenSwitchBody(OS);
  OS << "\n";
}
1103
1104
void RVVEmitter::parsePrototypes(StringRef Prototypes,
1105
0
                                 std::function<void(StringRef)> Handler) {
1106
0
  const StringRef Primaries("evwqom0ztul");
1107
0
  while (!Prototypes.empty()) {
1108
0
    size_t Idx = 0;
1109
    // Skip over complex prototype because it could contain primitive type
1110
    // character.
1111
0
    if (Prototypes[0] == '(')
1112
0
      Idx = Prototypes.find_first_of(')');
1113
0
    Idx = Prototypes.find_first_of(Primaries, Idx);
1114
0
    assert(Idx != StringRef::npos);
1115
0
    Handler(Prototypes.slice(0, Idx + 1));
1116
0
    Prototypes = Prototypes.drop_front(Idx + 1);
1117
0
  }
1118
0
}
1119
1120
std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL,
1121
0
                                     StringRef Prototypes) {
1122
0
  SmallVector<std::string> SuffixStrs;
1123
0
  parsePrototypes(Prototypes, [&](StringRef Proto) {
1124
0
    auto T = computeType(Type, Log2LMUL, Proto);
1125
0
    SuffixStrs.push_back(T.getValue()->getShortStr());
1126
0
  });
1127
0
  return join(SuffixStrs, "_");
1128
0
}
1129
1130
void RVVEmitter::createRVVIntrinsics(
1131
0
    std::vector<std::unique_ptr<RVVIntrinsic>> &Out) {
1132
0
  std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin");
1133
0
  for (auto *R : RV) {
1134
0
    StringRef Name = R->getValueAsString("Name");
1135
0
    StringRef SuffixProto = R->getValueAsString("Suffix");
1136
0
    StringRef MangledName = R->getValueAsString("MangledName");
1137
0
    StringRef MangledSuffixProto = R->getValueAsString("MangledSuffix");
1138
0
    StringRef Prototypes = R->getValueAsString("Prototype");
1139
0
    StringRef TypeRange = R->getValueAsString("TypeRange");
1140
0
    bool HasMask = R->getValueAsBit("HasMask");
1141
0
    bool HasMaskedOffOperand = R->getValueAsBit("HasMaskedOffOperand");
1142
0
    bool HasVL = R->getValueAsBit("HasVL");
1143
0
    bool HasPolicy = R->getValueAsBit("HasPolicy");
1144
0
    bool HasNoMaskedOverloaded = R->getValueAsBit("HasNoMaskedOverloaded");
1145
0
    std::vector<int64_t> Log2LMULList = R->getValueAsListOfInts("Log2LMUL");
1146
0
    StringRef ManualCodegen = R->getValueAsString("ManualCodegen");
1147
0
    StringRef ManualCodegenMask = R->getValueAsString("ManualCodegenMask");
1148
0
    std::vector<int64_t> IntrinsicTypes =
1149
0
        R->getValueAsListOfInts("IntrinsicTypes");
1150
0
    std::vector<StringRef> RequiredExtensions =
1151
0
        R->getValueAsListOfStrings("RequiredExtensions");
1152
0
    StringRef IRName = R->getValueAsString("IRName");
1153
0
    StringRef IRNameMask = R->getValueAsString("IRNameMask");
1154
0
    unsigned NF = R->getValueAsInt("NF");
1155
1156
0
    StringRef HeaderCodeStr = R->getValueAsString("HeaderCode");
1157
0
    bool HasAutoDef = HeaderCodeStr.empty();
1158
0
    if (!HeaderCodeStr.empty()) {
1159
0
      HeaderCode += HeaderCodeStr.str();
1160
0
    }
1161
    // Parse prototype and create a list of primitive type with transformers
1162
    // (operand) in ProtoSeq. ProtoSeq[0] is output operand.
1163
0
    SmallVector<std::string> ProtoSeq;
1164
0
    parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) {
1165
0
      ProtoSeq.push_back(Proto.str());
1166
0
    });
1167
1168
    // Compute Builtin types
1169
0
    SmallVector<std::string> ProtoMaskSeq = ProtoSeq;
1170
0
    if (HasMask) {
1171
      // If HasMaskedOffOperand, insert result type as first input operand.
1172
0
      if (HasMaskedOffOperand) {
1173
0
        if (NF == 1) {
1174
0
          ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, ProtoSeq[0]);
1175
0
        } else {
1176
          // Convert
1177
          // (void, op0 address, op1 address, ...)
1178
          // to
1179
          // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1180
0
          for (unsigned I = 0; I < NF; ++I)
1181
0
            ProtoMaskSeq.insert(
1182
0
                ProtoMaskSeq.begin() + NF + 1,
1183
0
                ProtoSeq[1].substr(1)); // Use substr(1) to skip '*'
1184
0
        }
1185
0
      }
1186
0
      if (HasMaskedOffOperand && NF > 1) {
1187
        // Convert
1188
        // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1189
        // to
1190
        // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1191
        // ...)
1192
0
        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m");
1193
0
      } else {
1194
        // If HasMask, insert 'm' as first input operand.
1195
0
        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m");
1196
0
      }
1197
0
    }
1198
    // If HasVL, append 'z' to last operand
1199
0
    if (HasVL) {
1200
0
      ProtoSeq.push_back("z");
1201
0
      ProtoMaskSeq.push_back("z");
1202
0
    }
1203
1204
    // Create Intrinsics for each type and LMUL.
1205
0
    for (char I : TypeRange) {
1206
0
      for (int Log2LMUL : Log2LMULList) {
1207
0
        Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq);
1208
        // Ignored to create new intrinsic if there are any illegal types.
1209
0
        if (!Types.hasValue())
1210
0
          continue;
1211
1212
0
        auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto);
1213
0
        auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto);
1214
        // Create a non-mask intrinsic
1215
0
        Out.push_back(std::make_unique<RVVIntrinsic>(
1216
0
            Name, SuffixStr, MangledName, MangledSuffixStr, IRName,
1217
0
            /*IsMask=*/false, /*HasMaskedOffOperand=*/false, HasVL, HasPolicy,
1218
0
            HasNoMaskedOverloaded, HasAutoDef, ManualCodegen, Types.getValue(),
1219
0
            IntrinsicTypes, RequiredExtensions, NF));
1220
0
        if (HasMask) {
1221
          // Create a mask intrinsic
1222
0
          Optional<RVVTypes> MaskTypes =
1223
0
              computeTypes(I, Log2LMUL, NF, ProtoMaskSeq);
1224
0
          Out.push_back(std::make_unique<RVVIntrinsic>(
1225
0
              Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask,
1226
0
              /*IsMask=*/true, HasMaskedOffOperand, HasVL, HasPolicy,
1227
0
              HasNoMaskedOverloaded, HasAutoDef, ManualCodegenMask,
1228
0
              MaskTypes.getValue(), IntrinsicTypes, RequiredExtensions, NF));
1229
0
        }
1230
0
      } // end for Log2LMULList
1231
0
    }   // end for TypeRange
1232
0
  }
1233
0
}
1234
1235
0
/// Copy the HeaderCode of every RVVHeader record to OS verbatim, in the
/// order the records are returned.
void RVVEmitter::createRVVHeaders(raw_ostream &OS) {
  for (Record *Header : Records.getAllDerivedDefinitions("RVVHeader"))
    OS << Header->getValueAsString("HeaderCode");
}
1243
1244
/// Resolve every prototype in PrototypeSeq for basic type BT at Log2LMUL.
/// Returns None if LMUL * NF exceeds 8 or if any single type is illegal.
Optional<RVVTypes>
RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
                         ArrayRef<std::string> PrototypeSeq) {
  // LMUL x NF must be less than or equal to 8.
  const bool TooWide = Log2LMUL >= 1 && (1 << Log2LMUL) * NF > 8;
  if (TooWide)
    return llvm::None;

  RVVTypes Types;
  for (const std::string &Proto : PrototypeSeq) {
    Optional<RVVTypePtr> Ty = computeType(BT, Log2LMUL, Proto);
    // One illegal operand invalidates the whole signature.
    if (!Ty.hasValue())
      return llvm::None;
    Types.push_back(Ty.getValue());
  }
  return Types;
}
1261
1262
/// Return a pointer to the cached RVVType for (BT, Log2LMUL, Proto), or None
/// if that combination is illegal. Results, both legal and illegal, are
/// memoized in LegalTypes / IllegalTypes keyed by the concatenated inputs.
Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL,
                                             StringRef Proto) {
  std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str();
  // Search first
  auto It = LegalTypes.find(Idx);
  if (It != LegalTypes.end())
    return &(It->second);
  if (IllegalTypes.count(Idx))
    return llvm::None;
  // Compute type and record the result.
  RVVType T(BT, Log2LMUL, Proto);
  if (T.isValid()) {
    // Record legal type index and value. Reuse the iterator returned by
    // insert instead of a second lookup through operator[], which also
    // required RVVType to be default-constructible.
    auto Inserted = LegalTypes.insert({Idx, std::move(T)});
    return &(Inserted.first->second);
  }
  // Record illegal type index.
  IllegalTypes.insert(Idx);
  return llvm::None;
}
1282
1283
void RVVEmitter::emitArchMacroAndBody(
1284
    std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS,
1285
0
    std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) {
1286
0
  uint8_t PrevExt = (*Defs.begin())->getRISCVExtensions();
1287
0
  bool NeedEndif = emitExtDefStr(PrevExt, OS);
1288
0
  for (auto &Def : Defs) {
1289
0
    uint8_t CurExt = Def->getRISCVExtensions();
1290
0
    if (CurExt != PrevExt) {
1291
0
      if (NeedEndif)
1292
0
        OS << "#endif\n\n";
1293
0
      NeedEndif = emitExtDefStr(CurExt, OS);
1294
0
      PrevExt = CurExt;
1295
0
    }
1296
0
    if (Def->hasAutoDef())
1297
0
      PrintBody(OS, *Def);
1298
0
  }
1299
0
  if (NeedEndif)
1300
0
    OS << "#endif\n\n";
1301
0
}
1302
1303
0
bool RVVEmitter::emitExtDefStr(uint8_t Extents, raw_ostream &OS) {
1304
0
  if (Extents == RISCVExtension::Basic)
1305
0
    return false;
1306
0
  OS << "#if ";
1307
0
  ListSeparator LS(" && ");
1308
0
  if (Extents & RISCVExtension::F)
1309
0
    OS << LS << "defined(__riscv_f)";
1310
0
  if (Extents & RISCVExtension::D)
1311
0
    OS << LS << "defined(__riscv_d)";
1312
0
  if (Extents & RISCVExtension::Zfh)
1313
0
    OS << LS << "defined(__riscv_zfh)";
1314
0
  if (Extents & RISCVExtension::Zvlsseg)
1315
0
    OS << LS << "defined(__riscv_zvlsseg)";
1316
0
  if (Extents & RISCVExtension::RV64)
1317
0
    OS << LS << "(__riscv_xlen == 64)";
1318
0
  OS << "\n";
1319
0
  return true;
1320
0
}
1321
1322
namespace clang {
1323
0
void EmitRVVHeader(RecordKeeper &Records, raw_ostream &OS) {
1324
0
  RVVEmitter(Records).createHeader(OS);
1325
0
}
1326
1327
0
void EmitRVVBuiltins(RecordKeeper &Records, raw_ostream &OS) {
1328
0
  RVVEmitter(Records).createBuiltins(OS);
1329
0
}
1330
1331
0
void EmitRVVBuiltinCG(RecordKeeper &Records, raw_ostream &OS) {
1332
0
  RVVEmitter(Records).createCodeGen(OS);
1333
0
}
1334
1335
} // End namespace clang