Coverage Report

Created: 2019-07-24 05:18

/Users/buildslave/jenkins/workspace/clang-stage2-coverage-R/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp
Line
Count
Source (jump to first uncovered line)
1
//===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
///
9
/// \file
10
/// Fix bitcasted functions.
11
///
12
/// WebAssembly requires caller and callee signatures to match, however in LLVM,
13
/// some amount of slop is vaguely permitted. Detect mismatch by looking for
14
/// bitcasts of functions and rewrite them to use wrapper functions instead.
15
///
16
/// This doesn't catch all cases, such as when a function's address is taken in
17
/// one place and casted in another, but it works for many common cases.
18
///
19
/// Note that LLVM already optimizes away function bitcasts in common cases by
20
/// dropping arguments as needed, so this pass only ends up getting used in less
21
/// common cases.
22
///
23
//===----------------------------------------------------------------------===//
24
25
#include "WebAssembly.h"
26
#include "llvm/IR/CallSite.h"
27
#include "llvm/IR/Constants.h"
28
#include "llvm/IR/Instructions.h"
29
#include "llvm/IR/Module.h"
30
#include "llvm/IR/Operator.h"
31
#include "llvm/Pass.h"
32
#include "llvm/Support/Debug.h"
33
#include "llvm/Support/raw_ostream.h"
34
using namespace llvm;
35
36
#define DEBUG_TYPE "wasm-fix-function-bitcasts"
37
38
namespace {
39
class FixFunctionBitcasts final : public ModulePass {
40
0
  StringRef getPassName() const override {
41
0
    return "WebAssembly Fix Function Bitcasts";
42
0
  }
43
44
426
  void getAnalysisUsage(AnalysisUsage &AU) const override {
45
426
    AU.setPreservesCFG();
46
426
    ModulePass::getAnalysisUsage(AU);
47
426
  }
48
49
  bool runOnModule(Module &M) override;
50
51
public:
52
  static char ID;
53
426
  FixFunctionBitcasts() : ModulePass(ID) {}
54
};
55
} // End anonymous namespace
56
57
// Definition of the static pass ID that the constructor passes to ModulePass.
char FixFunctionBitcasts::ID = 0;
58
// INITIALIZE_PASS is an LLVM macro defined outside this file. It is given the
// pass class, DEBUG_TYPE as the pass argument string and a description string.
// NOTE(review): the two trailing `false` flags presumably mean "not CFG-only"
// and "not an analysis" -- confirm against the macro definition.
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
59
                "Fix mismatching bitcasts for WebAssembly", false, false)
60
61
426
// Factory for the pass: returns a newly heap-allocated FixFunctionBitcasts.
// Nothing in this file deletes the object, so the caller presumably takes
// ownership of it -- confirm against the pass manager that consumes it.
ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
62
426
  return new FixFunctionBitcasts();
63
426
}
64
65
// Recursively descend the def-use lists from V to find non-bitcast users of
66
// bitcasts of V.
67
static void findUses(Value *V, Function &F,
68
                     SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
69
5.29k
                     SmallPtrSetImpl<Constant *> &ConstantBCs) {
70
5.29k
  for (Use &U : V->uses()) {
71
1.58k
    if (auto *BC = dyn_cast<BitCastOperator>(U.getUser()))
72
83
      findUses(BC, F, Uses, ConstantBCs);
73
1.50k
    else if (U.get()->getType() != F.getType()) {
74
136
      CallSite CS(U.getUser());
75
136
      if (!CS)
76
103
        // Skip uses that aren't immediately called
77
103
        continue;
78
33
      Value *Callee = CS.getCalledValue();
79
33
      if (Callee != V)
80
2
        // Skip calls where the function isn't the callee
81
2
        continue;
82
31
      if (isa<Constant>(U.get())) {
83
29
        // Only add constant bitcasts to the list once; they get RAUW'd
84
29
        auto C = ConstantBCs.insert(cast<Constant>(U.get()));
85
29
        if (!C.second)
86
4
          continue;
87
27
      }
88
27
      Uses.push_back(std::make_pair(&U, &F));
89
27
    }
90
1.58k
  }
91
5.29k
}
92
93
// Create a wrapper function with type Ty that calls F (which may have a
94
// different type). Attempt to support common bitcasted function idioms:
95
//  - Call with more arguments than needed: arguments are dropped
96
//  - Call with fewer arguments than needed: arguments are filled in with undef
97
//  - Return value is not needed: drop it
98
//  - Return value needed but not present: supply an undef
99
//
100
// If all the argument types are trivially castable to one another (i.e.
101
// I32 vs pointer type) then we don't create a wrapper at all (return nullptr
102
// instead).
103
//
104
// If there is a type mismatch that we know would result in an invalid wasm
105
// module then generate wrapper that contains unreachable (i.e. abort at
106
// runtime).  Such programs are deep into undefined behaviour territory,
107
// but we choose to fail at runtime rather than generate an invalid module
108
// or fail at compile time.  The reason we delay the error is that we want
109
// to support CMake, which expects to be able to compile and link programs
110
// that refer to functions with entirely incorrect signatures (this is how
111
// CMake detects the existence of a function in a toolchain).
112
//
113
// For bitcasts that involve struct types we don't know at this stage if they
114
// would be equivalent at the wasm level and so we can't know if we need to
115
// generate a wrapper.
116
28
static Function *createWrapper(Function *F, FunctionType *Ty) {
  Module *M = F->getParent();
  LLVMContext &Ctx = M->getContext();
  FunctionType *CalleeTy = F->getFunctionType();

  // Start with a private "<name>_bitcast" function of the requested type. It
  // is discarded again below if it turns out to be unnecessary or invalid.
  Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage,
                                       F->getName() + "_bitcast", M);
  BasicBlock *BB = BasicBlock::Create(Ctx, "body", Wrapper);
  const DataLayout &DL = M->getDataLayout();

  Type *ExpectedRtnType = CalleeTy->getReturnType();
  Type *RtnType = Ty->getReturnType();

  bool TypeMismatch = false;
  // A wrapper is wanted as soon as the parameter count, the variadic-ness or
  // the return type differs between the callee and the requested type.
  bool WrapperNeeded = CalleeTy->getNumParams() != Ty->getNumParams() ||
                       CalleeTy->isVarArg() != Ty->isVarArg() ||
                       ExpectedRtnType != RtnType;

  // Determine what arguments to pass, pairing wrapper arguments with callee
  // parameters for as long as both lists last.
  SmallVector<Value *, 4> Args;
  Function::arg_iterator AI = Wrapper->arg_begin();
  Function::arg_iterator AE = Wrapper->arg_end();
  FunctionType::param_iterator PI = CalleeTy->param_begin();
  FunctionType::param_iterator PE = CalleeTy->param_end();

  while (AI != AE && PI != PE) {
    Type *ArgType = AI->getType();
    Type *ParamType = *PI;

    if (ArgType == ParamType) {
      // Same type: forward the argument unchanged.
      Args.push_back(&*AI);
    } else if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
      // Trivially castable: insert the cast and forward its result.
      Instruction *PtrCast =
          CastInst::CreateBitOrPointerCast(&*AI, ParamType, "cast");
      BB->getInstList().push_back(PtrCast);
      Args.push_back(PtrCast);
    } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
      // Struct types cannot be judged at this stage, so no wrapper is made.
      LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "
                        << F->getName() << "\n");
      WrapperNeeded = false;
    } else {
      LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "
                        << F->getName() << "\n");
      LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
                        << *ParamType << " Got: " << *ArgType << "\n");
      TypeMismatch = true;
      break;
    }

    ++AI;
    ++PI;
  }

  if (WrapperNeeded && !TypeMismatch) {
    // Callee parameters the caller did not supply are filled with undef.
    for (; PI != PE; ++PI)
      Args.push_back(UndefValue::get(*PI));
    // Surplus wrapper arguments are forwarded only to a variadic callee.
    if (F->isVarArg())
      for (; AI != AE; ++AI)
        Args.push_back(&*AI);

    CallInst *Call = CallInst::Create(F, Args, "", BB);

    // Determine what value to return.
    if (RtnType->isVoidTy()) {
      ReturnInst::Create(Ctx, BB);
    } else if (ExpectedRtnType->isVoidTy()) {
      LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
      ReturnInst::Create(Ctx, UndefValue::get(RtnType), BB);
    } else if (RtnType == ExpectedRtnType) {
      ReturnInst::Create(Ctx, Call, BB);
    } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
                                                    DL)) {
      Instruction *Cast =
          CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
      BB->getInstList().push_back(Cast);
      ReturnInst::Create(Ctx, Cast, BB);
    } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
      LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "
                        << F->getName() << "\n");
      WrapperNeeded = false;
    } else {
      LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "
                        << F->getName() << "\n");
      LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
                        << " Got: " << *RtnType << "\n");
      TypeMismatch = true;
    }
  }

  if (TypeMismatch) {
    // Create a new wrapper that simply contains `unreachable`.
    Wrapper->eraseFromParent();
    Wrapper = Function::Create(Ty, Function::PrivateLinkage,
                               F->getName() + "_bitcast_invalid", M);
    BasicBlock *InvalidBB = BasicBlock::Create(Ctx, "body", Wrapper);
    new UnreachableInst(Ctx, InvalidBB);
    Wrapper->setName(F->getName() + "_bitcast_invalid");
  } else if (!WrapperNeeded) {
    // Nothing to wrap: throw the tentative function away.
    LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()
                      << "\n");
    Wrapper->eraseFromParent();
    return nullptr;
  }

  LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");
  return Wrapper;
}
223
224
// Test whether a main function with type FuncTy should be rewritten to have
225
// type MainTy.
226
9
static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {
227
9
  // Only fix the main function if it's the standard zero-arg form. That way,
228
9
  // the standard cases will work as expected, and users will see signature
229
9
  // mismatches from the linker for non-standard cases.
230
9
  return FuncTy->getReturnType() == MainTy->getReturnType() &&
231
9
         
FuncTy->getNumParams() == 06
&&
232
9
         
!FuncTy->isVarArg()3
;
233
9
}
234
235
426
bool FixFunctionBitcasts::runOnModule(Module &M) {
  LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");

  Function *Main = nullptr;
  CallInst *CallMain = nullptr;
  SmallVector<std::pair<Use *, Function *>, 0> Uses;
  SmallPtrSet<Constant *, 2> ConstantBCs;

  // Step 1: collect all the places that need wrappers.
  for (Function &F : M) {
    findUses(&F, F, Uses, ConstantBCs);

    if (F.getName() != "main")
      continue;

    // We have a "main" function. Build the type of
    // "int main(int argc, char *argv[])" to compare it against.
    Main = &F;
    LLVMContext &C = M.getContext();
    Type *MainArgTys[] = {Type::getInt32Ty(C),
                          PointerType::get(Type::getInt8PtrTy(C), 0)};
    FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
                                             /*isVarArg=*/false);
    if (!shouldFixMainFunction(F.getFunctionType(), MainTy))
      continue;

    LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
                      << *F.getFunctionType() << "\n");

    // Create an artificial, unattached call to main bitcasted to the standard
    // type, so that a wrapper is generated for it and the C runtime can call
    // it.
    Value *Args[] = {UndefValue::get(MainArgTys[0]),
                     UndefValue::get(MainArgTys[1])};
    Value *Casted =
        ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
    CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");
    // The call has two arguments, so operand 2 is presumably the callee slot;
    // it is the use that gets redirected to the wrapper below.
    Use *UseMain = &CallMain->getOperandUse(2);
    Uses.push_back(std::make_pair(UseMain, &F));
  }

  // Step 2: create one wrapper per (function, target type) pair and redirect
  // the collected uses to it. A null entry records that no wrapper is needed.
  DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;

  for (auto &Entry : Uses) {
    Use *U = Entry.first;
    Function *F = Entry.second;
    Value *Used = U->get();
    auto *PTy = cast<PointerType>(Used->getType());
    auto *Ty = dyn_cast<FunctionType>(PTy->getElementType());

    // If the function is casted to something like i8* as a "generic pointer"
    // to be later casted to something else, we can't generate a wrapper for it.
    // Just ignore such casts for now.
    if (!Ty)
      continue;

    auto Inserted =
        Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
    if (Inserted.second)
      Inserted.first->second = createWrapper(F, Ty);

    Function *Wrapper = Inserted.first->second;
    if (!Wrapper)
      continue;

    // A constant bitcast is replaced everywhere at once; any other use is
    // patched individually.
    if (isa<Constant>(Used))
      Used->replaceAllUsesWith(Wrapper);
    else
      U->set(Wrapper);
  }

  // Step 3: if we created a wrapper for main, rename the wrapper so that it's
  // the one that gets called from startup.
  if (!CallMain)
    return true;

  Main->setName("__original_main");
  auto *MainWrapper =
      cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
  // The artificial call was never inserted anywhere, so free it directly.
  delete CallMain;

  if (Main->isDeclaration()) {
    // The wrapper is not needed in this case as we don't need to export
    // it to anyone else.
    MainWrapper->eraseFromParent();
    return true;
  }

  // Otherwise give the wrapper the same linkage as the original main
  // function, so that it can be called from the same places.
  MainWrapper->setName("main");
  MainWrapper->setLinkage(Main->getLinkage());
  MainWrapper->setVisibility(Main->getVisibility());
  return true;
}