1 /*
2  * Copyright 2010, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_export_func.h"
18 
19 #include <string>
20 
21 #include "clang/AST/ASTContext.h"
22 #include "clang/AST/Decl.h"
23 
24 #include "llvm/IR/DataLayout.h"
25 #include "llvm/IR/DerivedTypes.h"
26 
27 #include "slang_assert.h"
28 #include "slang_rs_context.h"
29 
30 namespace slang {
31 
32 namespace {
33 
34 // Ensure that the exported function is actually valid
ValidateFuncDecl(slang::RSContext * Context,const clang::FunctionDecl * FD)35 static bool ValidateFuncDecl(slang::RSContext *Context,
36                              const clang::FunctionDecl *FD) {
37   slangAssert(Context && FD);
38   const clang::ASTContext &C = FD->getASTContext();
39   if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
40     Context->ReportError(
41         FD->getLocation(),
42         "invokable non-static functions are required to return void");
43     return false;
44   }
45   return true;
46 }
47 
48 }  // namespace
49 
Create(RSContext * Context,const clang::FunctionDecl * FD)50 RSExportFunc *RSExportFunc::Create(RSContext *Context,
51                                    const clang::FunctionDecl *FD) {
52   llvm::StringRef Name = FD->getName();
53   RSExportFunc *F;
54 
55   slangAssert(!Name.empty() && "Function must have a name");
56 
57   if (!ValidateFuncDecl(Context, FD)) {
58     return nullptr;
59   }
60 
61   F = new RSExportFunc(Context, Name, FD);
62 
63   // Initialize mParamPacketType
64   if (FD->getNumParams() <= 0) {
65     F->mParamPacketType = nullptr;
66   } else {
67     clang::ASTContext &Ctx = Context->getASTContext();
68 
69     std::string Id = CreateDummyName("helper_func_param", F->getName());
70 
71     clang::RecordDecl *RD =
72         clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
73                                   Ctx.getTranslationUnitDecl(),
74                                   clang::SourceLocation(),
75                                   clang::SourceLocation(),
76                                   &Ctx.Idents.get(Id));
77 
78     for (unsigned i = 0; i < FD->getNumParams(); i++) {
79       const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
80       llvm::StringRef ParamName = PVD->getName();
81 
82       if (PVD->hasDefaultArg())
83         fprintf(stderr, "Note: parameter '%s' in function '%s' has default "
84                         "value which is not supported\n",
85                         ParamName.str().c_str(),
86                         F->getName().c_str());
87 
88       clang::FieldDecl *FD =
89           clang::FieldDecl::Create(Ctx,
90                                    RD,
91                                    clang::SourceLocation(),
92                                    clang::SourceLocation(),
93                                    PVD->getIdentifier(),
94                                    PVD->getOriginalType(),
95                                    nullptr,
96                                    /* BitWidth = */ nullptr,
97                                    /* Mutable = */ false,
98                                    /* HasInit = */ clang::ICIS_NoInit);
99       RD->addDecl(FD);
100     }
101 
102     RD->completeDefinition();
103 
104     clang::QualType T = Ctx.getTagDeclType(RD);
105     slangAssert(!T.isNull());
106 
107     RSExportType *ET =
108       RSExportType::Create(Context, T.getTypePtr(), NotLegacyKernelArgument);
109 
110     if (ET == nullptr) {
111       fprintf(stderr, "Failed to export the function %s. There's at least one "
112                       "parameter whose type is not supported by the "
113                       "reflection\n", F->getName().c_str());
114       return nullptr;
115     }
116 
117     slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
118            "Parameter packet must be a record");
119 
120     F->mParamPacketType = static_cast<RSExportRecordType *>(ET);
121   }
122 
123   return F;
124 }
125 
126 bool
checkParameterPacketType(llvm::StructType * ParamTy) const127 RSExportFunc::checkParameterPacketType(llvm::StructType *ParamTy) const {
128   if (ParamTy == nullptr)
129     return !hasParam();
130   else if (!hasParam())
131     return false;
132 
133   slangAssert(mParamPacketType != nullptr);
134 
135   const RSExportRecordType *ERT = mParamPacketType;
136   // must have same number of elements
137   if (ERT->getFields().size() != ParamTy->getNumElements())
138     return false;
139 
140   const llvm::StructLayout *ParamTySL =
141       getRSContext()->getDataLayout().getStructLayout(ParamTy);
142 
143   unsigned Index = 0;
144   for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
145        FE = ERT->fields_end(); FI != FE; FI++, Index++) {
146     const RSExportRecordType::Field *F = *FI;
147 
148     llvm::Type *T1 = F->getType()->getLLVMType();
149     llvm::Type *T2 = ParamTy->getTypeAtIndex(Index);
150 
151     // Fast check
152     if (T1 == T2)
153       continue;
154 
155     // Check offset
156     size_t T1Offset = F->getOffsetInParent();
157     size_t T2Offset = ParamTySL->getElementOffset(Index);
158 
159     if (T1Offset != T2Offset)
160       return false;
161 
162     // Check size
163     size_t T1Size = F->getType()->getAllocSize();
164     size_t T2Size = getRSContext()->getDataLayout().getTypeAllocSize(T2);
165 
166     if (T1Size != T2Size)
167       return false;
168   }
169 
170   return true;
171 }
172 
173 }  // namespace slang
174