1 /*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "slang_rs_export_func.h"
18
19 #include <string>
20
21 #include "clang/AST/ASTContext.h"
22 #include "clang/AST/Decl.h"
23
24 #include "llvm/IR/DataLayout.h"
25 #include "llvm/IR/DerivedTypes.h"
26
27 #include "slang_assert.h"
28 #include "slang_rs_context.h"
29
30 namespace slang {
31
32 namespace {
33
34 // Ensure that the exported function is actually valid
ValidateFuncDecl(slang::RSContext * Context,const clang::FunctionDecl * FD)35 static bool ValidateFuncDecl(slang::RSContext *Context,
36 const clang::FunctionDecl *FD) {
37 slangAssert(Context && FD);
38 const clang::ASTContext &C = FD->getASTContext();
39 if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
40 Context->ReportError(
41 FD->getLocation(),
42 "invokable non-static functions are required to return void");
43 return false;
44 }
45 return true;
46 }
47
48 } // namespace
49
Create(RSContext * Context,const clang::FunctionDecl * FD)50 RSExportFunc *RSExportFunc::Create(RSContext *Context,
51 const clang::FunctionDecl *FD) {
52 llvm::StringRef Name = FD->getName();
53 RSExportFunc *F;
54
55 slangAssert(!Name.empty() && "Function must have a name");
56
57 if (!ValidateFuncDecl(Context, FD)) {
58 return nullptr;
59 }
60
61 F = new RSExportFunc(Context, Name, FD);
62
63 // Initialize mParamPacketType
64 if (FD->getNumParams() <= 0) {
65 F->mParamPacketType = nullptr;
66 } else {
67 clang::ASTContext &Ctx = Context->getASTContext();
68
69 std::string Id = CreateDummyName("helper_func_param", F->getName());
70
71 clang::RecordDecl *RD =
72 clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
73 Ctx.getTranslationUnitDecl(),
74 clang::SourceLocation(),
75 clang::SourceLocation(),
76 &Ctx.Idents.get(Id));
77
78 for (unsigned i = 0; i < FD->getNumParams(); i++) {
79 const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
80 llvm::StringRef ParamName = PVD->getName();
81
82 if (PVD->hasDefaultArg())
83 fprintf(stderr, "Note: parameter '%s' in function '%s' has default "
84 "value which is not supported\n",
85 ParamName.str().c_str(),
86 F->getName().c_str());
87
88 clang::FieldDecl *FD =
89 clang::FieldDecl::Create(Ctx,
90 RD,
91 clang::SourceLocation(),
92 clang::SourceLocation(),
93 PVD->getIdentifier(),
94 PVD->getOriginalType(),
95 nullptr,
96 /* BitWidth = */ nullptr,
97 /* Mutable = */ false,
98 /* HasInit = */ clang::ICIS_NoInit);
99 RD->addDecl(FD);
100 }
101
102 RD->completeDefinition();
103
104 clang::QualType T = Ctx.getTagDeclType(RD);
105 slangAssert(!T.isNull());
106
107 RSExportType *ET =
108 RSExportType::Create(Context, T.getTypePtr(), NotLegacyKernelArgument);
109
110 if (ET == nullptr) {
111 fprintf(stderr, "Failed to export the function %s. There's at least one "
112 "parameter whose type is not supported by the "
113 "reflection\n", F->getName().c_str());
114 return nullptr;
115 }
116
117 slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
118 "Parameter packet must be a record");
119
120 F->mParamPacketType = static_cast<RSExportRecordType *>(ET);
121 }
122
123 return F;
124 }
125
126 bool
checkParameterPacketType(llvm::StructType * ParamTy) const127 RSExportFunc::checkParameterPacketType(llvm::StructType *ParamTy) const {
128 if (ParamTy == nullptr)
129 return !hasParam();
130 else if (!hasParam())
131 return false;
132
133 slangAssert(mParamPacketType != nullptr);
134
135 const RSExportRecordType *ERT = mParamPacketType;
136 // must have same number of elements
137 if (ERT->getFields().size() != ParamTy->getNumElements())
138 return false;
139
140 const llvm::StructLayout *ParamTySL =
141 getRSContext()->getDataLayout().getStructLayout(ParamTy);
142
143 unsigned Index = 0;
144 for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
145 FE = ERT->fields_end(); FI != FE; FI++, Index++) {
146 const RSExportRecordType::Field *F = *FI;
147
148 llvm::Type *T1 = F->getType()->getLLVMType();
149 llvm::Type *T2 = ParamTy->getTypeAtIndex(Index);
150
151 // Fast check
152 if (T1 == T2)
153 continue;
154
155 // Check offset
156 size_t T1Offset = F->getOffsetInParent();
157 size_t T2Offset = ParamTySL->getElementOffset(Index);
158
159 if (T1Offset != T2Offset)
160 return false;
161
162 // Check size
163 size_t T1Size = F->getType()->getAllocSize();
164 size_t T2Size = getRSContext()->getDataLayout().getTypeAllocSize(T2);
165
166 if (T1Size != T2Size)
167 return false;
168 }
169
170 return true;
171 }
172
173 } // namespace slang
174