xref: /haiku/src/add-ons/kernel/network/protocols/tcp/tcp.cpp (revision f23596149e0d173463f70629581aa10cc305d32e)
1 /*
2  * Copyright 2006, Haiku, Inc. All Rights Reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Axel Dörfler, axeld@pinc-software.de
7  *		Andrew Galante, haiku.galante@gmail.com
8  */
9 
10 
11 #include "TCPConnection.h"
12 
13 #include <net_protocol.h>
14 
15 #include <KernelExport.h>
16 #include <util/list.h>
17 
18 #include <netinet/in.h>
19 #include <netinet/ip.h>
20 #include <new>
21 #include <stdlib.h>
22 #include <string.h>
23 
24 #include <lock.h>
25 #include <util/AutoLock.h>
26 
27 #include <NetBufferUtilities.h>
28 #include <NetUtilities.h>
29 
// Debug output switch for this file: while TRACE_TCP is defined, TRACE()
// forwards its parenthesized argument list to dprintf(); otherwise TRACE()
// and TRACE_BLOCK() expand to nothing.
#define TRACE_TCP
#if defined(TRACE_TCP)
#	define TRACE(x) dprintf x
	// NOTE(review): dump_block() is not defined in this file -- TRACE_BLOCK
	// is unused here.
#	define TRACE_BLOCK(x) dump_block x
#else
#	define TRACE(x)
#	define TRACE_BLOCK(x)
#endif


// Table size handed to hash_init() when gConnectionHash is created.
#define MAX_HASH_TCP	64
41 
42 
43 net_domain *gDomain;
44 net_address_module_info *gAddressModule;
45 net_buffer_module_info *gBufferModule;
46 net_datalink_module_info *gDatalinkModule;
47 net_stack_module_info *gStackModule;
48 hash_table *gConnectionHash;
49 benaphore gConnectionLock;
50 
51 
#ifdef TRACE_TCP
#	define DUMP_TCP_HASH tcp_dump_hash()

/*!
	Prints every connection stored in gConnectionHash, with its local and
	peer address, to the debug output.
	Acquires gConnectionLock itself, so the caller must NOT hold it.
	While gDomain is still unset nothing can be formatted (AddressString()
	needs the domain), so only a warning is printed.
*/
void
tcp_dump_hash()
{
	BenaphoreLocker locker(&gConnectionLock);

	if (gDomain != NULL) {
		struct hash_iterator it;
		hash_open(gConnectionHash, &it);
		hash_rewind(gConnectionHash, &it);
			// NOTE(review): hash_open() presumably rewinds already -- the
			// explicit rewind looks redundant; confirm before dropping it

		TRACE(("Active TCP Connections:\n"));
		for (TCPConnection *entry = (TCPConnection *)hash_next(gConnectionHash, &it);
				entry != NULL;
				entry = (TCPConnection *)hash_next(gConnectionHash, &it)) {
			TRACE(("  TCPConnection %p: %s, %s\n", entry,
				AddressString(gDomain, (sockaddr *)&entry->socket->address, true).Data(),
				AddressString(gDomain, (sockaddr *)&entry->socket->peer, true).Data()));
		}

		hash_close(gConnectionHash, &it, false);
		return;
	}

	TRACE(("Unable to dump TCP Connections!\n"));
}
#else
#	define DUMP_TCP_HASH 0
#endif
78 
79 
80 status_t
81 set_domain(net_interface *interface = NULL)
82 {
83 	if (gDomain == NULL) {
84 		// domain and address module are not known yet, we copy them from
85 		// the buffer's interface (if any):
86 		if (interface == NULL || interface->domain == NULL)
87 			gDomain = gStackModule->get_domain(AF_INET);
88 		else
89 			gDomain = interface->domain;
90 
91 		if (gDomain == NULL) {
92 			// this shouldn't occur, of course, but who knows...
93 			return B_BAD_VALUE;
94 		}
95 		gAddressModule = gDomain->address_module;
96 	}
97 
98 	return B_OK;
99 }
100 
101 
102 /*!
103 	Constructs a TCP header on \a buffer with the specified values
104 	for \a flags, \a seq \a ack and \a advertisedWindow.
105 */
106 status_t
107 add_tcp_header(net_buffer *buffer, uint16 flags, uint32 sequence, uint32 ack,
108 	uint16 advertisedWindow)
109 {
110 	buffer->protocol = IPPROTO_TCP;
111 
112 	NetBufferPrepend<tcp_header> bufferHeader(buffer);
113 	if (bufferHeader.Status() != B_OK)
114 		return bufferHeader.Status();
115 
116 	tcp_header &header = bufferHeader.Data();
117 
118 	header.source_port = gAddressModule->get_port((sockaddr *)&buffer->source);
119 	header.destination_port = gAddressModule->get_port((sockaddr *)&buffer->destination);
120 	header.sequence_num = htonl(sequence);
121 	header.acknowledge_num = htonl(ack);
122 	header.reserved = 0;
123 	header.header_length = 5;
124 		// currently no options supported
125 	header.flags = (uint8)flags;
126 	header.advertised_window = htons(advertisedWindow);
127 	header.checksum = 0;
128 	header.urgent_ptr = 0;
129 		// urgent pointer not supported
130 
131 	// compute and store checksum
132 	Checksum checksum;
133 	gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->source);
134 	gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->destination);
135 	checksum
136 		<< (uint16)htons(IPPROTO_TCP)
137 		<< (uint16)htons(buffer->size)
138 		<< Checksum::BufferHelper(buffer, gBufferModule);
139 	header.checksum = checksum;
140 	TRACE(("TCP: Checksum for segment %p is %X\n", buffer, header.checksum));
141 	return B_OK;
142 }
143 
144 
145 //	#pragma mark - protocol API
146 
147 
148 net_protocol *
149 tcp_init_protocol(net_socket *socket)
150 {
151 	DUMP_TCP_HASH;
152 	socket->protocol = IPPROTO_TCP;
153 	TCPConnection *protocol = new (std::nothrow) TCPConnection(socket);
154 	TRACE(("Creating new TCPConnection: %p\n", protocol));
155 	return protocol;
156 }
157 
158 
159 status_t
160 tcp_uninit_protocol(net_protocol *protocol)
161 {
162 	DUMP_TCP_HASH;
163 	TRACE(("Deleting TCPConnection: %p\n", protocol));
164 	delete (TCPConnection *)protocol;
165 	return B_OK;
166 }
167 
168 
169 status_t
170 tcp_open(net_protocol *protocol)
171 {
172 	if (gDomain == NULL && set_domain() != B_OK)
173 		return B_ERROR;
174 
175 	DUMP_TCP_HASH;
176 
177 	return ((TCPConnection *)protocol)->Open();
178 }
179 
180 
181 status_t
182 tcp_close(net_protocol *protocol)
183 {
184 	DUMP_TCP_HASH;
185 	return ((TCPConnection *)protocol)->Close();
186 }
187 
188 
189 status_t
190 tcp_free(net_protocol *protocol)
191 {
192 	DUMP_TCP_HASH;
193 	return ((TCPConnection *)protocol)->Free();
194 }
195 
196 
197 status_t
198 tcp_connect(net_protocol *protocol, const struct sockaddr *address)
199 {
200 	DUMP_TCP_HASH;
201 	return ((TCPConnection *)protocol)->Connect(address);
202 }
203 
204 
205 status_t
206 tcp_accept(net_protocol *protocol, struct net_socket **_acceptedSocket)
207 {
208 	return ((TCPConnection *)protocol)->Accept(_acceptedSocket);
209 }
210 
211 
212 status_t
213 tcp_control(net_protocol *protocol, int level, int option, void *value,
214 	size_t *_length)
215 {
216 	return protocol->next->module->control(protocol->next, level, option,
217 		value, _length);
218 }
219 
220 
221 status_t
222 tcp_bind(net_protocol *protocol, struct sockaddr *address)
223 {
224 	DUMP_TCP_HASH;
225 	return ((TCPConnection *)protocol)->Bind(address);
226 }
227 
228 
229 status_t
230 tcp_unbind(net_protocol *protocol, struct sockaddr *address)
231 {
232 	DUMP_TCP_HASH;
233 	return ((TCPConnection *)protocol)->Unbind(address);
234 }
235 
236 
237 status_t
238 tcp_listen(net_protocol *protocol, int count)
239 {
240 	return ((TCPConnection *)protocol)->Listen(count);
241 }
242 
243 
244 status_t
245 tcp_shutdown(net_protocol *protocol, int direction)
246 {
247 	return ((TCPConnection *)protocol)->Shutdown(direction);
248 }
249 
250 
251 status_t
252 tcp_send_data(net_protocol *protocol, net_buffer *buffer)
253 {
254 	return ((TCPConnection *)protocol)->SendData(buffer);
255 }
256 
257 
258 status_t
259 tcp_send_routed_data(net_protocol *protocol, struct net_route *route,
260 	net_buffer *buffer)
261 {
262 	return ((TCPConnection *)protocol)->SendRoutedData(route, buffer);
263 }
264 
265 
266 ssize_t
267 tcp_send_avail(net_protocol *protocol)
268 {
269 	return ((TCPConnection *)protocol)->SendAvailable();
270 }
271 
272 
273 status_t
274 tcp_read_data(net_protocol *protocol, size_t numBytes, uint32 flags,
275 	net_buffer **_buffer)
276 {
277 	return ((TCPConnection *)protocol)->ReadData(numBytes, flags, _buffer);
278 }
279 
280 
281 ssize_t
282 tcp_read_avail(net_protocol *protocol)
283 {
284 	return ((TCPConnection *)protocol)->ReadAvailable();
285 }
286 
287 
288 struct net_domain *
289 tcp_get_domain(net_protocol *protocol)
290 {
291 	return protocol->next->module->get_domain(protocol->next);
292 }
293 
294 
295 size_t
296 tcp_get_mtu(net_protocol *protocol, const struct sockaddr *address)
297 {
298 	return protocol->next->module->get_mtu(protocol->next, address);
299 }
300 
301 
302 status_t
303 tcp_receive_data(net_buffer *buffer)
304 {
305 	TRACE(("TCP: Received buffer %p\n", buffer));
306 
307 	if (gDomain == NULL && set_domain(buffer->interface) != B_OK)
308 		return B_ERROR;
309 
310 	NetBufferHeader<tcp_header> bufferHeader(buffer);
311 	if (bufferHeader.Status() < B_OK)
312 		return bufferHeader.Status();
313 
314 	tcp_header &header = bufferHeader.Data();
315 
316 	tcp_connection_key key;
317 	key.peer = (struct sockaddr *)&buffer->source;
318 	key.local = (struct sockaddr *)&buffer->destination;
319 
320 	// TODO: check TCP Checksum
321 
322 	gAddressModule->set_port((struct sockaddr *)&buffer->source, header.source_port);
323 	gAddressModule->set_port((struct sockaddr *)&buffer->destination, header.destination_port);
324 
325 	DUMP_TCP_HASH;
326 
327 	BenaphoreLocker hashLock(&gConnectionLock);
328 	TCPConnection *connection = (TCPConnection *)hash_lookup(gConnectionHash, &key);
329 	TRACE(("TCP: Received packet corresponds to connection %p\n", connection));
330 	if (connection != NULL){
331 		return connection->ReceiveData(buffer);
332 	} else {
333 		/* TODO:
334 		   No explicit connection exists.  Check for wildcard connections:
335 		   First check if any connections exist where local = IPADDR_ANY
336 		   then check when local = peer = IPADDR_ANY.
337 		   port numbers always remain the same */
338 
339 		// If no connection exists (and RST is not set) send RST
340 		if (!(header.flags & TCP_FLG_RST)) {
341 			TRACE(("TCP:  Connection does not exist!\n"));
342 			net_buffer *reply = gBufferModule->create(512);
343 			if (reply == NULL)
344 				return B_NO_MEMORY;
345 
346 			gAddressModule->set_to((sockaddr *)&reply->source,
347 				(sockaddr *)&buffer->destination);
348 			gAddressModule->set_to((sockaddr *)&reply->destination,
349 				(sockaddr *)&buffer->source);
350 
351 			uint32 sequence, acknowledge;
352 			uint16 flags;
353 			if (header.flags & TCP_FLG_ACK) {
354 				sequence = ntohl(header.acknowledge_num);
355 				acknowledge = 0;
356 				flags = TCP_FLG_RST;
357 			} else {
358 				sequence = 0;
359 				acknowledge = ntohl(header.sequence_num) + 1
360 					+ buffer->size - ((uint32)header.header_length << 2);
361 				flags = TCP_FLG_RST | TCP_FLG_ACK;
362 			}
363 
364 			status_t status = add_tcp_header(reply, flags, sequence, acknowledge, 0);
365 
366 			if (status == B_OK) {
367 				TRACE(("TCP:  Sending RST...\n"));
368 				status = gDomain->module->send_data(NULL, reply);
369 			}
370 
371 			if (status != B_OK) {
372 				gBufferModule->free(reply);
373 				return status;
374 			}
375 		}
376 	}
377 	return B_OK;
378 }
379 
380 
381 status_t
382 tcp_error(uint32 code, net_buffer *data)
383 {
384 	return B_ERROR;
385 }
386 
387 
388 status_t
389 tcp_error_reply(net_protocol *protocol, net_buffer *causedError, uint32 code,
390 	void *errorData)
391 {
392 	return B_ERROR;
393 }
394 
395 
396 //	#pragma mark -
397 
398 
399 static status_t
400 tcp_init()
401 {
402 	status_t status;
403 
404 	gDomain = NULL;
405 	gAddressModule = NULL;
406 
407 	status = get_module(NET_STACK_MODULE_NAME, (module_info **)&gStackModule);
408 	if (status < B_OK)
409 		return status;
410 	status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
411 	if (status < B_OK)
412 		goto err1;
413 	status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&gDatalinkModule);
414 	if (status < B_OK)
415 		goto err2;
416 
417 	gConnectionHash = hash_init(MAX_HASH_TCP, TCPConnection::HashOffset(),
418 		&TCPConnection::Compare, &TCPConnection::Hash);
419 	if (gConnectionHash == NULL)
420 		goto err3;
421 
422 	status = benaphore_init(&gConnectionLock, "TCP Hash Lock");
423 	if (status < B_OK)
424 		goto err4;
425 
426 	status = gStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, IPPROTO_IP,
427 		"network/protocols/tcp/v1",
428 		"network/protocols/ipv4/v1",
429 		NULL);
430 	if (status < B_OK)
431 		goto err5;
432 
433 	status = gStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, IPPROTO_TCP,
434 		"network/protocols/tcp/v1",
435 		"network/protocols/ipv4/v1",
436 		NULL);
437 	if (status < B_OK)
438 		goto err5;
439 
440 	status = gStackModule->register_domain_receiving_protocol(AF_INET, IPPROTO_TCP,
441 		"network/protocols/tcp/v1");
442 	if (status < B_OK)
443 		goto err5;
444 
445 	return B_OK;
446 
447 err5:
448 	benaphore_destroy(&gConnectionLock);
449 err4:
450 	hash_uninit(gConnectionHash);
451 err3:
452 	put_module(NET_DATALINK_MODULE_NAME);
453 err2:
454 	put_module(NET_BUFFER_MODULE_NAME);
455 err1:
456 	put_module(NET_STACK_MODULE_NAME);
457 
458 	TRACE(("init_tcp() fails with %lx (%s)\n", status, strerror(status)));
459 	return status;
460 }
461 
462 
463 static status_t
464 tcp_uninit()
465 {
466 	benaphore_destroy(&gConnectionLock);
467 	hash_uninit(gConnectionHash);
468 	put_module(NET_DATALINK_MODULE_NAME);
469 	put_module(NET_BUFFER_MODULE_NAME);
470 	put_module(NET_STACK_MODULE_NAME);
471 
472 	return B_OK;
473 }
474 
475 
476 static status_t
477 tcp_std_ops(int32 op, ...)
478 {
479 	switch (op) {
480 		case B_MODULE_INIT:
481 			return tcp_init();
482 
483 		case B_MODULE_UNINIT:
484 			return tcp_uninit();
485 
486 		default:
487 			return B_ERROR;
488 	}
489 }
490 
491 
492 net_protocol_module_info sTCPModule = {
493 	{
494 		"network/protocols/tcp/v1",
495 		0,
496 		tcp_std_ops
497 	},
498 	tcp_init_protocol,
499 	tcp_uninit_protocol,
500 	tcp_open,
501 	tcp_close,
502 	tcp_free,
503 	tcp_connect,
504 	tcp_accept,
505 	tcp_control,
506 	tcp_bind,
507 	tcp_unbind,
508 	tcp_listen,
509 	tcp_shutdown,
510 	tcp_send_data,
511 	tcp_send_routed_data,
512 	tcp_send_avail,
513 	tcp_read_data,
514 	tcp_read_avail,
515 	tcp_get_domain,
516 	tcp_get_mtu,
517 	tcp_receive_data,
518 	tcp_error,
519 	tcp_error_reply,
520 };
521 
522 module_info *modules[] = {
523 	(module_info *)&sTCPModule,
524 	NULL
525 };
526