1 /* 2 * Copyright 2006-2009, Haiku, Inc. All Rights Reserved. 3 * Distributed under the terms of the MIT License. 4 * 5 * Authors: 6 * Axel Dörfler, axeld@pinc-software.de 7 * Hugo Santos, hugosantos@gmail.com 8 */ 9 10 11 #include "EndpointManager.h" 12 13 #include <new> 14 #include <unistd.h> 15 16 #include <KernelExport.h> 17 18 #include <NetUtilities.h> 19 #include <tracing.h> 20 21 #include "TCPEndpoint.h" 22 23 24 //#define TRACE_ENDPOINT_MANAGER 25 #ifdef TRACE_ENDPOINT_MANAGER 26 # define TRACE(x) dprintf x 27 #else 28 # define TRACE(x) 29 #endif 30 31 #if TCP_TRACING 32 # define ENDPOINT_TRACING 33 #endif 34 #ifdef ENDPOINT_TRACING 35 namespace EndpointTracing { 36 37 class Bind : public AbstractTraceEntry { 38 public: 39 Bind(TCPEndpoint* endpoint, ConstSocketAddress& address, bool ephemeral) 40 : 41 fEndpoint(endpoint), 42 fEphemeral(ephemeral) 43 { 44 address.AsString(fAddress, sizeof(fAddress), true); 45 Initialized(); 46 } 47 48 Bind(TCPEndpoint* endpoint, SocketAddress& address, bool ephemeral) 49 : 50 fEndpoint(endpoint), 51 fEphemeral(ephemeral) 52 { 53 address.AsString(fAddress, sizeof(fAddress), true); 54 Initialized(); 55 } 56 57 virtual void AddDump(TraceOutput& out) 58 { 59 out.Print("tcp:e:%p bind%s address %s", fEndpoint, 60 fEphemeral ? " ephemeral" : "", fAddress); 61 } 62 63 protected: 64 TCPEndpoint* fEndpoint; 65 char fAddress[32]; 66 bool fEphemeral; 67 }; 68 69 class Connect : public AbstractTraceEntry { 70 public: 71 Connect(TCPEndpoint* endpoint) 72 : 73 fEndpoint(endpoint) 74 { 75 endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true); 76 endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true); 77 Initialized(); 78 } 79 80 virtual void AddDump(TraceOutput& out) 81 { 82 out.Print("tcp:e:%p connect local %s, peer %s", fEndpoint, fLocal, 83 fPeer); 84 } 85 86 protected: 87 TCPEndpoint* fEndpoint; 88 char fLocal[32]; 89 char fPeer[32]; 90 }; 91 92 class Unbind : public AbstractTraceEntry { 93 public: 94 Unbind(TCPEndpoint* endpoint) 95 : 96 fEndpoint(endpoint) 97 { 98 //fStackTrace = capture_tracing_stack_trace(10, 0, false); 99 100 endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true); 101 endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true); 102 Initialized(); 103 } 104 105 #if 0 106 virtual void DumpStackTrace(TraceOutput& out) 107 { 108 out.PrintStackTrace(fStackTrace); 109 } 110 #endif 111 112 virtual void AddDump(TraceOutput& out) 113 { 114 out.Print("tcp:e:%p unbind, local %s, peer %s", fEndpoint, fLocal, 115 fPeer); 116 } 117 118 protected: 119 TCPEndpoint* fEndpoint; 120 //tracing_stack_trace* fStackTrace; 121 char fLocal[32]; 122 char fPeer[32]; 123 }; 124 125 } // namespace EndpointTracing 126 127 # define T(x) new(std::nothrow) EndpointTracing::x 128 #else 129 # define T(x) 130 #endif // ENDPOINT_TRACING 131 132 133 static const uint16 kLastReservedPort = 1023; 134 static const uint16 kFirstEphemeralPort = 40000; 135 136 137 ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager* manager) 138 : 139 fManager(manager) 140 { 141 } 142 143 144 size_t 145 ConnectionHashDefinition::HashKey(const KeyType& key) const 146 { 147 return ConstSocketAddress(fManager->AddressModule(), 148 key.first).HashPair(key.second); 149 } 150 151 152 size_t 153 ConnectionHashDefinition::Hash(TCPEndpoint* endpoint) const 154 { 155 return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress()); 156 } 157 158 159 bool 160 ConnectionHashDefinition::Compare(const KeyType& key, 161 TCPEndpoint* endpoint) const 162 { 163 return endpoint->LocalAddress().EqualTo(key.first, true) 164 && endpoint->PeerAddress().EqualTo(key.second, true); 165 } 166 167 168 TCPEndpoint*& 169 ConnectionHashDefinition::GetLink(TCPEndpoint* endpoint) const 170 { 171 return endpoint->fConnectionHashLink; 172 } 173 174 175 // #pragma mark - 176 177 178 size_t 179 EndpointHashDefinition::HashKey(uint16 port) const 180 { 181 return port; 182 } 183 184 185 size_t 186 EndpointHashDefinition::Hash(TCPEndpoint* endpoint) const 187 { 188 return endpoint->LocalAddress().Port(); 189 } 190 191 192 bool 193 EndpointHashDefinition::Compare(uint16 port, TCPEndpoint* endpoint) const 194 { 195 return endpoint->LocalAddress().Port() == port; 196 } 197 198 199 bool 200 EndpointHashDefinition::CompareValues(TCPEndpoint* first, 201 TCPEndpoint* second) const 202 { 203 return first->LocalAddress().Port() == second->LocalAddress().Port(); 204 } 205 206 207 TCPEndpoint*& 208 EndpointHashDefinition::GetLink(TCPEndpoint* endpoint) const 209 { 210 return endpoint->fEndpointHashLink; 211 } 212 213 214 // #pragma mark - 215 216 217 EndpointManager::EndpointManager(net_domain* domain) 218 : 219 fDomain(domain), 220 fConnectionHash(this), 221 fLastPort(kFirstEphemeralPort) 222 { 223 rw_lock_init(&fLock, "TCP endpoint manager"); 224 } 225 226 227 EndpointManager::~EndpointManager() 228 { 229 rw_lock_destroy(&fLock); 230 } 231 232 233 status_t 234 EndpointManager::Init() 235 { 236 status_t status = fConnectionHash.Init(); 237 if (status == B_OK) 238 status = fEndpointHash.Init(); 239 240 return status; 241 } 242 243 244 // #pragma mark - connections 245 246 247 /*! Returns the endpoint matching the connection. 248 You must hold the manager's lock when calling this method (either read or 249 write). 250 */ 251 TCPEndpoint* 252 EndpointManager::_LookupConnection(const sockaddr* local, const sockaddr* peer) 253 { 254 return fConnectionHash.Lookup(std::make_pair(local, peer)); 255 } 256 257 258 status_t 259 EndpointManager::SetConnection(TCPEndpoint* endpoint, const sockaddr* _local, 260 const sockaddr* peer, const sockaddr* interfaceLocal) 261 { 262 TRACE(("EndpointManager::SetConnection(%p)\n", endpoint)); 263 264 WriteLocker _(fLock); 265 266 SocketAddressStorage local(AddressModule()); 267 local.SetTo(_local); 268 269 if (local.IsEmpty(false)) { 270 uint16 port = local.Port(); 271 local.SetTo(interfaceLocal); 272 local.SetPort(port); 273 } 274 275 // We want to create a connection for (local, peer), so check to make sure 276 // that this pair is not already in use by an existing connection. 277 if (_LookupConnection(*local, peer) != NULL) 278 return EADDRINUSE; 279 280 endpoint->LocalAddress().SetTo(*local); 281 endpoint->PeerAddress().SetTo(peer); 282 T(Connect(endpoint)); 283 284 // BOpenHashTable doesn't support inserting duplicate objects. Since 285 // BOpenHashTable is a chained hash table where the items are required to 286 // be intrusive linked list nodes, inserting the same object twice will 287 // create a cycle in the linked list, which is not handled currently. 288 // 289 // We need to makes sure to remove any existing copy of this endpoint 290 // object from the table in order to handle calling connect() on a closed 291 // socket to connect to a different remote (address, port) than it was 292 // originally used for. 293 // 294 // We use RemoveUnchecked here because we don't want the hash table to 295 // resize itself after this removal when we are planning to just add 296 // another. 297 fConnectionHash.RemoveUnchecked(endpoint); 298 299 fConnectionHash.Insert(endpoint); 300 return B_OK; 301 } 302 303 304 status_t 305 EndpointManager::SetPassive(TCPEndpoint* endpoint) 306 { 307 WriteLocker _(fLock); 308 309 if (!endpoint->IsBound()) { 310 // if the socket is unbound first bind it to ephemeral 311 SocketAddressStorage local(AddressModule()); 312 local.SetToEmpty(); 313 314 status_t status = _BindToEphemeral(endpoint, *local); 315 if (status < B_OK) 316 return status; 317 } 318 319 SocketAddressStorage passive(AddressModule()); 320 passive.SetToEmpty(); 321 322 if (_LookupConnection(*endpoint->LocalAddress(), *passive)) 323 return EADDRINUSE; 324 325 endpoint->PeerAddress().SetTo(*passive); 326 fConnectionHash.Insert(endpoint); 327 return B_OK; 328 } 329 330 331 TCPEndpoint* 332 EndpointManager::FindConnection(sockaddr* local, sockaddr* peer) 333 { 334 ReadLocker _(fLock); 335 336 TCPEndpoint *endpoint = _LookupConnection(local, peer); 337 if (endpoint != NULL) { 338 TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n", 339 endpoint)); 340 if (gSocketModule->acquire_socket(endpoint->socket)) 341 return endpoint; 342 } 343 344 // no explicit endpoint exists, check for wildcard endpoints 345 346 SocketAddressStorage wildcard(AddressModule()); 347 wildcard.SetToEmpty(); 348 349 endpoint = _LookupConnection(local, *wildcard); 350 if (endpoint != NULL) { 351 TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", 352 endpoint)); 353 if (gSocketModule->acquire_socket(endpoint->socket)) 354 return endpoint; 355 } 356 357 SocketAddressStorage localWildcard(AddressModule()); 358 localWildcard.SetToEmpty(); 359 localWildcard.SetPort(AddressModule()->get_port(local)); 360 361 endpoint = _LookupConnection(*localWildcard, *wildcard); 362 if (endpoint != NULL) { 363 TRACE(("TCP: Received packet corresponds to local wildcard endpoint " 364 "%p\n", endpoint)); 365 if (gSocketModule->acquire_socket(endpoint->socket)) 366 return endpoint; 367 } 368 369 // no matching endpoint exists 370 TRACE(("TCP: no matching endpoint!\n")); 371 372 return NULL; 373 } 374 375 376 // #pragma mark - endpoints 377 378 379 status_t 380 EndpointManager::Bind(TCPEndpoint* endpoint, const sockaddr* address) 381 { 382 // check the family 383 if (!AddressModule()->is_same_family(address)) 384 return EAFNOSUPPORT; 385 386 WriteLocker locker(fLock); 387 388 if (AddressModule()->get_port(address) == 0) 389 return _BindToEphemeral(endpoint, address); 390 391 return _BindToAddress(locker, endpoint, address); 392 } 393 394 395 status_t 396 EndpointManager::BindChild(TCPEndpoint* endpoint, const sockaddr* address) 397 { 398 WriteLocker _(fLock); 399 return _Bind(endpoint, address); 400 } 401 402 403 /*! You must have fLock write locked when calling this method. */ 404 status_t 405 EndpointManager::_BindToAddress(WriteLocker& locker, TCPEndpoint* endpoint, 406 const sockaddr* _address) 407 { 408 ConstSocketAddress address(AddressModule(), _address); 409 uint16 port = address.Port(); 410 411 TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint)); 412 T(Bind(endpoint, address, false)); 413 414 // TODO: this check follows very typical UNIX semantics 415 // and generally should be improved. 416 if (ntohs(port) <= kLastReservedPort && geteuid() != 0) 417 return B_PERMISSION_DENIED; 418 419 bool retrying = false; 420 int32 retry = 0; 421 do { 422 EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port); 423 retry = false; 424 425 while (portUsers.HasNext()) { 426 TCPEndpoint* user = portUsers.Next(); 427 428 if (user->LocalAddress().IsEmpty(false) 429 || address.EqualTo(*user->LocalAddress(), false)) { 430 // Check if this belongs to a local connection 431 432 // Note, while we hold our lock, the endpoint cannot go away, 433 // it can only change its state - IsLocal() is safe to be used 434 // without having the endpoint locked. 435 tcp_state userState = user->State(); 436 if (user->IsLocal() 437 && (userState > ESTABLISHED || userState == CLOSED)) { 438 // This is a closing local connection - wait until it's 439 // gone away for real 440 locker.Unlock(); 441 snooze(10000); 442 locker.Lock(); 443 // TODO: make this better 444 if (!retrying) { 445 retrying = true; 446 retry = 5; 447 } 448 break; 449 } 450 451 if ((endpoint->socket->options & SO_REUSEADDR) == 0) 452 return EADDRINUSE; 453 454 if (userState != TIME_WAIT && userState != CLOSED) 455 return EADDRINUSE; 456 } 457 } 458 } while (retry-- > 0); 459 460 return _Bind(endpoint, *address); 461 } 462 463 464 /*! You must have fLock write locked when calling this method. */ 465 status_t 466 EndpointManager::_BindToEphemeral(TCPEndpoint* endpoint, 467 const sockaddr* address) 468 { 469 TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint)); 470 471 uint32 max = fLastPort + 65536; 472 473 for (int32 i = 1; i < 5; i++) { 474 // try to retrieve a more or less random port 475 uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1; 476 uint32 counter = fLastPort + step; 477 478 while (counter < max) { 479 uint16 port = counter & 0xffff; 480 if (port <= kLastReservedPort) 481 port += kLastReservedPort; 482 483 fLastPort = port; 484 port = htons(port); 485 486 if (!fEndpointHash.Lookup(port).HasNext()) { 487 // found a port 488 SocketAddressStorage newAddress(AddressModule()); 489 newAddress.SetTo(address); 490 newAddress.SetPort(port); 491 492 TRACE((" EndpointManager::BindToEphemeral(%p) -> %s\n", 493 endpoint, AddressString(Domain(), *newAddress, 494 true).Data())); 495 T(Bind(endpoint, newAddress, true)); 496 497 return _Bind(endpoint, *newAddress); 498 } 499 500 counter += step; 501 } 502 } 503 504 // could not find a port! 505 return EADDRINUSE; 506 } 507 508 509 status_t 510 EndpointManager::_Bind(TCPEndpoint* endpoint, const sockaddr* address) 511 { 512 // Thus far we have checked if the Bind() is allowed 513 514 status_t status = endpoint->next->module->bind(endpoint->next, address); 515 if (status < B_OK) 516 return status; 517 518 fEndpointHash.Insert(endpoint); 519 520 return B_OK; 521 } 522 523 524 status_t 525 EndpointManager::Unbind(TCPEndpoint* endpoint) 526 { 527 TRACE(("EndpointManager::Unbind(%p)\n", endpoint)); 528 T(Unbind(endpoint)); 529 530 if (endpoint == NULL || !endpoint->IsBound()) { 531 TRACE((" endpoint is unbound.\n")); 532 return B_BAD_VALUE; 533 } 534 535 WriteLocker _(fLock); 536 537 if (!fEndpointHash.Remove(endpoint)) 538 panic("bound endpoint %p not in hash!", endpoint); 539 540 fConnectionHash.Remove(endpoint); 541 542 (*endpoint->LocalAddress())->sa_len = 0; 543 544 return B_OK; 545 } 546 547 548 status_t 549 EndpointManager::ReplyWithReset(tcp_segment_header& segment, net_buffer* buffer) 550 { 551 TRACE(("TCP: Sending RST...\n")); 552 553 net_buffer* reply = gBufferModule->create(512); 554 if (reply == NULL) 555 return B_NO_MEMORY; 556 557 AddressModule()->set_to(reply->source, buffer->destination); 558 AddressModule()->set_to(reply->destination, buffer->source); 559 560 tcp_segment_header outSegment(TCP_FLAG_RESET); 561 outSegment.sequence = 0; 562 outSegment.acknowledge = 0; 563 outSegment.advertised_window = 0; 564 outSegment.urgent_offset = 0; 565 566 if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) { 567 outSegment.flags |= TCP_FLAG_ACKNOWLEDGE; 568 outSegment.acknowledge = segment.sequence + buffer->size; 569 // TODO: Confirm: 570 if ((segment.flags & (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH)) != 0) 571 outSegment.acknowledge++; 572 } else 573 outSegment.sequence = segment.acknowledge; 574 575 status_t status = add_tcp_header(AddressModule(), outSegment, reply); 576 if (status == B_OK) 577 status = Domain()->module->send_data(NULL, reply); 578 579 if (status != B_OK) 580 gBufferModule->free(reply); 581 582 return status; 583 } 584 585 586 void 587 EndpointManager::Dump() const 588 { 589 kprintf("-------- TCP Domain %p ---------\n", this); 590 kprintf("%10s %21s %21s %8s %8s %12s\n", "address", "local", "peer", 591 "recv-q", "send-q", "state"); 592 593 ConnectionTable::Iterator iterator = fConnectionHash.GetIterator(); 594 595 while (iterator.HasNext()) { 596 TCPEndpoint *endpoint = iterator.Next(); 597 598 char localBuf[64], peerBuf[64]; 599 endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true); 600 endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true); 601 602 kprintf("%p %21s %21s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf, 603 endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(), 604 name_for_state(endpoint->State())); 605 } 606 } 607 608