1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "jni_binder.h"
18 
19 #include <dlfcn.h>
20 #include <inttypes.h>
21 #include <stdio.h>
22 
23 #include "android-base/logging.h"
24 #include "android-base/stringprintf.h"
25 
26 #include "jvmti_helper.h"
27 #include "scoped_local_ref.h"
28 #include "scoped_utf_chars.h"
29 #include "ti_utf.h"
30 
31 namespace art {
32 
MangleForJni(const std::string & s)33 static std::string MangleForJni(const std::string& s) {
34   std::string result;
35   size_t char_count = ti::CountModifiedUtf8Chars(s.c_str(), s.length());
36   const char* cp = &s[0];
37   for (size_t i = 0; i < char_count; ++i) {
38     uint32_t ch = ti::GetUtf16FromUtf8(&cp);
39     if ((ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9')) {
40       result.push_back(ch);
41     } else if (ch == '.' || ch == '/') {
42       result += "_";
43     } else if (ch == '_') {
44       result += "_1";
45     } else if (ch == ';') {
46       result += "_2";
47     } else if (ch == '[') {
48       result += "_3";
49     } else {
50       const uint16_t leading = ti::GetLeadingUtf16Char(ch);
51       const uint32_t trailing = ti::GetTrailingUtf16Char(ch);
52 
53       android::base::StringAppendF(&result, "_0%04x", leading);
54       if (trailing != 0) {
55         android::base::StringAppendF(&result, "_0%04x", trailing);
56       }
57     }
58   }
59   return result;
60 }
61 
GetJniShortName(const std::string & class_descriptor,const std::string & method)62 static std::string GetJniShortName(const std::string& class_descriptor, const std::string& method) {
63   // Remove the leading 'L' and trailing ';'...
64   std::string class_name(class_descriptor);
65   CHECK_EQ(class_name[0], 'L') << class_name;
66   CHECK_EQ(class_name[class_name.size() - 1], ';') << class_name;
67   class_name.erase(0, 1);
68   class_name.erase(class_name.size() - 1, 1);
69 
70   std::string short_name;
71   short_name += "Java_";
72   short_name += MangleForJni(class_name);
73   short_name += "_";
74   short_name += MangleForJni(method);
75   return short_name;
76 }
77 
BindMethod(jvmtiEnv * jvmti_env,JNIEnv * env,jclass klass,jmethodID method)78 static void BindMethod(jvmtiEnv* jvmti_env, JNIEnv* env, jclass klass, jmethodID method) {
79   std::string name;
80   std::string signature;
81   std::string mangled_names[2];
82   {
83     char* name_cstr;
84     char* sig_cstr;
85     jvmtiError name_result = jvmti_env->GetMethodName(method, &name_cstr, &sig_cstr, nullptr);
86     CheckJvmtiError(jvmti_env, name_result);
87     CHECK(name_cstr != nullptr);
88     CHECK(sig_cstr != nullptr);
89     name = name_cstr;
90     signature = sig_cstr;
91 
92     char* klass_name;
93     jvmtiError klass_result = jvmti_env->GetClassSignature(klass, &klass_name, nullptr);
94     CheckJvmtiError(jvmti_env, klass_result);
95 
96     mangled_names[0] = GetJniShortName(klass_name, name);
97     // TODO: Long JNI name.
98 
99     CheckJvmtiError(jvmti_env, Deallocate(jvmti_env, name_cstr));
100     CheckJvmtiError(jvmti_env, Deallocate(jvmti_env, sig_cstr));
101     CheckJvmtiError(jvmti_env, Deallocate(jvmti_env, klass_name));
102   }
103 
104   for (const std::string& mangled_name : mangled_names) {
105     if (mangled_name.empty()) {
106       continue;
107     }
108     void* sym = dlsym(RTLD_DEFAULT, mangled_name.c_str());
109     if (sym == nullptr) {
110       continue;
111     }
112 
113     JNINativeMethod native_method;
114     native_method.fnPtr = sym;
115     native_method.name = name.c_str();
116     native_method.signature = signature.c_str();
117 
118     env->RegisterNatives(klass, &native_method, 1);
119 
120     return;
121   }
122 
123   LOG(FATAL) << "Could not find " << mangled_names[0];
124 }
125 
// Converts a type descriptor to the dotted form accepted by Class.forName:
//   "Ljava/lang/String;"  -> "java.lang.String"     (class: 'L' and ';' are stripped)
//   "[Ljava/lang/String;" -> "[Ljava.lang.String;"  (array: 'L' and ';' are kept)
// Descriptors of length 0 or 1 (e.g. primitive "I") are returned unchanged.
static std::string DescriptorToDot(const char* descriptor) {
  std::string dotted(descriptor);
  const size_t length = dotted.size();
  if (length <= 1) {
    return dotted;
  }
  if (dotted.front() == 'L' && dotted.back() == ';') {
    dotted = dotted.substr(1, length - 2);
  }
  std::replace(dotted.begin(), dotted.end(), '/', '.');
  return dotted;
}
144 
GetSystemClassLoader(JNIEnv * env)145 static jobject GetSystemClassLoader(JNIEnv* env) {
146   ScopedLocalRef<jclass> cl_klass(env, env->FindClass("java/lang/ClassLoader"));
147   CHECK(cl_klass.get() != nullptr);
148   jmethodID getsystemclassloader_method = env->GetStaticMethodID(cl_klass.get(),
149                                                                  "getSystemClassLoader",
150                                                                  "()Ljava/lang/ClassLoader;");
151   CHECK(getsystemclassloader_method != nullptr);
152   return env->CallStaticObjectMethod(cl_klass.get(), getsystemclassloader_method);
153 }
154 
FindClassWithClassLoader(JNIEnv * env,const char * class_name,jobject class_loader)155 static jclass FindClassWithClassLoader(JNIEnv* env, const char* class_name, jobject class_loader) {
156   // Create a String of the name.
157   std::string descriptor = android::base::StringPrintf("L%s;", class_name);
158   std::string dot_name = DescriptorToDot(descriptor.c_str());
159   ScopedLocalRef<jstring> name_str(env, env->NewStringUTF(dot_name.c_str()));
160 
161   // Call Class.forName with it.
162   ScopedLocalRef<jclass> c_klass(env, env->FindClass("java/lang/Class"));
163   CHECK(c_klass.get() != nullptr);
164   jmethodID forname_method = env->GetStaticMethodID(
165       c_klass.get(),
166       "forName",
167       "(Ljava/lang/String;ZLjava/lang/ClassLoader;)Ljava/lang/Class;");
168   CHECK(forname_method != nullptr);
169 
170   return static_cast<jclass>(env->CallStaticObjectMethod(c_klass.get(),
171                                                          forname_method,
172                                                          name_str.get(),
173                                                          JNI_FALSE,
174                                                          class_loader));
175 }
176 
GetClass(jvmtiEnv * jvmti_env,JNIEnv * env,const char * class_name,jobject class_loader)177 jclass GetClass(jvmtiEnv* jvmti_env, JNIEnv* env, const char* class_name, jobject class_loader) {
178   if (class_loader != nullptr) {
179     return FindClassWithClassLoader(env, class_name, class_loader);
180   }
181 
182   jclass from_implied = env->FindClass(class_name);
183   if (from_implied != nullptr) {
184     return from_implied;
185   }
186   env->ExceptionClear();
187 
188   ScopedLocalRef<jobject> system_class_loader(env, GetSystemClassLoader(env));
189   CHECK(system_class_loader.get() != nullptr);
190   jclass from_system = FindClassWithClassLoader(env, class_name, system_class_loader.get());
191   if (from_system != nullptr) {
192     return from_system;
193   }
194   env->ExceptionClear();
195 
196   // Look at the context classloaders of all threads.
197   jint thread_count;
198   jthread* threads;
199   CheckJvmtiError(jvmti_env, jvmti_env->GetAllThreads(&thread_count, &threads));
200   JvmtiUniquePtr threads_uptr = MakeJvmtiUniquePtr(jvmti_env, threads);
201 
202   jclass result = nullptr;
203   for (jint t = 0; t != thread_count; ++t) {
204     // Always loop over all elements, as we need to free the local references.
205     if (result == nullptr) {
206       jvmtiThreadInfo info;
207       CheckJvmtiError(jvmti_env, jvmti_env->GetThreadInfo(threads[t], &info));
208       CheckJvmtiError(jvmti_env, Deallocate(jvmti_env, info.name));
209       if (info.thread_group != nullptr) {
210         env->DeleteLocalRef(info.thread_group);
211       }
212       if (info.context_class_loader != nullptr) {
213         result = FindClassWithClassLoader(env, class_name, info.context_class_loader);
214         env->ExceptionClear();
215         env->DeleteLocalRef(info.context_class_loader);
216       }
217     }
218     env->DeleteLocalRef(threads[t]);
219   }
220 
221   if (result != nullptr) {
222     return result;
223   }
224 
225   // TODO: Implement scanning *all* classloaders.
226   LOG(WARNING) << "Scanning all classloaders unimplemented";
227 
228   return nullptr;
229 }
230 
BindFunctionsOnClass(jvmtiEnv * jvmti_env,JNIEnv * env,jclass klass)231 void BindFunctionsOnClass(jvmtiEnv* jvmti_env, JNIEnv* env, jclass klass) {
232   // Use JVMTI to get the methods.
233   jint method_count;
234   jmethodID* methods;
235   jvmtiError methods_result = jvmti_env->GetClassMethods(klass, &method_count, &methods);
236   CheckJvmtiError(jvmti_env, methods_result);
237 
238   // Check each method.
239   for (jint i = 0; i < method_count; ++i) {
240     jint modifiers;
241     jvmtiError mod_result = jvmti_env->GetMethodModifiers(methods[i], &modifiers);
242     CheckJvmtiError(jvmti_env, mod_result);
243     constexpr jint kNative = static_cast<jint>(0x0100);
244     if ((modifiers & kNative) != 0) {
245       BindMethod(jvmti_env, env, klass, methods[i]);
246     }
247   }
248 
249   CheckJvmtiError(jvmti_env, Deallocate(jvmti_env, methods));
250 }
251 
BindFunctions(jvmtiEnv * jvmti_env,JNIEnv * env,const char * class_name,jobject class_loader)252 void BindFunctions(jvmtiEnv* jvmti_env, JNIEnv* env, const char* class_name, jobject class_loader) {
253   // Use JNI to load the class.
254   ScopedLocalRef<jclass> klass(env, GetClass(jvmti_env, env, class_name, class_loader));
255   CHECK(klass.get() != nullptr) << class_name;
256   BindFunctionsOnClass(jvmti_env, env, klass.get());
257 }
258 
259 }  // namespace art
260