1 /* 2 * Copyright 2006, Haiku, Inc. All Rights Reserved. 3 * Distributed under the terms of the MIT License. 4 * 5 * Authors: 6 * Axel Dörfler, axeld@pinc-software.de 7 * Andrew Galante, haiku.galante@gmail.com 8 */ 9 10 11 #include "EndpointManager.h" 12 #include "TCPEndpoint.h" 13 14 #include <net_protocol.h> 15 #include <net_stat.h> 16 17 #include <KernelExport.h> 18 #include <util/list.h> 19 20 #include <netinet/in.h> 21 #include <netinet/ip.h> 22 #include <new> 23 #include <stdlib.h> 24 #include <string.h> 25 26 #include <lock.h> 27 #include <util/AutoLock.h> 28 29 #include <NetBufferUtilities.h> 30 #include <NetUtilities.h> 31 32 //#define TRACE_TCP 33 #ifdef TRACE_TCP 34 # define TRACE(x) dprintf x 35 # define TRACE_BLOCK(x) dump_block x 36 #else 37 # define TRACE(x) 38 # define TRACE_BLOCK(x) 39 #endif 40 41 42 typedef NetBufferField<uint16, offsetof(tcp_header, checksum)> TCPChecksumField; 43 44 45 net_domain *gDomain; 46 net_address_module_info *gAddressModule; 47 net_buffer_module_info *gBufferModule; 48 net_datalink_module_info *gDatalinkModule; 49 net_socket_module_info *gSocketModule; 50 net_stack_module_info *gStackModule; 51 EndpointManager *gEndpointManager; 52 53 54 status_t 55 set_domain(net_interface *interface = NULL) 56 { 57 if (gDomain == NULL) { 58 // domain and address module are not known yet, we copy them from 59 // the buffer's interface (if any): 60 if (interface == NULL || interface->domain == NULL) 61 gDomain = gStackModule->get_domain(AF_INET); 62 else 63 gDomain = interface->domain; 64 65 if (gDomain == NULL) { 66 // this shouldn't occur, of course, but who knows... 67 return B_BAD_VALUE; 68 } 69 gAddressModule = gDomain->address_module; 70 } 71 72 return B_OK; 73 } 74 75 76 static inline void 77 bump_option(tcp_option *&option, size_t &length) 78 { 79 if (option->kind <= TCP_OPTION_NOP) { 80 length++; 81 option = (tcp_option *)((uint8 *)option + 1); 82 } else { 83 length += option->length; 84 option = (tcp_option *)((uint8 *)option + option->length); 85 } 86 } 87 88 89 static inline size_t 90 add_options(tcp_segment_header &segment, uint8 *buffer, size_t bufferSize) 91 { 92 tcp_option *option = (tcp_option *)buffer; 93 size_t length = 0; 94 95 if (segment.max_segment_size > 0 && length + 8 < bufferSize) { 96 option->kind = TCP_OPTION_MAX_SEGMENT_SIZE; 97 option->length = 4; 98 option->max_segment_size = htons(segment.max_segment_size); 99 bump_option(option, length); 100 } 101 if (segment.has_window_shift && length + 4 < bufferSize) { 102 // insert one NOP so that the subsequent data is aligned on a 4 byte boundary 103 option->kind = TCP_OPTION_NOP; 104 bump_option(option, length); 105 106 option->kind = TCP_OPTION_WINDOW_SHIFT; 107 option->length = 3; 108 option->window_shift = segment.window_shift; 109 bump_option(option, length); 110 } 111 112 if ((length & 3) == 0) { 113 // options completely fill out the option space 114 return length; 115 } 116 117 option->kind = TCP_OPTION_END; 118 return (length + 3) & ~3; 119 // bump to a multiple of 4 length 120 } 121 122 123 /*! 124 Constructs a TCP header on \a buffer with the specified values 125 for \a flags, \a seq \a ack and \a advertisedWindow. 126 */ 127 status_t 128 add_tcp_header(tcp_segment_header &segment, net_buffer *buffer) 129 { 130 buffer->protocol = IPPROTO_TCP; 131 132 uint8 optionsBuffer[32]; 133 uint32 optionsLength = add_options(segment, optionsBuffer, sizeof(optionsBuffer)); 134 135 NetBufferPrepend<tcp_header> bufferHeader(buffer, sizeof(tcp_header) + optionsLength); 136 if (bufferHeader.Status() != B_OK) 137 return bufferHeader.Status(); 138 139 tcp_header &header = bufferHeader.Data(); 140 141 header.source_port = gAddressModule->get_port((sockaddr *)&buffer->source); 142 header.destination_port = gAddressModule->get_port((sockaddr *)&buffer->destination); 143 header.sequence = htonl(segment.sequence); 144 header.acknowledge = (segment.flags & TCP_FLAG_ACKNOWLEDGE) 145 ? htonl(segment.acknowledge) : 0; 146 header.reserved = 0; 147 header.header_length = (sizeof(tcp_header) + optionsLength) >> 2; 148 header.flags = segment.flags; 149 header.advertised_window = htons(segment.advertised_window); 150 header.checksum = 0; 151 header.urgent_offset = 0; 152 // TODO: urgent pointer not yet supported 153 154 // we must detach before calculating the checksum as we may 155 // not have a contiguous buffer. 156 bufferHeader.Sync(); 157 158 if (optionsLength > 0) 159 gBufferModule->write(buffer, sizeof(tcp_header), optionsBuffer, optionsLength); 160 161 TRACE(("add_tcp_header(): buffer %p, flags 0x%x, seq %lu, ack %lu, win %u\n", buffer, 162 segment.flags, segment.sequence, segment.acknowledge, segment.advertised_window)); 163 164 // compute and store checksum 165 Checksum checksum; 166 gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->source); 167 gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->destination); 168 checksum 169 << (uint16)htons(IPPROTO_TCP) 170 << (uint16)htons(buffer->size) 171 << Checksum::BufferHelper(buffer, gBufferModule); 172 173 *TCPChecksumField(buffer) = checksum; 174 175 return B_OK; 176 } 177 178 179 void 180 process_options(tcp_segment_header &segment, net_buffer *buffer, int32 size) 181 { 182 if (size == 0) 183 return; 184 185 tcp_option *option; 186 uint8 optionsBuffer[32]; 187 if (gBufferModule->direct_access(buffer, sizeof(tcp_header), size, 188 (void **)&option) != B_OK) { 189 if (size > 32) { 190 dprintf("options too large to take into account (%ld bytes)\n", size); 191 return; 192 } 193 194 gBufferModule->read(buffer, sizeof(tcp_header), optionsBuffer, size); 195 option = (tcp_option *)optionsBuffer; 196 } 197 198 while (size > 0) { 199 uint32 length = 1; 200 switch (option->kind) { 201 case TCP_OPTION_END: 202 case TCP_OPTION_NOP: 203 break; 204 case TCP_OPTION_MAX_SEGMENT_SIZE: 205 segment.max_segment_size = ntohs(option->max_segment_size); 206 length = 4; 207 break; 208 case TCP_OPTION_WINDOW_SHIFT: 209 segment.has_window_shift = true; 210 segment.window_shift = option->window_shift; 211 length = 3; 212 break; 213 case TCP_OPTION_TIMESTAMP: 214 // TODO: support timestamp! 215 length = 10; 216 break; 217 218 default: 219 length = option->length; 220 // make sure we don't end up in an endless loop 221 if (length == 0) 222 return; 223 break; 224 } 225 226 size -= length; 227 option = (tcp_option *)((uint8 *)option + length); 228 } 229 // TODO: check if options are valid! 230 } 231 232 233 status_t 234 reply_with_reset(tcp_segment_header &segment, net_buffer *buffer) 235 { 236 TRACE(("TCP: Sending RST...\n")); 237 238 net_buffer *reply = gBufferModule->create(512); 239 if (reply == NULL) 240 return B_NO_MEMORY; 241 242 gAddressModule->set_to((sockaddr *)&reply->source, 243 (sockaddr *)&buffer->destination); 244 gAddressModule->set_to((sockaddr *)&reply->destination, 245 (sockaddr *)&buffer->source); 246 247 tcp_segment_header outSegment; 248 outSegment.flags = TCP_FLAG_RESET; 249 outSegment.sequence = 0; 250 outSegment.acknowledge = 0; 251 outSegment.advertised_window = 0; 252 outSegment.urgent_offset = 0; 253 254 if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) { 255 outSegment.flags |= TCP_FLAG_ACKNOWLEDGE; 256 outSegment.acknowledge = segment.sequence + buffer->size; 257 } else 258 outSegment.sequence = segment.acknowledge; 259 260 status_t status = add_tcp_header(outSegment, reply); 261 if (status == B_OK) 262 status = gDomain->module->send_data(NULL, reply); 263 264 if (status != B_OK) 265 gBufferModule->free(reply); 266 267 return status; 268 } 269 270 271 const char * 272 name_for_state(tcp_state state) 273 { 274 switch (state) { 275 case CLOSED: 276 return "closed"; 277 case LISTEN: 278 return "listen"; 279 case SYNCHRONIZE_SENT: 280 return "syn-sent"; 281 case SYNCHRONIZE_RECEIVED: 282 return "syn-received"; 283 case ESTABLISHED: 284 return "established"; 285 286 // peer closes the connection 287 case FINISH_RECEIVED: 288 return "close-wait"; 289 case WAIT_FOR_FINISH_ACKNOWLEDGE: 290 return "last-ack"; 291 292 // we close the connection 293 case FINISH_SENT: 294 return "fin-wait1"; 295 case FINISH_ACKNOWLEDGED: 296 return "fin-wait2"; 297 case CLOSING: 298 return "closing"; 299 300 case TIME_WAIT: 301 return "time-wait"; 302 } 303 304 return "-"; 305 } 306 307 308 #if 0 309 static void 310 dump_tcp_header(tcp_header &header) 311 { 312 dprintf(" source port: %u\n", ntohs(header.source_port)); 313 dprintf(" dest port: %u\n", ntohs(header.destination_port)); 314 dprintf(" sequence: %lu\n", header.Sequence()); 315 dprintf(" ack: %lu\n", header.Acknowledge()); 316 dprintf(" flags: %s%s%s%s%s%s\n", (header.flags & TCP_FLAG_FINISH) ? "FIN " : "", 317 (header.flags & TCP_FLAG_SYNCHRONIZE) ? "SYN " : "", 318 (header.flags & TCP_FLAG_RESET) ? "RST " : "", 319 (header.flags & TCP_FLAG_PUSH) ? "PUSH " : "", 320 (header.flags & TCP_FLAG_ACKNOWLEDGE) ? "ACK " : "", 321 (header.flags & TCP_FLAG_URGENT) ? "URG " : ""); 322 dprintf(" window: %u\n", header.AdvertisedWindow()); 323 dprintf(" urgent offset: %u\n", header.UrgentOffset()); 324 } 325 #endif 326 327 328 // #pragma mark - protocol API 329 330 331 net_protocol * 332 tcp_init_protocol(net_socket *socket) 333 { 334 TCPEndpoint *protocol = new (std::nothrow) TCPEndpoint(socket); 335 if (protocol == NULL) 336 return NULL; 337 338 if (protocol->InitCheck() != B_OK) { 339 delete protocol; 340 return NULL; 341 } 342 343 TRACE(("Creating new TCPEndpoint: %p\n", protocol)); 344 socket->protocol = IPPROTO_TCP; 345 return protocol; 346 } 347 348 349 status_t 350 tcp_uninit_protocol(net_protocol *protocol) 351 { 352 TRACE(("Deleting TCPEndpoint: %p\n", protocol)); 353 delete (TCPEndpoint *)protocol; 354 return B_OK; 355 } 356 357 358 status_t 359 tcp_open(net_protocol *protocol) 360 { 361 if (gDomain == NULL && set_domain() != B_OK) 362 return B_ERROR; 363 364 return ((TCPEndpoint *)protocol)->Open(); 365 } 366 367 368 status_t 369 tcp_close(net_protocol *protocol) 370 { 371 return ((TCPEndpoint *)protocol)->Close(); 372 } 373 374 375 status_t 376 tcp_free(net_protocol *protocol) 377 { 378 return ((TCPEndpoint *)protocol)->Free(); 379 } 380 381 382 status_t 383 tcp_connect(net_protocol *protocol, const struct sockaddr *address) 384 { 385 return ((TCPEndpoint *)protocol)->Connect(address); 386 } 387 388 389 status_t 390 tcp_accept(net_protocol *protocol, struct net_socket **_acceptedSocket) 391 { 392 return ((TCPEndpoint *)protocol)->Accept(_acceptedSocket); 393 } 394 395 396 status_t 397 tcp_control(net_protocol *_protocol, int level, int option, void *value, 398 size_t *_length) 399 { 400 TCPEndpoint *protocol = (TCPEndpoint *)_protocol; 401 402 switch (level & LEVEL_MASK) { 403 case IPPROTO_TCP: 404 if (option == NET_STAT_SOCKET) { 405 net_stat *stat = (net_stat *)value; 406 strlcpy(stat->state, name_for_state(protocol->State()), 407 sizeof(stat->state)); 408 return B_OK; 409 } 410 break; 411 case SOL_SOCKET: 412 break; 413 414 default: 415 return protocol->next->module->control(protocol->next, level, option, 416 value, _length); 417 } 418 419 return B_BAD_VALUE; 420 } 421 422 423 status_t 424 tcp_bind(net_protocol *protocol, struct sockaddr *address) 425 { 426 return ((TCPEndpoint *)protocol)->Bind(address); 427 } 428 429 430 status_t 431 tcp_unbind(net_protocol *protocol, struct sockaddr *address) 432 { 433 return ((TCPEndpoint *)protocol)->Unbind(address); 434 } 435 436 437 status_t 438 tcp_listen(net_protocol *protocol, int count) 439 { 440 return ((TCPEndpoint *)protocol)->Listen(count); 441 } 442 443 444 status_t 445 tcp_shutdown(net_protocol *protocol, int direction) 446 { 447 return ((TCPEndpoint *)protocol)->Shutdown(direction); 448 } 449 450 451 status_t 452 tcp_send_data(net_protocol *protocol, net_buffer *buffer) 453 { 454 return ((TCPEndpoint *)protocol)->SendData(buffer); 455 } 456 457 458 status_t 459 tcp_send_routed_data(net_protocol *protocol, struct net_route *route, 460 net_buffer *buffer) 461 { 462 // TCP never sends routed data 463 return B_ERROR; 464 } 465 466 467 ssize_t 468 tcp_send_avail(net_protocol *protocol) 469 { 470 return ((TCPEndpoint *)protocol)->SendAvailable(); 471 } 472 473 474 status_t 475 tcp_read_data(net_protocol *protocol, size_t numBytes, uint32 flags, 476 net_buffer **_buffer) 477 { 478 return ((TCPEndpoint *)protocol)->ReadData(numBytes, flags, _buffer); 479 } 480 481 482 ssize_t 483 tcp_read_avail(net_protocol *protocol) 484 { 485 return ((TCPEndpoint *)protocol)->ReadAvailable(); 486 } 487 488 489 struct net_domain * 490 tcp_get_domain(net_protocol *protocol) 491 { 492 return protocol->next->module->get_domain(protocol->next); 493 } 494 495 496 size_t 497 tcp_get_mtu(net_protocol *protocol, const struct sockaddr *address) 498 { 499 return protocol->next->module->get_mtu(protocol->next, address); 500 } 501 502 503 status_t 504 tcp_receive_data(net_buffer *buffer) 505 { 506 TRACE(("TCP: Received buffer %p\n", buffer)); 507 508 if (gDomain == NULL && set_domain(buffer->interface) != B_OK) 509 return B_ERROR; 510 511 NetBufferHeaderReader<tcp_header> bufferHeader(buffer); 512 if (bufferHeader.Status() < B_OK) 513 return bufferHeader.Status(); 514 515 tcp_header &header = bufferHeader.Data(); 516 517 uint16 headerLength = header.HeaderLength(); 518 if (headerLength < sizeof(tcp_header)) 519 return B_BAD_DATA; 520 521 // compute checksum using a pseudo IP header 522 Checksum checksum; 523 gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->source); 524 gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->destination); 525 checksum << (uint16)htons(IPPROTO_TCP) 526 << (uint16)htons(buffer->size) 527 << Checksum::BufferHelper(buffer, gBufferModule); 528 529 if (checksum != 0) 530 return B_BAD_DATA; 531 532 gAddressModule->set_port((struct sockaddr *)&buffer->source, header.source_port); 533 gAddressModule->set_port((struct sockaddr *)&buffer->destination, header.destination_port); 534 535 TRACE((" Looking for: peer %s, local %s\n", 536 AddressString(gDomain, (sockaddr *)&buffer->source, true).Data(), 537 AddressString(gDomain, (sockaddr *)&buffer->destination, true).Data())); 538 //dump_tcp_header(header); 539 //gBufferModule->dump(buffer); 540 541 tcp_segment_header segment; 542 segment.sequence = header.Sequence(); 543 segment.acknowledge = header.Acknowledge(); 544 segment.advertised_window = header.AdvertisedWindow(); 545 segment.urgent_offset = header.UrgentOffset(); 546 segment.flags = header.flags; 547 if ((segment.flags & TCP_FLAG_SYNCHRONIZE) != 0) { 548 // for now, we only process the options in the SYN segment 549 // TODO: when we support timestamps, they could be handled specifically 550 process_options(segment, buffer, headerLength - sizeof(tcp_header)); 551 } 552 553 bufferHeader.Remove(headerLength); 554 // we no longer need to keep the header around 555 556 RecursiveLocker locker(gEndpointManager->Locker()); 557 int32 segmentAction = DROP; 558 559 TCPEndpoint *endpoint = gEndpointManager->FindConnection( 560 (struct sockaddr *)&buffer->destination, (struct sockaddr *)&buffer->source); 561 if (endpoint != NULL) { 562 RecursiveLocker locker(endpoint->Lock()); 563 TRACE(("Endpoint %p in state %s\n", endpoint, name_for_state(endpoint->State()))); 564 565 switch (endpoint->State()) { 566 case LISTEN: 567 segmentAction = endpoint->ListenReceive(segment, buffer); 568 break; 569 570 case SYNCHRONIZE_SENT: 571 segmentAction = endpoint->SynchronizeSentReceive(segment, buffer); 572 break; 573 574 case SYNCHRONIZE_RECEIVED: 575 case ESTABLISHED: 576 case FINISH_RECEIVED: 577 case WAIT_FOR_FINISH_ACKNOWLEDGE: 578 case FINISH_SENT: 579 case FINISH_ACKNOWLEDGED: 580 case CLOSING: 581 case TIME_WAIT: 582 case CLOSED: 583 segmentAction = endpoint->Receive(segment, buffer); 584 break; 585 } 586 587 // process acknowledge action as asked for by the *Receive() method 588 if (segmentAction & IMMEDIATE_ACKNOWLEDGE) 589 endpoint->SendAcknowledge(); 590 else if (segmentAction & ACKNOWLEDGE) 591 endpoint->DelayedAcknowledge(); 592 else if (segmentAction & DELETE) 593 gSocketModule->delete_socket(endpoint->socket); 594 } else if ((segment.flags & TCP_FLAG_RESET) == 0) 595 segmentAction = DROP | RESET; 596 597 if (segmentAction & RESET) { 598 // send reset 599 reply_with_reset(segment, buffer); 600 } 601 if (segmentAction & DROP) 602 gBufferModule->free(buffer); 603 604 return B_OK; 605 } 606 607 608 status_t 609 tcp_error(uint32 code, net_buffer *data) 610 { 611 return B_ERROR; 612 } 613 614 615 status_t 616 tcp_error_reply(net_protocol *protocol, net_buffer *causedError, uint32 code, 617 void *errorData) 618 { 619 return B_ERROR; 620 } 621 622 623 // #pragma mark - 624 625 626 static status_t 627 tcp_init() 628 { 629 status_t status; 630 631 gDomain = NULL; 632 gAddressModule = NULL; 633 634 status = get_module(NET_STACK_MODULE_NAME, (module_info **)&gStackModule); 635 if (status < B_OK) 636 return status; 637 status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule); 638 if (status < B_OK) 639 goto err1; 640 status = get_module(NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule); 641 if (status < B_OK) 642 goto err2; 643 status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&gDatalinkModule); 644 if (status < B_OK) 645 goto err3; 646 647 gEndpointManager = new (std::nothrow) EndpointManager(); 648 if (gEndpointManager == NULL) { 649 status = B_NO_MEMORY; 650 goto err4; 651 } 652 status = gEndpointManager->InitCheck(); 653 if (status < B_OK) 654 goto err5; 655 656 status = gStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, 0, 657 "network/protocols/tcp/v1", 658 "network/protocols/ipv4/v1", 659 NULL); 660 if (status < B_OK) 661 goto err5; 662 663 status = gStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, IPPROTO_TCP, 664 "network/protocols/tcp/v1", 665 "network/protocols/ipv4/v1", 666 NULL); 667 if (status < B_OK) 668 goto err5; 669 670 status = gStackModule->register_domain_receiving_protocol(AF_INET, IPPROTO_TCP, 671 "network/protocols/tcp/v1"); 672 if (status < B_OK) 673 goto err5; 674 675 return B_OK; 676 677 err5: 678 delete gEndpointManager; 679 err4: 680 put_module(NET_DATALINK_MODULE_NAME); 681 err3: 682 put_module(NET_SOCKET_MODULE_NAME); 683 err2: 684 put_module(NET_BUFFER_MODULE_NAME); 685 err1: 686 put_module(NET_STACK_MODULE_NAME); 687 688 TRACE(("init_tcp() fails with %lx (%s)\n", status, strerror(status))); 689 return status; 690 } 691 692 693 static status_t 694 tcp_uninit() 695 { 696 delete gEndpointManager; 697 698 put_module(NET_DATALINK_MODULE_NAME); 699 put_module(NET_SOCKET_MODULE_NAME); 700 put_module(NET_BUFFER_MODULE_NAME); 701 put_module(NET_STACK_MODULE_NAME); 702 703 return B_OK; 704 } 705 706 707 static status_t 708 tcp_std_ops(int32 op, ...) 709 { 710 switch (op) { 711 case B_MODULE_INIT: 712 return tcp_init(); 713 714 case B_MODULE_UNINIT: 715 return tcp_uninit(); 716 717 default: 718 return B_ERROR; 719 } 720 } 721 722 723 net_protocol_module_info sTCPModule = { 724 { 725 "network/protocols/tcp/v1", 726 0, 727 tcp_std_ops 728 }, 729 tcp_init_protocol, 730 tcp_uninit_protocol, 731 tcp_open, 732 tcp_close, 733 tcp_free, 734 tcp_connect, 735 tcp_accept, 736 tcp_control, 737 tcp_bind, 738 tcp_unbind, 739 tcp_listen, 740 tcp_shutdown, 741 tcp_send_data, 742 tcp_send_routed_data, 743 tcp_send_avail, 744 tcp_read_data, 745 tcp_read_avail, 746 tcp_get_domain, 747 tcp_get_mtu, 748 tcp_receive_data, 749 tcp_error, 750 tcp_error_reply, 751 }; 752 753 module_info *modules[] = { 754 (module_info *)&sTCPModule, 755 NULL 756 }; 757