1 /*
2  * Copyright 2008, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <errno.h>
18 #include <stdlib.h>
19 #include <string.h>
20 #include <sys/socket.h>
21 #include <sys/uio.h>
22 #include <linux/if_ether.h>
23 #include <linux/if_packet.h>
24 #include <netinet/in.h>
25 #include <netinet/ip.h>
26 #include <netinet/udp.h>
27 #include <unistd.h>
28 
29 #ifdef ANDROID
30 #define LOG_TAG "DHCP"
31 #include <log/log.h>
32 #else
33 #include <stdio.h>
34 #define ALOGD printf
35 #define ALOGW printf
36 #endif
37 
38 #include "dhcpmsg.h"
39 
40 int fatal(const char*);
41 
/*
 * Open a PF_PACKET / SOCK_DGRAM socket (close-on-exec) and bind it to the
 * ETH_P_IP protocol on the interface whose index is if_index.
 *
 * ifname is unused.  hwaddr (ETH_ALEN bytes) is copied into the bind
 * address.
 *
 * Returns the socket descriptor on success; on failure returns whatever
 * fatal() returns for the corresponding message.
 */
int open_raw_socket(const char* ifname __unused, uint8_t hwaddr[ETH_ALEN], int if_index) {
    int s = socket(PF_PACKET, SOCK_DGRAM | SOCK_CLOEXEC, 0);
    if (s < 0) return fatal("socket(PF_PACKET)");

    struct sockaddr_ll bindaddr = {
            .sll_family = AF_PACKET,
            .sll_protocol = htons(ETH_P_IP),
            .sll_ifindex = if_index,
            .sll_halen = ETH_ALEN,
    };
    memcpy(bindaddr.sll_addr, hwaddr, ETH_ALEN);

    if (bind(s, (struct sockaddr *)&bindaddr, sizeof(bindaddr)) < 0) {
        /*
         * close() is allowed to modify errno; preserve bind()'s error so it
         * is still available to fatal() (which presumably reports errno --
         * its definition is not in this file).
         */
        int saved_errno = errno;
        close(s);
        errno = saved_errno;
        return fatal("Cannot bind raw socket to interface");
    }

    return s;
}
61 
/*
 * Accumulate the 16-bit ones'-complement sum (RFC 1071) of the count bytes
 * at buffer onto startsum.
 *
 * Words are taken in memory order, so the sum is in the same byte order as
 * the data and, after finish_sum(), can be stored directly into a header
 * checksum field.  An odd trailing byte is treated as if followed by a zero
 * pad byte.  The result is folded to at most 16 bits, so it can be passed
 * back in as startsum to sum non-contiguous pieces (e.g. a pseudo header
 * followed by a payload).
 */
static uint32_t checksum(const void *buffer, unsigned int count, uint32_t startsum)
{
    const uint8_t *p = buffer;
    uint32_t sum = startsum;
    uint16_t word;

    while (count > 1) {
        /* memcpy avoids misaligned and type-punned (strict-aliasing) loads. */
        memcpy(&word, p, sizeof(word));
        /* Fold early so the 32-bit accumulator can never overflow. */
        if (sum & 0x80000000u) {
            sum = (sum & 0xffff) + (sum >> 16);
        }
        sum += word;
        p += sizeof(word);
        count -= sizeof(word);
    }
    if (count > 0) {
        /*
         * Build the word as [byte, 0x00] in memory; read natively this is the
         * correctly zero-padded word on both little- and big-endian hosts.
         */
        word = 0;
        memcpy(&word, p, 1);
        if (sum & 0x80000000u) {
            sum = (sum & 0xffff) + (sum >> 16);
        }
        sum += word;
    }
    while ((sum >> 16) != 0) {
        sum = (sum & 0xffff) + (sum >> 16);
    }
    return sum;
}
80 
/*
 * Turn a ones'-complement sum into the final 16-bit Internet checksum:
 * keep the low 16 bits and invert them.
 */
static uint32_t finish_sum(uint32_t sum)
{
    uint32_t low16 = sum & 0xffffu;

    return low16 ^ 0xffffu;
}
85 
/*
 * Send a DHCP message as a UDP/IPv4 datagram over packet socket s.
 *
 * The IPv4 and UDP headers (including both checksums) are built here and
 * the frame is addressed to the Ethernet broadcast address on interface
 * if_index.  msg/size is the UDP payload.  saddr and daddr are copied into
 * the IP header as-is, so they are expected in network byte order; sport
 * and dport are converted with htons(), so they are in host byte order.
 *
 * Returns the result of sendmsg() (bytes sent, or -1 with errno set), or -1
 * with errno = EINVAL if size is negative or too large for one IPv4 datagram.
 */
int send_packet(int s, int if_index, struct dhcp_msg *msg, int size,
                uint32_t saddr, uint32_t daddr, uint32_t sport, uint32_t dport)
{
    struct iphdr ip;
    struct udphdr udp;
    uint32_t udpsum;
    uint16_t temp;

    /*
     * A negative size would become a huge unsigned byte count in checksum()
     * and the iovec; an oversized one would wrap the 16-bit length fields.
     */
    if (size < 0 || (size_t)size > 0xffff - sizeof(ip) - sizeof(udp)) {
        errno = EINVAL;
        return -1;
    }

    /* tos, id, frag_off and check stay zero; check must be zero while summing. */
    memset(&ip, 0, sizeof(ip));
    ip.version = IPVERSION;
    ip.ihl = sizeof(ip) >> 2;
    ip.tot_len = htons(sizeof(ip) + sizeof(udp) + size);
    ip.ttl = IPDEFTTL;
    ip.protocol = IPPROTO_UDP;
    ip.saddr = saddr;
    ip.daddr = daddr;
    ip.check = finish_sum(checksum(&ip, sizeof(ip), 0));

    /* udp.check stays zero while the UDP checksum is computed. */
    memset(&udp, 0, sizeof(udp));
    udp.source = htons(sport);
    udp.dest = htons(dport);
    udp.len = htons(sizeof(udp) + size);

    /* Calculate checksum for pseudo header: saddr, daddr, protocol, UDP length */
    udpsum = checksum(&ip.saddr, sizeof(ip.saddr), 0);
    udpsum = checksum(&ip.daddr, sizeof(ip.daddr), udpsum);
    temp = htons(IPPROTO_UDP);
    udpsum = checksum(&temp, sizeof(temp), udpsum);
    temp = udp.len;
    udpsum = checksum(&temp, sizeof(temp), udpsum);

    /* Add in the checksum for the udp header */
    udpsum = checksum(&udp, sizeof(udp), udpsum);

    /* Add in the checksum for the data */
    udpsum = checksum(msg, size, udpsum);
    udp.check = finish_sum(udpsum);
    /*
     * RFC 768: a transmitted UDP checksum of 0 means "no checksum", so a
     * computed value of 0 must be sent as all ones.  receive_packet() in
     * this file applies the same 0 -> 0xffff mapping when verifying.
     */
    if (udp.check == 0) {
        udp.check = 0xffff;
    }

    struct iovec iov[] = {
        { .iov_base = &ip, .iov_len = sizeof(ip) },
        { .iov_base = &udp, .iov_len = sizeof(udp) },
        { .iov_base = msg, .iov_len = (size_t)size },
    };

    struct sockaddr_ll destaddr = {
        .sll_family = AF_PACKET,
        .sll_protocol = htons(ETH_P_IP),
        .sll_ifindex = if_index,
        .sll_halen = ETH_ALEN,
    };
    /* Ethernet broadcast: ff:ff:ff:ff:ff:ff */
    memset(destaddr.sll_addr, 0xff, ETH_ALEN);

    /* Designated initializer zeroes msg_control/msg_flags and any padding. */
    struct msghdr msghdr = {
        .msg_name = &destaddr,
        .msg_namelen = sizeof(destaddr),
        .msg_iov = iov,
        .msg_iovlen = sizeof(iov) / sizeof(iov[0]),
    };
    return sendmsg(s, &msghdr, 0);
}
152 
receive_packet(int s,struct dhcp_msg * msg)153 int receive_packet(int s, struct dhcp_msg *msg)
154 {
155     int nread;
156     int is_valid;
157     struct dhcp_packet {
158         struct iphdr ip;
159         struct udphdr udp;
160         struct dhcp_msg dhcp;
161     } packet;
162     int dhcp_size;
163     uint32_t sum;
164     uint16_t temp;
165     uint32_t saddr, daddr;
166 
167     nread = read(s, &packet, sizeof(packet));
168     if (nread < 0) {
169         return -1;
170     }
171     /*
172      * The raw packet interface gives us all packets received by the
173      * network interface. We need to filter out all packets that are
174      * not meant for us.
175      */
176     is_valid = 0;
177     if (nread < (int)(sizeof(struct iphdr) + sizeof(struct udphdr))) {
178 #if VERBOSE
179         ALOGD("Packet is too small (%d) to be a UDP datagram", nread);
180 #endif
181     } else if (packet.ip.version != IPVERSION || packet.ip.ihl != (sizeof(packet.ip) >> 2)) {
182 #if VERBOSE
183         ALOGD("Not a valid IP packet");
184 #endif
185     } else if (nread < ntohs(packet.ip.tot_len)) {
186 #if VERBOSE
187         ALOGD("Packet was truncated (read %d, needed %d)", nread, ntohs(packet.ip.tot_len));
188 #endif
189     } else if (packet.ip.protocol != IPPROTO_UDP) {
190 #if VERBOSE
191         ALOGD("IP protocol (%d) is not UDP", packet.ip.protocol);
192 #endif
193     } else if (packet.udp.dest != htons(PORT_BOOTP_CLIENT)) {
194 #if VERBOSE
195         ALOGD("UDP dest port (%d) is not DHCP client", ntohs(packet.udp.dest));
196 #endif
197     } else {
198         is_valid = 1;
199     }
200 
201     if (!is_valid) {
202         return -1;
203     }
204 
205     /* Seems like it's probably a valid DHCP packet */
206     /* validate IP header checksum */
207     sum = finish_sum(checksum(&packet.ip, sizeof(packet.ip), 0));
208     if (sum != 0) {
209         ALOGW("IP header checksum failure (0x%x)", packet.ip.check);
210         return -1;
211     }
212     /*
213      * Validate the UDP checksum.
214      * Since we don't need the IP header anymore, we "borrow" it
215      * to construct the pseudo header used in the checksum calculation.
216      */
217     dhcp_size = ntohs(packet.udp.len) - sizeof(packet.udp);
218     /*
219      * check validity of dhcp_size.
220      * 1) cannot be negative or zero.
221      * 2) src buffer contains enough bytes to copy
222      * 3) cannot exceed destination buffer
223      */
224     if ((dhcp_size <= 0) ||
225         ((int)(nread - sizeof(struct iphdr) - sizeof(struct udphdr)) < dhcp_size) ||
226         ((int)sizeof(struct dhcp_msg) < dhcp_size)) {
227 #if VERBOSE
228         ALOGD("Malformed Packet");
229 #endif
230         return -1;
231     }
232     saddr = packet.ip.saddr;
233     daddr = packet.ip.daddr;
234     nread = ntohs(packet.ip.tot_len);
235     memset(&packet.ip, 0, sizeof(packet.ip));
236     packet.ip.saddr = saddr;
237     packet.ip.daddr = daddr;
238     packet.ip.protocol = IPPROTO_UDP;
239     packet.ip.tot_len = packet.udp.len;
240     temp = packet.udp.check;
241     packet.udp.check = 0;
242     sum = finish_sum(checksum(&packet, nread, 0));
243     packet.udp.check = temp;
244     if (!sum)
245         sum = finish_sum(sum);
246     if (temp != sum) {
247         ALOGW("UDP header checksum failure (0x%x should be 0x%x)", sum, temp);
248         return -1;
249     }
250     memcpy(msg, &packet.dhcp, dhcp_size);
251     return dhcp_size;
252 }
253