1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "wificond/net/nl80211_packet.h"
18 
19 #include <android-base/logging.h>
20 
21 using std::vector;
22 
23 namespace android {
24 namespace wificond {
25 
NL80211Packet(const vector<uint8_t> & data)26 NL80211Packet::NL80211Packet(const vector<uint8_t>& data)
27     : data_(data) {
28   data_ = data;
29 }
30 
// Copies the raw bytes of |packet|. A warning is logged on every call because
// this constructor is only expected to be exercised by unit tests.
NL80211Packet::NL80211Packet(const NL80211Packet& packet)
    : data_(packet.data_) {
  LOG(WARNING) << "Copy constructor is only used for unit tests";
}
35 
NL80211Packet(uint16_t type,uint8_t command,uint32_t sequence,uint32_t pid)36 NL80211Packet::NL80211Packet(uint16_t type,
37                              uint8_t command,
38                              uint32_t sequence,
39                              uint32_t pid) {
40   // Initialize the netlink header and generic netlink header.
41   // NLMSG_HDRLEN and GENL_HDRLEN already include the padding size.
42   data_.resize(NLMSG_HDRLEN + GENL_HDRLEN, 0);
43   // Initialize length field.
44   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
45   nl_header->nlmsg_len = data_.size();
46   // Add NLM_F_REQUEST flag.
47   nl_header->nlmsg_flags = nl_header->nlmsg_flags | NLM_F_REQUEST;
48   nl_header->nlmsg_type = type;
49   nl_header->nlmsg_seq = sequence;
50   nl_header->nlmsg_pid = pid;
51 
52   genlmsghdr* genl_header =
53       reinterpret_cast<genlmsghdr*>(data_.data() + NLMSG_HDRLEN);
54   genl_header->version = 1;
55   genl_header->cmd = command;
56   // genl_header->reserved is aready 0.
57 }
58 
IsValid() const59 bool NL80211Packet::IsValid() const {
60   // Verify the size of packet.
61   if (data_.size() < NLMSG_HDRLEN) {
62     LOG(ERROR) << "Cannot retrieve netlink header.";
63     return false;
64   }
65 
66   const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
67 
68   // If type < NLMSG_MIN_TYPE, this should be a reserved control message,
69   // which doesn't carry a generic netlink header.
70   if (GetMessageType() >= NLMSG_MIN_TYPE) {
71     if (data_.size() < NLMSG_HDRLEN + GENL_HDRLEN ||
72         nl_header->nlmsg_len < NLMSG_HDRLEN + GENL_HDRLEN) {
73       LOG(ERROR) << "Cannot retrieve generic netlink header.";
74       return false;
75     }
76   }
77   // If it is an ERROR message, it should be long enough to carry an extra error
78   // code field.
79   // Kernel uses int for this field.
80   if (GetMessageType() == NLMSG_ERROR) {
81     if (data_.size() < NLMSG_HDRLEN + sizeof(int) ||
82         nl_header->nlmsg_len < NLMSG_HDRLEN + sizeof(int)) {
83      LOG(ERROR) << "Broken error message.";
84      return false;
85     }
86   }
87 
88   // Verify the netlink header.
89   if (data_.size() < nl_header->nlmsg_len ||
90       nl_header->nlmsg_len < sizeof(nlmsghdr)) {
91     LOG(ERROR) << "Discarding incomplete / invalid message.";
92     return false;
93   }
94   return true;
95 }
96 
IsDump() const97 bool NL80211Packet::IsDump() const {
98   return GetFlags() & NLM_F_DUMP;
99 }
100 
IsMulti() const101 bool NL80211Packet::IsMulti() const {
102   return GetFlags() & NLM_F_MULTI;
103 }
104 
GetCommand() const105 uint8_t NL80211Packet::GetCommand() const {
106   const genlmsghdr* genl_header = reinterpret_cast<const genlmsghdr*>(
107       data_.data() + NLMSG_HDRLEN);
108   return genl_header->cmd;
109 }
110 
GetFlags() const111 uint16_t NL80211Packet::GetFlags() const {
112   const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
113   return nl_header->nlmsg_flags;
114 }
115 
GetMessageType() const116 uint16_t NL80211Packet::GetMessageType() const {
117   const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
118   return nl_header->nlmsg_type;
119 }
120 
GetMessageSequence() const121 uint32_t NL80211Packet::GetMessageSequence() const {
122   const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
123   return nl_header->nlmsg_seq;
124 }
125 
GetPortId() const126 uint32_t NL80211Packet::GetPortId() const {
127   const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
128   return nl_header->nlmsg_pid;
129 }
130 
GetErrorCode() const131 int NL80211Packet::GetErrorCode() const {
132   return -*reinterpret_cast<const int*>(data_.data() + NLMSG_HDRLEN);
133 }
134 
GetConstData() const135 const vector<uint8_t>& NL80211Packet::GetConstData() const {
136   return data_;
137 }
138 
SetCommand(uint8_t command)139 void NL80211Packet::SetCommand(uint8_t command) {
140   genlmsghdr* genl_header = reinterpret_cast<genlmsghdr*>(
141       data_.data() + NLMSG_HDRLEN);
142   genl_header->cmd = command;
143 }
144 
AddFlag(uint16_t flag)145 void NL80211Packet::AddFlag(uint16_t flag) {
146   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
147   nl_header->nlmsg_flags |= flag;
148 }
149 
SetFlags(uint16_t flags)150 void NL80211Packet::SetFlags(uint16_t flags) {
151   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
152   nl_header->nlmsg_flags = flags;
153 }
154 
SetMessageType(uint16_t message_type)155 void NL80211Packet::SetMessageType(uint16_t message_type) {
156   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
157   nl_header->nlmsg_type = message_type;
158 }
159 
SetMessageSequence(uint32_t message_sequence)160 void NL80211Packet::SetMessageSequence(uint32_t message_sequence) {
161   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
162   nl_header->nlmsg_seq = message_sequence;
163 }
164 
SetPortId(uint32_t pid)165 void NL80211Packet::SetPortId(uint32_t pid) {
166   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
167   nl_header->nlmsg_pid = pid;
168 }
169 
AddAttribute(const BaseNL80211Attr & attribute)170 void NL80211Packet::AddAttribute(const BaseNL80211Attr& attribute) {
171   const vector<uint8_t>& append_data = attribute.GetConstData();
172   // Append the data of |attribute| to |this|.
173   data_.insert(data_.end(), append_data.begin(), append_data.end());
174   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
175   // We don't need to worry about padding for a nl80211 packet.
176   // Because as long as all sub attributes have padding, the payload is aligned.
177   nl_header->nlmsg_len += append_data.size();
178 }
179 
AddFlagAttribute(int attribute_id)180 void NL80211Packet::AddFlagAttribute(int attribute_id) {
181   // We only need to append a header for flag attribute.
182   // Make space for the new attribute.
183   data_.resize(data_.size() + NLA_HDRLEN, 0);
184   nlattr* flag_header =
185       reinterpret_cast<nlattr*>(data_.data() + data_.size() - NLA_HDRLEN);
186   flag_header->nla_type = attribute_id;
187   flag_header->nla_len = NLA_HDRLEN;
188   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
189   nl_header->nlmsg_len += NLA_HDRLEN;
190 }
191 
HasAttribute(int id) const192 bool NL80211Packet::HasAttribute(int id) const {
193   return BaseNL80211Attr::GetAttributeImpl(
194       data_.data() + NLMSG_HDRLEN + GENL_HDRLEN,
195       data_.size() - NLMSG_HDRLEN - GENL_HDRLEN,
196       id, nullptr, nullptr);
197 }
198 
GetAttribute(int id,NL80211NestedAttr * attribute) const199 bool NL80211Packet::GetAttribute(int id,
200     NL80211NestedAttr* attribute) const {
201   uint8_t* start = nullptr;
202   uint8_t* end = nullptr;
203   if (!BaseNL80211Attr::GetAttributeImpl(
204           data_.data() + NLMSG_HDRLEN + GENL_HDRLEN,
205           data_.size() - NLMSG_HDRLEN - GENL_HDRLEN,
206           id, &start, &end) ||
207       start == nullptr ||
208       end == nullptr) {
209     return false;
210   }
211   *attribute = NL80211NestedAttr(vector<uint8_t>(start, end));
212   if (!attribute->IsValid()) {
213     return false;
214   }
215   return true;
216 }
217 
GetAllAttributes(vector<BaseNL80211Attr> * attributes) const218 bool NL80211Packet::GetAllAttributes(
219     vector<BaseNL80211Attr>* attributes) const {
220   const uint8_t* ptr = data_.data() + NLMSG_HDRLEN + GENL_HDRLEN;
221   const uint8_t* end_ptr = data_.data() + data_.size();
222   while (ptr + NLA_HDRLEN <= end_ptr) {
223     auto header = reinterpret_cast<const nlattr*>(ptr);
224     if (ptr + NLA_ALIGN(header->nla_len) > end_ptr ||
225       header->nla_len == 0) {
226       LOG(ERROR) << "broken nl80211 atrribute.";
227       return false;
228     }
229     attributes->emplace_back(
230         header->nla_type,
231         vector<uint8_t>(ptr + NLA_HDRLEN, ptr + header->nla_len));
232     ptr += NLA_ALIGN(header->nla_len);
233   }
234   return true;
235 }
236 
DebugLog() const237 void NL80211Packet::DebugLog() const {
238   const uint8_t* ptr = data_.data() + NLMSG_HDRLEN + GENL_HDRLEN;
239   const uint8_t* end_ptr = data_.data() + data_.size();
240   while (ptr + NLA_HDRLEN <= end_ptr) {
241     const nlattr* header = reinterpret_cast<const nlattr*>(ptr);
242     if (ptr + NLA_ALIGN(header->nla_len) > end_ptr) {
243       LOG(ERROR) << "broken nl80211 atrribute.";
244       return;
245     }
246     LOG(INFO) << "Have attribute with nla_type=" << header->nla_type
247               << " and nla_len=" << header->nla_len;
248     if (header->nla_len == 0) {
249       LOG(ERROR) << "0 is a bad nla_len";
250       return;
251     }
252     ptr += NLA_ALIGN(header->nla_len);
253   }
254 }
255 
256 }  // namespace wificond
257 }  // namespace android
258