1 /* 2 * Copyright 2006, Haiku, Inc. All Rights Reserved. 3 * Distributed under the terms of the MIT License. 4 * 5 * Authors: 6 * Axel Dörfler, axeld@pinc-software.de 7 * Andrew Galante, haiku.galante@gmail.com 8 */ 9 10 11 #include "TCPConnection.h" 12 13 #include <net_protocol.h> 14 15 #include <KernelExport.h> 16 #include <util/list.h> 17 18 #include <netinet/in.h> 19 #include <netinet/ip.h> 20 #include <new> 21 #include <stdlib.h> 22 #include <string.h> 23 24 #include <lock.h> 25 #include <util/AutoLock.h> 26 27 #include <NetBufferUtilities.h> 28 #include <NetUtilities.h> 29 30 #define TRACE_TCP 31 #ifdef TRACE_TCP 32 # define TRACE(x) dprintf x 33 # define TRACE_BLOCK(x) dump_block x 34 #else 35 # define TRACE(x) 36 # define TRACE_BLOCK(x) 37 #endif 38 39 40 #define MAX_HASH_TCP 64 41 42 43 net_domain *gDomain; 44 net_address_module_info *gAddressModule; 45 net_buffer_module_info *gBufferModule; 46 net_datalink_module_info *gDatalinkModule; 47 net_stack_module_info *gStackModule; 48 hash_table *gConnectionHash; 49 benaphore gConnectionLock; 50 51 52 #ifdef TRACE_TCP 53 # define DUMP_TCP_HASH tcp_dump_hash() 54 // Dumps the TCP Connection hash. gConnectionLock must NOT be held when calling 55 void 56 tcp_dump_hash() 57 { 58 BenaphoreLocker lock(&gConnectionLock); 59 if (gDomain == NULL) { 60 TRACE(("Unable to dump TCP Connections!\n")); 61 return; 62 } 63 struct hash_iterator iterator; 64 hash_open(gConnectionHash, &iterator); 65 TCPConnection *connection; 66 hash_rewind(gConnectionHash, &iterator); 67 TRACE(("Active TCP Connections:\n")); 68 while ((connection = (TCPConnection *)hash_next(gConnectionHash, &iterator)) != NULL) { 69 TRACE((" TCPConnection %p: %s, %s\n", connection, 70 AddressString(gDomain, (sockaddr *)&connection->socket->address, true).Data(), 71 AddressString(gDomain, (sockaddr *)&connection->socket->peer, true).Data())); 72 } 73 hash_close(gConnectionHash, &iterator, false); 74 } 75 #else 76 # define DUMP_TCP_HASH 0 77 #endif 78 79 80 status_t 81 set_domain(net_interface *interface = NULL) 82 { 83 if (gDomain == NULL) { 84 // domain and address module are not known yet, we copy them from 85 // the buffer's interface (if any): 86 if (interface == NULL || interface->domain == NULL) 87 gDomain = gStackModule->get_domain(AF_INET); 88 else 89 gDomain = interface->domain; 90 91 if (gDomain == NULL) { 92 // this shouldn't occur, of course, but who knows... 93 return B_BAD_VALUE; 94 } 95 gAddressModule = gDomain->address_module; 96 } 97 98 return B_OK; 99 } 100 101 102 /*! 103 Constructs a TCP header on \a buffer with the specified values 104 for \a flags, \a seq \a ack and \a advertisedWindow. 105 */ 106 status_t 107 add_tcp_header(net_buffer *buffer, uint16 flags, uint32 sequence, uint32 ack, 108 uint16 advertisedWindow) 109 { 110 buffer->protocol = IPPROTO_TCP; 111 112 NetBufferPrepend<tcp_header> bufferHeader(buffer); 113 if (bufferHeader.Status() != B_OK) 114 return bufferHeader.Status(); 115 116 tcp_header &header = bufferHeader.Data(); 117 118 header.source_port = gAddressModule->get_port((sockaddr *)&buffer->source); 119 header.destination_port = gAddressModule->get_port((sockaddr *)&buffer->destination); 120 header.sequence_num = htonl(sequence); 121 header.acknowledge_num = htonl(ack); 122 header.reserved = 0; 123 header.header_length = 5; 124 // currently no options supported 125 header.flags = (uint8)flags; 126 header.advertised_window = htons(advertisedWindow); 127 header.checksum = 0; 128 header.urgent_ptr = 0; 129 // urgent pointer not supported 130 131 // compute and store checksum 132 Checksum checksum; 133 gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->source); 134 gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->destination); 135 checksum 136 << (uint16)htons(IPPROTO_TCP) 137 << (uint16)htons(buffer->size) 138 << Checksum::BufferHelper(buffer, gBufferModule); 139 header.checksum = checksum; 140 TRACE(("TCP: Checksum for segment %p is %X\n", buffer, header.checksum)); 141 return B_OK; 142 } 143 144 145 // #pragma mark - protocol API 146 147 148 net_protocol * 149 tcp_init_protocol(net_socket *socket) 150 { 151 DUMP_TCP_HASH; 152 socket->protocol = IPPROTO_TCP; 153 TCPConnection *protocol = new (std::nothrow) TCPConnection(socket); 154 TRACE(("Creating new TCPConnection: %p\n", protocol)); 155 return protocol; 156 } 157 158 159 status_t 160 tcp_uninit_protocol(net_protocol *protocol) 161 { 162 DUMP_TCP_HASH; 163 TRACE(("Deleting TCPConnection: %p\n", protocol)); 164 delete (TCPConnection *)protocol; 165 return B_OK; 166 } 167 168 169 status_t 170 tcp_open(net_protocol *protocol) 171 { 172 if (gDomain == NULL && set_domain() != B_OK) 173 return B_ERROR; 174 175 DUMP_TCP_HASH; 176 177 return ((TCPConnection *)protocol)->Open(); 178 } 179 180 181 status_t 182 tcp_close(net_protocol *protocol) 183 { 184 DUMP_TCP_HASH; 185 return ((TCPConnection *)protocol)->Close(); 186 } 187 188 189 status_t 190 tcp_free(net_protocol *protocol) 191 { 192 DUMP_TCP_HASH; 193 return ((TCPConnection *)protocol)->Free(); 194 } 195 196 197 status_t 198 tcp_connect(net_protocol *protocol, const struct sockaddr *address) 199 { 200 DUMP_TCP_HASH; 201 return ((TCPConnection *)protocol)->Connect(address); 202 } 203 204 205 status_t 206 tcp_accept(net_protocol *protocol, struct net_socket **_acceptedSocket) 207 { 208 return ((TCPConnection *)protocol)->Accept(_acceptedSocket); 209 } 210 211 212 status_t 213 tcp_control(net_protocol *protocol, int level, int option, void *value, 214 size_t *_length) 215 { 216 return protocol->next->module->control(protocol->next, level, option, 217 value, _length); 218 } 219 220 221 status_t 222 tcp_bind(net_protocol *protocol, struct sockaddr *address) 223 { 224 DUMP_TCP_HASH; 225 return ((TCPConnection *)protocol)->Bind(address); 226 } 227 228 229 status_t 230 tcp_unbind(net_protocol *protocol, struct sockaddr *address) 231 { 232 DUMP_TCP_HASH; 233 return ((TCPConnection *)protocol)->Unbind(address); 234 } 235 236 237 status_t 238 tcp_listen(net_protocol *protocol, int count) 239 { 240 return ((TCPConnection *)protocol)->Listen(count); 241 } 242 243 244 status_t 245 tcp_shutdown(net_protocol *protocol, int direction) 246 { 247 return ((TCPConnection *)protocol)->Shutdown(direction); 248 } 249 250 251 status_t 252 tcp_send_data(net_protocol *protocol, net_buffer *buffer) 253 { 254 return ((TCPConnection *)protocol)->SendData(buffer); 255 } 256 257 258 status_t 259 tcp_send_routed_data(net_protocol *protocol, struct net_route *route, 260 net_buffer *buffer) 261 { 262 return ((TCPConnection *)protocol)->SendRoutedData(route, buffer); 263 } 264 265 266 ssize_t 267 tcp_send_avail(net_protocol *protocol) 268 { 269 return ((TCPConnection *)protocol)->SendAvailable(); 270 } 271 272 273 status_t 274 tcp_read_data(net_protocol *protocol, size_t numBytes, uint32 flags, 275 net_buffer **_buffer) 276 { 277 return ((TCPConnection *)protocol)->ReadData(numBytes, flags, _buffer); 278 } 279 280 281 ssize_t 282 tcp_read_avail(net_protocol *protocol) 283 { 284 return ((TCPConnection *)protocol)->ReadAvailable(); 285 } 286 287 288 struct net_domain * 289 tcp_get_domain(net_protocol *protocol) 290 { 291 return protocol->next->module->get_domain(protocol->next); 292 } 293 294 295 size_t 296 tcp_get_mtu(net_protocol *protocol, const struct sockaddr *address) 297 { 298 return protocol->next->module->get_mtu(protocol->next, address); 299 } 300 301 302 status_t 303 tcp_receive_data(net_buffer *buffer) 304 { 305 TRACE(("TCP: Received buffer %p\n", buffer)); 306 307 if (gDomain == NULL && set_domain(buffer->interface) != B_OK) 308 return B_ERROR; 309 310 NetBufferHeader<tcp_header> bufferHeader(buffer); 311 if (bufferHeader.Status() < B_OK) 312 return bufferHeader.Status(); 313 314 tcp_header &header = bufferHeader.Data(); 315 316 tcp_connection_key key; 317 key.peer = (struct sockaddr *)&buffer->source; 318 key.local = (struct sockaddr *)&buffer->destination; 319 320 // TODO: check TCP Checksum 321 322 gAddressModule->set_port((struct sockaddr *)&buffer->source, header.source_port); 323 gAddressModule->set_port((struct sockaddr *)&buffer->destination, header.destination_port); 324 325 DUMP_TCP_HASH; 326 327 BenaphoreLocker hashLock(&gConnectionLock); 328 TCPConnection *connection = (TCPConnection *)hash_lookup(gConnectionHash, &key); 329 TRACE(("TCP: Received packet corresponds to connection %p\n", connection)); 330 if (connection != NULL){ 331 return connection->ReceiveData(buffer); 332 } else { 333 /* TODO: 334 No explicit connection exists. Check for wildcard connections: 335 First check if any connections exist where local = IPADDR_ANY 336 then check when local = peer = IPADDR_ANY. 337 port numbers always remain the same */ 338 339 // If no connection exists (and RST is not set) send RST 340 if (!(header.flags & TCP_FLG_RST)) { 341 TRACE(("TCP: Connection does not exist!\n")); 342 net_buffer *reply = gBufferModule->create(512); 343 if (reply == NULL) 344 return B_NO_MEMORY; 345 346 gAddressModule->set_to((sockaddr *)&reply->source, 347 (sockaddr *)&buffer->destination); 348 gAddressModule->set_to((sockaddr *)&reply->destination, 349 (sockaddr *)&buffer->source); 350 351 uint32 sequence, acknowledge; 352 uint16 flags; 353 if (header.flags & TCP_FLG_ACK) { 354 sequence = ntohl(header.acknowledge_num); 355 acknowledge = 0; 356 flags = TCP_FLG_RST; 357 } else { 358 sequence = 0; 359 acknowledge = ntohl(header.sequence_num) + 1 360 + buffer->size - ((uint32)header.header_length << 2); 361 flags = TCP_FLG_RST | TCP_FLG_ACK; 362 } 363 364 status_t status = add_tcp_header(reply, flags, sequence, acknowledge, 0); 365 366 if (status == B_OK) { 367 TRACE(("TCP: Sending RST...\n")); 368 status = gDomain->module->send_data(NULL, reply); 369 } 370 371 if (status != B_OK) { 372 gBufferModule->free(reply); 373 return status; 374 } 375 } 376 } 377 return B_OK; 378 } 379 380 381 status_t 382 tcp_error(uint32 code, net_buffer *data) 383 { 384 return B_ERROR; 385 } 386 387 388 status_t 389 tcp_error_reply(net_protocol *protocol, net_buffer *causedError, uint32 code, 390 void *errorData) 391 { 392 return B_ERROR; 393 } 394 395 396 // #pragma mark - 397 398 399 static status_t 400 tcp_init() 401 { 402 status_t status; 403 404 gDomain = NULL; 405 gAddressModule = NULL; 406 407 status = get_module(NET_STACK_MODULE_NAME, (module_info **)&gStackModule); 408 if (status < B_OK) 409 return status; 410 status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule); 411 if (status < B_OK) 412 goto err1; 413 status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&gDatalinkModule); 414 if (status < B_OK) 415 goto err2; 416 417 gConnectionHash = hash_init(MAX_HASH_TCP, TCPConnection::HashOffset(), 418 &TCPConnection::Compare, &TCPConnection::Hash); 419 if (gConnectionHash == NULL) 420 goto err3; 421 422 status = benaphore_init(&gConnectionLock, "TCP Hash Lock"); 423 if (status < B_OK) 424 goto err4; 425 426 status = gStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, IPPROTO_IP, 427 "network/protocols/tcp/v1", 428 "network/protocols/ipv4/v1", 429 NULL); 430 if (status < B_OK) 431 goto err5; 432 433 status = gStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, IPPROTO_TCP, 434 "network/protocols/tcp/v1", 435 "network/protocols/ipv4/v1", 436 NULL); 437 if (status < B_OK) 438 goto err5; 439 440 status = gStackModule->register_domain_receiving_protocol(AF_INET, IPPROTO_TCP, 441 "network/protocols/tcp/v1"); 442 if (status < B_OK) 443 goto err5; 444 445 return B_OK; 446 447 err5: 448 benaphore_destroy(&gConnectionLock); 449 err4: 450 hash_uninit(gConnectionHash); 451 err3: 452 put_module(NET_DATALINK_MODULE_NAME); 453 err2: 454 put_module(NET_BUFFER_MODULE_NAME); 455 err1: 456 put_module(NET_STACK_MODULE_NAME); 457 458 TRACE(("init_tcp() fails with %lx (%s)\n", status, strerror(status))); 459 return status; 460 } 461 462 463 static status_t 464 tcp_uninit() 465 { 466 benaphore_destroy(&gConnectionLock); 467 hash_uninit(gConnectionHash); 468 put_module(NET_DATALINK_MODULE_NAME); 469 put_module(NET_BUFFER_MODULE_NAME); 470 put_module(NET_STACK_MODULE_NAME); 471 472 return B_OK; 473 } 474 475 476 static status_t 477 tcp_std_ops(int32 op, ...) 478 { 479 switch (op) { 480 case B_MODULE_INIT: 481 return tcp_init(); 482 483 case B_MODULE_UNINIT: 484 return tcp_uninit(); 485 486 default: 487 return B_ERROR; 488 } 489 } 490 491 492 net_protocol_module_info sTCPModule = { 493 { 494 "network/protocols/tcp/v1", 495 0, 496 tcp_std_ops 497 }, 498 tcp_init_protocol, 499 tcp_uninit_protocol, 500 tcp_open, 501 tcp_close, 502 tcp_free, 503 tcp_connect, 504 tcp_accept, 505 tcp_control, 506 tcp_bind, 507 tcp_unbind, 508 tcp_listen, 509 tcp_shutdown, 510 tcp_send_data, 511 tcp_send_routed_data, 512 tcp_send_avail, 513 tcp_read_data, 514 tcp_read_avail, 515 tcp_get_domain, 516 tcp_get_mtu, 517 tcp_receive_data, 518 tcp_error, 519 tcp_error_reply, 520 }; 521 522 module_info *modules[] = { 523 (module_info *)&sTCPModule, 524 NULL 525 }; 526