1 /* 2 * Copyright 2006-2008, Haiku, Inc. All Rights Reserved. 3 * Distributed under the terms of the MIT License. 4 * 5 * Authors: 6 * Axel Dörfler, axeld@pinc-software.de 7 * Hugo Santos, hugosantos@gmail.com 8 */ 9 10 11 #include "EndpointManager.h" 12 13 #include <new> 14 #include <unistd.h> 15 16 #include <KernelExport.h> 17 18 #include <NetUtilities.h> 19 #include <util/AutoLock.h> 20 #include <tracing.h> 21 22 #include "TCPEndpoint.h" 23 24 25 //#define TRACE_ENDPOINT_MANAGER 26 #ifdef TRACE_ENDPOINT_MANAGER 27 # define TRACE(x) dprintf x 28 #else 29 # define TRACE(x) 30 #endif 31 32 #if TCP_TRACING 33 # define ENDPOINT_TRACING 34 #endif 35 #ifdef ENDPOINT_TRACING 36 namespace EndpointTracing { 37 38 class Bind : public AbstractTraceEntry { 39 public: 40 Bind(TCPEndpoint* endpoint, ConstSocketAddress& address, bool ephemeral) 41 : 42 fEndpoint(endpoint), 43 fEphemeral(ephemeral) 44 { 45 address.AsString(fAddress, sizeof(fAddress), true); 46 Initialized(); 47 } 48 49 Bind(TCPEndpoint* endpoint, SocketAddress& address, bool ephemeral) 50 : 51 fEndpoint(endpoint), 52 fEphemeral(ephemeral) 53 { 54 address.AsString(fAddress, sizeof(fAddress), true); 55 Initialized(); 56 } 57 58 virtual void AddDump(TraceOutput& out) 59 { 60 out.Print("tcp:e:%p bind%s address %s", fEndpoint, 61 fEphemeral ? " ephemeral" : "", fAddress); 62 } 63 64 protected: 65 TCPEndpoint* fEndpoint; 66 char fAddress[32]; 67 bool fEphemeral; 68 }; 69 70 class Connect : public AbstractTraceEntry { 71 public: 72 Connect(TCPEndpoint* endpoint) 73 : 74 fEndpoint(endpoint) 75 { 76 endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true); 77 endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true); 78 Initialized(); 79 } 80 81 virtual void AddDump(TraceOutput& out) 82 { 83 out.Print("tcp:e:%p connect local %s, peer %s", fEndpoint, fLocal, 84 fPeer); 85 } 86 87 protected: 88 TCPEndpoint* fEndpoint; 89 char fLocal[32]; 90 char fPeer[32]; 91 }; 92 93 class Unbind : public AbstractTraceEntry { 94 public: 95 Unbind(TCPEndpoint* endpoint) 96 : 97 fEndpoint(endpoint) 98 { 99 //fStackTrace = capture_tracing_stack_trace(10, 0, false); 100 101 endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true); 102 endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true); 103 Initialized(); 104 } 105 106 #if 0 107 virtual void DumpStackTrace(TraceOutput& out) 108 { 109 out.PrintStackTrace(fStackTrace); 110 } 111 #endif 112 113 virtual void AddDump(TraceOutput& out) 114 { 115 out.Print("tcp:e:%p unbind, local %s, peer %s", fEndpoint, fLocal, 116 fPeer); 117 } 118 119 protected: 120 TCPEndpoint* fEndpoint; 121 //tracing_stack_trace* fStackTrace; 122 char fLocal[32]; 123 char fPeer[32]; 124 }; 125 126 } // namespace EndpointTracing 127 128 # define T(x) new(std::nothrow) EndpointTracing::x 129 #else 130 # define T(x) 131 #endif // ENDPOINT_TRACING 132 133 134 static const uint16 kLastReservedPort = 1023; 135 static const uint16 kFirstEphemeralPort = 40000; 136 137 138 ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager* manager) 139 : 140 fManager(manager) 141 { 142 } 143 144 145 size_t 146 ConnectionHashDefinition::HashKey(const KeyType& key) const 147 { 148 return ConstSocketAddress(fManager->AddressModule(), 149 key.first).HashPair(key.second); 150 } 151 152 153 size_t 154 ConnectionHashDefinition::Hash(TCPEndpoint* endpoint) const 155 { 156 return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress()); 157 } 158 159 160 bool 161 ConnectionHashDefinition::Compare(const KeyType& key, 162 TCPEndpoint* endpoint) const 163 { 164 return endpoint->LocalAddress().EqualTo(key.first, true) 165 && endpoint->PeerAddress().EqualTo(key.second, true); 166 } 167 168 169 HashTableLink<TCPEndpoint>* 170 ConnectionHashDefinition::GetLink(TCPEndpoint* endpoint) const 171 { 172 return &endpoint->fConnectionHashLink; 173 } 174 175 176 // #pragma mark - 177 178 179 size_t 180 EndpointHashDefinition::HashKey(uint16 port) const 181 { 182 return port; 183 } 184 185 186 size_t 187 EndpointHashDefinition::Hash(TCPEndpoint* endpoint) const 188 { 189 return endpoint->LocalAddress().Port(); 190 } 191 192 193 bool 194 EndpointHashDefinition::Compare(uint16 port, TCPEndpoint* endpoint) const 195 { 196 return endpoint->LocalAddress().Port() == port; 197 } 198 199 200 bool 201 EndpointHashDefinition::CompareValues(TCPEndpoint* first, 202 TCPEndpoint* second) const 203 { 204 return first->LocalAddress().Port() == second->LocalAddress().Port(); 205 } 206 207 208 HashTableLink<TCPEndpoint>* 209 EndpointHashDefinition::GetLink(TCPEndpoint* endpoint) const 210 { 211 return &endpoint->fEndpointHashLink; 212 } 213 214 215 // #pragma mark - 216 217 218 EndpointManager::EndpointManager(net_domain* domain) 219 : 220 fDomain(domain), 221 fConnectionHash(this), 222 fLastPort(kFirstEphemeralPort) 223 { 224 mutex_init(&fLock, "endpoint manager"); 225 } 226 227 228 EndpointManager::~EndpointManager() 229 { 230 mutex_destroy(&fLock); 231 } 232 233 234 status_t 235 EndpointManager::Init() 236 { 237 status_t status = fConnectionHash.Init(); 238 if (status == B_OK) 239 status = fEndpointHash.Init(); 240 241 return status; 242 } 243 244 245 // #pragma mark - connections 246 247 248 /*! Returns the endpoint matching the connection. 249 You must hold the manager's lock when calling this method. 250 */ 251 TCPEndpoint* 252 EndpointManager::_LookupConnection(const sockaddr* local, const sockaddr* peer) 253 { 254 return fConnectionHash.Lookup(std::make_pair(local, peer)); 255 } 256 257 258 status_t 259 EndpointManager::SetConnection(TCPEndpoint* endpoint, const sockaddr* _local, 260 const sockaddr* peer, const sockaddr* interfaceLocal) 261 { 262 TRACE(("EndpointManager::SetConnection(%p)\n", endpoint)); 263 264 MutexLocker _(fLock); 265 266 SocketAddressStorage local(AddressModule()); 267 local.SetTo(_local); 268 269 if (local.IsEmpty(false)) { 270 uint16 port = local.Port(); 271 local.SetTo(interfaceLocal); 272 local.SetPort(port); 273 } 274 275 if (_LookupConnection(*local, peer) != NULL) 276 return EADDRINUSE; 277 278 endpoint->LocalAddress().SetTo(*local); 279 endpoint->PeerAddress().SetTo(peer); 280 T(Connect(endpoint)); 281 282 fConnectionHash.Insert(endpoint); 283 return B_OK; 284 } 285 286 287 status_t 288 EndpointManager::SetPassive(TCPEndpoint* endpoint) 289 { 290 MutexLocker _(fLock); 291 292 if (!endpoint->IsBound()) { 293 // if the socket is unbound first bind it to ephemeral 294 SocketAddressStorage local(AddressModule()); 295 local.SetToEmpty(); 296 297 status_t status = _BindToEphemeral(endpoint, *local); 298 if (status < B_OK) 299 return status; 300 } 301 302 SocketAddressStorage passive(AddressModule()); 303 passive.SetToEmpty(); 304 305 if (_LookupConnection(*endpoint->LocalAddress(), *passive)) 306 return EADDRINUSE; 307 308 endpoint->PeerAddress().SetTo(*passive); 309 fConnectionHash.Insert(endpoint); 310 return B_OK; 311 } 312 313 314 TCPEndpoint* 315 EndpointManager::FindConnection(sockaddr* local, sockaddr* peer) 316 { 317 MutexLocker _(fLock); 318 319 TCPEndpoint *endpoint = _LookupConnection(local, peer); 320 if (endpoint != NULL) { 321 TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n", endpoint)); 322 return endpoint; 323 } 324 325 // no explicit endpoint exists, check for wildcard endpoints 326 327 SocketAddressStorage wildcard(AddressModule()); 328 wildcard.SetToEmpty(); 329 330 endpoint = _LookupConnection(local, *wildcard); 331 if (endpoint != NULL) { 332 TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", endpoint)); 333 return endpoint; 334 } 335 336 SocketAddressStorage localWildcard(AddressModule()); 337 localWildcard.SetToEmpty(); 338 localWildcard.SetPort(AddressModule()->get_port(local)); 339 340 endpoint = _LookupConnection(*localWildcard, *wildcard); 341 if (endpoint != NULL) { 342 TRACE(("TCP: Received packet corresponds to local wildcard endpoint %p\n", endpoint)); 343 return endpoint; 344 } 345 346 // no matching endpoint exists 347 TRACE(("TCP: no matching endpoint!\n")); 348 349 return NULL; 350 } 351 352 353 // #pragma mark - endpoints 354 355 356 status_t 357 EndpointManager::Bind(TCPEndpoint* endpoint, const sockaddr* address) 358 { 359 // TODO check the family: 360 // 361 // if (!AddressModule()->is_understandable(address)) 362 // return EAFNOSUPPORT; 363 364 MutexLocker _(fLock); 365 366 if (AddressModule()->get_port(address) == 0) 367 return _BindToEphemeral(endpoint, address); 368 369 return _BindToAddress(endpoint, address); 370 } 371 372 373 status_t 374 EndpointManager::BindChild(TCPEndpoint* endpoint) 375 { 376 MutexLocker _(fLock); 377 return _Bind(endpoint, *endpoint->LocalAddress()); 378 } 379 380 381 /*! You must hold fLock when calling this method. */ 382 status_t 383 EndpointManager::_BindToAddress(TCPEndpoint* endpoint, const sockaddr* _address) 384 { 385 ConstSocketAddress address(AddressModule(), _address); 386 uint16 port = address.Port(); 387 388 TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint)); 389 T(Bind(endpoint, address, false)); 390 391 // TODO this check follows very typical UNIX semantics 392 // and generally should be improved. 393 if (ntohs(port) <= kLastReservedPort && geteuid() != 0) 394 return B_PERMISSION_DENIED; 395 396 bool retrying = false; 397 int32 retry = 0; 398 do { 399 EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port); 400 retry = false; 401 402 while (portUsers.HasNext()) { 403 TCPEndpoint* user = portUsers.Next(); 404 405 if (user->LocalAddress().IsEmpty(false) 406 || address.EqualTo(*user->LocalAddress(), false)) { 407 // Check if this belongs to a local connection 408 409 // Note, while we hold our lock, the endpoint cannot go away, 410 // it can only change its state - IsLocal() is safe to be used 411 // without having the endpoint locked. 412 tcp_state userState = user->State(); 413 if (user->IsLocal() 414 && (userState > ESTABLISHED || userState == CLOSED)) { 415 // This is a closing local connection - wait until it's 416 // gone away for real 417 mutex_unlock(&fLock); 418 snooze(10000); 419 mutex_lock(&fLock); 420 // TODO: make this better 421 if (!retrying) { 422 retrying = true; 423 retry = 5; 424 } 425 break; 426 } 427 428 if ((endpoint->socket->options & SO_REUSEADDR) == 0) 429 return EADDRINUSE; 430 431 if (userState != TIME_WAIT && userState != CLOSED) 432 return EADDRINUSE; 433 } 434 } 435 } while (retry-- > 0); 436 437 return _Bind(endpoint, *address); 438 } 439 440 441 /*! You must hold fLock when calling this method. */ 442 status_t 443 EndpointManager::_BindToEphemeral(TCPEndpoint* endpoint, 444 const sockaddr* address) 445 { 446 TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint)); 447 448 uint32 max = fLastPort + 65536; 449 450 for (int32 i = 1; i < 5; i++) { 451 // try to retrieve a more or less random port 452 uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1; 453 uint32 counter = fLastPort + step; 454 455 while (counter < max) { 456 uint16 port = counter & 0xffff; 457 if (port <= kLastReservedPort) 458 port += kLastReservedPort; 459 460 fLastPort = port; 461 port = htons(port); 462 463 if (!fEndpointHash.Lookup(port).HasNext()) { 464 // found a port 465 SocketAddressStorage newAddress(AddressModule()); 466 newAddress.SetTo(address); 467 newAddress.SetPort(port); 468 469 TRACE((" EndpointManager::BindToEphemeral(%p) -> %s\n", 470 endpoint, AddressString(Domain(), *newAddress, 471 true).Data())); 472 T(Bind(endpoint, newAddress, true)); 473 474 return _Bind(endpoint, *newAddress); 475 } 476 477 counter += step; 478 } 479 } 480 481 // could not find a port! 482 return EADDRINUSE; 483 } 484 485 486 status_t 487 EndpointManager::_Bind(TCPEndpoint* endpoint, const sockaddr* address) 488 { 489 // Thus far we have checked if the Bind() is allowed 490 491 status_t status = endpoint->next->module->bind(endpoint->next, address); 492 if (status < B_OK) 493 return status; 494 495 fEndpointHash.Insert(endpoint); 496 497 return B_OK; 498 } 499 500 501 status_t 502 EndpointManager::Unbind(TCPEndpoint* endpoint) 503 { 504 TRACE(("EndpointManager::Unbind(%p)\n", endpoint)); 505 T(Unbind(endpoint)); 506 507 if (endpoint == NULL || !endpoint->IsBound()) { 508 TRACE((" endpoint is unbound.\n")); 509 return B_BAD_VALUE; 510 } 511 512 MutexLocker _(fLock); 513 514 if (!fEndpointHash.Remove(endpoint)) 515 panic("bound endpoint %p not in hash!", endpoint); 516 517 fConnectionHash.Remove(endpoint); 518 519 (*endpoint->LocalAddress())->sa_len = 0; 520 521 return B_OK; 522 } 523 524 525 status_t 526 EndpointManager::ReplyWithReset(tcp_segment_header& segment, net_buffer* buffer) 527 { 528 TRACE(("TCP: Sending RST...\n")); 529 530 net_buffer* reply = gBufferModule->create(512); 531 if (reply == NULL) 532 return B_NO_MEMORY; 533 534 AddressModule()->set_to(reply->source, buffer->destination); 535 AddressModule()->set_to(reply->destination, buffer->source); 536 537 tcp_segment_header outSegment(TCP_FLAG_RESET); 538 outSegment.sequence = 0; 539 outSegment.acknowledge = 0; 540 outSegment.advertised_window = 0; 541 outSegment.urgent_offset = 0; 542 543 if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) { 544 outSegment.flags |= TCP_FLAG_ACKNOWLEDGE; 545 outSegment.acknowledge = segment.sequence + buffer->size; 546 // TODO: Confirm: 547 if ((segment.flags & (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH)) != 0) 548 outSegment.acknowledge++; 549 } else 550 outSegment.sequence = segment.acknowledge; 551 552 status_t status = add_tcp_header(AddressModule(), outSegment, reply); 553 if (status == B_OK) 554 status = Domain()->module->send_data(NULL, reply); 555 556 if (status != B_OK) 557 gBufferModule->free(reply); 558 559 return status; 560 } 561 562 563 void 564 EndpointManager::Dump() const 565 { 566 kprintf("-------- TCP Domain %p ---------\n", this); 567 kprintf("%10s %20s %20s %8s %8s %12s\n", "address", "local", "peer", 568 "recv-q", "send-q", "state"); 569 570 ConnectionTable::Iterator iterator = fConnectionHash.GetIterator(); 571 572 while (iterator.HasNext()) { 573 TCPEndpoint *endpoint = iterator.Next(); 574 575 char localBuf[64], peerBuf[64]; 576 endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true); 577 endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true); 578 579 kprintf("%p %20s %20s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf, 580 endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(), 581 name_for_state(endpoint->State())); 582 } 583 } 584 585