xref: /haiku/src/add-ons/kernel/file_systems/netfs/shared/InsecureConnection.cpp (revision 22440f4105cafc95cc1d49f9bc65bb395c527d86)
1 // InsecureConnection.cpp
2 
3 // TODO: Asynchronous connecting on client side?
4 
5 #include <new>
6 
7 #include <errno.h>
8 #include <netdb.h>
9 #include <stdio.h>
10 #include <string.h>
11 #include <unistd.h>
12 
13 #include <ByteOrder.h>
14 
15 #include "Compatibility.h"
16 #include "DebugSupport.h"
17 #include "InsecureChannel.h"
18 #include "InsecureConnection.h"
19 #include "NetAddress.h"
20 #include "NetFSDefs.h"
21 
// Constants used by the connection handshake implemented in this file.
namespace InsecureConnectionDefs {

// Version the client puts into its connect request. The server side
// (FinishInitialization()) answers any other value with an error reply.
const int32 kProtocolVersion = 1;

// Maximum time the server waits for the next additional channel to connect.
// The timer is restarted after every accepted channel.
const bigtime_t kAcceptingTimeout = 10000000;	// 10 s

// number of client up/down stream channels
// The server clamps the requested counts to [min, max]; the defaults are what
// the client requests when its parameter string does not specify counts.
const int32 kMinUpStreamChannels		= 1;
const int32 kMaxUpStreamChannels		= 10;
const int32 kDefaultUpStreamChannels	= 5;
const int32 kMinDownStreamChannels		= 1;
const int32 kMaxDownStreamChannels		= 5;
const int32 kDefaultDownStreamChannels	= 1;

} // namespace InsecureConnectionDefs
37 
38 using namespace InsecureConnectionDefs;
39 
40 // SocketCloser
41 struct SocketCloser {
42 	SocketCloser(int fd) : fFD(fd) {}
43 	~SocketCloser()
44 	{
45 		if (fFD >= 0)
46 			closesocket(fFD);
47 	}
48 
49 	int Detach()
50 	{
51 		int fd = fFD;
52 		fFD = -1;
53 		return fd;
54 	}
55 
56 private:
57 	int	fFD;
58 };
59 
60 
61 // #pragma mark -
62 // #pragma mark ----- InsecureConnection -----
63 
// constructor
// Performs no setup of its own; the channels are created by the Init()
// methods (and, on the server side, FinishInitialization()).
InsecureConnection::InsecureConnection()
{
}
68 
// destructor
// Empty: this class does no cleanup of its own here.
InsecureConnection::~InsecureConnection()
{
}
73 
74 // Init (server side)
75 status_t
76 InsecureConnection::Init(int fd)
77 {
78 	status_t error = AbstractConnection::Init();
79 	if (error != B_OK) {
80 		closesocket(fd);
81 		return error;
82 	}
83 	// create the initial channel
84 	Channel* channel = new(std::nothrow) InsecureChannel(fd);
85 	if (!channel) {
86 		closesocket(fd);
87 		return B_NO_MEMORY;
88 	}
89 	// add it
90 	error = AddDownStreamChannel(channel);
91 	if (error != B_OK) {
92 		delete channel;
93 		return error;
94 	}
95 	return B_OK;
96 }
97 
98 // Init (client side)
99 status_t
100 InsecureConnection::Init(const char* parameters)
101 {
102 PRINT(("InsecureConnection::Init\n"));
103 	if (!parameters)
104 		return B_BAD_VALUE;
105 	status_t error = AbstractConnection::Init();
106 	if (error != B_OK)
107 		return error;
108 	// parse the parameters to get a server name and a port we shall connect to
109 	// parameter format is "<server>[:port] [ <up> [ <down> ] ]"
110 	char server[256];
111 	uint16 port = kDefaultInsecureConnectionPort;
112 	int upStreamChannels = kDefaultUpStreamChannels;
113 	int downStreamChannels = kDefaultDownStreamChannels;
114 	if (strchr(parameters, ':')) {
115 		int result = sscanf(parameters, "%255[^:]:%hu %d %d", server, &port,
116 			&upStreamChannels, &downStreamChannels);
117 		if (result < 2)
118 			return B_BAD_VALUE;
119 	} else {
120 		int result = sscanf(parameters, "%255[^:] %d %d", server,
121 			&upStreamChannels, &downStreamChannels);
122 		if (result < 1)
123 			return B_BAD_VALUE;
124 	}
125 	// resolve server address
126 	NetAddress netAddress;
127 	error = NetAddressResolver().GetHostAddress(server, &netAddress);
128 	if (error != B_OK)
129 		return error;
130 	in_addr serverAddr = netAddress.GetAddress().sin_addr;
131 	// open the initial channel
132 	Channel* channel;
133 	error = _OpenClientChannel(serverAddr, port, &channel);
134 	if (error != B_OK)
135 		return error;
136 	error = AddUpStreamChannel(channel);
137 	if (error != B_OK) {
138 		delete channel;
139 		return error;
140 	}
141 	// send the server a connect request
142 	ConnectRequest request;
143 	request.protocolVersion = B_HOST_TO_BENDIAN_INT32(kProtocolVersion);
144 	request.serverAddress = serverAddr.s_addr;
145 	request.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels);
146 	request.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels);
147 	error = channel->Send(&request, sizeof(ConnectRequest));
148 	if (error != B_OK)
149 		return error;
150 	// get the server reply
151 	ConnectReply reply;
152 	error = channel->Receive(&reply, sizeof(ConnectReply));
153 	if (error != B_OK)
154 		return error;
155 	error = B_BENDIAN_TO_HOST_INT32(reply.error);
156 	if (error != B_OK)
157 		return error;
158 	upStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.upStreamChannels);
159 	downStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.downStreamChannels);
160 	port = B_BENDIAN_TO_HOST_INT16(reply.port);
161 	// open the remaining channels
162 	int32 allChannels = upStreamChannels + downStreamChannels;
163 	for (int32 i = 1; i < allChannels; i++) {
164 		PRINT("  creating channel %ld\n", i);
165 		// open the channel
166 		error = _OpenClientChannel(serverAddr, port, &channel);
167 		if (error != B_OK)
168 			RETURN_ERROR(error);
169 		// add it
170 		if (i < upStreamChannels)
171 			error = AddUpStreamChannel(channel);
172 		else
173 			error = AddDownStreamChannel(channel);
174 		if (error != B_OK) {
175 			delete channel;
176 			return error;
177 		}
178 	}
179 	return B_OK;
180 }
181 
182 // FinishInitialization
183 status_t
184 InsecureConnection::FinishInitialization()
185 {
186 PRINT(("InsecureConnection::FinishInitialization()\n"));
187 	// get the down stream channel
188 	InsecureChannel* channel
189 		= dynamic_cast<InsecureChannel*>(DownStreamChannelAt(0));
190 	if (!channel)
191 		return B_BAD_VALUE;
192 	// receive the connect request
193 	ConnectRequest request;
194 	status_t error = channel->Receive(&request, sizeof(ConnectRequest));
195 	if (error != B_OK)
196 		return error;
197 	// check the protocol version
198 	int32 protocolVersion = B_BENDIAN_TO_HOST_INT32(request.protocolVersion);
199 	if (protocolVersion != kProtocolVersion) {
200 		_SendErrorReply(channel, B_ERROR);
201 		return B_ERROR;
202 	}
203 	// get our address (we need it for binding)
204 	in_addr serverAddr;
205 	serverAddr.s_addr = request.serverAddress;
206 	// check number of up and down stream channels
207 	int32 upStreamChannels = B_BENDIAN_TO_HOST_INT32(request.upStreamChannels);
208 	int32 downStreamChannels = B_BENDIAN_TO_HOST_INT32(
209 		request.downStreamChannels);
210 	if (upStreamChannels < kMinUpStreamChannels)
211 		upStreamChannels = kMinUpStreamChannels;
212 	else if (upStreamChannels > kMaxUpStreamChannels)
213 		upStreamChannels = kMaxUpStreamChannels;
214 	if (downStreamChannels < kMinDownStreamChannels)
215 		downStreamChannels = kMinDownStreamChannels;
216 	else if (downStreamChannels > kMaxDownStreamChannels)
217 		downStreamChannels = kMaxDownStreamChannels;
218 	// due to a bug on BONE we have a maximum of 2 working connections
219 	// accepted on one listener socket.
220 	NetAddress peerAddress;
221 	if (channel->GetPeerAddress(&peerAddress) == B_OK
222 		&& peerAddress.IsLocal()) {
223 		upStreamChannels = 1;
224 		if (downStreamChannels > 2)
225 			downStreamChannels = 2;
226 	}
227 	int32 allChannels = upStreamChannels + downStreamChannels;
228 	// create a listener socket
229 	int fd = socket(AF_INET, SOCK_STREAM, 0);
230 	if (fd < 0) {
231 		error = errno;
232 		_SendErrorReply(channel, error);
233 		return error;
234 	}
235 	SocketCloser _(fd);
236 	// bind it to some port
237 	sockaddr_in addr;
238 	addr.sin_family = AF_INET;
239 	addr.sin_port = 0;
240 	addr.sin_addr = serverAddr;
241 	if (bind(fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
242 		error = errno;
243 		_SendErrorReply(channel, error);
244 		return error;
245 	}
246 	// get the port
247 	socklen_t addrSize = sizeof(addr);
248 	if (getsockname(fd, (sockaddr*)&addr, &addrSize) < 0) {
249 		error = errno;
250 		_SendErrorReply(channel, error);
251 		return error;
252 	}
253 	// set socket to non-blocking
254 	int dontBlock = 1;
255 	if (setsockopt(fd, SOL_SOCKET, SO_NONBLOCK, &dontBlock, sizeof(int)) < 0) {
256 		error = errno;
257 		_SendErrorReply(channel, error);
258 		return error;
259 	}
260 	// start listening
261 	if (listen(fd, allChannels - 1) < 0) {
262 		error = errno;
263 		_SendErrorReply(channel, error);
264 		return error;
265 	}
266 	// send the reply
267 	ConnectReply reply;
268 	reply.error = B_HOST_TO_BENDIAN_INT32(B_OK);
269 	reply.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels);
270 	reply.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels);
271 	reply.port = addr.sin_port;
272 	error = channel->Send(&reply, sizeof(ConnectReply));
273 	if (error != B_OK)
274 		return error;
275 	// start accepting
276 	bigtime_t startAccepting = system_time();
277 	for (int32 i = 1; i < allChannels; ) {
278 		// accept a connection
279 		int channelFD = accept(fd, NULL, 0);
280 		if (channelFD < 0) {
281 			error = errno;
282 			if (error == B_INTERRUPTED) {
283 				error = B_OK;
284 				continue;
285 			}
286 			if (error == B_WOULD_BLOCK) {
287 				bigtime_t now = system_time();
288 				if (now - startAccepting > kAcceptingTimeout)
289 					RETURN_ERROR(B_TIMED_OUT);
290 				snooze(10000);
291 				continue;
292 			}
293 			RETURN_ERROR(error);
294 		}
295 		PRINT("  accepting channel %ld\n", i);
296 		// create a channel
297 		channel = new(std::nothrow) InsecureChannel(channelFD);
298 		if (!channel) {
299 			closesocket(channelFD);
300 			return B_NO_MEMORY;
301 		}
302 		// add it
303 		if (i < upStreamChannels)	// inverse, since we are on server side
304 			error = AddDownStreamChannel(channel);
305 		else
306 			error = AddUpStreamChannel(channel);
307 		if (error != B_OK) {
308 			delete channel;
309 			return error;
310 		}
311 		i++;
312 		startAccepting = system_time();
313 	}
314 	return B_OK;
315 }
316 
317 // _OpenClientChannel
318 status_t
319 InsecureConnection::_OpenClientChannel(in_addr serverAddr, uint16 port,
320 	Channel** _channel)
321 {
322 	// create a socket
323 	int fd = socket(AF_INET, SOCK_STREAM, 0);
324 	if (fd < 0)
325 		return errno;
326 	// connect
327 	sockaddr_in addr;
328 	addr.sin_family = AF_INET;
329 	addr.sin_port = htons(port);
330 	addr.sin_addr = serverAddr;
331 	if (connect(fd, (sockaddr*)&addr, sizeof(addr)) < 0) {
332 		status_t error = errno;
333 		closesocket(fd);
334 		RETURN_ERROR(error);
335 	}
336 	// create the channel
337 	Channel* channel = new(std::nothrow) InsecureChannel(fd);
338 	if (!channel) {
339 		closesocket(fd);
340 		return B_NO_MEMORY;
341 	}
342 	*_channel = channel;
343 	return B_OK;
344 }
345 
346 // _SendErrorReply
347 status_t
348 InsecureConnection::_SendErrorReply(Channel* channel, status_t error)
349 {
350 	ConnectReply reply;
351 	reply.error = B_HOST_TO_BENDIAN_INT32(error);
352 	return channel->Send(&reply, sizeof(ConnectReply));
353 }
354 
355