1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net.cts;
18 
19 import static org.junit.Assert.assertArrayEquals;
20 
21 import android.content.Context;
22 import android.net.ConnectivityManager;
23 import android.net.IpSecAlgorithm;
24 import android.net.IpSecManager;
25 import android.net.IpSecTransform;
26 import android.platform.test.annotations.AppModeFull;
27 import android.system.Os;
28 import android.system.OsConstants;
29 import android.util.Log;
30 
31 import androidx.test.InstrumentationRegistry;
32 import androidx.test.runner.AndroidJUnit4;
33 
34 import java.io.FileDescriptor;
35 import java.io.IOException;
36 import java.net.DatagramPacket;
37 import java.net.DatagramSocket;
38 import java.net.Inet6Address;
39 import java.net.InetAddress;
40 import java.net.InetSocketAddress;
41 import java.net.ServerSocket;
42 import java.net.Socket;
43 import java.net.SocketException;
44 import java.util.Arrays;
45 import java.util.concurrent.atomic.AtomicInteger;
46 
47 import org.junit.Before;
48 import org.junit.Test;
49 import org.junit.runner.RunWith;
50 
51 @RunWith(AndroidJUnit4.class)
52 public class IpSecBaseTest {
53 
54     private static final String TAG = IpSecBaseTest.class.getSimpleName();
55 
56     protected static final String IPV4_LOOPBACK = "127.0.0.1";
57     protected static final String IPV6_LOOPBACK = "::1";
58     protected static final String[] LOOPBACK_ADDRS = new String[] {IPV4_LOOPBACK, IPV6_LOOPBACK};
59     protected static final int[] DIRECTIONS =
60             new int[] {IpSecManager.DIRECTION_IN, IpSecManager.DIRECTION_OUT};
61 
62     protected static final byte[] TEST_DATA = "Best test data ever!".getBytes();
63     protected static final int DATA_BUFFER_LEN = 4096;
64     protected static final int SOCK_TIMEOUT = 500;
65 
66     private static final byte[] KEY_DATA = {
67         0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
68         0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
69         0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
70         0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
71         0x20, 0x21, 0x22, 0x23
72     };
73 
74     protected static final byte[] AUTH_KEY = getKey(256);
75     protected static final byte[] CRYPT_KEY = getKey(256);
76 
77     protected ConnectivityManager mCM;
78     protected IpSecManager mISM;
79 
80     @Before
setUp()81     public void setUp() throws Exception {
82         mISM =
83                 (IpSecManager)
84                         InstrumentationRegistry.getContext()
85                                 .getSystemService(Context.IPSEC_SERVICE);
86         mCM =
87                 (ConnectivityManager)
88                         InstrumentationRegistry.getContext()
89                                 .getSystemService(Context.CONNECTIVITY_SERVICE);
90     }
91 
getKey(int bitLength)92     protected static byte[] getKey(int bitLength) {
93         return Arrays.copyOf(KEY_DATA, bitLength / 8);
94     }
95 
getDomain(InetAddress address)96     protected static int getDomain(InetAddress address) {
97         int domain;
98         if (address instanceof Inet6Address) {
99             domain = OsConstants.AF_INET6;
100         } else {
101             domain = OsConstants.AF_INET;
102         }
103         return domain;
104     }
105 
getPort(FileDescriptor sock)106     protected static int getPort(FileDescriptor sock) throws Exception {
107         return ((InetSocketAddress) Os.getsockname(sock)).getPort();
108     }
109 
110     public static interface GenericSocket extends AutoCloseable {
send(byte[] data)111         void send(byte[] data) throws Exception;
112 
receive()113         byte[] receive() throws Exception;
114 
getPort()115         int getPort() throws Exception;
116 
close()117         void close() throws Exception;
118 
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)119         void applyTransportModeTransform(
120                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception;
121 
removeTransportModeTransforms(IpSecManager ism)122         void removeTransportModeTransforms(IpSecManager ism) throws Exception;
123     }
124 
125     public static interface GenericTcpSocket extends GenericSocket {}
126 
127     public static interface GenericUdpSocket extends GenericSocket {
sendTo(byte[] data, InetAddress dstAddr, int port)128         void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception;
129     }
130 
131     public abstract static class NativeSocket implements GenericSocket {
132         public FileDescriptor mFd;
133 
NativeSocket(FileDescriptor fd)134         public NativeSocket(FileDescriptor fd) {
135             mFd = fd;
136         }
137 
138         @Override
send(byte[] data)139         public void send(byte[] data) throws Exception {
140             Os.write(mFd, data, 0, data.length);
141         }
142 
143         @Override
receive()144         public byte[] receive() throws Exception {
145             byte[] in = new byte[DATA_BUFFER_LEN];
146             AtomicInteger bytesRead = new AtomicInteger(-1);
147 
148             Thread readSockThread = new Thread(() -> {
149                 long startTime = System.currentTimeMillis();
150                 while (bytesRead.get() < 0 && System.currentTimeMillis() < startTime + SOCK_TIMEOUT) {
151                     try {
152                         bytesRead.set(Os.recvfrom(mFd, in, 0, DATA_BUFFER_LEN, 0, null));
153                     } catch (Exception e) {
154                         Log.e(TAG, "Error encountered reading from socket", e);
155                     }
156                 }
157             });
158 
159             readSockThread.start();
160             readSockThread.join(SOCK_TIMEOUT);
161 
162             if (bytesRead.get() < 0) {
163                 throw new IOException("No data received from socket");
164             }
165 
166             return Arrays.copyOfRange(in, 0, bytesRead.get());
167         }
168 
169         @Override
getPort()170         public int getPort() throws Exception {
171             return IpSecBaseTest.getPort(mFd);
172         }
173 
174         @Override
close()175         public void close() throws Exception {
176             Os.close(mFd);
177         }
178 
179         @Override
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)180         public void applyTransportModeTransform(
181                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
182             ism.applyTransportModeTransform(mFd, direction, transform);
183         }
184 
185         @Override
removeTransportModeTransforms(IpSecManager ism)186         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
187             ism.removeTransportModeTransforms(mFd);
188         }
189     }
190 
191     public static class NativeTcpSocket extends NativeSocket implements GenericTcpSocket {
NativeTcpSocket(FileDescriptor fd)192         public NativeTcpSocket(FileDescriptor fd) {
193             super(fd);
194         }
195     }
196 
197     public static class NativeUdpSocket extends NativeSocket implements GenericUdpSocket {
NativeUdpSocket(FileDescriptor fd)198         public NativeUdpSocket(FileDescriptor fd) {
199             super(fd);
200         }
201 
202         @Override
sendTo(byte[] data, InetAddress dstAddr, int port)203         public void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception {
204             Os.sendto(mFd, data, 0, data.length, 0, dstAddr, port);
205         }
206     }
207 
208     public static class JavaUdpSocket implements GenericUdpSocket {
209         public final DatagramSocket mSocket;
210 
JavaUdpSocket(InetAddress localAddr, int port)211         public JavaUdpSocket(InetAddress localAddr, int port) {
212             try {
213                 mSocket = new DatagramSocket(port, localAddr);
214                 mSocket.setSoTimeout(SOCK_TIMEOUT);
215             } catch (SocketException e) {
216                 // Fail loudly if we can't set up sockets properly. And without the timeout, we
217                 // could easily end up in an endless wait.
218                 throw new RuntimeException(e);
219             }
220         }
221 
JavaUdpSocket(InetAddress localAddr)222         public JavaUdpSocket(InetAddress localAddr) {
223             try {
224                 mSocket = new DatagramSocket(0, localAddr);
225                 mSocket.setSoTimeout(SOCK_TIMEOUT);
226             } catch (SocketException e) {
227                 // Fail loudly if we can't set up sockets properly. And without the timeout, we
228                 // could easily end up in an endless wait.
229                 throw new RuntimeException(e);
230             }
231         }
232 
233         @Override
send(byte[] data)234         public void send(byte[] data) throws Exception {
235             mSocket.send(new DatagramPacket(data, data.length));
236         }
237 
238         @Override
sendTo(byte[] data, InetAddress dstAddr, int port)239         public void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception {
240             mSocket.send(new DatagramPacket(data, data.length, dstAddr, port));
241         }
242 
243         @Override
getPort()244         public int getPort() throws Exception {
245             return mSocket.getLocalPort();
246         }
247 
248         @Override
close()249         public void close() throws Exception {
250             mSocket.close();
251         }
252 
253         @Override
receive()254         public byte[] receive() throws Exception {
255             DatagramPacket data = new DatagramPacket(new byte[DATA_BUFFER_LEN], DATA_BUFFER_LEN);
256             mSocket.receive(data);
257             return Arrays.copyOfRange(data.getData(), 0, data.getLength());
258         }
259 
260         @Override
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)261         public void applyTransportModeTransform(
262                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
263             ism.applyTransportModeTransform(mSocket, direction, transform);
264         }
265 
266         @Override
removeTransportModeTransforms(IpSecManager ism)267         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
268             ism.removeTransportModeTransforms(mSocket);
269         }
270     }
271 
272     public static class JavaTcpSocket implements GenericTcpSocket {
273         public final Socket mSocket;
274 
JavaTcpSocket(Socket socket)275         public JavaTcpSocket(Socket socket) {
276             mSocket = socket;
277             try {
278                 mSocket.setSoTimeout(SOCK_TIMEOUT);
279             } catch (SocketException e) {
280                 // Fail loudly if we can't set up sockets properly. And without the timeout, we
281                 // could easily end up in an endless wait.
282                 throw new RuntimeException(e);
283             }
284         }
285 
286         @Override
send(byte[] data)287         public void send(byte[] data) throws Exception {
288             mSocket.getOutputStream().write(data);
289         }
290 
291         @Override
receive()292         public byte[] receive() throws Exception {
293             byte[] in = new byte[DATA_BUFFER_LEN];
294             int bytesRead = mSocket.getInputStream().read(in);
295             return Arrays.copyOfRange(in, 0, bytesRead);
296         }
297 
298         @Override
getPort()299         public int getPort() throws Exception {
300             return mSocket.getLocalPort();
301         }
302 
303         @Override
close()304         public void close() throws Exception {
305             mSocket.close();
306         }
307 
308         @Override
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)309         public void applyTransportModeTransform(
310                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
311             ism.applyTransportModeTransform(mSocket, direction, transform);
312         }
313 
314         @Override
removeTransportModeTransforms(IpSecManager ism)315         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
316             ism.removeTransportModeTransforms(mSocket);
317         }
318     }
319 
320     public static class SocketPair<T> {
321         public final T mLeftSock;
322         public final T mRightSock;
323 
SocketPair(T leftSock, T rightSock)324         public SocketPair(T leftSock, T rightSock) {
325             mLeftSock = leftSock;
326             mRightSock = rightSock;
327         }
328     }
329 
applyTransformBidirectionally( IpSecManager ism, IpSecTransform transform, GenericSocket socket)330     protected static void applyTransformBidirectionally(
331             IpSecManager ism, IpSecTransform transform, GenericSocket socket) throws Exception {
332         for (int direction : DIRECTIONS) {
333             socket.applyTransportModeTransform(ism, direction, transform);
334         }
335     }
336 
getNativeUdpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)337     public static SocketPair<NativeUdpSocket> getNativeUdpSocketPair(
338             InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)
339             throws Exception {
340         int domain = getDomain(localAddr);
341 
342         NativeUdpSocket leftSock = new NativeUdpSocket(
343             Os.socket(domain, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP));
344         NativeUdpSocket rightSock = new NativeUdpSocket(
345             Os.socket(domain, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP));
346 
347         for (NativeUdpSocket sock : new NativeUdpSocket[] {leftSock, rightSock}) {
348             applyTransformBidirectionally(ism, transform, sock);
349             Os.bind(sock.mFd, localAddr, 0);
350         }
351 
352         if (connected) {
353             Os.connect(leftSock.mFd, localAddr, rightSock.getPort());
354             Os.connect(rightSock.mFd, localAddr, leftSock.getPort());
355         }
356 
357         return new SocketPair<>(leftSock, rightSock);
358     }
359 
getNativeTcpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform)360     public static SocketPair<NativeTcpSocket> getNativeTcpSocketPair(
361             InetAddress localAddr, IpSecManager ism, IpSecTransform transform) throws Exception {
362         int domain = getDomain(localAddr);
363 
364         NativeTcpSocket server = new NativeTcpSocket(
365                 Os.socket(domain, OsConstants.SOCK_STREAM, OsConstants.IPPROTO_TCP));
366         NativeTcpSocket client = new NativeTcpSocket(
367                 Os.socket(domain, OsConstants.SOCK_STREAM, OsConstants.IPPROTO_TCP));
368 
369         Os.bind(server.mFd, localAddr, 0);
370 
371         applyTransformBidirectionally(ism, transform, server);
372         applyTransformBidirectionally(ism, transform, client);
373 
374         Os.listen(server.mFd, 10);
375         Os.connect(client.mFd, localAddr, server.getPort());
376         NativeTcpSocket accepted = new NativeTcpSocket(Os.accept(server.mFd, null));
377 
378         applyTransformBidirectionally(ism, transform, accepted);
379         server.close();
380 
381         return new SocketPair<>(client, accepted);
382     }
383 
getJavaUdpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)384     public static SocketPair<JavaUdpSocket> getJavaUdpSocketPair(
385             InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)
386             throws Exception {
387         JavaUdpSocket leftSock = new JavaUdpSocket(localAddr);
388         JavaUdpSocket rightSock = new JavaUdpSocket(localAddr);
389 
390         applyTransformBidirectionally(ism, transform, leftSock);
391         applyTransformBidirectionally(ism, transform, rightSock);
392 
393         if (connected) {
394             leftSock.mSocket.connect(localAddr, rightSock.mSocket.getLocalPort());
395             rightSock.mSocket.connect(localAddr, leftSock.mSocket.getLocalPort());
396         }
397 
398         return new SocketPair<>(leftSock, rightSock);
399     }
400 
getJavaTcpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform)401     public static SocketPair<JavaTcpSocket> getJavaTcpSocketPair(
402             InetAddress localAddr, IpSecManager ism, IpSecTransform transform) throws Exception {
403         JavaTcpSocket clientSock = new JavaTcpSocket(new Socket());
404         ServerSocket serverSocket = new ServerSocket();
405         serverSocket.bind(new InetSocketAddress(localAddr, 0));
406 
407         // While technically the client socket does not need to be bound, the OpenJDK implementation
408         // of Socket only allocates an FD when bind() or connect() or other similar methods are
409         // called. So we call bind to force the FD creation, so that we can apply a transform to it
410         // prior to socket connect.
411         clientSock.mSocket.bind(new InetSocketAddress(localAddr, 0));
412 
413         // IpSecService doesn't support serverSockets at the moment; workaround using FD
414         FileDescriptor serverFd = serverSocket.getImpl().getFD$();
415 
416         applyTransformBidirectionally(ism, transform, new NativeTcpSocket(serverFd));
417         applyTransformBidirectionally(ism, transform, clientSock);
418 
419         clientSock.mSocket.connect(new InetSocketAddress(localAddr, serverSocket.getLocalPort()));
420         JavaTcpSocket acceptedSock = new JavaTcpSocket(serverSocket.accept());
421 
422         applyTransformBidirectionally(ism, transform, acceptedSock);
423         serverSocket.close();
424 
425         return new SocketPair<>(clientSock, acceptedSock);
426     }
427 
checkSocketPair(GenericSocket left, GenericSocket right)428     private void checkSocketPair(GenericSocket left, GenericSocket right) throws Exception {
429         left.send(TEST_DATA);
430         assertArrayEquals(TEST_DATA, right.receive());
431 
432         right.send(TEST_DATA);
433         assertArrayEquals(TEST_DATA, left.receive());
434 
435         left.close();
436         right.close();
437     }
438 
checkUnconnectedUdpSocketPair( GenericUdpSocket left, GenericUdpSocket right, InetAddress localAddr)439     private void checkUnconnectedUdpSocketPair(
440             GenericUdpSocket left, GenericUdpSocket right, InetAddress localAddr) throws Exception {
441         left.sendTo(TEST_DATA, localAddr, right.getPort());
442         assertArrayEquals(TEST_DATA, right.receive());
443 
444         right.sendTo(TEST_DATA, localAddr, left.getPort());
445         assertArrayEquals(TEST_DATA, left.receive());
446 
447         left.close();
448         right.close();
449     }
450 
buildIpSecTransform( Context context, IpSecManager.SecurityParameterIndex spi, IpSecManager.UdpEncapsulationSocket encapSocket, InetAddress remoteAddr)451     protected static IpSecTransform buildIpSecTransform(
452             Context context,
453             IpSecManager.SecurityParameterIndex spi,
454             IpSecManager.UdpEncapsulationSocket encapSocket,
455             InetAddress remoteAddr)
456             throws Exception {
457         IpSecTransform.Builder builder =
458                 new IpSecTransform.Builder(context)
459                         .setEncryption(new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY))
460                         .setAuthentication(
461                                 new IpSecAlgorithm(
462                                         IpSecAlgorithm.AUTH_HMAC_SHA256,
463                                         AUTH_KEY,
464                                         AUTH_KEY.length * 4));
465 
466         if (encapSocket != null) {
467             builder.setIpv4Encapsulation(encapSocket, encapSocket.getPort());
468         }
469 
470         return builder.buildTransportModeTransform(remoteAddr, spi);
471     }
472 
buildDefaultTransform(InetAddress localAddr)473     private IpSecTransform buildDefaultTransform(InetAddress localAddr) throws Exception {
474         try (IpSecManager.SecurityParameterIndex spi =
475                 mISM.allocateSecurityParameterIndex(localAddr)) {
476             return buildIpSecTransform(InstrumentationRegistry.getContext(), spi, null, localAddr);
477         }
478     }
479 
480     @Test
481     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testJavaTcpSocketPair()482     public void testJavaTcpSocketPair() throws Exception {
483         for (String addr : LOOPBACK_ADDRS) {
484             InetAddress local = InetAddress.getByName(addr);
485             try (IpSecTransform transform = buildDefaultTransform(local)) {
486                 SocketPair<JavaTcpSocket> sockets = getJavaTcpSocketPair(local, mISM, transform);
487                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
488             }
489         }
490     }
491 
492     @Test
493     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testJavaUdpSocketPair()494     public void testJavaUdpSocketPair() throws Exception {
495         for (String addr : LOOPBACK_ADDRS) {
496             InetAddress local = InetAddress.getByName(addr);
497             try (IpSecTransform transform = buildDefaultTransform(local)) {
498                 SocketPair<JavaUdpSocket> sockets =
499                         getJavaUdpSocketPair(local, mISM, transform, true);
500                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
501             }
502         }
503     }
504 
505     @Test
506     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testJavaUdpSocketPairUnconnected()507     public void testJavaUdpSocketPairUnconnected() throws Exception {
508         for (String addr : LOOPBACK_ADDRS) {
509             InetAddress local = InetAddress.getByName(addr);
510             try (IpSecTransform transform = buildDefaultTransform(local)) {
511                 SocketPair<JavaUdpSocket> sockets =
512                         getJavaUdpSocketPair(local, mISM, transform, false);
513                 checkUnconnectedUdpSocketPair(sockets.mLeftSock, sockets.mRightSock, local);
514             }
515         }
516     }
517 
518     @Test
519     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testNativeTcpSocketPair()520     public void testNativeTcpSocketPair() throws Exception {
521         for (String addr : LOOPBACK_ADDRS) {
522             InetAddress local = InetAddress.getByName(addr);
523             try (IpSecTransform transform = buildDefaultTransform(local)) {
524                 SocketPair<NativeTcpSocket> sockets =
525                         getNativeTcpSocketPair(local, mISM, transform);
526                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
527             }
528         }
529     }
530 
531     @Test
532     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testNativeUdpSocketPair()533     public void testNativeUdpSocketPair() throws Exception {
534         for (String addr : LOOPBACK_ADDRS) {
535             InetAddress local = InetAddress.getByName(addr);
536             try (IpSecTransform transform = buildDefaultTransform(local)) {
537                 SocketPair<NativeUdpSocket> sockets =
538                         getNativeUdpSocketPair(local, mISM, transform, true);
539                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
540             }
541         }
542     }
543 
544     @Test
545     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testNativeUdpSocketPairUnconnected()546     public void testNativeUdpSocketPairUnconnected() throws Exception {
547         for (String addr : LOOPBACK_ADDRS) {
548             InetAddress local = InetAddress.getByName(addr);
549             try (IpSecTransform transform = buildDefaultTransform(local)) {
550                 SocketPair<NativeUdpSocket> sockets =
551                         getNativeUdpSocketPair(local, mISM, transform, false);
552                 checkUnconnectedUdpSocketPair(sockets.mLeftSock, sockets.mRightSock, local);
553             }
554         }
555     }
556 }
557