1 /*
2  * Copyright 2010-2012, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_export_type.h"
18 
19 #include <list>
20 #include <vector>
21 
22 #include "clang/AST/ASTContext.h"
23 #include "clang/AST/Attr.h"
24 #include "clang/AST/RecordLayout.h"
25 
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Type.h"
30 
31 #include "slang_assert.h"
32 #include "slang_rs_context.h"
33 #include "slang_rs_export_element.h"
34 #include "slang_version.h"
35 
// Intended for use inside a matchODR() implementation: makes the enclosing
// function return false as soon as the parent-class comparison
// ParentClass::matchODR(E, true) fails.
//
// NOTE: this expands to a bare `if` statement, so do not place it directly
// before an `else` (the `else` would bind to the macro's `if`).
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  if (!ParentClass::matchODR(E, true))        \
    return false;
39 
40 namespace slang {
41 
42 namespace {
43 
// Reflection information for every data type we support, one row per type.
// Columns, in order:
//  Category      - data type category
//  SName         - "common name" in script (C99)
//  RsType        - element name in RenderScript
//  RsShortType   - short element name in RenderScript
//  SizeInBits    - size in bits
//  CName         - reflected C name
//  JavaName      - reflected Java name
//  JavaArrayElementName - reflected name in Java arrays
//  CVecName      - prefix for C vector types
//  JavaVecName   - prefix for Java vector type
//  JavaPromotion - true for unsigned types whose reflected Java type is a
//                  wider signed type (e.g. uchar -> short, uint -> long)
//
// A `_` entry (nullptr) marks a column that has no value for that type, e.g.
// the packed 16-bit pixel formats have no script, C or Java names.
//
// IMPORTANT: The data types in this table should be at the same index as
// specified by the corresponding DataType enum.
//
// TODO: Pull this information out into a separate file.
static RSReflectionType gReflectionTypes[] = {
#define _ nullptr
  //      Category     SName              RsType       RsST   Size    CName         JN      JAEN       CVN       JVN     JP
{PrimitiveDataType,   "half",         "FLOAT_16",     "F16", 16,     "half",   "short",  "short",   "Half",  "Short", false},
{PrimitiveDataType,  "float",         "FLOAT_32",     "F32", 32,    "float",   "float",  "float",  "Float",  "Float", false},
{PrimitiveDataType, "double",         "FLOAT_64",     "F64", 64,   "double",  "double", "double", "Double", "Double", false},
{PrimitiveDataType,   "char",         "SIGNED_8",      "I8",  8,   "int8_t",    "byte",   "byte",   "Byte",   "Byte", false},
{PrimitiveDataType,  "short",        "SIGNED_16",     "I16", 16,  "int16_t",   "short",  "short",  "Short",  "Short", false},
{PrimitiveDataType,    "int",        "SIGNED_32",     "I32", 32,  "int32_t",     "int",    "int",    "Int",    "Int", false},
{PrimitiveDataType,   "long",        "SIGNED_64",     "I64", 64,  "int64_t",    "long",   "long",   "Long",   "Long", false},
{PrimitiveDataType,  "uchar",       "UNSIGNED_8",      "U8",  8,  "uint8_t",   "short",   "byte",  "UByte",  "Short",  true},
{PrimitiveDataType, "ushort",      "UNSIGNED_16",     "U16", 16, "uint16_t",     "int",  "short", "UShort",    "Int",  true},
{PrimitiveDataType,   "uint",      "UNSIGNED_32",     "U32", 32, "uint32_t",    "long",    "int",   "UInt",   "Long",  true},
{PrimitiveDataType,  "ulong",      "UNSIGNED_64",     "U64", 64, "uint64_t",    "long",   "long",  "ULong",   "Long", false},
{PrimitiveDataType,   "bool",          "BOOLEAN", "BOOLEAN",  8,     "bool", "boolean",   "byte",        _,        _, false},
{PrimitiveDataType,        _,   "UNSIGNED_5_6_5",         _, 16,          _,         _,        _,        _,        _, false},
{PrimitiveDataType,        _, "UNSIGNED_5_5_5_1",         _, 16,          _,         _,        _,        _,        _, false},
{PrimitiveDataType,        _, "UNSIGNED_4_4_4_4",         _, 16,          _,         _,        _,        _,        _, false},

{MatrixDataType, "rs_matrix2x2", "MATRIX_2X2", _,  4*32, "rs_matrix2x2", "Matrix2f", _, _, _, false},
{MatrixDataType, "rs_matrix3x3", "MATRIX_3X3", _,  9*32, "rs_matrix3x3", "Matrix3f", _, _, _, false},
{MatrixDataType, "rs_matrix4x4", "MATRIX_4X4", _, 16*32, "rs_matrix4x4", "Matrix4f", _, _, _, false},

// RS object types are 32 bits in 32-bit RS, but 256 bits in 64-bit RS.
// This is handled specially by the GetElementSizeInBits() method.
{ObjectDataType,          "rs_element",          "RS_ELEMENT",          "ELEMENT", 32,         "Element",         "Element", _, _, _, false},
{ObjectDataType,             "rs_type",             "RS_TYPE",             "TYPE", 32,            "Type",            "Type", _, _, _, false},
{ObjectDataType,       "rs_allocation",       "RS_ALLOCATION",       "ALLOCATION", 32,      "Allocation",      "Allocation", _, _, _, false},
{ObjectDataType,          "rs_sampler",          "RS_SAMPLER",          "SAMPLER", 32,         "Sampler",         "Sampler", _, _, _, false},
{ObjectDataType,           "rs_script",           "RS_SCRIPT",           "SCRIPT", 32,          "Script",          "Script", _, _, _, false},
{ObjectDataType,             "rs_mesh",             "RS_MESH",             "MESH", 32,            "Mesh",            "Mesh", _, _, _, false},
{ObjectDataType,             "rs_path",             "RS_PATH",             "PATH", 32,            "Path",            "Path", _, _, _, false},
{ObjectDataType, "rs_program_fragment", "RS_PROGRAM_FRAGMENT", "PROGRAM_FRAGMENT", 32, "ProgramFragment", "ProgramFragment", _, _, _, false},
{ObjectDataType,   "rs_program_vertex",   "RS_PROGRAM_VERTEX",   "PROGRAM_VERTEX", 32,   "ProgramVertex",   "ProgramVertex", _, _, _, false},
{ObjectDataType,   "rs_program_raster",   "RS_PROGRAM_RASTER",   "PROGRAM_RASTER", 32,   "ProgramRaster",   "ProgramRaster", _, _, _, false},
{ObjectDataType,    "rs_program_store",    "RS_PROGRAM_STORE",    "PROGRAM_STORE", 32,    "ProgramStore",    "ProgramStore", _, _, _, false},
{ObjectDataType,             "rs_font",             "RS_FONT",             "FONT", 32,            "Font",            "Font", _, _, _, false},
#undef _
};
100 
101 const int kMaxVectorSize = 4;
102 
103 struct BuiltinInfo {
104   clang::BuiltinType::Kind builtinTypeKind;
105   DataType type;
106   /* TODO If we return std::string instead of llvm::StringRef, we could build
107    * the name instead of duplicating the entries.
108    */
109   const char *cname[kMaxVectorSize];
110 };
111 
112 
// Maps each exportable clang builtin kind to its RenderScript DataType and to
// the names used in script code for the scalar (cname[0]) and for the 2-, 3-
// and 4-element vector forms (cname[1..3]). Several clang kinds share one RS
// type (e.g. Char_U and UChar, ULong and ULongLong). A builtin kind absent
// from this table makes FindBuiltinType() return nullptr, and
// TypeExportableHelper() then treats the type as non-exportable.
//
// NOTE(review): Char16 and Char32 map to the *signed* 16/32-bit RS types,
// although char16_t/char32_t are unsigned in C++ -- confirm this is intended.
BuiltinInfo BuiltinInfoTable[] = {
    {clang::BuiltinType::Bool, DataTypeBoolean,
     {"bool", "bool2", "bool3", "bool4"}},
    {clang::BuiltinType::Char_U, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::UChar, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::Char16, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Char32, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::UShort, DataTypeUnsigned16,
     {"ushort", "ushort2", "ushort3", "ushort4"}},
    {clang::BuiltinType::UInt, DataTypeUnsigned32,
     {"uint", "uint2", "uint3", "uint4"}},
    {clang::BuiltinType::ULong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},
    {clang::BuiltinType::ULongLong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},

    {clang::BuiltinType::Char_S, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::SChar, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::Short, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Int, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::Long, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::LongLong, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::Half, DataTypeFloat16,
     {"half", "half2", "half3", "half4"}},
    {clang::BuiltinType::Float, DataTypeFloat32,
     {"float", "float2", "float3", "float4"}},
    {clang::BuiltinType::Double, DataTypeFloat64,
     {"double", "double2", "double3", "double4"}},
};
// Number of entries in BuiltinInfoTable.
const int BuiltinInfoTableCount = sizeof(BuiltinInfoTable) / sizeof(BuiltinInfoTable[0]);
153 
// Pairs a script-level type name with the DataType it denotes.
struct NameAndPrimitiveType {
  const char *name;
  DataType dataType;
};

// The RS matrix types and RS object types, keyed by their spelled names in
// script code. The names agree with the SName column of gReflectionTypes.
static NameAndPrimitiveType MatrixAndObjectDataTypes[] = {
    {"rs_matrix2x2", DataTypeRSMatrix2x2},
    {"rs_matrix3x3", DataTypeRSMatrix3x3},
    {"rs_matrix4x4", DataTypeRSMatrix4x4},
    {"rs_element", DataTypeRSElement},
    {"rs_type", DataTypeRSType},
    {"rs_allocation", DataTypeRSAllocation},
    {"rs_sampler", DataTypeRSSampler},
    {"rs_script", DataTypeRSScript},
    {"rs_mesh", DataTypeRSMesh},
    {"rs_path", DataTypeRSPath},
    {"rs_program_fragment", DataTypeRSProgramFragment},
    {"rs_program_vertex", DataTypeRSProgramVertex},
    {"rs_program_raster", DataTypeRSProgramRaster},
    {"rs_program_store", DataTypeRSProgramStore},
    {"rs_font", DataTypeRSFont},
};

// Number of entries in MatrixAndObjectDataTypes.
const int MatrixAndObjectDataTypesCount =
    sizeof(MatrixAndObjectDataTypes) / sizeof(MatrixAndObjectDataTypes[0]);
179 
180 static const clang::Type *TypeExportableHelper(
181     const clang::Type *T,
182     llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
183     slang::RSContext *Context,
184     const clang::VarDecl *VD,
185     const clang::RecordDecl *TopLevelRecord,
186     ExportKind EK);
187 
188 template <unsigned N>
ReportTypeError(slang::RSContext * Context,const clang::NamedDecl * ND,const clang::RecordDecl * TopLevelRecord,const char (& Message)[N],unsigned int TargetAPI=0)189 static void ReportTypeError(slang::RSContext *Context,
190                             const clang::NamedDecl *ND,
191                             const clang::RecordDecl *TopLevelRecord,
192                             const char (&Message)[N],
193                             unsigned int TargetAPI = 0) {
194   // Attempt to use the type declaration first (if we have one).
195   // Fall back to the variable definition, if we are looking at something
196   // like an array declaration that can't be exported.
197   if (TopLevelRecord) {
198     Context->ReportError(TopLevelRecord->getLocation(), Message)
199         << TopLevelRecord->getName() << TargetAPI;
200   } else if (ND) {
201     Context->ReportError(ND->getLocation(), Message) << ND->getName()
202                                                      << TargetAPI;
203   } else {
204     slangAssert(false && "Variables should be validated before exporting");
205   }
206 }
207 
// Checks whether the constant-size array type CAT can be exported.
//
// Reports an error and returns nullptr for arrays of arrays, for arrays of
// vectors whose element type is not primitive, and for arrays of 3-element
// vectors whose length is not exactly 1. Otherwise the element type is
// checked recursively with TypeExportableHelper(); CAT is returned when that
// succeeds and nullptr when it does not.
static const clang::Type *ConstantArrayTypeExportableHelper(
    const clang::ConstantArrayType *CAT,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    slang::RSContext *Context,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord,
    ExportKind EK) {
  const clang::Type *ElemTy = GetConstantArrayElementType(CAT);

  // Arrays of arrays have no reflected representation.
  if (ElemTy->isArrayType()) {
    ReportTypeError(Context, VD, TopLevelRecord,
                    "multidimensional arrays cannot be exported: '%0'");
    return nullptr;
  }

  if (ElemTy->isExtVectorType()) {
    const clang::ExtVectorType *VecTy =
        static_cast<const clang::ExtVectorType*>(ElemTy);

    const clang::Type *LaneTy = GetExtVectorElementType(VecTy);
    if (!RSExportPrimitiveType::IsPrimitiveType(LaneTy)) {
      ReportTypeError(Context, VD, TopLevelRecord,
        "vectors of non-primitive types cannot be exported: '%0'");
      return nullptr;
    }

    // A lone 3-vector is accepted; arrays of any other length are not.
    const bool IsVec3 = VecTy->getNumElements() == 3;
    if (IsVec3 && CAT->getSize() != 1) {
      ReportTypeError(Context, VD, TopLevelRecord,
        "arrays of width 3 vector types cannot be exported: '%0'");
      return nullptr;
    }
  }

  const clang::Type *CheckedElem =
      TypeExportableHelper(ElemTy, SPS, Context, VD, TopLevelRecord, EK);
  if (CheckedElem == nullptr)
    return nullptr;
  return CAT;
}
247 
FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind)248 BuiltinInfo *FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind) {
249   for (int i = 0; i < BuiltinInfoTableCount; i++) {
250     if (builtinTypeKind == BuiltinInfoTable[i].builtinTypeKind) {
251       return &BuiltinInfoTable[i];
252     }
253   }
254   return nullptr;
255 }
256 
// Recursive worker for TypeExportable().
//
// T              - type to check; it is canonicalized first.
// SPS            - record types already visited. A type found in this set is
//                  accepted immediately, so a record type is only walked once
//                  per top-level query.
// Context        - used to report diagnostics.
// VD             - exported variable being checked (may be null); used as the
//                  diagnostic anchor when no record is involved.
// TopLevelRecord - outermost struct being checked, or null until the first
//                  struct is entered.
// EK             - whether T belongs to an argument of a legacy kernel.
//
// Returns the type to export (the canonical T, or `int` for enums), or
// nullptr if T is not exportable. Not every rejection produces a diagnostic:
// builtin kinds missing from BuiltinInfoTable (this includes void, so even a
// void pointer that is tolerated for a legacy kernel argument yields nullptr),
// vectors with fewer than 2 or more than 4 elements, vectors whose element
// type is not a builtin, and structs with a flexible array member or
// hasObjectMember() are all rejected silently.
static const clang::Type *TypeExportableHelper(
    clang::Type const *T,
    llvm::SmallPtrSet<clang::Type const *, 8> &SPS,
    slang::RSContext *Context,
    clang::VarDecl const *VD,
    clang::RecordDecl const *TopLevelRecord,
    ExportKind EK) {
  // Normalize first
  if ((T = GetCanonicalType(T)) == nullptr)
    return nullptr;

  // Already visited (only record types are ever inserted, below).
  if (SPS.count(T))
    return T;

  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();

  switch (T->getTypeClass()) {
    case clang::Type::Builtin: {
      // Silent rejection: no diagnostic for unsupported builtins.
      const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
      return FindBuiltinType(BT->getKind()) == nullptr ? nullptr : T;
    }
    case clang::Type::Record: {
      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
        return T;  // RS object type, no further checks are needed
      }

      // Check internal struct
      if (T->isUnionType()) {
        ReportTypeError(Context, VD, T->getAsUnionType()->getDecl(),
                        "unions cannot be exported: '%0'");
        return nullptr;
      } else if (!T->isStructureType()) {
        slangAssert(false && "Unknown type cannot be exported");
        return nullptr;
      }

      clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
      slangAssert(RD);
      RD = RD->getDefinition();
      if (RD == nullptr) {
        ReportTypeError(Context, nullptr, T->getAsStructureType()->getDecl(),
                        "struct is not defined in this module");
        return nullptr;
      }

      // Remember the outermost struct so errors in nested fields (e.g.
      // pointers) are reported against it.
      if (!TopLevelRecord) {
        TopLevelRecord = RD;
      }
      if (RD->getName().empty()) {
        ReportTypeError(Context, nullptr, RD,
                        "anonymous structures cannot be exported");
        return nullptr;
      }

      // Fast check
      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
        return nullptr;

      // Insert myself into checking set
      SPS.insert(T);

      // Check all element
      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
               FE = RD->field_end();
           FI != FE;
           FI++) {
        const clang::FieldDecl *FD = *FI;
        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
        FT = GetCanonicalType(FT);

        if (!TypeExportableHelper(FT, SPS, Context, VD, TopLevelRecord,
                                  EK)) {
          return nullptr;
        }

        // We don't support bit fields yet
        //
        // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
        if (FD->isBitField()) {
          Context->ReportError(
              FD->getLocation(),
              "bit fields are not able to be exported: '%0.%1'")
              << RD->getName() << FD->getName();
          return nullptr;
        }
      }

      return T;
    }
    case clang::Type::FunctionProto:
    case clang::Type::FunctionNoProto:
      ReportTypeError(Context, VD, TopLevelRecord,
                      "function types cannot be exported: '%0'");
      return nullptr;
    case clang::Type::Pointer: {
      // Any pointer reached while inside a struct is rejected.
      if (TopLevelRecord) {
        ReportTypeError(Context, VD, TopLevelRecord,
            "structures containing pointers cannot be used as the type of "
            "an exported global variable or the parameter to an exported "
            "function: '%0'");
        return nullptr;
      }

      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
      const clang::Type *PointeeType = GetPointeeType(PT);

      if (PointeeType->getTypeClass() == clang::Type::Pointer) {
        ReportTypeError(Context, VD, TopLevelRecord,
            "multiple levels of pointers cannot be exported: '%0'");
        return nullptr;
      }

      // Void pointers are forbidden for export, although we must accept
      // void pointers that come in as arguments to a legacy kernel.
      // (For a legacy kernel the void pointee still reaches the Builtin case
      // below, which returns nullptr without reporting an error.)
      if (PointeeType->isVoidType() && EK != LegacyKernelArgument) {
        ReportTypeError(Context, VD, TopLevelRecord,
            "void pointers cannot be exported: '%0'");
        return nullptr;
      }

      // We don't support pointer with array-type pointee
      if (PointeeType->isArrayType()) {
        ReportTypeError(Context, VD, TopLevelRecord,
            "pointers to arrays cannot be exported: '%0'");
        return nullptr;
      }

      // Check for unsupported pointee type
      if (TypeExportableHelper(PointeeType, SPS, Context, VD,
                                TopLevelRecord, EK) == nullptr)
        return nullptr;
      else
        return T;
    }
    case clang::Type::ExtVector: {
      const clang::ExtVectorType *EVT =
              static_cast<const clang::ExtVectorType*>(CTI);
      // Only vector with size 2, 3 and 4 are supported.
      if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
        return nullptr;

      // Check base element type
      const clang::Type *ElementType = GetExtVectorElementType(EVT);

      if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
          (TypeExportableHelper(ElementType, SPS, Context, VD,
                                TopLevelRecord, EK) == nullptr))
        return nullptr;
      else
        return T;
    }
    case clang::Type::ConstantArray: {
      const clang::ConstantArrayType *CAT =
              static_cast<const clang::ConstantArrayType*>(CTI);

      return ConstantArrayTypeExportableHelper(CAT, SPS, Context, VD,
                                               TopLevelRecord, EK);
    }
    case clang::Type::Enum: {
      // FIXME: We currently convert enums to integers, rather than reflecting
      // a more complete (and nicer type-safe Java version).
      return Context->getASTContext().IntTy.getTypePtr();
    }
    default: {
      slangAssert(false && "Unknown type cannot be validated");
      return nullptr;
    }
  }
}
426 
427 // Return the type that can be used to create RSExportType, will always return
428 // the canonical type.
429 //
430 // If the Type T is not exportable, this function returns nullptr. DiagEngine is
431 // used to generate proper Clang diagnostic messages when a non-exportable type
432 // is detected. TopLevelRecord is used to capture the highest struct (in the
433 // case of a nested hierarchy) for detecting other types that cannot be exported
434 // (mostly pointers within a struct).
TypeExportable(const clang::Type * T,slang::RSContext * Context,const clang::VarDecl * VD,ExportKind EK)435 static const clang::Type *TypeExportable(const clang::Type *T,
436                                          slang::RSContext *Context,
437                                          const clang::VarDecl *VD,
438                                          ExportKind EK) {
439   llvm::SmallPtrSet<const clang::Type*, 8> SPS =
440       llvm::SmallPtrSet<const clang::Type*, 8>();
441 
442   return TypeExportableHelper(T, SPS, Context, VD, nullptr, EK);
443 }
444 
ValidateRSObjectInVarDecl(slang::RSContext * Context,const clang::VarDecl * VD,bool InCompositeType,unsigned int TargetAPI)445 static bool ValidateRSObjectInVarDecl(slang::RSContext *Context,
446                                       const clang::VarDecl *VD, bool InCompositeType,
447                                       unsigned int TargetAPI) {
448   if (TargetAPI < SLANG_JB_TARGET_API) {
449     // Only if we are already in a composite type (like an array or structure).
450     if (InCompositeType) {
451       // Only if we are actually exported (i.e. non-static).
452       if (VD->hasLinkage() &&
453           (VD->getFormalLinkage() == clang::ExternalLinkage)) {
454         // Only if we are not a pointer to an object.
455         const clang::Type *T = GetCanonicalType(VD->getType().getTypePtr());
456         if (T->getTypeClass() != clang::Type::Pointer) {
457           ReportTypeError(Context, VD, nullptr,
458                           "arrays/structures containing RS object types "
459                           "cannot be exported in target API < %1: '%0'",
460                           SLANG_JB_TARGET_API);
461           return false;
462         }
463       }
464     }
465   }
466 
467   return true;
468 }
469 
// Helper function for ValidateType(). We do a recursive descent on the
// type hierarchy to ensure that we can properly export/handle the
// declaration.
// \return true if the variable declaration is valid,
//         false if it is invalid. Most rejection paths report a diagnostic;
//         the flexible-array-member / hasObjectMember() fast check does not.
//
// Context - RSContext used to report diagnostics.
// C - ASTContext (for comparisons against builtin types).
// T - sub-type that we are validating; overwritten with its canonical type.
// ND - (optional) top-level named declaration that we are validating.
// Loc - source location used for diagnostics reported directly here.
// SPS - set of record types we have already seen/validated.
// InCompositeType - true if we are within an outer composite type.
// UnionDecl - set if we are in a sub-type of a union.
// TargetAPI - target SDK API level.
// IsFilterscript - whether or not we are compiling for Filterscript
// IsExtern - is this type externally visible (i.e. extern global or parameter
//                                             to an extern function)
static bool ValidateTypeHelper(
    slang::RSContext *Context,
    clang::ASTContext &C,
    const clang::Type *&T,
    const clang::NamedDecl *ND,
    clang::SourceLocation Loc,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    bool InCompositeType,
    clang::RecordDecl *UnionDecl,
    unsigned int TargetAPI,
    bool IsFilterscript,
    bool IsExtern) {
  if ((T = GetCanonicalType(T)) == nullptr)
    return true;

  // Record types already walked are accepted without re-checking.
  if (SPS.count(T))
    return true;

  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();

  switch (T->getTypeClass()) {
    case clang::Type::Record: {
      // RS objects inside composite types are restricted on old targets.
      if (RSExportPrimitiveType::IsRSObjectType(T)) {
        const clang::VarDecl *VD = (ND ? llvm::dyn_cast<clang::VarDecl>(ND) : nullptr);
        if (VD && !ValidateRSObjectInVarDecl(Context, VD, InCompositeType,
                                             TargetAPI)) {
          return false;
        }
      }

      // RS-specific types need no further checks outside a union. Inside a
      // union, RS object types are rejected; other RS-specific types fall
      // through to the generic record walk below.
      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
        if (!UnionDecl) {
          return true;
        } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
          ReportTypeError(Context, nullptr, UnionDecl,
              "unions containing RS object types are not allowed");
          return false;
        }
      }

      clang::RecordDecl *RD = nullptr;

      // Check internal struct
      if (T->isUnionType()) {
        RD = T->getAsUnionType()->getDecl();
        // Fields of this union (and anything nested in them) are "in a union".
        UnionDecl = RD;
      } else if (T->isStructureType()) {
        RD = T->getAsStructureType()->getDecl();
      } else {
        slangAssert(false && "Unknown type cannot be exported");
        return false;
      }

      slangAssert(RD);
      RD = RD->getDefinition();
      if (RD == nullptr) {
        // FIXME
        // A declared-but-undefined record is currently accepted.
        return true;
      }

      // Fast check
      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
        return false;

      // Insert myself into checking set
      SPS.insert(T);

      // Check all elements
      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
               FE = RD->field_end();
           FI != FE;
           FI++) {
        const clang::FieldDecl *FD = *FI;
        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
        FT = GetCanonicalType(FT);

        if (!ValidateTypeHelper(Context, C, FT, ND, Loc, SPS, true, UnionDecl,
                                TargetAPI, IsFilterscript, IsExtern)) {
          return false;
        }
      }

      return true;
    }

    case clang::Type::Builtin: {
      if (IsFilterscript) {
        // NOTE(review): the unsigned 64-bit kinds (C.UnsignedLongTy,
        // C.UnsignedLongLongTy) are not in this list; confirm whether that is
        // intentional before relying on the "> 32 bits" wording below.
        clang::QualType QT = T->getCanonicalTypeInternal();
        if (QT == C.DoubleTy ||
            QT == C.LongDoubleTy ||
            QT == C.LongTy ||
            QT == C.LongLongTy) {
          if (ND) {
            Context->ReportError(
                Loc,
                "Builtin types > 32 bits in size are forbidden in "
                "Filterscript: '%0'")
                << ND->getName();
          } else {
            Context->ReportError(
                Loc,
                "Builtin types > 32 bits in size are forbidden in "
                "Filterscript");
          }
          return false;
        }
      }
      break;
    }

    case clang::Type::Pointer: {
      if (IsFilterscript) {
        if (ND) {
          Context->ReportError(Loc,
                               "Pointers are forbidden in Filterscript: '%0'")
              << ND->getName();
          return false;
        } else {
          // TODO(srhines): Find a better way to handle expressions (i.e. no
          // NamedDecl) involving pointers in FS that should be allowed.
          // An example would be calls to library functions like
          // rsMatrixMultiply() that take rs_matrixNxN * types.
        }
      }

      // Forbid pointers in structures that are externally visible.
      if (InCompositeType && IsExtern) {
        if (ND) {
          Context->ReportError(Loc,
              "structures containing pointers cannot be used as the type of "
              "an exported global variable or the parameter to an exported "
              "function: '%0'")
            << ND->getName();
        } else {
          Context->ReportError(Loc,
              "structures containing pointers cannot be used as the type of "
              "an exported global variable or the parameter to an exported "
              "function");
        }
        return false;
      }

      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
      const clang::Type *PointeeType = GetPointeeType(PT);

      // The pointee keeps the caller's InCompositeType flag.
      return ValidateTypeHelper(Context, C, PointeeType, ND, Loc, SPS,
                                InCompositeType, UnionDecl, TargetAPI,
                                IsFilterscript, IsExtern);
    }

    case clang::Type::ExtVector: {
      const clang::ExtVectorType *EVT =
              static_cast<const clang::ExtVectorType*>(CTI);
      const clang::Type *ElementType = GetExtVectorElementType(EVT);
      // Before the ICS target API, externally linked declarations may not
      // hold 3-element vectors inside a composite type.
      if (TargetAPI < SLANG_ICS_TARGET_API &&
          InCompositeType &&
          EVT->getNumElements() == 3 &&
          ND &&
          ND->getFormalLinkage() == clang::ExternalLinkage) {
        ReportTypeError(Context, ND, nullptr,
                        "structs containing vectors of dimension 3 cannot "
                        "be exported at this API level: '%0'");
        return false;
      }
      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
                                UnionDecl, TargetAPI, IsFilterscript, IsExtern);
    }

    case clang::Type::ConstantArray: {
      const clang::ConstantArrayType *CAT = static_cast<const clang::ConstantArrayType*>(CTI);
      const clang::Type *ElementType = GetConstantArrayElementType(CAT);
      // Array elements count as being inside a composite type.
      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
                                UnionDecl, TargetAPI, IsFilterscript, IsExtern);
    }

    default: {
      break;
    }
  }

  return true;
}
668 
669 }  // namespace
670 
// Builds a placeholder name of the form "<type>", or "<type:name>" when name
// is non-empty.
//
// Built with plain std::string appends, so it needs only <string> (this file
// does not include <sstream>) and avoids constructing a stream. A null `type`
// is treated as an empty string; streaming a null `const char *` is undefined
// behavior.
std::string CreateDummyName(const char *type, const std::string &name) {
  std::string Result(1, '<');
  if (type != nullptr)
    Result += type;
  if (!name.empty()) {
    Result += ':';
    Result += name;
  }
  Result += '>';
  return Result;
}
680 
681 /****************************** RSExportType ******************************/
NormalizeType(const clang::Type * & T,llvm::StringRef & TypeName,RSContext * Context,const clang::VarDecl * VD,ExportKind EK)682 bool RSExportType::NormalizeType(const clang::Type *&T,
683                                  llvm::StringRef &TypeName,
684                                  RSContext *Context,
685                                  const clang::VarDecl *VD,
686                                  ExportKind EK) {
687   if ((T = TypeExportable(T, Context, VD, EK)) == nullptr) {
688     return false;
689   }
690   // Get type name
691   TypeName = RSExportType::GetTypeName(T);
692   if (Context && TypeName.empty()) {
693     if (VD) {
694       Context->ReportError(VD->getLocation(),
695                            "anonymous types cannot be exported");
696     } else {
697       Context->ReportError("anonymous types cannot be exported");
698     }
699     return false;
700   }
701 
702   return true;
703 }
704 
ValidateType(slang::RSContext * Context,clang::ASTContext & C,clang::QualType QT,const clang::NamedDecl * ND,clang::SourceLocation Loc,unsigned int TargetAPI,bool IsFilterscript,bool IsExtern)705 bool RSExportType::ValidateType(slang::RSContext *Context, clang::ASTContext &C,
706                                 clang::QualType QT, const clang::NamedDecl *ND,
707                                 clang::SourceLocation Loc,
708                                 unsigned int TargetAPI, bool IsFilterscript,
709                                 bool IsExtern) {
710   const clang::Type *T = QT.getTypePtr();
711   llvm::SmallPtrSet<const clang::Type*, 8> SPS =
712       llvm::SmallPtrSet<const clang::Type*, 8>();
713 
714   // If this is an externally visible variable declaration, we check if the
715   // type is able to be exported first.
716   if (auto VD = llvm::dyn_cast_or_null<clang::VarDecl>(ND)) {
717     if (VD->getFormalLinkage() == clang::ExternalLinkage) {
718       if (!TypeExportable(T, Context, VD, NotLegacyKernelArgument)) {
719         return false;
720       }
721     }
722   }
723   return ValidateTypeHelper(Context, C, T, ND, Loc, SPS, false, nullptr, TargetAPI,
724                             IsFilterscript, IsExtern);
725 }
726 
ValidateVarDecl(slang::RSContext * Context,clang::VarDecl * VD,unsigned int TargetAPI,bool IsFilterscript)727 bool RSExportType::ValidateVarDecl(slang::RSContext *Context,
728                                    clang::VarDecl *VD, unsigned int TargetAPI,
729                                    bool IsFilterscript) {
730   return ValidateType(Context, VD->getASTContext(), VD->getType(), VD,
731                       VD->getLocation(), TargetAPI, IsFilterscript,
732                       (VD->getFormalLinkage() == clang::ExternalLinkage));
733 }
734 
735 const clang::Type
GetTypeOfDecl(const clang::DeclaratorDecl * DD)736 *RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
737   if (DD) {
738     clang::QualType T = DD->getType();
739 
740     if (T.isNull())
741       return nullptr;
742     else
743       return T.getTypePtr();
744   }
745   return nullptr;
746 }
747 
GetTypeName(const clang::Type * T)748 llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
749   T = GetCanonicalType(T);
750   if (T == nullptr)
751     return llvm::StringRef();
752 
753   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
754 
755   switch (T->getTypeClass()) {
756     case clang::Type::Builtin: {
757       const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
758       BuiltinInfo *info = FindBuiltinType(BT->getKind());
759       if (info != nullptr) {
760         return info->cname[0];
761       }
762       slangAssert(false && "Unknown data type of the builtin");
763       break;
764     }
765     case clang::Type::Record: {
766       clang::RecordDecl *RD;
767       if (T->isStructureType()) {
768         RD = T->getAsStructureType()->getDecl();
769       } else {
770         break;
771       }
772 
773       llvm::StringRef Name = RD->getName();
774       if (Name.empty()) {
775         if (RD->getTypedefNameForAnonDecl() != nullptr) {
776           Name = RD->getTypedefNameForAnonDecl()->getName();
777         }
778 
779         if (Name.empty()) {
780           // Try to find a name from redeclaration (i.e. typedef)
781           for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
782                    RE = RD->redecls_end();
783                RI != RE;
784                RI++) {
785             slangAssert(*RI != nullptr && "cannot be NULL object");
786 
787             Name = (*RI)->getName();
788             if (!Name.empty())
789               break;
790           }
791         }
792       }
793       return Name;
794     }
795     case clang::Type::Pointer: {
796       // "*" plus pointee name
797       const clang::PointerType *P = static_cast<const clang::PointerType*>(CTI);
798       const clang::Type *PT = GetPointeeType(P);
799       llvm::StringRef PointeeName;
800       if (NormalizeType(PT, PointeeName, nullptr, nullptr,
801                         NotLegacyKernelArgument)) {
802         char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
803         Name[0] = '*';
804         memcpy(Name + 1, PointeeName.data(), PointeeName.size());
805         Name[PointeeName.size() + 1] = '\0';
806         return Name;
807       }
808       break;
809     }
810     case clang::Type::ExtVector: {
811       const clang::ExtVectorType *EVT =
812               static_cast<const clang::ExtVectorType*>(CTI);
813       return RSExportVectorType::GetTypeName(EVT);
814       break;
815     }
816     case clang::Type::ConstantArray : {
817       // Construct name for a constant array is too complicated.
818       return "<ConstantArray>";
819     }
820     default: {
821       break;
822     }
823   }
824 
825   return llvm::StringRef();
826 }
827 
828 
Create(RSContext * Context,const clang::Type * T,const llvm::StringRef & TypeName,ExportKind EK)829 RSExportType *RSExportType::Create(RSContext *Context,
830                                    const clang::Type *T,
831                                    const llvm::StringRef &TypeName,
832                                    ExportKind EK) {
833   // Lookup the context to see whether the type was processed before.
834   // Newly created RSExportType will insert into context
835   // in RSExportType::RSExportType()
836   RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
837 
838   if (ETI != Context->export_types_end())
839     return ETI->second;
840 
841   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
842 
843   RSExportType *ET = nullptr;
844   switch (T->getTypeClass()) {
845     case clang::Type::Record: {
846       DataType dt = RSExportPrimitiveType::GetRSSpecificType(TypeName);
847       switch (dt) {
848         case DataTypeUnknown: {
849           // User-defined types
850           ET = RSExportRecordType::Create(Context,
851                                           T->getAsStructureType(),
852                                           TypeName);
853           break;
854         }
855         case DataTypeRSMatrix2x2: {
856           // 2 x 2 Matrix type
857           ET = RSExportMatrixType::Create(Context,
858                                           T->getAsStructureType(),
859                                           TypeName,
860                                           2);
861           break;
862         }
863         case DataTypeRSMatrix3x3: {
864           // 3 x 3 Matrix type
865           ET = RSExportMatrixType::Create(Context,
866                                           T->getAsStructureType(),
867                                           TypeName,
868                                           3);
869           break;
870         }
871         case DataTypeRSMatrix4x4: {
872           // 4 x 4 Matrix type
873           ET = RSExportMatrixType::Create(Context,
874                                           T->getAsStructureType(),
875                                           TypeName,
876                                           4);
877           break;
878         }
879         default: {
880           // Others are primitive types
881           ET = RSExportPrimitiveType::Create(Context, T, TypeName);
882           break;
883         }
884       }
885       break;
886     }
887     case clang::Type::Builtin: {
888       ET = RSExportPrimitiveType::Create(Context, T, TypeName);
889       break;
890     }
891     case clang::Type::Pointer: {
892       ET = RSExportPointerType::Create(Context,
893                                        static_cast<const clang::PointerType*>(CTI),
894                                        TypeName);
895       // FIXME: free the name (allocated in RSExportType::GetTypeName)
896       delete [] TypeName.data();
897       break;
898     }
899     case clang::Type::ExtVector: {
900       ET = RSExportVectorType::Create(Context,
901                                       static_cast<const clang::ExtVectorType*>(CTI),
902                                       TypeName);
903       break;
904     }
905     case clang::Type::ConstantArray: {
906       ET = RSExportConstantArrayType::Create(
907               Context,
908               static_cast<const clang::ConstantArrayType*>(CTI));
909       break;
910     }
911     default: {
912       Context->ReportError("unknown type cannot be exported: '%0'")
913           << T->getTypeClassName();
914       break;
915     }
916   }
917 
918   return ET;
919 }
920 
Create(RSContext * Context,const clang::Type * T,ExportKind EK,const clang::VarDecl * VD)921 RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T,
922                                    ExportKind EK, const clang::VarDecl *VD) {
923   llvm::StringRef TypeName;
924   if (NormalizeType(T, TypeName, Context, VD, EK)) {
925     return Create(Context, T, TypeName, EK);
926   } else {
927     return nullptr;
928   }
929 }
930 
CreateFromDecl(RSContext * Context,const clang::VarDecl * VD)931 RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
932                                            const clang::VarDecl *VD) {
933   return RSExportType::Create(Context, GetTypeOfDecl(VD),
934                               NotLegacyKernelArgument, VD);
935 }
936 
getStoreSize() const937 size_t RSExportType::getStoreSize() const {
938   return getRSContext()->getDataLayout().getTypeStoreSize(getLLVMType());
939 }
940 
getAllocSize() const941 size_t RSExportType::getAllocSize() const {
942     return getRSContext()->getDataLayout().getTypeAllocSize(getLLVMType());
943 }
944 
RSExportType(RSContext * Context,ExportClass Class,const llvm::StringRef & Name,clang::SourceLocation Loc)945 RSExportType::RSExportType(RSContext *Context,
946                            ExportClass Class,
947                            const llvm::StringRef &Name, clang::SourceLocation Loc)
948     : RSExportable(Context, RSExportable::EX_TYPE, Loc),
949       mClass(Class),
950       // Make a copy on Name since memory stored @Name is either allocated in
951       // ASTContext or allocated in GetTypeName which will be destroyed later.
952       mName(Name.data(), Name.size()),
953       mLLVMType(nullptr) {
954   // Don't cache the type whose name start with '<'. Those type failed to
955   // get their name since constructing their name in GetTypeName() requiring
956   // complicated work.
957   if (!IsDummyName(Name)) {
958     // TODO(zonr): Need to check whether the insertion is successful or not.
959     Context->insertExportType(llvm::StringRef(Name), this);
960   }
961 
962 }
963 
keep()964 bool RSExportType::keep() {
965   if (!RSExportable::keep())
966     return false;
967   // Invalidate converted LLVM type.
968   mLLVMType = nullptr;
969   return true;
970 }
971 
matchODR(const RSExportType * E,bool) const972 bool RSExportType::matchODR(const RSExportType *E, bool /* LookInto */) const {
973   return (E->getClass() == getClass());
974 }
975 
~RSExportType()976 RSExportType::~RSExportType() {
977 }
978 
979 /************************** RSExportPrimitiveType **************************/
// Map from RS-specific type names (the MatrixAndObjectDataTypes table) to
// their DataType. Populated lazily on first use by
// GetRSSpecificType(const llvm::StringRef &).
llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
RSExportPrimitiveType::RSSpecificTypeMap;
982 
IsPrimitiveType(const clang::Type * T)983 bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
984   if ((T != nullptr) && (T->getTypeClass() == clang::Type::Builtin))
985     return true;
986   else
987     return false;
988 }
989 
990 DataType
GetRSSpecificType(const llvm::StringRef & TypeName)991 RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
992   if (TypeName.empty())
993     return DataTypeUnknown;
994 
995   if (RSSpecificTypeMap->empty()) {
996     for (int i = 0; i < MatrixAndObjectDataTypesCount; i++) {
997       (*RSSpecificTypeMap)[MatrixAndObjectDataTypes[i].name] =
998           MatrixAndObjectDataTypes[i].dataType;
999     }
1000   }
1001 
1002   RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
1003   if (I == RSSpecificTypeMap->end())
1004     return DataTypeUnknown;
1005   else
1006     return I->getValue();
1007 }
1008 
GetRSSpecificType(const clang::Type * T)1009 DataType RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
1010   T = GetCanonicalType(T);
1011   if ((T == nullptr) || (T->getTypeClass() != clang::Type::Record))
1012     return DataTypeUnknown;
1013 
1014   return GetRSSpecificType( RSExportType::GetTypeName(T) );
1015 }
1016 
IsRSMatrixType(DataType DT)1017 bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
1018     if (DT < 0 || DT >= DataTypeMax) {
1019         return false;
1020     }
1021     return gReflectionTypes[DT].category == MatrixDataType;
1022 }
1023 
IsRSObjectType(DataType DT)1024 bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
1025     if (DT < 0 || DT >= DataTypeMax) {
1026         return false;
1027     }
1028     return gReflectionTypes[DT].category == ObjectDataType;
1029 }
1030 
IsStructureTypeWithRSObject(const clang::Type * T)1031 bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
1032   bool RSObjectTypeSeen = false;
1033   slangAssert(T);
1034   while (T->isArrayType()) {
1035     T = T->getArrayElementTypeNoTypeQual();
1036     slangAssert(T);
1037   }
1038 
1039   const clang::RecordType *RT = T->getAsStructureType();
1040   if (!RT) {
1041     return false;
1042   }
1043 
1044   const clang::RecordDecl *RD = RT->getDecl();
1045   if (RD) {
1046     RD = RD->getDefinition();
1047   }
1048   if (!RD) {
1049     return false;
1050   }
1051 
1052   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1053          FE = RD->field_end();
1054        FI != FE;
1055        FI++) {
1056     // We just look through all field declarations to see if we find a
1057     // declaration for an RS object type (or an array of one).
1058     const clang::FieldDecl *FD = *FI;
1059     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1060     slangAssert(FT);
1061     while (FT->isArrayType()) {
1062       FT = FT->getArrayElementTypeNoTypeQual();
1063       slangAssert(FT);
1064     }
1065 
1066     DataType DT = GetRSSpecificType(FT);
1067     if (IsRSObjectType(DT)) {
1068       // RS object types definitely need to be zero-initialized
1069       RSObjectTypeSeen = true;
1070     } else {
1071       switch (DT) {
1072         case DataTypeRSMatrix2x2:
1073         case DataTypeRSMatrix3x3:
1074         case DataTypeRSMatrix4x4:
1075           // Matrix types should get zero-initialized as well
1076           RSObjectTypeSeen = true;
1077           break;
1078         default:
1079           // Ignore all other primitive types
1080           break;
1081       }
1082       if (FT->isStructureType()) {
1083         // Recursively handle structs of structs (even though these can't
1084         // be exported, it is possible for a user to have them internally).
1085         RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
1086       }
1087     }
1088   }
1089 
1090   return RSObjectTypeSeen;
1091 }
1092 
GetElementSizeInBits(const RSExportPrimitiveType * EPT)1093 size_t RSExportPrimitiveType::GetElementSizeInBits(const RSExportPrimitiveType *EPT) {
1094   int type = EPT->getType();
1095   slangAssert((type > DataTypeUnknown && type < DataTypeMax) &&
1096               "RSExportPrimitiveType::GetElementSizeInBits : unknown data type");
1097   // All RS object types are 256 bits in 64-bit RS.
1098   if (EPT->isRSObjectType() && EPT->getRSContext()->is64Bit()) {
1099     return 256;
1100   }
1101   return gReflectionTypes[type].size_in_bits;
1102 }
1103 
1104 DataType
GetDataType(RSContext * Context,const clang::Type * T)1105 RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
1106   if (T == nullptr)
1107     return DataTypeUnknown;
1108 
1109   switch (T->getTypeClass()) {
1110     case clang::Type::Builtin: {
1111       const clang::BuiltinType *BT =
1112               static_cast<const clang::BuiltinType*>(T->getCanonicalTypeInternal().getTypePtr());
1113       BuiltinInfo *info = FindBuiltinType(BT->getKind());
1114       if (info != nullptr) {
1115         return info->type;
1116       }
1117       // The size of type WChar depend on platform so we abandon the support
1118       // to them.
1119       Context->ReportError("built-in type cannot be exported: '%0'")
1120           << T->getTypeClassName();
1121       break;
1122     }
1123     case clang::Type::Record: {
1124       // must be RS object type
1125       return RSExportPrimitiveType::GetRSSpecificType(T);
1126     }
1127     default: {
1128       Context->ReportError("primitive type cannot be exported: '%0'")
1129           << T->getTypeClassName();
1130       break;
1131     }
1132   }
1133 
1134   return DataTypeUnknown;
1135 }
1136 
1137 RSExportPrimitiveType
Create(RSContext * Context,const clang::Type * T,const llvm::StringRef & TypeName,bool Normalized)1138 *RSExportPrimitiveType::Create(RSContext *Context,
1139                                const clang::Type *T,
1140                                const llvm::StringRef &TypeName,
1141                                bool Normalized) {
1142   DataType DT = GetDataType(Context, T);
1143 
1144   if ((DT == DataTypeUnknown) || TypeName.empty())
1145     return nullptr;
1146   else
1147     return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
1148                                      DT, Normalized);
1149 }
1150 
Create(RSContext * Context,const clang::Type * T)1151 RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
1152                                                      const clang::Type *T) {
1153   llvm::StringRef TypeName;
1154   if (RSExportType::NormalizeType(T, TypeName, Context, nullptr,
1155                                   NotLegacyKernelArgument) &&
1156       IsPrimitiveType(T)) {
1157     return Create(Context, T, TypeName);
1158   } else {
1159     return nullptr;
1160   }
1161 }
1162 
convertToLLVMType() const1163 llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
1164   llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1165 
1166   if (isRSObjectType()) {
1167     // struct {
1168     //   int *p;
1169     // } __attribute__((packed, aligned(pointer_size)))
1170     //
1171     // which is
1172     //
1173     // <{ [1 x i32] }> in LLVM
1174     //
1175     std::vector<llvm::Type *> Elements;
1176     if (getRSContext()->is64Bit()) {
1177       // 64-bit path
1178       Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt64Ty(C), 4));
1179       return llvm::StructType::get(C, Elements, true);
1180     } else {
1181       // 32-bit legacy path
1182       Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
1183       return llvm::StructType::get(C, Elements, true);
1184     }
1185   }
1186 
1187   switch (mType) {
1188     case DataTypeFloat16: {
1189       return llvm::Type::getHalfTy(C);
1190       break;
1191     }
1192     case DataTypeFloat32: {
1193       return llvm::Type::getFloatTy(C);
1194       break;
1195     }
1196     case DataTypeFloat64: {
1197       return llvm::Type::getDoubleTy(C);
1198       break;
1199     }
1200     case DataTypeBoolean: {
1201       return llvm::Type::getInt1Ty(C);
1202       break;
1203     }
1204     case DataTypeSigned8:
1205     case DataTypeUnsigned8: {
1206       return llvm::Type::getInt8Ty(C);
1207       break;
1208     }
1209     case DataTypeSigned16:
1210     case DataTypeUnsigned16:
1211     case DataTypeUnsigned565:
1212     case DataTypeUnsigned5551:
1213     case DataTypeUnsigned4444: {
1214       return llvm::Type::getInt16Ty(C);
1215       break;
1216     }
1217     case DataTypeSigned32:
1218     case DataTypeUnsigned32: {
1219       return llvm::Type::getInt32Ty(C);
1220       break;
1221     }
1222     case DataTypeSigned64:
1223     case DataTypeUnsigned64: {
1224       return llvm::Type::getInt64Ty(C);
1225       break;
1226     }
1227     default: {
1228       slangAssert(false && "Unknown data type");
1229     }
1230   }
1231 
1232   return nullptr;
1233 }
1234 
matchODR(const RSExportType * E,bool) const1235 bool RSExportPrimitiveType::matchODR(const RSExportType *E,
1236                                      bool /* LookInto */) const {
1237   CHECK_PARENT_EQUALITY(RSExportType, E);
1238   return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
1239 }
1240 
getRSReflectionType(DataType DT)1241 RSReflectionType *RSExportPrimitiveType::getRSReflectionType(DataType DT) {
1242   if (DT > DataTypeUnknown && DT < DataTypeMax) {
1243     return &gReflectionTypes[DT];
1244   } else {
1245     return nullptr;
1246   }
1247 }
1248 
1249 /**************************** RSExportPointerType ****************************/
1250 
1251 RSExportPointerType
Create(RSContext * Context,const clang::PointerType * PT,const llvm::StringRef & TypeName)1252 *RSExportPointerType::Create(RSContext *Context,
1253                              const clang::PointerType *PT,
1254                              const llvm::StringRef &TypeName) {
1255   const clang::Type *PointeeType = GetPointeeType(PT);
1256   const RSExportType *PointeeET;
1257 
1258   if (PointeeType->getTypeClass() != clang::Type::Pointer) {
1259     PointeeET = RSExportType::Create(Context, PointeeType,
1260                                      NotLegacyKernelArgument);
1261   } else {
1262     // Double or higher dimension of pointer, export as int*
1263     PointeeET = RSExportPrimitiveType::Create(Context,
1264                     Context->getASTContext().IntTy.getTypePtr());
1265   }
1266 
1267   if (PointeeET == nullptr) {
1268     // Error diagnostic is emitted for corresponding pointee type
1269     return nullptr;
1270   }
1271 
1272   return new RSExportPointerType(Context, TypeName, PointeeET);
1273 }
1274 
convertToLLVMType() const1275 llvm::Type *RSExportPointerType::convertToLLVMType() const {
1276   llvm::Type *PointeeType = mPointeeType->getLLVMType();
1277   return llvm::PointerType::getUnqual(PointeeType);
1278 }
1279 
keep()1280 bool RSExportPointerType::keep() {
1281   if (!RSExportType::keep())
1282     return false;
1283   const_cast<RSExportType*>(mPointeeType)->keep();
1284   return true;
1285 }
1286 
matchODR(const RSExportType * E,bool) const1287 bool RSExportPointerType::matchODR(const RSExportType *E,
1288                                    bool /* LookInto */) const {
1289   // Exported types cannot contain pointers
1290   slangAssert(false && "Not supposed to perform ODR check on pointers");
1291   return false;
1292 }
1293 
1294 /***************************** RSExportVectorType *****************************/
1295 llvm::StringRef
GetTypeName(const clang::ExtVectorType * EVT)1296 RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
1297   const clang::Type *ElementType = GetExtVectorElementType(EVT);
1298   llvm::StringRef name;
1299 
1300   if ((ElementType->getTypeClass() != clang::Type::Builtin))
1301     return name;
1302 
1303   const clang::BuiltinType *BT =
1304           static_cast<const clang::BuiltinType*>(
1305               ElementType->getCanonicalTypeInternal().getTypePtr());
1306 
1307   if ((EVT->getNumElements() < 1) ||
1308       (EVT->getNumElements() > 4))
1309     return name;
1310 
1311   BuiltinInfo *info = FindBuiltinType(BT->getKind());
1312   if (info != nullptr) {
1313     int I = EVT->getNumElements() - 1;
1314     if (I < kMaxVectorSize) {
1315       name = info->cname[I];
1316     } else {
1317       slangAssert(false && "Max vector is 4");
1318     }
1319   }
1320   return name;
1321 }
1322 
Create(RSContext * Context,const clang::ExtVectorType * EVT,const llvm::StringRef & TypeName,bool Normalized)1323 RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
1324                                                const clang::ExtVectorType *EVT,
1325                                                const llvm::StringRef &TypeName,
1326                                                bool Normalized) {
1327   slangAssert(EVT != nullptr && EVT->getTypeClass() == clang::Type::ExtVector);
1328 
1329   const clang::Type *ElementType = GetExtVectorElementType(EVT);
1330   DataType DT = RSExportPrimitiveType::GetDataType(Context, ElementType);
1331 
1332   if (DT != DataTypeUnknown)
1333     return new RSExportVectorType(Context,
1334                                   TypeName,
1335                                   DT,
1336                                   Normalized,
1337                                   EVT->getNumElements());
1338   else
1339     return nullptr;
1340 }
1341 
convertToLLVMType() const1342 llvm::Type *RSExportVectorType::convertToLLVMType() const {
1343   llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
1344   return llvm::VectorType::get(ElementType, getNumElement());
1345 }
1346 
matchODR(const RSExportType * E,bool) const1347 bool RSExportVectorType::matchODR(const RSExportType *E,
1348                                   bool /* LookInto*/) const {
1349   CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
1350   return (static_cast<const RSExportVectorType*>(E)->getNumElement()
1351               == getNumElement());
1352 }
1353 
1354 /***************************** RSExportMatrixType *****************************/
Create(RSContext * Context,const clang::RecordType * RT,const llvm::StringRef & TypeName,unsigned Dim)1355 RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
1356                                                const clang::RecordType *RT,
1357                                                const llvm::StringRef &TypeName,
1358                                                unsigned Dim) {
1359   slangAssert((RT != nullptr) && (RT->getTypeClass() == clang::Type::Record));
1360   slangAssert((Dim > 1) && "Invalid dimension of matrix");
1361 
1362   // Check whether the struct rs_matrix is in our expected form (but assume it's
1363   // correct if we're not sure whether it's correct or not)
1364   const clang::RecordDecl* RD = RT->getDecl();
1365   RD = RD->getDefinition();
1366   if (RD != nullptr) {
1367     // Find definition, perform further examination
1368     if (RD->field_empty()) {
1369       Context->ReportError(
1370           RD->getLocation(),
1371           "invalid matrix struct: must have 1 field for saving values: '%0'")
1372           << RD->getName();
1373       return nullptr;
1374     }
1375 
1376     clang::RecordDecl::field_iterator FIT = RD->field_begin();
1377     const clang::FieldDecl *FD = *FIT;
1378     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1379     if ((FT == nullptr) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
1380       Context->ReportError(RD->getLocation(),
1381                            "invalid matrix struct: first field should"
1382                            " be an array with constant size: '%0'")
1383           << RD->getName();
1384       return nullptr;
1385     }
1386     const clang::ConstantArrayType *CAT =
1387       static_cast<const clang::ConstantArrayType *>(FT);
1388     const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1389     if ((ElementType == nullptr) ||
1390         (ElementType->getTypeClass() != clang::Type::Builtin) ||
1391         (static_cast<const clang::BuiltinType *>(ElementType)->getKind() !=
1392          clang::BuiltinType::Float)) {
1393       Context->ReportError(RD->getLocation(),
1394                            "invalid matrix struct: first field "
1395                            "should be a float array: '%0'")
1396           << RD->getName();
1397       return nullptr;
1398     }
1399 
1400     if (CAT->getSize() != Dim * Dim) {
1401       Context->ReportError(RD->getLocation(),
1402                            "invalid matrix struct: first field "
1403                            "should be an array with size %0: '%1'")
1404           << (Dim * Dim) << (RD->getName());
1405       return nullptr;
1406     }
1407 
1408     FIT++;
1409     if (FIT != RD->field_end()) {
1410       Context->ReportError(RD->getLocation(),
1411                            "invalid matrix struct: must have "
1412                            "exactly 1 field: '%0'")
1413           << RD->getName();
1414       return nullptr;
1415     }
1416   }
1417 
1418   return new RSExportMatrixType(Context, TypeName, Dim);
1419 }
1420 
convertToLLVMType() const1421 llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1422   // Construct LLVM type:
1423   // struct {
1424   //  float X[mDim * mDim];
1425   // }
1426 
1427   llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1428   llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1429                                             mDim * mDim);
1430   return llvm::StructType::get(C, X, false);
1431 }
1432 
matchODR(const RSExportType * E,bool) const1433 bool RSExportMatrixType::matchODR(const RSExportType *E,
1434                                   bool /* LookInto */) const {
1435   CHECK_PARENT_EQUALITY(RSExportType, E);
1436   return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1437 }
1438 
1439 /************************* RSExportConstantArrayType *************************/
1440 RSExportConstantArrayType
Create(RSContext * Context,const clang::ConstantArrayType * CAT)1441 *RSExportConstantArrayType::Create(RSContext *Context,
1442                                    const clang::ConstantArrayType *CAT) {
1443   slangAssert(CAT != nullptr && CAT->getTypeClass() == clang::Type::ConstantArray);
1444 
1445   slangAssert((CAT->getSize().getActiveBits() < 32) && "array too large");
1446 
1447   unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1448   slangAssert((Size > 0) && "Constant array should have size greater than 0");
1449 
1450   const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1451   RSExportType *ElementET = RSExportType::Create(Context, ElementType,
1452                                                  NotLegacyKernelArgument);
1453 
1454   if (ElementET == nullptr) {
1455     return nullptr;
1456   }
1457 
1458   return new RSExportConstantArrayType(Context,
1459                                        ElementET,
1460                                        Size);
1461 }
1462 
convertToLLVMType() const1463 llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1464   return llvm::ArrayType::get(mElementType->getLLVMType(), getNumElement());
1465 }
1466 
keep()1467 bool RSExportConstantArrayType::keep() {
1468   if (!RSExportType::keep())
1469     return false;
1470   const_cast<RSExportType*>(mElementType)->keep();
1471   return true;
1472 }
1473 
matchODR(const RSExportType * E,bool LookInto) const1474 bool RSExportConstantArrayType::matchODR(const RSExportType *E,
1475                                          bool LookInto) const {
1476   CHECK_PARENT_EQUALITY(RSExportType, E);
1477   const RSExportConstantArrayType *RHS =
1478       static_cast<const RSExportConstantArrayType*>(E);
1479   return ((getNumElement() == RHS->getNumElement()) &&
1480           (getElementType()->matchODR(RHS->getElementType(), LookInto)));
1481 }
1482 
1483 /**************************** RSExportRecordType ****************************/
Create(RSContext * Context,const clang::RecordType * RT,const llvm::StringRef & TypeName,bool mIsArtificial)1484 RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1485                                                const clang::RecordType *RT,
1486                                                const llvm::StringRef &TypeName,
1487                                                bool mIsArtificial) {
1488   slangAssert(RT != nullptr && RT->getTypeClass() == clang::Type::Record);
1489 
1490   const clang::RecordDecl *RD = RT->getDecl();
1491   slangAssert(RD->isStruct());
1492 
1493   RD = RD->getDefinition();
1494   if (RD == nullptr) {
1495     slangAssert(false && "struct is not defined in this module");
1496     return nullptr;
1497   }
1498 
1499   // Struct layout construct by clang. We rely on this for obtaining the
1500   // alloc size of a struct and offset of every field in that struct.
1501   const clang::ASTRecordLayout *RL =
1502       &Context->getASTContext().getASTRecordLayout(RD);
1503   slangAssert((RL != nullptr) &&
1504       "Failed to retrieve the struct layout from Clang.");
1505 
1506   RSExportRecordType *ERT =
1507       new RSExportRecordType(Context,
1508                              TypeName,
1509                              RD->getLocation(),
1510                              RD->hasAttr<clang::PackedAttr>(),
1511                              mIsArtificial,
1512                              RL->getDataSize().getQuantity(),
1513                              RL->getSize().getQuantity());
1514   unsigned int Index = 0;
1515 
1516   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1517            FE = RD->field_end();
1518        FI != FE;
1519        FI++, Index++) {
1520 
1521     // FIXME: All fields should be primitive type
1522     slangAssert(FI->getKind() == clang::Decl::Field);
1523     clang::FieldDecl *FD = *FI;
1524 
1525     if (FD->isBitField()) {
1526       return nullptr;
1527     }
1528 
1529     if (FD->isImplicit() && (FD->getName() == RS_PADDING_FIELD_NAME))
1530       continue;
1531 
1532     // Type
1533     RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1534 
1535     if (ET != nullptr) {
1536       ERT->mFields.push_back(
1537           new Field(ET, FD->getName(), ERT,
1538                     static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1539     } else {
1540       // clang static analysis complains about a potential memory leak
1541       // for the memory pointed by ERT at the end of this basic
1542       // block. This is a false warning because the compiler does not
1543       // see that the pointer to this memory is saved away in the
1544       // constructor for RSExportRecordType by calling
1545       // RSContext::newExportable(this). So, we disable this
1546       // particular instance of the warning.
1547       Context->ReportError(RD->getLocation(),
1548                            "field type cannot be exported: '%0.%1'")
1549           << RD->getName() << FD->getName(); // NOLINT
1550       return nullptr;
1551     }
1552   }
1553 
1554   return ERT;
1555 }
1556 
convertToLLVMType() const1557 llvm::Type *RSExportRecordType::convertToLLVMType() const {
1558   // Create an opaque type since struct may reference itself recursively.
1559 
1560   // TODO(sliao): LLVM took out the OpaqueType. Any other to migrate to?
1561   std::vector<llvm::Type*> FieldTypes;
1562 
1563   for (const_field_iterator FI = fields_begin(), FE = fields_end();
1564        FI != FE;
1565        FI++) {
1566     const Field *F = *FI;
1567     const RSExportType *FET = F->getType();
1568 
1569     FieldTypes.push_back(FET->getLLVMType());
1570   }
1571 
1572   llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1573                                                FieldTypes,
1574                                                mIsPacked);
1575   if (ST != nullptr) {
1576     return ST;
1577   } else {
1578     return nullptr;
1579   }
1580 }
1581 
keep()1582 bool RSExportRecordType::keep() {
1583   if (!RSExportType::keep())
1584     return false;
1585   for (std::list<const Field*>::iterator I = mFields.begin(),
1586           E = mFields.end();
1587        I != E;
1588        I++) {
1589     const_cast<RSExportType*>((*I)->getType())->keep();
1590   }
1591   return true;
1592 }
1593 
matchODR(const RSExportType * E,bool LookInto) const1594 bool RSExportRecordType::matchODR(const RSExportType *E, bool LookInto) const {
1595   CHECK_PARENT_EQUALITY(RSExportType, E);
1596   // Enforce ODR checking - the type E represents must hold
1597   // *exactly* the same "definition" as the one defined previously. We
1598   // say two record types A and B have the same definition iff:
1599   //
1600   //  struct A {              struct B {
1601   //    Type(a1) a1,            Type(b1) b1,
1602   //    Type(a2) a2,            Type(b1) b2,
1603   //    ...                     ...
1604   //    Type(aN) aN             Type(bM) bM,
1605   //  };                      }
1606   //  Cond. #0. A = B;
1607   //  Cond. #1. They have same number of fields, i.e., N = M;
1608   //  Cond. #2. for (i := 1 to N)
1609   //              Type(ai).matchODR(Type(bi)) must hold;
1610   //  Cond. #3. for (i := 1 to N)
1611   //              Name(ai) = Name(bi) must hold;
1612   //
1613   // where,
1614   //  Type(F) = the type of field F and
1615   //  Name(F) = the field name.
1616 
1617 
1618   const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1619   // Cond. #0.
1620   if (getName() != ERT->getName())
1621     return false;
1622 
1623   // Examine fields - types and names
1624   if (LookInto) {
1625     // Cond. #1
1626     if (ERT->getFields().size() != getFields().size())
1627       return false;
1628 
1629     for (RSExportRecordType::const_field_iterator AI = fields_begin(),
1630          BI = ERT->fields_begin(), AE = fields_end(); AI != AE; ++AI, ++BI) {
1631       const RSExportType *AITy = (*AI)->getType();
1632       const RSExportType *BITy = (*BI)->getType();
1633       // Cond. #3; field names must agree
1634       if ((*AI)->getName() != (*BI)->getName())
1635         return false;
1636 
1637       // Cond. #2; field types must agree recursively until we see another
1638       // next level of RSExportRecordType - such field types will be
1639       // examined and reported later when checkODR() encounters them.
1640       if (!AITy->matchODR(BITy, false))
1641         return false;
1642     }
1643   }
1644   return true;
1645 }
1646 
convertToRTD(RSReflectionTypeData * rtd) const1647 void RSExportType::convertToRTD(RSReflectionTypeData *rtd) const {
1648     memset(rtd, 0, sizeof(*rtd));
1649     rtd->vecSize = 1;
1650 
1651     switch(getClass()) {
1652     case RSExportType::ExportClassPrimitive: {
1653             const RSExportPrimitiveType *EPT = static_cast<const RSExportPrimitiveType*>(this);
1654             rtd->type = RSExportPrimitiveType::getRSReflectionType(EPT);
1655             return;
1656         }
1657     case RSExportType::ExportClassPointer: {
1658             const RSExportPointerType *EPT = static_cast<const RSExportPointerType*>(this);
1659             const RSExportType *PointeeType = EPT->getPointeeType();
1660             PointeeType->convertToRTD(rtd);
1661             rtd->isPointer = true;
1662             return;
1663         }
1664     case RSExportType::ExportClassVector: {
1665             const RSExportVectorType *EVT = static_cast<const RSExportVectorType*>(this);
1666             rtd->type = EVT->getRSReflectionType(EVT);
1667             rtd->vecSize = EVT->getNumElement();
1668             return;
1669         }
1670     case RSExportType::ExportClassMatrix: {
1671             const RSExportMatrixType *EMT = static_cast<const RSExportMatrixType*>(this);
1672             unsigned Dim = EMT->getDim();
1673             slangAssert((Dim >= 2) && (Dim <= 4));
1674             rtd->type = &gReflectionTypes[15 + Dim-2];
1675             return;
1676         }
1677     case RSExportType::ExportClassConstantArray: {
1678             const RSExportConstantArrayType* CAT =
1679               static_cast<const RSExportConstantArrayType*>(this);
1680             CAT->getElementType()->convertToRTD(rtd);
1681             rtd->arraySize = CAT->getNumElement();
1682             return;
1683         }
1684     case RSExportType::ExportClassRecord: {
1685             slangAssert(!"RSExportType::ExportClassRecord not implemented");
1686             return;// RS_TYPE_CLASS_NAME_PREFIX + ET->getName() + ".Item";
1687         }
1688     default: {
1689             slangAssert(false && "Unknown class of type");
1690         }
1691     }
1692 }
1693 
1694 
1695 }  // namespace slang
1696