1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "common/libs/net/netlink_client.h"
17 
18 #include <linux/rtnetlink.h>
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include <android-base/logging.h>
23 
24 #include <iostream>
25 #include <memory>
26 
27 using ::testing::ElementsAreArray;
28 using ::testing::MatchResultListener;
29 using ::testing::Return;
30 
31 namespace cuttlefish {
32 namespace {
// Empty C-linkage definition of klog_write(level, format, ...): every
// argument is ignored and the call has no effect.
// NOTE(review): presumably present so this test links without the real klog
// implementation — confirm. Because it is declared inside the unnamed
// namespace, verify the toolchain still exports it under the C symbol name.
extern "C" {
void klog_write(int /* level */, const char* /* format */, ...) {}
}
34 
35 // Dump hex buffer to test log.
Dump(MatchResultListener * result_listener,const char * title,const uint8_t * data,size_t length)36 void Dump(MatchResultListener* result_listener, const char* title,
37           const uint8_t* data, size_t length) {
38   for (size_t item = 0; item < length;) {
39     *result_listener << title;
40     do {
41       result_listener->stream()->width(2);
42       result_listener->stream()->fill('0');
43       *result_listener << std::hex << +data[item] << " ";
44       ++item;
45     } while (item & 0xf);
46     *result_listener << "\n";
47   }
48 }
49 
50 // Compare two memory areas byte by byte, print information about first
51 // difference. Dumps both bufferst to user log.
Compare(MatchResultListener * result_listener,const uint8_t * exp,const uint8_t * act,size_t length)52 bool Compare(MatchResultListener* result_listener,
53              const uint8_t* exp, const uint8_t* act, size_t length) {
54   for (size_t index = 0; index < length; ++index) {
55     if (exp[index] != act[index]) {
56       *result_listener << "\nUnexpected data at offset " << index << "\n";
57       Dump(result_listener, "Data Expected: ", exp, length);
58       Dump(result_listener, "  Data Actual: ", act, length);
59       return false;
60     }
61   }
62 
63   return true;
64 }
65 
66 // Matcher validating Netlink Request data.
67 MATCHER_P2(RequestDataIs, data, length, "Matches expected request data") {
68   size_t offset = sizeof(nlmsghdr);
69   if (offset + length != arg.RequestLength()) {
70     *result_listener << "Unexpected request length: "
71                      << arg.RequestLength() - offset << " vs " << length;
72     return false;
73   }
74 
75   // Note: Request begins with header (nlmsghdr). Header is not covered by this
76   // call.
77   const uint8_t* exp_data = static_cast<const uint8_t*>(
78       static_cast<const void*>(data));
79   const uint8_t* act_data = static_cast<const uint8_t*>(arg.RequestData());
80   return Compare(
81       result_listener, exp_data, &act_data[offset], length);
82 }
83 
84 MATCHER_P4(RequestHeaderIs, length, type, flags, seq,
85            "Matches request header") {
86   nlmsghdr* header = static_cast<nlmsghdr*>(arg.RequestData());
87   if (arg.RequestLength() < sizeof(header)) {
88     *result_listener << "Malformed header: too short.";
89     return false;
90   }
91 
92   if (header->nlmsg_len != length) {
93     *result_listener << "Invalid message length: "
94                      << header->nlmsg_len << " vs " << length;
95     return false;
96   }
97 
98   if (header->nlmsg_type != type) {
99     *result_listener << "Invalid header type: "
100                      << header->nlmsg_type << " vs " << type;
101     return false;
102   }
103 
104   if (header->nlmsg_flags != flags) {
105     *result_listener << "Invalid header flags: "
106                      << header->nlmsg_flags << " vs " << flags;
107     return false;
108   }
109 
110   if (header->nlmsg_seq != seq) {
111     *result_listener << "Invalid header sequence number: "
112                      << header->nlmsg_seq << " vs " << seq;
113     return false;
114   }
115 
116   return true;
117 }
118 }  // namespace
119 
TEST(NetlinkClientTest, BasicStringNode) {
  // A single string attribute: { Dummy: "long string" }.
  constexpr uint16_t kDummyTag = 0xfce2;
  constexpr char kLongString[] = "long string";

  // Expected payload: nlattr header followed by the NUL-terminated text.
  struct {
    // 4-byte header + 11 characters + terminating NUL = 16 bytes.
    const uint16_t len = 16;
    const uint16_t type = kDummyTag;
    char payload[sizeof(kLongString)];  // sizeof counts the terminating NUL.
  } golden;
  memcpy(golden.payload, kLongString, sizeof(golden.payload));

  NetlinkRequest req(RTM_SETLINK, 0);
  req.AddString(kDummyTag, kLongString);
  EXPECT_THAT(req, RequestDataIs(&golden, sizeof(golden)));
}
137 
TEST(NetlinkClientTest, BasicIntNode) {
  // A single integer attribute: { Dummy: Value }.
  constexpr uint16_t kDummyTag = 0xfce2;
  constexpr int32_t kValue = 0x1badd00d;

  NetlinkRequest req(RTM_SETLINK, 0);
  req.AddInt(kDummyTag, kValue);

  // Expected payload: nlattr header followed by the 32-bit value.
  struct {
    const uint16_t len = 8;  // 4 (header) + 4 (value).
    const uint16_t type = kDummyTag;
    const uint32_t value = kValue;
  } golden;

  EXPECT_THAT(req, RequestDataIs(&golden, sizeof(golden)));
}
153 
TEST(NetlinkClientTest, AllIntegerTypes) {
  // One attribute per integer width and signedness, all holding the same
  // small value. This pins each type's payload size and the padding of every
  // attribute to a 4-byte boundary.
  constexpr uint16_t kDummyTag = 0xfce2;
  constexpr uint8_t kValue = 0x1b;

  // The attribute is necessary for correct binary alignment: without
  // __packed__, the 64-bit fields would (on typical 64-bit ABIs) be aligned
  // to 8 bytes and the compiler would insert padding not present on the wire.
  constexpr struct __attribute__((__packed__)) {
    uint16_t attr_length_i64 = 12;
    uint16_t attr_type_i64 = kDummyTag;
    int64_t attr_value_i64 = kValue;
    uint16_t attr_length_i32 = 8;
    uint16_t attr_type_i32 = kDummyTag + 1;
    int32_t attr_value_i32 = kValue;
    uint16_t attr_length_i16 = 6;
    uint16_t attr_type_i16 = kDummyTag + 2;
    int16_t attr_value_i16 = kValue;
    uint8_t attr_padding_i16[2] = {0, 0};
    uint16_t attr_length_i8 = 5;
    uint16_t attr_type_i8 = kDummyTag + 3;
    int8_t attr_value_i8 = kValue;
    uint8_t attr_padding_i8[3] = {0, 0, 0};
    uint16_t attr_length_u64 = 12;
    uint16_t attr_type_u64 = kDummyTag + 4;
    uint64_t attr_value_u64 = kValue;
    uint16_t attr_length_u32 = 8;
    uint16_t attr_type_u32 = kDummyTag + 5;
    uint32_t attr_value_u32 = kValue;
    uint16_t attr_length_u16 = 6;
    uint16_t attr_type_u16 = kDummyTag + 6;
    uint16_t attr_value_u16 = kValue;
    uint8_t attr_padding_u16[2] = {0, 0};
    uint16_t attr_length_u8 = 5;
    uint16_t attr_type_u8 = kDummyTag + 7;
    uint8_t attr_value_u8 = kValue;
    uint8_t attr_padding_u8[3] = {0, 0, 0};
  } expected = {};
  // Two 12-byte 64-bit attributes plus six 8-byte (padded) attributes.
  static_assert(sizeof(expected) == 72,
                "expected buffer layout does not match netlink wire layout");

  NetlinkRequest request(RTM_SETLINK, 0);
  request.AddInt<int64_t>(kDummyTag, kValue);
  request.AddInt<int32_t>(kDummyTag + 1, kValue);
  request.AddInt<int16_t>(kDummyTag + 2, kValue);
  request.AddInt<int8_t>(kDummyTag + 3, kValue);
  request.AddInt<uint64_t>(kDummyTag + 4, kValue);
  request.AddInt<uint32_t>(kDummyTag + 5, kValue);
  // These two must use the unsigned types so that they match the u16/u8
  // entries of `expected` and actually exercise the unsigned variants.
  request.AddInt<uint16_t>(kDummyTag + 6, kValue);
  request.AddInt<uint8_t>(kDummyTag + 7, kValue);

  EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected)));
}
203 
TEST(NetlinkClientTest, SingleList) {
  // One list wrapping one integer attribute: List: { Dummy: Value }.
  constexpr uint16_t kDummyTag = 0xfce2;
  constexpr uint16_t kListTag = 0xcafe;
  constexpr int32_t kValue = 0x1badd00d;

  NetlinkRequest req(RTM_SETLINK, 0);
  req.PushList(kListTag);
  req.AddInt(kDummyTag, kValue);
  req.PopList();

  struct {
    // List header; its length covers itself plus the nested attribute.
    const uint16_t list_len = 12;  // 4 (list header) + 8 (attribute).
    const uint16_t list_type = kListTag;
    // Nested attribute.
    const uint16_t attr_len = 8;  // 4 (header) + 4 (value).
    const uint16_t attr_type = kDummyTag;
    const uint32_t attr_value = kValue;
  } golden;

  EXPECT_THAT(req, RequestDataIs(&golden, sizeof(golden)));
}
225 
TEST(NetlinkClientTest, NestedList) {
  // Two nested lists around one attribute: List1: { List2: { Dummy: Value } }.
  constexpr uint16_t kDummyTag = 0xfce2;
  constexpr uint16_t kList1Tag = 0xcafe;
  constexpr uint16_t kList2Tag = 0xfeed;
  constexpr int32_t kValue = 0x1badd00d;

  NetlinkRequest req(RTM_SETLINK, 0);
  req.PushList(kList1Tag);
  req.PushList(kList2Tag);
  req.AddInt(kDummyTag, kValue);
  req.PopList();
  req.PopList();

  struct {
    const uint16_t outer_len = 16;  // 4 (outer header) + 12 (inner list).
    const uint16_t outer_type = kList1Tag;
    const uint16_t inner_len = 12;  // 4 (inner header) + 8 (attribute).
    const uint16_t inner_type = kList2Tag;
    const uint16_t attr_len = 8;  // 4 (header) + 4 (value).
    const uint16_t attr_type = kDummyTag;
    const uint32_t attr_value = kValue;
  } golden;

  EXPECT_THAT(req, RequestDataIs(&golden, sizeof(golden)));
}
252 
TEST(NetlinkClientTest, ListSequence) {
  // Two sibling lists: List1: { Dummy1: Value1 }, List2: { Dummy2: Value2 }.
  constexpr uint16_t kDummy1Tag = 0xfce2;
  constexpr uint16_t kDummy2Tag = 0xfd38;
  constexpr uint16_t kList1Tag = 0xcafe;
  constexpr uint16_t kList2Tag = 0xfeed;
  constexpr int32_t kValue1 = 0x1badd00d;
  constexpr int32_t kValue2 = 0xfee1;

  NetlinkRequest req(RTM_SETLINK, 0);
  req.PushList(kList1Tag);
  req.AddInt(kDummy1Tag, kValue1);
  req.PopList();
  req.PushList(kList2Tag);
  req.AddInt(kDummy2Tag, kValue2);
  req.PopList();

  struct {
    // First list and its attribute.
    const uint16_t first_list_len = 12;  // 4 (list header) + 8 (attribute).
    const uint16_t first_list_type = kList1Tag;
    const uint16_t first_attr_len = 8;
    const uint16_t first_attr_type = kDummy1Tag;
    const uint32_t first_attr_value = kValue1;
    // Second list and its attribute.
    const uint16_t second_list_len = 12;
    const uint16_t second_list_type = kList2Tag;
    const uint16_t second_attr_len = 8;
    const uint16_t second_attr_type = kDummy2Tag;
    const uint32_t second_attr_value = kValue2;
  } golden;

  EXPECT_THAT(req, RequestDataIs(&golden, sizeof(golden)));
}
285 
TEST(NetlinkClientTest, ComplexList) {
  // A list holding a nested list and a plain attribute:
  // List1: { List2: { Dummy1: Value1 }, Dummy2: Value2 }.
  constexpr uint16_t kDummy1Tag = 0xfce2;
  constexpr uint16_t kDummy2Tag = 0xfd38;
  constexpr uint16_t kList1Tag = 0xcafe;
  constexpr uint16_t kList2Tag = 0xfeed;
  constexpr int32_t kValue1 = 0x1badd00d;
  constexpr int32_t kValue2 = 0xfee1;

  NetlinkRequest req(RTM_SETLINK, 0);
  req.PushList(kList1Tag);
  req.PushList(kList2Tag);
  req.AddInt(kDummy1Tag, kValue1);
  req.PopList();
  req.AddInt(kDummy2Tag, kValue2);
  req.PopList();

  struct {
    // 4 (outer header) + 12 (inner list) + 8 (trailing attribute).
    const uint16_t outer_len = 24;
    const uint16_t outer_type = kList1Tag;
    // The inner list ends right after Value1; Dummy2 is outside it.
    const uint16_t inner_len = 12;
    const uint16_t inner_type = kList2Tag;
    const uint16_t inner_attr_len = 8;
    const uint16_t inner_attr_type = kDummy1Tag;
    const uint32_t inner_attr_value = kValue1;
    const uint16_t outer_attr_len = 8;
    const uint16_t outer_attr_type = kDummy2Tag;
    const uint32_t outer_attr_value = kValue2;
  } golden;

  EXPECT_THAT(req, RequestDataIs(&golden, sizeof(golden)));
}
318 
TEST(NetlinkClientTest, SimpleNetlinkCreateHeader) {
  // Message = nlmsghdr + one string attribute (payload aligned to 4 bytes).
  constexpr char kValue[] = "random string";
  constexpr size_t kMsgLength =
      sizeof(nlmsghdr) + sizeof(nlattr) + RTA_ALIGN(sizeof(kValue));
  // The caller's create flags are expected to be combined with REQUEST/ACK.
  constexpr auto kExpectedFlags =
      NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL | NLM_F_REQUEST;

  NetlinkRequest first(RTM_NEWLINK, NLM_F_CREATE | NLM_F_EXCL);
  first.AddString(0, kValue);  // Have something to work with.
  const uint32_t base_seq = first.SeqNo();
  EXPECT_THAT(first, RequestHeaderIs(
      kMsgLength, RTM_NEWLINK, kExpectedFlags, base_seq));

  // The next request is expected to carry the following sequence number.
  NetlinkRequest second(RTM_NEWLINK, NLM_F_CREATE | NLM_F_EXCL);
  second.AddString(0, kValue);  // Have something to work with.
  EXPECT_THAT(second, RequestHeaderIs(
      kMsgLength, RTM_NEWLINK, kExpectedFlags, base_seq + 1));
}
342 
TEST(NetlinkClientTest, SimpleNetlinkUpdateHeader) {
  // Message = nlmsghdr + one string attribute (payload aligned to 4 bytes).
  constexpr char kValue[] = "random string";
  constexpr size_t kMsgLength =
      sizeof(nlmsghdr) + sizeof(nlattr) + RTA_ALIGN(sizeof(kValue));
  // With no caller flags, only REQUEST/ACK are expected in the header.
  constexpr auto kExpectedFlags = NLM_F_REQUEST | NLM_F_ACK;

  NetlinkRequest first(RTM_SETLINK, 0);
  first.AddString(0, kValue);  // Have something to work with.
  const uint32_t base_seq = first.SeqNo();
  EXPECT_THAT(first, RequestHeaderIs(
      kMsgLength, RTM_SETLINK, kExpectedFlags, base_seq));

  // The next request is expected to carry the following sequence number.
  NetlinkRequest second(RTM_SETLINK, 0);
  second.AddString(0, kValue);  // Have something to work with.
  EXPECT_THAT(second, RequestHeaderIs(
      kMsgLength, RTM_SETLINK, kExpectedFlags, base_seq + 1));
}
360 
361 }  // namespace cuttlefish
362