1 /*
2 * Copyright 2015, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "Assert.h"
18 #include "Log.h"
19 #include "RSUtils.h"
20
21 #include <algorithm>
22 #include <vector>
23
24 #include <llvm/IR/CallSite.h>
25 #include <llvm/IR/Type.h>
26 #include <llvm/IR/Instructions.h>
27 #include <llvm/IR/Module.h>
28 #include <llvm/IR/Function.h>
29 #include <llvm/Pass.h>
30
31 namespace { // anonymous namespace
32
33 static const bool kDebug = false;
34
35 /* RSX86_64CallConvPass: This pass scans for calls to Renderscript functions in
36 * the CPU reference driver. For such calls, it identifies the
37 * pass-by-reference large-object pointer arguments introduced by the frontend
38 * to conform to the AArch64 calling convention (AAPCS). These pointer
39 * arguments are converted to pass-by-value to match the calling convention of
40 * the CPU reference driver.
41 */
42 class RSX86_64CallConvPass: public llvm::ModulePass {
43 private:
IsRSFunctionOfInterest(llvm::Function & F)44 bool IsRSFunctionOfInterest(llvm::Function &F) {
45 // Only Renderscript functions that are not defined locally be considered
46 if (!F.empty()) // defined locally
47 return false;
48
49 // llvm intrinsic or internal function
50 llvm::StringRef FName = F.getName();
51 if (FName.startswith("llvm."))
52 return false;
53
54 // All other functions need to be checked for large-object parameters.
55 // Disallowed (non-Renderscript) functions are detected by a different pass.
56 return true;
57 }
58
59 // Test if this argument needs to be converted to pass-by-value.
IsDerefNeeded(llvm::Function * F,llvm::Argument & Arg)60 bool IsDerefNeeded(llvm::Function *F, llvm::Argument &Arg) {
61 unsigned ArgNo = Arg.getArgNo();
62 llvm::Type *ArgTy = Arg.getType();
63
64 // Do not consider arguments with 'sret' attribute. Parameters with this
65 // attribute are actually pointers to structure return values.
66 if (Arg.hasStructRetAttr())
67 return false;
68
69 // Dereference needed only if type is a pointer to a struct
70 if (!ArgTy->isPointerTy() || !ArgTy->getPointerElementType()->isStructTy())
71 return false;
72
73 // Dereference needed only for certain RS struct objects.
74 llvm::Type *StructTy = ArgTy->getPointerElementType();
75 if (!isRsObjectType(StructTy))
76 return false;
77
78 // TODO Find a better way to encode exceptions
79 llvm::StringRef FName = F->getName();
80 // rsSetObject's first parameter is a pointer
81 if (FName.find("rsSetObject") != std::string::npos && ArgNo == 0)
82 return false;
83 // rsClearObject's first parameter is a pointer
84 if (FName.find("rsClearObject") != std::string::npos && ArgNo == 0)
85 return false;
86 // rsForEachInternal's fifth parameter is a pointer
87 if (FName.find("rsForEachInternal") != std::string::npos && ArgNo == 4)
88 return false;
89
90 return true;
91 }
92
93 // Compute which arguments to this function need be converted to pass-by-value
FillArgsToDeref(llvm::Function * F,std::vector<unsigned> & ArgNums)94 bool FillArgsToDeref(llvm::Function *F, std::vector<unsigned> &ArgNums) {
95 bccAssert(ArgNums.size() == 0);
96
97 for (auto &Arg: F->getArgumentList()) {
98 if (IsDerefNeeded(F, Arg)) {
99 ArgNums.push_back(Arg.getArgNo());
100
101 if (kDebug) {
102 ALOGV("Lowering argument %u for function %s\n", Arg.getArgNo(),
103 F->getName().str().c_str());
104 }
105 }
106 }
107 return ArgNums.size() > 0;
108 }
109
RedefineFn(llvm::Function * OrigFn,std::vector<unsigned> & ArgsToDeref)110 llvm::Function *RedefineFn(llvm::Function *OrigFn,
111 std::vector<unsigned> &ArgsToDeref) {
112
113 llvm::FunctionType *FTy = OrigFn->getFunctionType();
114 std::vector<llvm::Type *> Params(FTy->param_begin(), FTy->param_end());
115
116 llvm::FunctionType *NewTy = llvm::FunctionType::get(FTy->getReturnType(),
117 Params,
118 FTy->isVarArg());
119 llvm::Function *NewFn = llvm::Function::Create(NewTy,
120 OrigFn->getLinkage(),
121 OrigFn->getName(),
122 OrigFn->getParent());
123
124 // Add the ByVal attribute to the attribute list corresponding to this
125 // argument. The list at index (i+1) corresponds to the i-th argument. The
126 // list at index 0 corresponds to the return value's attribute.
127 for (auto i: ArgsToDeref) {
128 NewFn->addAttribute(i+1, llvm::Attribute::ByVal);
129 }
130
131 NewFn->copyAttributesFrom(OrigFn);
132 NewFn->takeName(OrigFn);
133
134 for (auto AI=OrigFn->arg_begin(), AE=OrigFn->arg_end(),
135 NAI=NewFn->arg_begin();
136 AI != AE; ++ AI, ++NAI) {
137 NAI->takeName(&*AI);
138 }
139
140 return NewFn;
141 }
142
ReplaceCallInsn(llvm::CallSite & CS,llvm::Function * NewFn,std::vector<unsigned> & ArgsToDeref)143 void ReplaceCallInsn(llvm::CallSite &CS,
144 llvm::Function *NewFn,
145 std::vector<unsigned> &ArgsToDeref) {
146
147 llvm::CallInst *CI = llvm::cast<llvm::CallInst>(CS.getInstruction());
148 std::vector<llvm::Value *> Args(CS.arg_begin(), CS.arg_end());
149 auto NewCI = llvm::CallInst::Create(NewFn, Args, "", CI);
150
151 // Add the ByVal attribute to the attribute list corresponding to this
152 // argument. The list at index (i+1) corresponds to the i-th argument. The
153 // list at index 0 corresponds to the return value's attribute.
154 for (auto i: ArgsToDeref) {
155 NewCI->addAttribute(i+1, llvm::Attribute::ByVal);
156 }
157 if (CI->isTailCall())
158 NewCI->setTailCall();
159
160 if (!CI->getType()->isVoidTy())
161 CI->replaceAllUsesWith(NewCI);
162
163 CI->eraseFromParent();
164 }
165
166 public:
167 static char ID;
168
RSX86_64CallConvPass()169 RSX86_64CallConvPass()
170 : ModulePass (ID) {
171 }
172
getAnalysisUsage(llvm::AnalysisUsage & AU) const173 virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
174 // This pass does not use any other analysis passes, but it does
175 // modify the existing functions in the module (thus altering the CFG).
176 }
177
runOnModule(llvm::Module & M)178 bool runOnModule(llvm::Module &M) override {
179 // Avoid adding Functions and altering FunctionList while iterating over it
180 // by collecting functions and processing them later.
181 std::vector<llvm::Function *> FunctionsToHandle;
182
183 auto &FunctionList = M.getFunctionList();
184 for (auto &OrigFn: FunctionList) {
185 if (!IsRSFunctionOfInterest(OrigFn))
186 continue;
187 FunctionsToHandle.push_back(&OrigFn);
188 }
189
190 for (auto OrigFn: FunctionsToHandle) {
191 std::vector<unsigned> ArgsToDeref;
192 if (!FillArgsToDeref(OrigFn, ArgsToDeref))
193 continue;
194
195 // Replace all calls to OrigFn and erase it from parent.
196 llvm::Function *NewFn = RedefineFn(OrigFn, ArgsToDeref);
197 while (!OrigFn->use_empty()) {
198 llvm::CallSite CS(OrigFn->user_back());
199 ReplaceCallInsn(CS, NewFn, ArgsToDeref);
200 }
201 OrigFn->eraseFromParent();
202 }
203
204 return FunctionsToHandle.size() > 0;
205 }
206
207 };
208
209 }
210
211 char RSX86_64CallConvPass::ID = 0;
212
213 static llvm::RegisterPass<RSX86_64CallConvPass> X("X86-64-calling-conv",
214 "remove AArch64 assumptions from calls in X86-64");
215
216 namespace bcc {
217
218 llvm::ModulePass *
createRSX86_64CallConvPass()219 createRSX86_64CallConvPass() {
220 return new RSX86_64CallConvPass();
221 }
222
223 }
224