Coverage Report

Created: 2022-07-16 07:03

/Users/buildslave/jenkins/workspace/coverage/llvm-project/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "OffloadWrapper.h"
10
#include "llvm/ADT/ArrayRef.h"
11
#include "llvm/ADT/Triple.h"
12
#include "llvm/IR/Constants.h"
13
#include "llvm/IR/GlobalVariable.h"
14
#include "llvm/IR/IRBuilder.h"
15
#include "llvm/IR/LLVMContext.h"
16
#include "llvm/IR/Module.h"
17
#include "llvm/Support/Error.h"
18
#include "llvm/Transforms/Utils/ModuleUtils.h"
19
20
using namespace llvm;
21
22
namespace {
23
/// Magic number that begins the section containing the CUDA fatbinary.
24
constexpr unsigned CudaFatMagic = 0x466243b1;
25
constexpr unsigned HIPFatMagic = 0x48495046;
26
27
/// Copied from clang/CGCudaRuntime.h.
28
enum OffloadEntryKindFlag : uint32_t {
29
  /// Mark the entry as a global entry. This indicates the presence of a
30
  /// kernel if the size field is zero and a variable otherwise.
31
  OffloadGlobalEntry = 0x0,
32
  /// Mark the entry as a managed global variable.
33
  OffloadGlobalManagedEntry = 0x1,
34
  /// Mark the entry as a surface variable.
35
  OffloadGlobalSurfaceEntry = 0x2,
36
  /// Mark the entry as a texture variable.
37
  OffloadGlobalTextureEntry = 0x3,
38
};
39
40
0
IntegerType *getSizeTTy(Module &M) {
41
0
  LLVMContext &C = M.getContext();
42
0
  switch (M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))) {
43
0
  case 4u:
44
0
    return Type::getInt32Ty(C);
45
0
  case 8u:
46
0
    return Type::getInt64Ty(C);
47
0
  }
48
0
  llvm_unreachable("unsupported pointer type size");
49
0
}
50
51
// struct __tgt_offload_entry {
52
//   void *addr;
53
//   char *name;
54
//   size_t size;
55
//   int32_t flags;
56
//   int32_t reserved;
57
// };
58
0
StructType *getEntryTy(Module &M) {
59
0
  LLVMContext &C = M.getContext();
60
0
  StructType *EntryTy = StructType::getTypeByName(C, "__tgt_offload_entry");
61
0
  if (!EntryTy)
62
0
    EntryTy = StructType::create("__tgt_offload_entry", Type::getInt8PtrTy(C),
63
0
                                 Type::getInt8PtrTy(C), getSizeTTy(M),
64
0
                                 Type::getInt32Ty(C), Type::getInt32Ty(C));
65
0
  return EntryTy;
66
0
}
67
68
0
PointerType *getEntryPtrTy(Module &M) {
69
0
  return PointerType::getUnqual(getEntryTy(M));
70
0
}
71
72
// struct __tgt_device_image {
73
//   void *ImageStart;
74
//   void *ImageEnd;
75
//   __tgt_offload_entry *EntriesBegin;
76
//   __tgt_offload_entry *EntriesEnd;
77
// };
78
0
StructType *getDeviceImageTy(Module &M) {
79
0
  LLVMContext &C = M.getContext();
80
0
  StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image");
81
0
  if (!ImageTy)
82
0
    ImageTy = StructType::create("__tgt_device_image", Type::getInt8PtrTy(C),
83
0
                                 Type::getInt8PtrTy(C), getEntryPtrTy(M),
84
0
                                 getEntryPtrTy(M));
85
0
  return ImageTy;
86
0
}
87
88
0
PointerType *getDeviceImagePtrTy(Module &M) {
89
0
  return PointerType::getUnqual(getDeviceImageTy(M));
90
0
}
91
92
// struct __tgt_bin_desc {
93
//   int32_t NumDeviceImages;
94
//   __tgt_device_image *DeviceImages;
95
//   __tgt_offload_entry *HostEntriesBegin;
96
//   __tgt_offload_entry *HostEntriesEnd;
97
// };
98
0
StructType *getBinDescTy(Module &M) {
99
0
  LLVMContext &C = M.getContext();
100
0
  StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc");
101
0
  if (!DescTy)
102
0
    DescTy = StructType::create("__tgt_bin_desc", Type::getInt32Ty(C),
103
0
                                getDeviceImagePtrTy(M), getEntryPtrTy(M),
104
0
                                getEntryPtrTy(M));
105
0
  return DescTy;
106
0
}
107
108
0
PointerType *getBinDescPtrTy(Module &M) {
109
0
  return PointerType::getUnqual(getBinDescTy(M));
110
0
}
111
112
/// Creates binary descriptor for the given device images. Binary descriptor
113
/// is an object that is passed to the offloading runtime at program startup
114
/// and it describes all device images available in the executable or shared
115
/// library. It is defined as follows:
116
///
117
/// __attribute__((visibility("hidden")))
118
/// extern __tgt_offload_entry *__start_omp_offloading_entries;
119
/// __attribute__((visibility("hidden")))
120
/// extern __tgt_offload_entry *__stop_omp_offloading_entries;
121
///
122
/// static const char Image0[] = { <Bufs.front() contents> };
123
///  ...
124
/// static const char ImageN[] = { <Bufs.back() contents> };
125
///
126
/// static const __tgt_device_image Images[] = {
127
///   {
128
///     Image0,                            /*ImageStart*/
129
///     Image0 + sizeof(Image0),           /*ImageEnd*/
130
///     __start_omp_offloading_entries,    /*EntriesBegin*/
131
///     __stop_omp_offloading_entries      /*EntriesEnd*/
132
///   },
133
///   ...
134
///   {
135
///     ImageN,                            /*ImageStart*/
136
///     ImageN + sizeof(ImageN),           /*ImageEnd*/
137
///     __start_omp_offloading_entries,    /*EntriesBegin*/
138
///     __stop_omp_offloading_entries      /*EntriesEnd*/
139
///   }
140
/// };
141
///
142
/// static const __tgt_bin_desc BinDesc = {
143
///   sizeof(Images) / sizeof(Images[0]),  /*NumDeviceImages*/
144
///   Images,                              /*DeviceImages*/
145
///   __start_omp_offloading_entries,      /*HostEntriesBegin*/
146
///   __stop_omp_offloading_entries        /*HostEntriesEnd*/
147
/// };
148
///
149
/// Global variable that represents BinDesc is returned.
150
0
GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) {
151
0
  LLVMContext &C = M.getContext();
152
  // Create external begin/end symbols for the offload entries table.
153
0
  auto *EntriesB = new GlobalVariable(
154
0
      M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
155
0
      /*Initializer*/ nullptr, "__start_omp_offloading_entries");
156
0
  EntriesB->setVisibility(GlobalValue::HiddenVisibility);
157
0
  auto *EntriesE = new GlobalVariable(
158
0
      M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
159
0
      /*Initializer*/ nullptr, "__stop_omp_offloading_entries");
160
0
  EntriesE->setVisibility(GlobalValue::HiddenVisibility);
161
162
  // We assume that external begin/end symbols that we have created above will
163
  // be defined by the linker. But the linker will do that only if linker inputs
164
  // have a section with the "omp_offloading_entries" name, which is not guaranteed.
165
  // So, we just create a dummy zero-sized object in the offload entries section
166
  // to force the linker to define those symbols.
167
0
  auto *DummyInit =
168
0
      ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
169
0
  auto *DummyEntry = new GlobalVariable(
170
0
      M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
171
0
      "__dummy.omp_offloading.entry");
172
0
  DummyEntry->setSection("omp_offloading_entries");
173
0
  DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
174
175
0
  auto *Zero = ConstantInt::get(getSizeTTy(M), 0u);
176
0
  Constant *ZeroZero[] = {Zero, Zero};
177
178
  // Create initializer for the images array.
179
0
  SmallVector<Constant *, 4u> ImagesInits;
180
0
  ImagesInits.reserve(Bufs.size());
181
0
  for (ArrayRef<char> Buf : Bufs) {
182
0
    auto *Data = ConstantDataArray::get(C, Buf);
183
0
    auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
184
0
                                     GlobalVariable::InternalLinkage, Data,
185
0
                                     ".omp_offloading.device_image");
186
0
    Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
187
188
0
    auto *Size = ConstantInt::get(getSizeTTy(M), Buf.size());
189
0
    Constant *ZeroSize[] = {Zero, Size};
190
191
0
    auto *ImageB =
192
0
        ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroZero);
193
0
    auto *ImageE =
194
0
        ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize);
195
196
0
    ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB,
197
0
                                              ImageE, EntriesB, EntriesE));
198
0
  }
199
200
  // Then create images array.
201
0
  auto *ImagesData = ConstantArray::get(
202
0
      ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits);
203
204
0
  auto *Images =
205
0
      new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true,
206
0
                         GlobalValue::InternalLinkage, ImagesData,
207
0
                         ".omp_offloading.device_images");
208
0
  Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
209
210
0
  auto *ImagesB =
211
0
      ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero);
212
213
  // And finally create the binary descriptor object.
214
0
  auto *DescInit = ConstantStruct::get(
215
0
      getBinDescTy(M),
216
0
      ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB,
217
0
      EntriesB, EntriesE);
218
219
0
  return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true,
220
0
                            GlobalValue::InternalLinkage, DescInit,
221
0
                            ".omp_offloading.descriptor");
222
0
}
223
224
0
void createRegisterFunction(Module &M, GlobalVariable *BinDesc) {
225
0
  LLVMContext &C = M.getContext();
226
0
  auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
227
0
  auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
228
0
                                ".omp_offloading.descriptor_reg", &M);
229
0
  Func->setSection(".text.startup");
230
231
  // Get __tgt_register_lib function declaration.
232
0
  auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
233
0
                                      /*isVarArg*/ false);
234
0
  FunctionCallee RegFuncC =
235
0
      M.getOrInsertFunction("__tgt_register_lib", RegFuncTy);
236
237
  // Construct function body
238
0
  IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
239
0
  Builder.CreateCall(RegFuncC, BinDesc);
240
0
  Builder.CreateRetVoid();
241
242
  // Add this function to constructors.
243
  // Set priority to 1 so that __tgt_register_lib is executed AFTER
244
  // __tgt_register_requires (we want to know what requirements have been
245
  // asked for before we load a libomptarget plugin so that by the time the
246
  // plugin is loaded it can report how many devices there are which can
247
  // satisfy these requirements).
248
0
  appendToGlobalCtors(M, Func, /*Priority*/ 1);
249
0
}
250
251
0
void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) {
252
0
  LLVMContext &C = M.getContext();
253
0
  auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
254
0
  auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
255
0
                                ".omp_offloading.descriptor_unreg", &M);
256
0
  Func->setSection(".text.startup");
257
258
  // Get __tgt_unregister_lib function declaration.
259
0
  auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
260
0
                                        /*isVarArg*/ false);
261
0
  FunctionCallee UnRegFuncC =
262
0
      M.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy);
263
264
  // Construct function body
265
0
  IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
266
0
  Builder.CreateCall(UnRegFuncC, BinDesc);
267
0
  Builder.CreateRetVoid();
268
269
  // Add this function to global destructors.
270
  // Match priority of __tgt_register_lib
271
0
  appendToGlobalDtors(M, Func, /*Priority*/ 1);
272
0
}
273
274
// struct fatbin_wrapper {
275
//  int32_t magic;
276
//  int32_t version;
277
//  void *image;
278
//  void *reserved;
279
//};
280
0
StructType *getFatbinWrapperTy(Module &M) {
281
0
  LLVMContext &C = M.getContext();
282
0
  StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
283
0
  if (!FatbinTy)
284
0
    FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C),
285
0
                                  Type::getInt32Ty(C), Type::getInt8PtrTy(C),
286
0
                                  Type::getInt8PtrTy(C));
287
0
  return FatbinTy;
288
0
}
289
290
/// Embed the image \p Image into the module \p M so it can be found by the
291
/// runtime.
292
0
GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
293
0
  LLVMContext &C = M.getContext();
294
0
  llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C);
295
0
  llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
296
297
  // Create the global string containing the fatbinary.
298
0
  StringRef FatbinConstantSection =
299
0
      IsHIP ? ".hip_fatbin"
300
0
            : (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin");
301
0
  auto *Data = ConstantDataArray::get(C, Image);
302
0
  auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
303
0
                                    GlobalVariable::InternalLinkage, Data,
304
0
                                    ".fatbin_image");
305
0
  Fatbin->setSection(FatbinConstantSection);
306
307
  // Create the fatbinary wrapper
308
0
  StringRef FatbinWrapperSection = IsHIP               ? ".hipFatBinSegment"
309
0
                                   : Triple.isMacOSX() ? "__NV_CUDA,__fatbin"
310
0
                                                       : ".nvFatBinSegment";
311
0
  Constant *FatbinWrapper[] = {
312
0
      ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic),
313
0
      ConstantInt::get(Type::getInt32Ty(C), 1),
314
0
      ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
315
0
      ConstantPointerNull::get(Type::getInt8PtrTy(C))};
316
317
0
  Constant *FatbinInitializer =
318
0
      ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
319
320
0
  auto *FatbinDesc =
321
0
      new GlobalVariable(M, getFatbinWrapperTy(M),
322
0
                         /*isConstant*/ true, GlobalValue::InternalLinkage,
323
0
                         FatbinInitializer, ".fatbin_wrapper");
324
0
  FatbinDesc->setSection(FatbinWrapperSection);
325
0
  FatbinDesc->setAlignment(Align(8));
326
327
  // We create a dummy entry to ensure the linker will define the begin / end
328
  // symbols. The CUDA runtime should ignore the null address if we attempt to
329
  // register it.
330
0
  auto *DummyInit =
331
0
      ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
332
0
  auto *DummyEntry = new GlobalVariable(
333
0
      M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
334
0
      IsHIP ? "__dummy.hip_offloading.entry" : "__dummy.cuda_offloading.entry");
335
0
  DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
336
0
  DummyEntry->setSection(IsHIP ? "hip_offloading_entries"
337
0
                               : "cuda_offloading_entries");
338
339
0
  return FatbinDesc;
340
0
}
341
342
/// Create the register globals function. We will iterate all of the offloading
343
/// entries stored at the begin / end symbols and register them according to
344
/// their type. This creates the following function in IR:
345
///
346
/// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
347
/// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
348
///
349
/// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
350
///                                    void *, void *, void *, void *, int *);
351
/// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
352
///                               int64_t, int32_t, int32_t);
353
///
354
/// void __cudaRegisterTest(void **fatbinHandle) {
355
///   for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
356
///        entry != &__stop_cuda_offloading_entries; ++entry) {
357
///     if (!entry->size)
358
///       __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
359
///                              entry->name, -1, 0, 0, 0, 0, 0);
360
///     else
361
///       __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
362
///                         0, entry->size, 0, 0);
363
///   }
364
/// }
365
0
Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
366
0
  LLVMContext &C = M.getContext();
367
  // Get the __cudaRegisterFunction function declaration.
368
0
  auto *RegFuncTy = FunctionType::get(
369
0
      Type::getInt32Ty(C),
370
0
      {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
371
0
       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
372
0
       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C),
373
0
       Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)},
374
0
      /*isVarArg*/ false);
375
0
  FunctionCallee RegFunc = M.getOrInsertFunction(
376
0
      IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy);
377
378
  // Get the __cudaRegisterVar function declaration.
379
0
  auto *RegVarTy = FunctionType::get(
380
0
      Type::getVoidTy(C),
381
0
      {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
382
0
       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
383
0
       getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
384
0
      /*isVarArg*/ false);
385
0
  FunctionCallee RegVar = M.getOrInsertFunction(
386
0
      IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy);
387
388
  // Create the references to the start / stop symbols defined by the linker.
389
0
  auto *EntriesB =
390
0
      new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
391
0
                         /*isConstant*/ true, GlobalValue::ExternalLinkage,
392
0
                         /*Initializer*/ nullptr,
393
0
                         IsHIP ? "__start_hip_offloading_entries"
394
0
                               : "__start_cuda_offloading_entries");
395
0
  EntriesB->setVisibility(GlobalValue::HiddenVisibility);
396
0
  auto *EntriesE =
397
0
      new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
398
0
                         /*isConstant*/ true, GlobalValue::ExternalLinkage,
399
0
                         /*Initializer*/ nullptr,
400
0
                         IsHIP ? "__stop_hip_offloading_entries"
401
0
                               : "__stop_cuda_offloading_entries");
402
0
  EntriesE->setVisibility(GlobalValue::HiddenVisibility);
403
404
0
  auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C),
405
0
                                         Type::getInt8PtrTy(C)->getPointerTo(),
406
0
                                         /*isVarArg*/ false);
407
0
  auto *RegGlobalsFn =
408
0
      Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage,
409
0
                       IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg", &M);
410
0
  RegGlobalsFn->setSection(".text.startup");
411
412
  // Create the loop to register all the entries.
413
0
  IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
414
0
  auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
415
0
  auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
416
0
  auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
417
0
  auto *SwGlobalBB = BasicBlock::Create(C, "sw.global", RegGlobalsFn);
418
0
  auto *SwManagedBB = BasicBlock::Create(C, "sw.managed", RegGlobalsFn);
419
0
  auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface", RegGlobalsFn);
420
0
  auto *SwTextureBB = BasicBlock::Create(C, "sw.texture", RegGlobalsFn);
421
0
  auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
422
0
  auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);
423
424
0
  auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE);
425
0
  Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB);
426
0
  Builder.SetInsertPoint(EntryBB);
427
0
  auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry");
428
0
  auto *AddrPtr =
429
0
      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
430
0
                                {ConstantInt::get(getSizeTTy(M), 0),
431
0
                                 ConstantInt::get(Type::getInt32Ty(C), 0)});
432
0
  auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr");
433
0
  auto *NamePtr =
434
0
      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
435
0
                                {ConstantInt::get(getSizeTTy(M), 0),
436
0
                                 ConstantInt::get(Type::getInt32Ty(C), 1)});
437
0
  auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name");
438
0
  auto *SizePtr =
439
0
      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
440
0
                                {ConstantInt::get(getSizeTTy(M), 0),
441
0
                                 ConstantInt::get(Type::getInt32Ty(C), 2)});
442
0
  auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
443
0
  auto *FlagsPtr =
444
0
      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
445
0
                                {ConstantInt::get(getSizeTTy(M), 0),
446
0
                                 ConstantInt::get(Type::getInt32Ty(C), 3)});
447
0
  auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flag");
448
0
  auto *FnCond =
449
0
      Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
450
0
  Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);
451
452
  // Create kernel registration code.
453
0
  Builder.SetInsertPoint(IfThenBB);
454
0
  Builder.CreateCall(RegFunc,
455
0
                     {RegGlobalsFn->arg_begin(), Addr, Name, Name,
456
0
                      ConstantInt::get(Type::getInt32Ty(C), -1),
457
0
                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
458
0
                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
459
0
                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
460
0
                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
461
0
                      ConstantPointerNull::get(Type::getInt32PtrTy(C))});
462
0
  Builder.CreateBr(IfEndBB);
463
0
  Builder.SetInsertPoint(IfElseBB);
464
465
0
  auto *Switch = Builder.CreateSwitch(Flags, IfEndBB);
466
  // Create global variable registration code.
467
0
  Builder.SetInsertPoint(SwGlobalBB);
468
0
  Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
469
0
                              ConstantInt::get(Type::getInt32Ty(C), 0), Size,
470
0
                              ConstantInt::get(Type::getInt32Ty(C), 0),
471
0
                              ConstantInt::get(Type::getInt32Ty(C), 0)});
472
0
  Builder.CreateBr(IfEndBB);
473
0
  Switch->addCase(Builder.getInt32(OffloadGlobalEntry), SwGlobalBB);
474
475
  // Create managed variable registration code.
476
0
  Builder.SetInsertPoint(SwManagedBB);
477
0
  Builder.CreateBr(IfEndBB);
478
0
  Switch->addCase(Builder.getInt32(OffloadGlobalManagedEntry), SwManagedBB);
479
480
  // Create surface variable registration code.
481
0
  Builder.SetInsertPoint(SwSurfaceBB);
482
0
  Builder.CreateBr(IfEndBB);
483
0
  Switch->addCase(Builder.getInt32(OffloadGlobalSurfaceEntry), SwSurfaceBB);
484
485
  // Create texture variable registration code.
486
0
  Builder.SetInsertPoint(SwTextureBB);
487
0
  Builder.CreateBr(IfEndBB);
488
0
  Switch->addCase(Builder.getInt32(OffloadGlobalTextureEntry), SwTextureBB);
489
490
0
  Builder.SetInsertPoint(IfEndBB);
491
0
  auto *NewEntry = Builder.CreateInBoundsGEP(
492
0
      getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
493
0
  auto *Cmp = Builder.CreateICmpEQ(
494
0
      NewEntry,
495
0
      ConstantExpr::getInBoundsGetElementPtr(
496
0
          ArrayType::get(getEntryTy(M), 0), EntriesE,
497
0
          ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
498
0
                                ConstantInt::get(getSizeTTy(M), 0)})));
499
0
  Entry->addIncoming(
500
0
      ConstantExpr::getInBoundsGetElementPtr(
501
0
          ArrayType::get(getEntryTy(M), 0), EntriesB,
502
0
          ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
503
0
                                ConstantInt::get(getSizeTTy(M), 0)})),
504
0
      &RegGlobalsFn->getEntryBlock());
505
0
  Entry->addIncoming(NewEntry, IfEndBB);
506
0
  Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
507
0
  Builder.SetInsertPoint(ExitBB);
508
0
  Builder.CreateRetVoid();
509
510
0
  return RegGlobalsFn;
511
0
}
512
513
// Create the constructor and destructor to register the fatbinary with the CUDA
514
// runtime.
515
void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
516
0
                                  bool IsHIP) {
517
0
  LLVMContext &C = M.getContext();
518
0
  auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
519
0
  auto *CtorFunc =
520
0
      Function::Create(CtorFuncTy, GlobalValue::InternalLinkage,
521
0
                       IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M);
522
0
  CtorFunc->setSection(".text.startup");
523
524
0
  auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
525
0
  auto *DtorFunc =
526
0
      Function::Create(DtorFuncTy, GlobalValue::InternalLinkage,
527
0
                       IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M);
528
0
  DtorFunc->setSection(".text.startup");
529
530
  // Get the __cudaRegisterFatBinary function declaration.
531
0
  auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(),
532
0
                                     Type::getInt8PtrTy(C),
533
0
                                     /*isVarArg*/ false);
534
0
  FunctionCallee RegFatbin = M.getOrInsertFunction(
535
0
      IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy);
536
  // Get the __cudaRegisterFatBinaryEnd function declaration.
537
0
  auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C),
538
0
                                        Type::getInt8PtrTy(C)->getPointerTo(),
539
0
                                        /*isVarArg*/ false);
540
0
  FunctionCallee RegFatbinEnd =
541
0
      M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
542
  // Get the __cudaUnregisterFatBinary function declaration.
543
0
  auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C),
544
0
                                       Type::getInt8PtrTy(C)->getPointerTo(),
545
0
                                       /*isVarArg*/ false);
546
0
  FunctionCallee UnregFatbin = M.getOrInsertFunction(
547
0
      IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary",
548
0
      UnregFatTy);
549
550
0
  auto *AtExitTy =
551
0
      FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(),
552
0
                        /*isVarArg*/ false);
553
0
  FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
554
555
0
  auto *BinaryHandleGlobal = new llvm::GlobalVariable(
556
0
      M, Type::getInt8PtrTy(C)->getPointerTo(), false,
557
0
      llvm::GlobalValue::InternalLinkage,
558
0
      llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()),
559
0
      IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle");
560
561
  // Create the constructor to register this image with the runtime.
562
0
  IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
563
0
  CallInst *Handle = CtorBuilder.CreateCall(
564
0
      RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast(
565
0
                     FatbinDesc, Type::getInt8PtrTy(C)));
566
0
  CtorBuilder.CreateAlignedStore(
567
0
      Handle, BinaryHandleGlobal,
568
0
      Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
569
0
  CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle);
570
0
  if (!IsHIP)
571
0
    CtorBuilder.CreateCall(RegFatbinEnd, Handle);
572
0
  CtorBuilder.CreateCall(AtExit, DtorFunc);
573
0
  CtorBuilder.CreateRetVoid();
574
575
  // Create the destructor to unregister the image with the runtime. We cannot
576
  // use a standard global destructor after CUDA 9.2 so this must be called by
577
  // `atexit()` instead.
578
0
  IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
579
0
  LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
580
0
      Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal,
581
0
      Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
582
0
  DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
583
0
  DtorBuilder.CreateRetVoid();
584
585
  // Add this function to constructors.
586
0
  appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1);
587
0
}
588
589
} // namespace
590
591
0
Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
592
0
  GlobalVariable *Desc = createBinDesc(M, Images);
593
0
  if (!Desc)
594
0
    return createStringError(inconvertibleErrorCode(),
595
0
                             "No binary descriptors created.");
596
0
  createRegisterFunction(M, Desc);
597
0
  createUnregisterFunction(M, Desc);
598
0
  return Error::success();
599
0
}
600
601
0
Error wrapCudaBinary(Module &M, ArrayRef<char> Image) {
602
0
  GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ false);
603
0
  if (!Desc)
604
0
    return createStringError(inconvertibleErrorCode(),
605
0
                             "No fatinbary section created.");
606
607
0
  createRegisterFatbinFunction(M, Desc, /* IsHIP */ false);
608
0
  return Error::success();
609
0
}
610
611
0
Error wrapHIPBinary(Module &M, ArrayRef<char> Image) {
612
0
  GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ true);
613
0
  if (!Desc)
614
0
    return createStringError(inconvertibleErrorCode(),
615
0
                             "No fatinbary section created.");
616
617
0
  createRegisterFatbinFunction(M, Desc, /* IsHIP */ true);
618
0
  return Error::success();
619
0
}