1 // InsecureConnection.cpp
2
3 // TODO: Asynchronous connecting on client side?
4
5 #include <new>
6
7 #include <errno.h>
8 #include <netdb.h>
9 #include <stdio.h>
10 #include <string.h>
11 #include <unistd.h>
12
13 #include <ByteOrder.h>
14
15 #include "Compatibility.h"
16 #include "DebugSupport.h"
17 #include "InsecureChannel.h"
18 #include "InsecureConnection.h"
19 #include "NetAddress.h"
20 #include "NetFSDefs.h"
21
22 namespace InsecureConnectionDefs {
23
24 const int32 kProtocolVersion = 1;
25
26 const bigtime_t kAcceptingTimeout = 10000000; // 10 s
27
28 // number of client up/down stream channels
29 const int32 kMinUpStreamChannels = 1;
30 const int32 kMaxUpStreamChannels = 10;
31 const int32 kDefaultUpStreamChannels = 5;
32 const int32 kMinDownStreamChannels = 1;
33 const int32 kMaxDownStreamChannels = 5;
34 const int32 kDefaultDownStreamChannels = 1;
35
36 } // namespace InsecureConnectionDefs
37
38 using namespace InsecureConnectionDefs;
39
40 // SocketCloser
41 struct SocketCloser {
SocketCloserSocketCloser42 SocketCloser(int fd) : fFD(fd) {}
~SocketCloserSocketCloser43 ~SocketCloser()
44 {
45 if (fFD >= 0)
46 closesocket(fFD);
47 }
48
DetachSocketCloser49 int Detach()
50 {
51 int fd = fFD;
52 fFD = -1;
53 return fd;
54 }
55
56 private:
57 int fFD;
58 };
59
60
61 // #pragma mark -
62 // #pragma mark ----- InsecureConnection -----
63
64 // constructor
InsecureConnection()65 InsecureConnection::InsecureConnection()
66 {
67 }
68
69 // destructor
~InsecureConnection()70 InsecureConnection::~InsecureConnection()
71 {
72 }
73
74 // Init (server side)
75 status_t
Init(int fd)76 InsecureConnection::Init(int fd)
77 {
78 status_t error = AbstractConnection::Init();
79 if (error != B_OK) {
80 closesocket(fd);
81 return error;
82 }
83 // create the initial channel
84 Channel* channel = new(std::nothrow) InsecureChannel(fd);
85 if (!channel) {
86 closesocket(fd);
87 return B_NO_MEMORY;
88 }
89 // add it
90 error = AddDownStreamChannel(channel);
91 if (error != B_OK) {
92 delete channel;
93 return error;
94 }
95 return B_OK;
96 }
97
98 // Init (client side)
99 status_t
Init(const char * parameters)100 InsecureConnection::Init(const char* parameters)
101 {
102 PRINT(("InsecureConnection::Init\n"));
103 if (!parameters)
104 return B_BAD_VALUE;
105 status_t error = AbstractConnection::Init();
106 if (error != B_OK)
107 return error;
108 // parse the parameters to get a server name and a port we shall connect to
109 // parameter format is "<server>[:port] [ <up> [ <down> ] ]"
110 char server[256];
111 uint16 port = kDefaultInsecureConnectionPort;
112 int upStreamChannels = kDefaultUpStreamChannels;
113 int downStreamChannels = kDefaultDownStreamChannels;
114 if (strchr(parameters, ':')) {
115 int result = sscanf(parameters, "%255[^:]:%hu %d %d", server, &port,
116 &upStreamChannels, &downStreamChannels);
117 if (result < 2)
118 return B_BAD_VALUE;
119 } else {
120 int result = sscanf(parameters, "%255[^:] %d %d", server,
121 &upStreamChannels, &downStreamChannels);
122 if (result < 1)
123 return B_BAD_VALUE;
124 }
125 // resolve server address
126 NetAddress netAddress;
127 error = NetAddressResolver().GetHostAddress(server, &netAddress);
128 if (error != B_OK)
129 return error;
130 in_addr serverAddr = netAddress.GetAddress().sin_addr;
131 // open the initial channel
132 Channel* channel;
133 error = _OpenClientChannel(serverAddr, port, &channel);
134 if (error != B_OK)
135 return error;
136 error = AddUpStreamChannel(channel);
137 if (error != B_OK) {
138 delete channel;
139 return error;
140 }
141 // send the server a connect request
142 ConnectRequest request;
143 request.protocolVersion = B_HOST_TO_BENDIAN_INT32(kProtocolVersion);
144 request.serverAddress = serverAddr.s_addr;
145 request.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels);
146 request.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels);
147 error = channel->Send(&request, sizeof(ConnectRequest));
148 if (error != B_OK)
149 return error;
150 // get the server reply
151 ConnectReply reply;
152 error = channel->Receive(&reply, sizeof(ConnectReply));
153 if (error != B_OK)
154 return error;
155 error = B_BENDIAN_TO_HOST_INT32(reply.error);
156 if (error != B_OK)
157 return error;
158 upStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.upStreamChannels);
159 downStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.downStreamChannels);
160 port = B_BENDIAN_TO_HOST_INT16(reply.port);
161 // open the remaining channels
162 int32 allChannels = upStreamChannels + downStreamChannels;
163 for (int32 i = 1; i < allChannels; i++) {
164 PRINT(" creating channel %" B_PRId32 "\n", i);
165 // open the channel
166 error = _OpenClientChannel(serverAddr, port, &channel);
167 if (error != B_OK)
168 RETURN_ERROR(error);
169 // add it
170 if (i < upStreamChannels)
171 error = AddUpStreamChannel(channel);
172 else
173 error = AddDownStreamChannel(channel);
174 if (error != B_OK) {
175 delete channel;
176 return error;
177 }
178 }
179 return B_OK;
180 }
181
182 // FinishInitialization
183 status_t
FinishInitialization()184 InsecureConnection::FinishInitialization()
185 {
186 PRINT(("InsecureConnection::FinishInitialization()\n"));
187 // get the down stream channel
188 InsecureChannel* channel
189 = dynamic_cast<InsecureChannel*>(DownStreamChannelAt(0));
190 if (!channel)
191 return B_BAD_VALUE;
192 // receive the connect request
193 ConnectRequest request;
194 status_t error = channel->Receive(&request, sizeof(ConnectRequest));
195 if (error != B_OK)
196 return error;
197 // check the protocol version
198 int32 protocolVersion = B_BENDIAN_TO_HOST_INT32(request.protocolVersion);
199 if (protocolVersion != kProtocolVersion) {
200 _SendErrorReply(channel, B_ERROR);
201 return B_ERROR;
202 }
203 // get our address (we need it for binding)
204 in_addr serverAddr;
205 serverAddr.s_addr = request.serverAddress;
206 // check number of up and down stream channels
207 int32 upStreamChannels = B_BENDIAN_TO_HOST_INT32(request.upStreamChannels);
208 int32 downStreamChannels = B_BENDIAN_TO_HOST_INT32(
209 request.downStreamChannels);
210 if (upStreamChannels < kMinUpStreamChannels)
211 upStreamChannels = kMinUpStreamChannels;
212 else if (upStreamChannels > kMaxUpStreamChannels)
213 upStreamChannels = kMaxUpStreamChannels;
214 if (downStreamChannels < kMinDownStreamChannels)
215 downStreamChannels = kMinDownStreamChannels;
216 else if (downStreamChannels > kMaxDownStreamChannels)
217 downStreamChannels = kMaxDownStreamChannels;
218 // due to a bug on BONE we have a maximum of 2 working connections
219 // accepted on one listener socket.
220 NetAddress peerAddress;
221 if (channel->GetPeerAddress(&peerAddress) == B_OK
222 && peerAddress.IsLocal()) {
223 upStreamChannels = 1;
224 if (downStreamChannels > 2)
225 downStreamChannels = 2;
226 }
227 int32 allChannels = upStreamChannels + downStreamChannels;
228 // create a listener socket
229 int fd = socket(AF_INET, SOCK_STREAM, 0);
230 if (fd < 0) {
231 error = errno;
232 _SendErrorReply(channel, error);
233 return error;
234 }
235 SocketCloser _(fd);
236 // bind it to some port
237 sockaddr_in addr;
238 addr.sin_family = AF_INET;
239 addr.sin_port = 0;
240 addr.sin_addr = serverAddr;
241 if (bind(fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
242 error = errno;
243 _SendErrorReply(channel, error);
244 return error;
245 }
246 // get the port
247 socklen_t addrSize = sizeof(addr);
248 if (getsockname(fd, (sockaddr*)&addr, &addrSize) < 0) {
249 error = errno;
250 _SendErrorReply(channel, error);
251 return error;
252 }
253 // set socket to non-blocking
254 int dontBlock = 1;
255 if (setsockopt(fd, SOL_SOCKET, SO_NONBLOCK, &dontBlock, sizeof(int)) < 0) {
256 error = errno;
257 _SendErrorReply(channel, error);
258 return error;
259 }
260 // start listening
261 if (listen(fd, allChannels - 1) < 0) {
262 error = errno;
263 _SendErrorReply(channel, error);
264 return error;
265 }
266 // send the reply
267 ConnectReply reply;
268 reply.error = B_HOST_TO_BENDIAN_INT32(B_OK);
269 reply.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels);
270 reply.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels);
271 reply.port = addr.sin_port;
272 error = channel->Send(&reply, sizeof(ConnectReply));
273 if (error != B_OK)
274 return error;
275 // start accepting
276 bigtime_t startAccepting = system_time();
277 for (int32 i = 1; i < allChannels; ) {
278 // accept a connection
279 int channelFD = accept(fd, NULL, 0);
280 if (channelFD < 0) {
281 error = errno;
282 if (error == B_INTERRUPTED) {
283 error = B_OK;
284 continue;
285 }
286 if (error == B_WOULD_BLOCK) {
287 bigtime_t now = system_time();
288 if (now - startAccepting > kAcceptingTimeout)
289 RETURN_ERROR(B_TIMED_OUT);
290 snooze(10000);
291 continue;
292 }
293 RETURN_ERROR(error);
294 }
295 PRINT(" accepting channel %" B_PRId32 "\n", i);
296 // create a channel
297 channel = new(std::nothrow) InsecureChannel(channelFD);
298 if (!channel) {
299 closesocket(channelFD);
300 return B_NO_MEMORY;
301 }
302 // add it
303 if (i < upStreamChannels) // inverse, since we are on server side
304 error = AddDownStreamChannel(channel);
305 else
306 error = AddUpStreamChannel(channel);
307 if (error != B_OK) {
308 delete channel;
309 return error;
310 }
311 i++;
312 startAccepting = system_time();
313 }
314 return B_OK;
315 }
316
317 // _OpenClientChannel
318 status_t
_OpenClientChannel(in_addr serverAddr,uint16 port,Channel ** _channel)319 InsecureConnection::_OpenClientChannel(in_addr serverAddr, uint16 port,
320 Channel** _channel)
321 {
322 // create a socket
323 int fd = socket(AF_INET, SOCK_STREAM, 0);
324 if (fd < 0)
325 return errno;
326 // connect
327 sockaddr_in addr;
328 addr.sin_family = AF_INET;
329 addr.sin_port = htons(port);
330 addr.sin_addr = serverAddr;
331 if (connect(fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
332 status_t error = errno;
333 closesocket(fd);
334 RETURN_ERROR(error);
335 }
336 // create the channel
337 Channel* channel = new(std::nothrow) InsecureChannel(fd);
338 if (!channel) {
339 closesocket(fd);
340 return B_NO_MEMORY;
341 }
342 *_channel = channel;
343 return B_OK;
344 }
345
346 // _SendErrorReply
347 status_t
_SendErrorReply(Channel * channel,status_t error)348 InsecureConnection::_SendErrorReply(Channel* channel, status_t error)
349 {
350 ConnectReply reply;
351 reply.error = B_HOST_TO_BENDIAN_INT32(error);
352 return channel->Send(&reply, sizeof(ConnectReply));
353 }
354
355