1 /* 2 * Copyright 2006, Haiku, Inc. All Rights Reserved. 3 * Distributed under the terms of the MIT License. 4 * 5 * Authors: 6 * Axel Dörfler, axeld@pinc-software.de 7 * Andrew Galante, haiku.galante@gmail.com 8 */ 9 10 11 #include "EndpointManager.h" 12 #include "TCPEndpoint.h" 13 14 #include <net_protocol.h> 15 #include <net_stat.h> 16 17 #include <KernelExport.h> 18 #include <util/list.h> 19 20 #include <netinet/in.h> 21 #include <netinet/ip.h> 22 #include <new> 23 #include <stdlib.h> 24 #include <string.h> 25 26 #include <lock.h> 27 #include <util/AutoLock.h> 28 29 #include <NetBufferUtilities.h> 30 #include <NetUtilities.h> 31 32 //#define TRACE_TCP 33 #ifdef TRACE_TCP 34 # define TRACE(x) dprintf x 35 # define TRACE_BLOCK(x) dump_block x 36 #else 37 # define TRACE(x) 38 # define TRACE_BLOCK(x) 39 #endif 40 41 42 net_domain *gDomain; 43 net_address_module_info *gAddressModule; 44 net_buffer_module_info *gBufferModule; 45 net_datalink_module_info *gDatalinkModule; 46 net_socket_module_info *gSocketModule; 47 net_stack_module_info *gStackModule; 48 EndpointManager *gEndpointManager; 49 50 51 status_t 52 set_domain(net_interface *interface = NULL) 53 { 54 if (gDomain == NULL) { 55 // domain and address module are not known yet, we copy them from 56 // the buffer's interface (if any): 57 if (interface == NULL || interface->domain == NULL) 58 gDomain = gStackModule->get_domain(AF_INET); 59 else 60 gDomain = interface->domain; 61 62 if (gDomain == NULL) { 63 // this shouldn't occur, of course, but who knows... 64 return B_BAD_VALUE; 65 } 66 gAddressModule = gDomain->address_module; 67 } 68 69 return B_OK; 70 } 71 72 73 static inline void 74 bump_option(tcp_option *&option, size_t &length) 75 { 76 if (option->kind <= TCP_OPTION_NOP) { 77 length++; 78 option = (tcp_option *)((uint8 *)option + 1); 79 } else { 80 length += option->length; 81 option = (tcp_option *)((uint8 *)option + option->length); 82 } 83 } 84 85 86 static inline size_t 87 add_options(tcp_segment_header &segment, uint8 *buffer, size_t bufferSize) 88 { 89 tcp_option *option = (tcp_option *)buffer; 90 size_t length = 0; 91 92 if (segment.max_segment_size > 0 && length + 8 < bufferSize) { 93 option->kind = TCP_OPTION_MAX_SEGMENT_SIZE; 94 option->length = 4; 95 option->max_segment_size = htons(segment.max_segment_size); 96 bump_option(option, length); 97 } 98 if (segment.has_window_shift && length + 4 < bufferSize) { 99 // insert one NOP so that the subsequent data is aligned on a 4 byte boundary 100 option->kind = TCP_OPTION_NOP; 101 bump_option(option, length); 102 103 option->kind = TCP_OPTION_WINDOW_SHIFT; 104 option->length = 3; 105 option->window_shift = segment.window_shift; 106 bump_option(option, length); 107 } 108 109 if ((length & 3) == 0) { 110 // options completely fill out the option space 111 return length; 112 } 113 114 option->kind = TCP_OPTION_END; 115 return (length + 3) & ~3; 116 // bump to a multiple of 4 length 117 } 118 119 120 /*! 121 Constructs a TCP header on \a buffer with the specified values 122 for \a flags, \a seq \a ack and \a advertisedWindow. 123 */ 124 status_t 125 add_tcp_header(tcp_segment_header &segment, net_buffer *buffer) 126 { 127 buffer->protocol = IPPROTO_TCP; 128 129 uint8 optionsBuffer[32]; 130 uint32 optionsLength = add_options(segment, optionsBuffer, sizeof(optionsBuffer)); 131 132 NetBufferPrepend<tcp_header> bufferHeader(buffer, sizeof(tcp_header) + optionsLength); 133 if (bufferHeader.Status() != B_OK) 134 return bufferHeader.Status(); 135 136 tcp_header &header = bufferHeader.Data(); 137 138 header.source_port = gAddressModule->get_port((sockaddr *)&buffer->source); 139 header.destination_port = gAddressModule->get_port((sockaddr *)&buffer->destination); 140 header.sequence = htonl(segment.sequence); 141 header.acknowledge = (segment.flags & TCP_FLAG_ACKNOWLEDGE) 142 ? htonl(segment.acknowledge) : 0; 143 header.reserved = 0; 144 header.header_length = (sizeof(tcp_header) + optionsLength) >> 2; 145 header.flags = segment.flags; 146 header.advertised_window = htons(segment.advertised_window); 147 header.checksum = 0; 148 header.urgent_offset = 0; 149 // TODO: urgent pointer not yet supported 150 151 if (optionsLength > 0) 152 gBufferModule->write(buffer, sizeof(tcp_header), optionsBuffer, optionsLength); 153 154 TRACE(("add_tcp_header(): buffer %p, flags 0x%x, seq %lu, ack %lu, win %u\n", buffer, 155 segment.flags, segment.sequence, segment.acknowledge, segment.advertised_window)); 156 157 // compute and store checksum 158 Checksum checksum; 159 gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->source); 160 gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->destination); 161 checksum 162 << (uint16)htons(IPPROTO_TCP) 163 << (uint16)htons(buffer->size) 164 << Checksum::BufferHelper(buffer, gBufferModule); 165 header.checksum = checksum; 166 167 return B_OK; 168 } 169 170 171 void 172 process_options(tcp_segment_header &segment, net_buffer *buffer, int32 size) 173 { 174 if (size == 0) 175 return; 176 177 tcp_option *option; 178 uint8 optionsBuffer[32]; 179 if (gBufferModule->direct_access(buffer, sizeof(tcp_header), size, 180 (void **)&option) != B_OK) { 181 if (size > 32) { 182 dprintf("options too large to take into account (%ld bytes)\n", size); 183 return; 184 } 185 186 gBufferModule->read(buffer, sizeof(tcp_header), optionsBuffer, size); 187 option = (tcp_option *)optionsBuffer; 188 } 189 190 while (size > 0) { 191 uint32 length = 1; 192 switch (option->kind) { 193 case TCP_OPTION_END: 194 case TCP_OPTION_NOP: 195 break; 196 case TCP_OPTION_MAX_SEGMENT_SIZE: 197 segment.max_segment_size = ntohs(option->max_segment_size); 198 length = 4; 199 break; 200 case TCP_OPTION_WINDOW_SHIFT: 201 segment.has_window_shift = true; 202 segment.window_shift = option->window_shift; 203 length = 3; 204 break; 205 case TCP_OPTION_TIMESTAMP: 206 // TODO: support timestamp! 207 length = 10; 208 break; 209 210 default: 211 length = option->length; 212 // make sure we don't end up in an endless loop 213 if (length == 0) 214 return; 215 break; 216 } 217 218 size -= length; 219 option = (tcp_option *)((uint8 *)option + length); 220 } 221 // TODO: check if options are valid! 222 } 223 224 225 status_t 226 reply_with_reset(tcp_segment_header &segment, net_buffer *buffer) 227 { 228 TRACE(("TCP: Sending RST...\n")); 229 230 net_buffer *reply = gBufferModule->create(512); 231 if (reply == NULL) 232 return B_NO_MEMORY; 233 234 gAddressModule->set_to((sockaddr *)&reply->source, 235 (sockaddr *)&buffer->destination); 236 gAddressModule->set_to((sockaddr *)&reply->destination, 237 (sockaddr *)&buffer->source); 238 239 tcp_segment_header outSegment; 240 outSegment.flags = TCP_FLAG_RESET; 241 outSegment.sequence = 0; 242 outSegment.acknowledge = 0; 243 outSegment.advertised_window = 0; 244 outSegment.urgent_offset = 0; 245 246 if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) { 247 outSegment.flags |= TCP_FLAG_ACKNOWLEDGE; 248 outSegment.acknowledge = segment.sequence + buffer->size; 249 } else 250 outSegment.sequence = segment.acknowledge; 251 252 status_t status = add_tcp_header(outSegment, reply); 253 if (status == B_OK) 254 status = gDomain->module->send_data(NULL, reply); 255 256 if (status != B_OK) 257 gBufferModule->free(reply); 258 259 return status; 260 } 261 262 263 static const char * 264 name_for_state(tcp_state state) 265 { 266 switch (state) { 267 case CLOSED: 268 return "closed"; 269 case LISTEN: 270 return "listen"; 271 case SYNCHRONIZE_SENT: 272 return "syn-sent"; 273 case SYNCHRONIZE_RECEIVED: 274 return "syn-received"; 275 case ESTABLISHED: 276 return "established"; 277 278 // peer closes the connection 279 case FINISH_RECEIVED: 280 return "close-wait"; 281 case WAIT_FOR_FINISH_ACKNOWLEDGE: 282 return "last-ack"; 283 284 // we close the connection 285 case FINISH_SENT: 286 return "fin-wait1"; 287 case FINISH_ACKNOWLEDGED: 288 return "fin-wait2"; 289 case CLOSING: 290 return "closing"; 291 292 case TIME_WAIT: 293 return "time-wait"; 294 } 295 296 return "-"; 297 } 298 299 300 #if 0 301 static void 302 dump_tcp_header(tcp_header &header) 303 { 304 dprintf(" source port: %u\n", ntohs(header.source_port)); 305 dprintf(" dest port: %u\n", ntohs(header.destination_port)); 306 dprintf(" sequence: %lu\n", header.Sequence()); 307 dprintf(" ack: %lu\n", header.Acknowledge()); 308 dprintf(" flags: %s%s%s%s%s%s\n", (header.flags & TCP_FLAG_FINISH) ? "FIN " : "", 309 (header.flags & TCP_FLAG_SYNCHRONIZE) ? "SYN " : "", 310 (header.flags & TCP_FLAG_RESET) ? "RST " : "", 311 (header.flags & TCP_FLAG_PUSH) ? "PUSH " : "", 312 (header.flags & TCP_FLAG_ACKNOWLEDGE) ? "ACK " : "", 313 (header.flags & TCP_FLAG_URGENT) ? "URG " : ""); 314 dprintf(" window: %u\n", header.AdvertisedWindow()); 315 dprintf(" urgent offset: %u\n", header.UrgentOffset()); 316 } 317 #endif 318 319 320 // #pragma mark - protocol API 321 322 323 net_protocol * 324 tcp_init_protocol(net_socket *socket) 325 { 326 TCPEndpoint *protocol = new (std::nothrow) TCPEndpoint(socket); 327 if (protocol == NULL) 328 return NULL; 329 330 if (protocol->InitCheck() != B_OK) { 331 delete protocol; 332 return NULL; 333 } 334 335 TRACE(("Creating new TCPEndpoint: %p\n", protocol)); 336 socket->protocol = IPPROTO_TCP; 337 return protocol; 338 } 339 340 341 status_t 342 tcp_uninit_protocol(net_protocol *protocol) 343 { 344 TRACE(("Deleting TCPEndpoint: %p\n", protocol)); 345 delete (TCPEndpoint *)protocol; 346 return B_OK; 347 } 348 349 350 status_t 351 tcp_open(net_protocol *protocol) 352 { 353 if (gDomain == NULL && set_domain() != B_OK) 354 return B_ERROR; 355 356 return ((TCPEndpoint *)protocol)->Open(); 357 } 358 359 360 status_t 361 tcp_close(net_protocol *protocol) 362 { 363 return ((TCPEndpoint *)protocol)->Close(); 364 } 365 366 367 status_t 368 tcp_free(net_protocol *protocol) 369 { 370 return ((TCPEndpoint *)protocol)->Free(); 371 } 372 373 374 status_t 375 tcp_connect(net_protocol *protocol, const struct sockaddr *address) 376 { 377 return ((TCPEndpoint *)protocol)->Connect(address); 378 } 379 380 381 status_t 382 tcp_accept(net_protocol *protocol, struct net_socket **_acceptedSocket) 383 { 384 return ((TCPEndpoint *)protocol)->Accept(_acceptedSocket); 385 } 386 387 388 status_t 389 tcp_control(net_protocol *_protocol, int level, int option, void *value, 390 size_t *_length) 391 { 392 TCPEndpoint *protocol = (TCPEndpoint *)_protocol; 393 394 switch (level & LEVEL_MASK) { 395 case IPPROTO_TCP: 396 if (option == NET_STAT_SOCKET) { 397 net_stat *stat = (net_stat *)value; 398 strlcpy(stat->state, name_for_state(protocol->State()), 399 sizeof(stat->state)); 400 return B_OK; 401 } 402 break; 403 case SOL_SOCKET: 404 break; 405 406 default: 407 return protocol->next->module->control(protocol->next, level, option, 408 value, _length); 409 } 410 411 return B_BAD_VALUE; 412 } 413 414 415 status_t 416 tcp_bind(net_protocol *protocol, struct sockaddr *address) 417 { 418 return ((TCPEndpoint *)protocol)->Bind(address); 419 } 420 421 422 status_t 423 tcp_unbind(net_protocol *protocol, struct sockaddr *address) 424 { 425 return ((TCPEndpoint *)protocol)->Unbind(address); 426 } 427 428 429 status_t 430 tcp_listen(net_protocol *protocol, int count) 431 { 432 return ((TCPEndpoint *)protocol)->Listen(count); 433 } 434 435 436 status_t 437 tcp_shutdown(net_protocol *protocol, int direction) 438 { 439 return ((TCPEndpoint *)protocol)->Shutdown(direction); 440 } 441 442 443 status_t 444 tcp_send_data(net_protocol *protocol, net_buffer *buffer) 445 { 446 return ((TCPEndpoint *)protocol)->SendData(buffer); 447 } 448 449 450 status_t 451 tcp_send_routed_data(net_protocol *protocol, struct net_route *route, 452 net_buffer *buffer) 453 { 454 // TCP never sends routed data 455 return B_ERROR; 456 } 457 458 459 ssize_t 460 tcp_send_avail(net_protocol *protocol) 461 { 462 return ((TCPEndpoint *)protocol)->SendAvailable(); 463 } 464 465 466 status_t 467 tcp_read_data(net_protocol *protocol, size_t numBytes, uint32 flags, 468 net_buffer **_buffer) 469 { 470 return ((TCPEndpoint *)protocol)->ReadData(numBytes, flags, _buffer); 471 } 472 473 474 ssize_t 475 tcp_read_avail(net_protocol *protocol) 476 { 477 return ((TCPEndpoint *)protocol)->ReadAvailable(); 478 } 479 480 481 struct net_domain * 482 tcp_get_domain(net_protocol *protocol) 483 { 484 return protocol->next->module->get_domain(protocol->next); 485 } 486 487 488 size_t 489 tcp_get_mtu(net_protocol *protocol, const struct sockaddr *address) 490 { 491 return protocol->next->module->get_mtu(protocol->next, address); 492 } 493 494 495 status_t 496 tcp_receive_data(net_buffer *buffer) 497 { 498 TRACE(("TCP: Received buffer %p\n", buffer)); 499 500 if (gDomain == NULL && set_domain(buffer->interface) != B_OK) 501 return B_ERROR; 502 503 NetBufferHeader<tcp_header> bufferHeader(buffer); 504 if (bufferHeader.Status() < B_OK) 505 return bufferHeader.Status(); 506 507 tcp_header &header = bufferHeader.Data(); 508 509 uint16 headerLength = header.HeaderLength(); 510 if (headerLength < sizeof(tcp_header)) 511 return B_BAD_DATA; 512 513 // compute checksum using a pseudo IP header 514 Checksum checksum; 515 gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->source); 516 gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->destination); 517 checksum << (uint16)htons(IPPROTO_TCP) 518 << (uint16)htons(buffer->size) 519 << Checksum::BufferHelper(buffer, gBufferModule); 520 521 if (checksum != 0) 522 return B_BAD_DATA; 523 524 gAddressModule->set_port((struct sockaddr *)&buffer->source, header.source_port); 525 gAddressModule->set_port((struct sockaddr *)&buffer->destination, header.destination_port); 526 527 TRACE((" Looking for: peer %s, local %s\n", 528 AddressString(gDomain, (sockaddr *)&buffer->source, true).Data(), 529 AddressString(gDomain, (sockaddr *)&buffer->destination, true).Data())); 530 //dump_tcp_header(header); 531 //gBufferModule->dump(buffer); 532 533 tcp_segment_header segment; 534 segment.sequence = header.Sequence(); 535 segment.acknowledge = header.Acknowledge(); 536 segment.advertised_window = header.AdvertisedWindow(); 537 segment.urgent_offset = header.UrgentOffset(); 538 segment.flags = header.flags; 539 if ((segment.flags & TCP_FLAG_SYNCHRONIZE) != 0) { 540 // for now, we only process the options in the SYN segment 541 // TODO: when we support timestamps, they could be handled specifically 542 process_options(segment, buffer, headerLength - sizeof(tcp_header)); 543 } 544 545 bufferHeader.Remove(headerLength); 546 // we no longer need to keep the header around 547 548 RecursiveLocker locker(gEndpointManager->Locker()); 549 int32 segmentAction = DROP; 550 551 TCPEndpoint *endpoint = gEndpointManager->FindConnection( 552 (struct sockaddr *)&buffer->destination, (struct sockaddr *)&buffer->source); 553 if (endpoint != NULL) { 554 RecursiveLocker locker(endpoint->Lock()); 555 TRACE(("Endpoint %p in state %s\n", endpoint, name_for_state(endpoint->State()))); 556 557 switch (endpoint->State()) { 558 case LISTEN: 559 segmentAction = endpoint->ListenReceive(segment, buffer); 560 break; 561 562 case SYNCHRONIZE_SENT: 563 segmentAction = endpoint->SynchronizeSentReceive(segment, buffer); 564 break; 565 566 case SYNCHRONIZE_RECEIVED: 567 case ESTABLISHED: 568 case FINISH_RECEIVED: 569 case WAIT_FOR_FINISH_ACKNOWLEDGE: 570 case FINISH_SENT: 571 case FINISH_ACKNOWLEDGED: 572 case CLOSING: 573 case TIME_WAIT: 574 case CLOSED: 575 segmentAction = endpoint->Receive(segment, buffer); 576 break; 577 } 578 579 // process acknowledge action as asked for by the *Receive() method 580 if (segmentAction & IMMEDIATE_ACKNOWLEDGE) 581 endpoint->SendAcknowledge(); 582 else if (segmentAction & ACKNOWLEDGE) 583 endpoint->DelayedAcknowledge(); 584 else if (segmentAction & DELETE) 585 gSocketModule->delete_socket(endpoint->socket); 586 } else if ((segment.flags & TCP_FLAG_RESET) == 0) 587 segmentAction = DROP | RESET; 588 589 if (segmentAction & RESET) { 590 // send reset 591 reply_with_reset(segment, buffer); 592 } 593 if (segmentAction & DROP) 594 gBufferModule->free(buffer); 595 596 return B_OK; 597 } 598 599 600 status_t 601 tcp_error(uint32 code, net_buffer *data) 602 { 603 return B_ERROR; 604 } 605 606 607 status_t 608 tcp_error_reply(net_protocol *protocol, net_buffer *causedError, uint32 code, 609 void *errorData) 610 { 611 return B_ERROR; 612 } 613 614 615 // #pragma mark - 616 617 618 static status_t 619 tcp_init() 620 { 621 status_t status; 622 623 gDomain = NULL; 624 gAddressModule = NULL; 625 626 status = get_module(NET_STACK_MODULE_NAME, (module_info **)&gStackModule); 627 if (status < B_OK) 628 return status; 629 status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule); 630 if (status < B_OK) 631 goto err1; 632 status = get_module(NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule); 633 if (status < B_OK) 634 goto err2; 635 status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&gDatalinkModule); 636 if (status < B_OK) 637 goto err3; 638 639 gEndpointManager = new (std::nothrow) EndpointManager(); 640 if (gEndpointManager == NULL) { 641 status = B_NO_MEMORY; 642 goto err4; 643 } 644 status = gEndpointManager->InitCheck(); 645 if (status < B_OK) 646 goto err5; 647 648 status = gStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, 0, 649 "network/protocols/tcp/v1", 650 "network/protocols/ipv4/v1", 651 NULL); 652 if (status < B_OK) 653 goto err5; 654 655 status = gStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, IPPROTO_TCP, 656 "network/protocols/tcp/v1", 657 "network/protocols/ipv4/v1", 658 NULL); 659 if (status < B_OK) 660 goto err5; 661 662 status = gStackModule->register_domain_receiving_protocol(AF_INET, IPPROTO_TCP, 663 "network/protocols/tcp/v1"); 664 if (status < B_OK) 665 goto err5; 666 667 return B_OK; 668 669 err5: 670 delete gEndpointManager; 671 err4: 672 put_module(NET_DATALINK_MODULE_NAME); 673 err3: 674 put_module(NET_SOCKET_MODULE_NAME); 675 err2: 676 put_module(NET_BUFFER_MODULE_NAME); 677 err1: 678 put_module(NET_STACK_MODULE_NAME); 679 680 TRACE(("init_tcp() fails with %lx (%s)\n", status, strerror(status))); 681 return status; 682 } 683 684 685 static status_t 686 tcp_uninit() 687 { 688 delete gEndpointManager; 689 690 put_module(NET_DATALINK_MODULE_NAME); 691 put_module(NET_SOCKET_MODULE_NAME); 692 put_module(NET_BUFFER_MODULE_NAME); 693 put_module(NET_STACK_MODULE_NAME); 694 695 return B_OK; 696 } 697 698 699 static status_t 700 tcp_std_ops(int32 op, ...) 701 { 702 switch (op) { 703 case B_MODULE_INIT: 704 return tcp_init(); 705 706 case B_MODULE_UNINIT: 707 return tcp_uninit(); 708 709 default: 710 return B_ERROR; 711 } 712 } 713 714 715 net_protocol_module_info sTCPModule = { 716 { 717 "network/protocols/tcp/v1", 718 0, 719 tcp_std_ops 720 }, 721 tcp_init_protocol, 722 tcp_uninit_protocol, 723 tcp_open, 724 tcp_close, 725 tcp_free, 726 tcp_connect, 727 tcp_accept, 728 tcp_control, 729 tcp_bind, 730 tcp_unbind, 731 tcp_listen, 732 tcp_shutdown, 733 tcp_send_data, 734 tcp_send_routed_data, 735 tcp_send_avail, 736 tcp_read_data, 737 tcp_read_avail, 738 tcp_get_domain, 739 tcp_get_mtu, 740 tcp_receive_data, 741 tcp_error, 742 tcp_error_reply, 743 }; 744 745 module_info *modules[] = { 746 (module_info *)&sTCPModule, 747 NULL 748 }; 749