1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "common/libs/tcp_socket/tcp_socket.h"
18 
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>

#include <cerrno>
#include <cstring>

#include <android-base/logging.h>
26 
27 namespace cuttlefish {
28 
ClientSocket(int port)29 ClientSocket::ClientSocket(int port)
30     : fd_(SharedFD::SocketLocalClient(port, SOCK_STREAM)) {}
31 
RecvAny(std::size_t length)32 Message ClientSocket::RecvAny(std::size_t length) {
33   Message buf(length);
34   auto read_count = fd_->Read(buf.data(), buf.size());
35   if (read_count < 0) {
36     read_count = 0;
37   }
38   buf.resize(read_count);
39   return buf;
40 }
41 
closed() const42 bool ClientSocket::closed() const {
43   std::lock_guard<std::mutex> guard(closed_lock_);
44   return other_side_closed_;
45 }
46 
Recv(std::size_t length)47 Message ClientSocket::Recv(std::size_t length) {
48   Message buf(length);
49   ssize_t total_read = 0;
50   while (total_read < static_cast<ssize_t>(length)) {
51     auto just_read = fd_->Read(&buf[total_read], buf.size() - total_read);
52     if (just_read <= 0) {
53       if (just_read < 0) {
54         LOG(ERROR) << "read() error: " << strerror(errno);
55       }
56       {
57         std::lock_guard<std::mutex> guard(closed_lock_);
58         other_side_closed_ = true;
59       }
60       return Message{};
61     }
62     total_read += just_read;
63   }
64   CHECK(total_read == static_cast<ssize_t>(length));
65   return buf;
66 }
67 
SendNoSignal(const uint8_t * data,std::size_t size)68 ssize_t ClientSocket::SendNoSignal(const uint8_t* data, std::size_t size) {
69   std::lock_guard<std::mutex> lock(send_lock_);
70   ssize_t written{};
71   while (written < static_cast<ssize_t>(size)) {
72     if (!fd_->IsOpen()) {
73       LOG(ERROR) << "fd_ is closed";
74     }
75     auto just_written = fd_->Send(data + written, size - written, MSG_NOSIGNAL);
76     if (just_written <= 0) {
77       LOG(INFO) << "Couldn't write to client: " << strerror(errno);
78       {
79         std::lock_guard<std::mutex> guard(closed_lock_);
80         other_side_closed_ = true;
81       }
82       return just_written;
83     }
84     written += just_written;
85   }
86   return written;
87 }
88 
SendNoSignal(const Message & message)89 ssize_t ClientSocket::SendNoSignal(const Message& message) {
90   return SendNoSignal(&message[0], message.size());
91 }
92 
ServerSocket(int port)93 ServerSocket::ServerSocket(int port)
94     : fd_{SharedFD::SocketLocalServer(port, SOCK_STREAM)} {
95   if (!fd_->IsOpen()) {
96     LOG(FATAL) << "Couldn't open streaming server on port " << port;
97   }
98 }
99 
Accept()100 ClientSocket ServerSocket::Accept() {
101   SharedFD client = SharedFD::Accept(*fd_);
102   if (!client->IsOpen()) {
103     LOG(FATAL) << "Error attemping to accept: " << strerror(errno);
104   }
105   return ClientSocket{client};
106 }
107 
AppendInNetworkByteOrder(Message * msg,const std::uint8_t b)108 void AppendInNetworkByteOrder(Message* msg, const std::uint8_t b) {
109   msg->push_back(b);
110 }
111 
AppendInNetworkByteOrder(Message * msg,const std::uint16_t s)112 void AppendInNetworkByteOrder(Message* msg, const std::uint16_t s) {
113   const std::uint16_t n = htons(s);
114   auto p = reinterpret_cast<const std::uint8_t*>(&n);
115   msg->insert(msg->end(), p, p + sizeof n);
116 }
117 
AppendInNetworkByteOrder(Message * msg,const std::uint32_t w)118 void AppendInNetworkByteOrder(Message* msg, const std::uint32_t w) {
119   const std::uint32_t n = htonl(w);
120   auto p = reinterpret_cast<const std::uint8_t*>(&n);
121   msg->insert(msg->end(), p, p + sizeof n);
122 }
123 
AppendInNetworkByteOrder(Message * msg,const std::int32_t w)124 void AppendInNetworkByteOrder(Message* msg, const std::int32_t w) {
125   std::uint32_t u{};
126   std::memcpy(&u, &w, sizeof u);
127   AppendInNetworkByteOrder(msg, u);
128 }
129 
AppendInNetworkByteOrder(Message * msg,const std::string & str)130 void AppendInNetworkByteOrder(Message* msg, const std::string& str) {
131   msg->insert(msg->end(), str.begin(), str.end());
132 }
133 
134 }
135