1 /*
2  * Copyright (C) 2006 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "LocalSocketImpl"
18 
19 #include <nativehelper/JNIPlatformHelp.h>
20 #include "jni.h"
21 #include "utils/Log.h"
22 #include "utils/misc.h"
23 
24 #include <stdio.h>
25 #include <string.h>
26 #include <sys/types.h>
27 #include <sys/socket.h>
28 #include <sys/un.h>
29 #include <arpa/inet.h>
30 #include <netinet/in.h>
31 #include <stdlib.h>
32 #include <errno.h>
33 #include <unistd.h>
34 #include <sys/ioctl.h>
35 
36 #include <android-base/cmsg.h>
37 #include <android-base/macros.h>
38 #include <cutils/sockets.h>
39 #include <netinet/tcp.h>
40 #include <nativehelper/ScopedUtfChars.h>
41 
42 using android::base::ReceiveFileDescriptorVector;
43 using android::base::SendFileDescriptorVector;
44 
45 namespace android {
46 
47 static jfieldID field_inboundFileDescriptors;
48 static jfieldID field_outboundFileDescriptors;
49 static jclass class_Credentials;
50 static jclass class_FileDescriptor;
51 static jmethodID method_CredentialsInit;
52 
53 /* private native void connectLocal(FileDescriptor fd,
54  * String name, int namespace) throws IOException
55  */
56 static void
socket_connect_local(JNIEnv * env,jobject object,jobject fileDescriptor,jstring name,jint namespaceId)57 socket_connect_local(JNIEnv *env, jobject object,
58                         jobject fileDescriptor, jstring name, jint namespaceId)
59 {
60     int ret;
61     int fd;
62 
63     if (name == NULL) {
64         jniThrowNullPointerException(env, NULL);
65         return;
66     }
67 
68     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
69 
70     if (env->ExceptionCheck()) {
71         return;
72     }
73 
74     ScopedUtfChars nameUtf8(env, name);
75 
76     ret = socket_local_client_connect(
77                 fd,
78                 nameUtf8.c_str(),
79                 namespaceId,
80                 SOCK_STREAM);
81 
82     if (ret < 0) {
83         jniThrowIOException(env, errno);
84         return;
85     }
86 }
87 
88 #define DEFAULT_BACKLOG 4
89 
90 /* private native void bindLocal(FileDescriptor fd, String name, namespace)
91  * throws IOException;
92  */
93 
94 static void
socket_bind_local(JNIEnv * env,jobject object,jobject fileDescriptor,jstring name,jint namespaceId)95 socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
96                 jstring name, jint namespaceId)
97 {
98     int ret;
99     int fd;
100 
101     if (name == NULL) {
102         jniThrowNullPointerException(env, NULL);
103         return;
104     }
105 
106     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
107 
108     if (env->ExceptionCheck()) {
109         return;
110     }
111 
112     ScopedUtfChars nameUtf8(env, name);
113 
114     ret = socket_local_server_bind(fd, nameUtf8.c_str(), namespaceId);
115 
116     if (ret < 0) {
117         jniThrowIOException(env, errno);
118         return;
119     }
120 }
121 
122 /**
123  * Reads data from a socket into buf, processing any ancillary data
124  * and adding it to thisJ.
125  *
126  * Returns the length of normal data read, or -1 if an exception has
127  * been thrown in this function.
128  */
socket_read_all(JNIEnv * env,jobject thisJ,int fd,void * buffer,size_t len)129 static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
130         void *buffer, size_t len)
131 {
132     ssize_t ret;
133     std::vector<android::base::unique_fd> received_fds;
134 
135     ret = ReceiveFileDescriptorVector(fd, buffer, len, 64, &received_fds);
136 
137     if (ret < 0) {
138         if (errno == EPIPE) {
139             // Treat this as an end of stream
140             return 0;
141         }
142 
143         jniThrowIOException(env, errno);
144         return -1;
145     }
146 
147     if (received_fds.size() > 0) {
148         jobjectArray fdArray = env->NewObjectArray(received_fds.size(), class_FileDescriptor, NULL);
149 
150         if (fdArray == NULL) {
151             // NewObjectArray has thrown.
152             return -1;
153         }
154 
155         for (size_t i = 0; i < received_fds.size(); i++) {
156             jobject fdObject = jniCreateFileDescriptor(env, received_fds[i].get());
157 
158             if (env->ExceptionCheck()) {
159                 return -1;
160             }
161 
162             env->SetObjectArrayElement(fdArray, i, fdObject);
163 
164             if (env->ExceptionCheck()) {
165                 return -1;
166             }
167         }
168 
169         for (auto &fd : received_fds) {
170             // The fds are stored in java.io.FileDescriptors now.
171             static_cast<void>(fd.release());
172         }
173 
174         env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
175     }
176 
177     return ret;
178 }
179 
180 /**
181  * Writes all the data in the specified buffer to the specified socket.
182  *
183  * Returns 0 on success or -1 if an exception was thrown.
184  */
socket_write_all(JNIEnv * env,jobject object,int fd,void * buf,size_t len)185 static int socket_write_all(JNIEnv *env, jobject object, int fd,
186         void *buf, size_t len)
187 {
188     struct msghdr msg;
189     unsigned char *buffer = (unsigned char *)buf;
190     memset(&msg, 0, sizeof(msg));
191 
192     jobjectArray outboundFds
193             = (jobjectArray)env->GetObjectField(
194                 object, field_outboundFileDescriptors);
195 
196     if (env->ExceptionCheck()) {
197         return -1;
198     }
199 
200     int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
201     std::vector<int> fds;
202 
203     // Add any pending outbound file descriptors to the message
204     if (outboundFds != NULL) {
205         if (env->ExceptionCheck()) {
206             return -1;
207         }
208 
209         for (int i = 0; i < countFds; i++) {
210             jobject fdObject = env->GetObjectArrayElement(outboundFds, i);
211             if (env->ExceptionCheck()) {
212                 return -1;
213             }
214 
215             fds.push_back(jniGetFDFromFileDescriptor(env, fdObject));
216             if (env->ExceptionCheck()) {
217                 return -1;
218             }
219         }
220     }
221 
222     ssize_t rc = SendFileDescriptorVector(fd, buffer, len, fds);
223 
224     while (rc != len) {
225         if (rc == -1) {
226             jniThrowIOException(env, errno);
227             return -1;
228         }
229 
230         buffer += rc;
231         len -= rc;
232 
233         rc = send(fd, buffer, len, MSG_NOSIGNAL);
234     }
235 
236     return 0;
237 }
238 
socket_read(JNIEnv * env,jobject object,jobject fileDescriptor)239 static jint socket_read (JNIEnv *env, jobject object, jobject fileDescriptor)
240 {
241     int fd;
242     int err;
243 
244     if (fileDescriptor == NULL) {
245         jniThrowNullPointerException(env, NULL);
246         return (jint)-1;
247     }
248 
249     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
250 
251     if (env->ExceptionCheck()) {
252         return (jint)0;
253     }
254 
255     unsigned char buf;
256 
257     err = socket_read_all(env, object, fd, &buf, 1);
258 
259     if (err < 0) {
260         jniThrowIOException(env, errno);
261         return (jint)0;
262     }
263 
264     if (err == 0) {
265         // end of file
266         return (jint)-1;
267     }
268 
269     return (jint)buf;
270 }
271 
socket_readba(JNIEnv * env,jobject object,jbyteArray buffer,jint off,jint len,jobject fileDescriptor)272 static jint socket_readba (JNIEnv *env, jobject object,
273         jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
274 {
275     int fd;
276     jbyte* byteBuffer;
277     int ret;
278 
279     if (fileDescriptor == NULL || buffer == NULL) {
280         jniThrowNullPointerException(env, NULL);
281         return (jint)-1;
282     }
283 
284     if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
285         jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
286         return (jint)-1;
287     }
288 
289     if (len == 0) {
290         // because socket_read_all returns 0 on EOF
291         return 0;
292     }
293 
294     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
295 
296     if (env->ExceptionCheck()) {
297         return (jint)-1;
298     }
299 
300     byteBuffer = env->GetByteArrayElements(buffer, NULL);
301 
302     if (NULL == byteBuffer) {
303         // an exception will have been thrown
304         return (jint)-1;
305     }
306 
307     ret = socket_read_all(env, object,
308             fd, byteBuffer + off, len);
309 
310     // A return of -1 above means an exception is pending
311 
312     env->ReleaseByteArrayElements(buffer, byteBuffer, 0);
313 
314     return (jint) ((ret == 0) ? -1 : ret);
315 }
316 
socket_write(JNIEnv * env,jobject object,jint b,jobject fileDescriptor)317 static void socket_write (JNIEnv *env, jobject object,
318         jint b, jobject fileDescriptor)
319 {
320     int fd;
321     int err;
322 
323     if (fileDescriptor == NULL) {
324         jniThrowNullPointerException(env, NULL);
325         return;
326     }
327 
328     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
329 
330     if (env->ExceptionCheck()) {
331         return;
332     }
333 
334     err = socket_write_all(env, object, fd, &b, 1);
335     UNUSED(err);
336     // A return of -1 above means an exception is pending
337 }
338 
socket_writeba(JNIEnv * env,jobject object,jbyteArray buffer,jint off,jint len,jobject fileDescriptor)339 static void socket_writeba (JNIEnv *env, jobject object,
340         jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
341 {
342     int fd;
343     int err;
344     jbyte* byteBuffer;
345 
346     if (fileDescriptor == NULL || buffer == NULL) {
347         jniThrowNullPointerException(env, NULL);
348         return;
349     }
350 
351     if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
352         jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
353         return;
354     }
355 
356     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
357 
358     if (env->ExceptionCheck()) {
359         return;
360     }
361 
362     byteBuffer = env->GetByteArrayElements(buffer,NULL);
363 
364     if (NULL == byteBuffer) {
365         // an exception will have been thrown
366         return;
367     }
368 
369     err = socket_write_all(env, object, fd,
370             byteBuffer + off, len);
371     UNUSED(err);
372     // A return of -1 above means an exception is pending
373 
374     env->ReleaseByteArrayElements(buffer, byteBuffer, JNI_ABORT);
375 }
376 
socket_get_peer_credentials(JNIEnv * env,jobject object,jobject fileDescriptor)377 static jobject socket_get_peer_credentials(JNIEnv *env,
378         jobject object, jobject fileDescriptor)
379 {
380     int err;
381     int fd;
382 
383     if (fileDescriptor == NULL) {
384         jniThrowNullPointerException(env, NULL);
385         return NULL;
386     }
387 
388     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
389 
390     if (env->ExceptionCheck()) {
391         return NULL;
392     }
393 
394     struct ucred creds;
395 
396     memset(&creds, 0, sizeof(creds));
397     socklen_t szCreds = sizeof(creds);
398 
399     err = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
400 
401     if (err < 0) {
402         jniThrowIOException(env, errno);
403         return NULL;
404     }
405 
406     if (szCreds == 0) {
407         return NULL;
408     }
409 
410     return env->NewObject(class_Credentials, method_CredentialsInit,
411             creds.pid, creds.uid, creds.gid);
412 }
413 
414 /*
415  * JNI registration.
416  */
417 static const JNINativeMethod gMethods[] = {
418      /* name, signature, funcPtr */
419     {"connectLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V",
420                                                 (void*)socket_connect_local},
421     {"bindLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V", (void*)socket_bind_local},
422     {"read_native", "(Ljava/io/FileDescriptor;)I", (void*) socket_read},
423     {"readba_native", "([BIILjava/io/FileDescriptor;)I", (void*) socket_readba},
424     {"writeba_native", "([BIILjava/io/FileDescriptor;)V", (void*) socket_writeba},
425     {"write_native", "(ILjava/io/FileDescriptor;)V", (void*) socket_write},
426     {"getPeerCredentials_native",
427             "(Ljava/io/FileDescriptor;)Landroid/net/Credentials;",
428             (void*) socket_get_peer_credentials}
429 };
430 
register_android_net_LocalSocketImpl(JNIEnv * env)431 int register_android_net_LocalSocketImpl(JNIEnv *env)
432 {
433     jclass clazz;
434 
435     clazz = env->FindClass("android/net/LocalSocketImpl");
436 
437     if (clazz == NULL) {
438         goto error;
439     }
440 
441     field_inboundFileDescriptors = env->GetFieldID(clazz,
442             "inboundFileDescriptors", "[Ljava/io/FileDescriptor;");
443 
444     if (field_inboundFileDescriptors == NULL) {
445         goto error;
446     }
447 
448     field_outboundFileDescriptors = env->GetFieldID(clazz,
449             "outboundFileDescriptors", "[Ljava/io/FileDescriptor;");
450 
451     if (field_outboundFileDescriptors == NULL) {
452         goto error;
453     }
454 
455     class_Credentials = env->FindClass("android/net/Credentials");
456 
457     if (class_Credentials == NULL) {
458         goto error;
459     }
460 
461     class_Credentials = (jclass)env->NewGlobalRef(class_Credentials);
462 
463     class_FileDescriptor = env->FindClass("java/io/FileDescriptor");
464 
465     if (class_FileDescriptor == NULL) {
466         goto error;
467     }
468 
469     class_FileDescriptor = (jclass)env->NewGlobalRef(class_FileDescriptor);
470 
471     method_CredentialsInit
472             = env->GetMethodID(class_Credentials, "<init>", "(III)V");
473 
474     if (method_CredentialsInit == NULL) {
475         goto error;
476     }
477 
478     return jniRegisterNativeMethods(env,
479         "android/net/LocalSocketImpl", gMethods, NELEM(gMethods));
480 
481 error:
482     ALOGE("Error registering android.net.LocalSocketImpl");
483     return -1;
484 }
485 
486 };
487