1 /* 2 * Copyright 2006-2007, Haiku, Inc. All Rights Reserved. 3 * Distributed under the terms of the MIT License. 4 * 5 * Authors: 6 * Axel Dörfler, axeld@pinc-software.de 7 * Hugo Santos, hugosantos@gmail.com 8 */ 9 10 11 #include "EndpointManager.h" 12 13 #include <unistd.h> 14 15 #include <KernelExport.h> 16 17 #include <NetUtilities.h> 18 #include <util/AutoLock.h> 19 20 #include "TCPEndpoint.h" 21 22 23 //#define TRACE_ENDPOINT_MANAGER 24 #ifdef TRACE_ENDPOINT_MANAGER 25 # define TRACE(x) dprintf x 26 #else 27 # define TRACE(x) 28 #endif 29 30 31 static const uint16 kLastReservedPort = 1023; 32 static const uint16 kFirstEphemeralPort = 40000; 33 34 35 ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager *manager) 36 : 37 fManager(manager) 38 { 39 } 40 41 42 size_t 43 ConnectionHashDefinition::HashKey(const KeyType &key) const 44 { 45 return ConstSocketAddress(fManager->AddressModule(), 46 key.first).HashPair(key.second); 47 } 48 49 50 size_t 51 ConnectionHashDefinition::Hash(TCPEndpoint *endpoint) const 52 { 53 return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress()); 54 } 55 56 57 bool 58 ConnectionHashDefinition::Compare(const KeyType &key, 59 TCPEndpoint *endpoint) const 60 { 61 return endpoint->LocalAddress().EqualTo(key.first, true) 62 && endpoint->PeerAddress().EqualTo(key.second, true); 63 } 64 65 66 HashTableLink<TCPEndpoint> * 67 ConnectionHashDefinition::GetLink(TCPEndpoint *endpoint) const 68 { 69 return &endpoint->fConnectionHashLink; 70 } 71 72 73 size_t 74 EndpointHashDefinition::HashKey(uint16 port) const 75 { 76 return port; 77 } 78 79 80 size_t 81 EndpointHashDefinition::Hash(TCPEndpoint *endpoint) const 82 { 83 return endpoint->LocalAddress().Port(); 84 } 85 86 87 bool 88 EndpointHashDefinition::Compare(uint16 port, TCPEndpoint *endpoint) const 89 { 90 return endpoint->LocalAddress().Port() == port; 91 } 92 93 94 bool 95 EndpointHashDefinition::CompareValues(TCPEndpoint *first, 96 TCPEndpoint *second) const 97 { 98 return first->LocalAddress().Port() == second->LocalAddress().Port(); 99 } 100 101 102 HashTableLink<TCPEndpoint> * 103 EndpointHashDefinition::GetLink(TCPEndpoint *endpoint) const 104 { 105 return &endpoint->fEndpointHashLink; 106 } 107 108 109 EndpointManager::EndpointManager(net_domain *domain) 110 : fDomain(domain), fConnectionHash(this) 111 { 112 benaphore_init(&fLock, "endpoint manager"); 113 } 114 115 116 EndpointManager::~EndpointManager() 117 { 118 benaphore_destroy(&fLock); 119 } 120 121 122 status_t 123 EndpointManager::InitCheck() const 124 { 125 if (fConnectionHash.InitCheck() < B_OK) 126 return fConnectionHash.InitCheck(); 127 128 if (fEndpointHash.InitCheck() < B_OK) 129 return fEndpointHash.InitCheck(); 130 131 if (fLock.sem < B_OK) 132 return fLock.sem; 133 134 return B_OK; 135 } 136 137 138 // #pragma mark - connections 139 140 141 /*! 142 Returns the endpoint matching the connection. 143 You must hold the manager's lock when calling this method. 144 */ 145 TCPEndpoint * 146 EndpointManager::_LookupConnection(const sockaddr *local, const sockaddr *peer) 147 { 148 return fConnectionHash.Lookup(std::make_pair(local, peer)); 149 } 150 151 152 status_t 153 EndpointManager::SetConnection(TCPEndpoint *endpoint, 154 const sockaddr *_local, const sockaddr *peer, const sockaddr *interfaceLocal) 155 { 156 TRACE(("EndpointManager::SetConnection(%p)\n", endpoint)); 157 158 BenaphoreLocker _(fLock); 159 160 SocketAddressStorage local(AddressModule()); 161 local.SetTo(_local); 162 163 if (local.IsEmpty(false)) { 164 uint16 port = local.Port(); 165 local.SetTo(interfaceLocal); 166 local.SetPort(port); 167 } 168 169 if (_LookupConnection(*local, peer) != NULL) 170 return EADDRINUSE; 171 172 endpoint->LocalAddress().SetTo(*local); 173 endpoint->PeerAddress().SetTo(peer); 174 175 fConnectionHash.Insert(endpoint); 176 return B_OK; 177 } 178 179 180 status_t 181 EndpointManager::SetPassive(TCPEndpoint *endpoint) 182 { 183 BenaphoreLocker _(fLock); 184 185 if (!endpoint->IsBound()) { 186 // if the socket is unbound first bind it to ephemeral 187 SocketAddressStorage local(AddressModule()); 188 local.SetToEmpty(); 189 190 status_t status = _BindToEphemeral(endpoint, *local); 191 if (status < B_OK) 192 return status; 193 } 194 195 SocketAddressStorage passive(AddressModule()); 196 passive.SetToEmpty(); 197 198 if (_LookupConnection(*endpoint->LocalAddress(), *passive)) 199 return EADDRINUSE; 200 201 endpoint->PeerAddress().SetTo(*passive); 202 fConnectionHash.Insert(endpoint); 203 return B_OK; 204 } 205 206 207 TCPEndpoint * 208 EndpointManager::FindConnection(sockaddr *local, sockaddr *peer) 209 { 210 BenaphoreLocker _(fLock); 211 212 TCPEndpoint *endpoint = _LookupConnection(local, peer); 213 if (endpoint != NULL) { 214 TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n", endpoint)); 215 return endpoint; 216 } 217 218 // no explicit endpoint exists, check for wildcard endpoints 219 220 SocketAddressStorage wildcard(AddressModule()); 221 wildcard.SetToEmpty(); 222 223 endpoint = _LookupConnection(local, *wildcard); 224 if (endpoint != NULL) { 225 TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", endpoint)); 226 return endpoint; 227 } 228 229 SocketAddressStorage localWildcard(AddressModule()); 230 localWildcard.SetToEmpty(); 231 localWildcard.SetPort(AddressModule()->get_port(local)); 232 233 endpoint = _LookupConnection(*localWildcard, *wildcard); 234 if (endpoint != NULL) { 235 TRACE(("TCP: Received packet corresponds to local wildcard endpoint %p\n", endpoint)); 236 return endpoint; 237 } 238 239 // no matching endpoint exists 240 TRACE(("TCP: no matching endpoint!\n")); 241 242 return NULL; 243 } 244 245 246 // #pragma mark - endpoints 247 248 249 status_t 250 EndpointManager::Bind(TCPEndpoint *endpoint, const sockaddr *address) 251 { 252 // TODO check the family: 253 // 254 // if (!AddressModule()->is_understandable(address)) 255 // return EAFNOSUPPORT; 256 257 BenaphoreLocker _(fLock); 258 259 if (AddressModule()->get_port(address) == 0) 260 return _BindToEphemeral(endpoint, address); 261 262 return _BindToAddress(endpoint, address); 263 } 264 265 266 status_t 267 EndpointManager::BindChild(TCPEndpoint *endpoint) 268 { 269 BenaphoreLocker _(fLock); 270 return _Bind(endpoint, *endpoint->LocalAddress()); 271 } 272 273 274 status_t 275 EndpointManager::_BindToAddress(TCPEndpoint *endpoint, const sockaddr *_address) 276 { 277 TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint)); 278 279 ConstSocketAddress address(AddressModule(), _address); 280 281 uint16 port = address.Port(); 282 283 // TODO this check follows very typical UNIX semantics 284 // and generally should be improved. 285 if (ntohs(port) <= kLastReservedPort && geteuid() != 0) 286 return B_PERMISSION_DENIED; 287 288 EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port); 289 290 while (portUsers.HasNext()) { 291 TCPEndpoint *user = portUsers.Next(); 292 293 if (user->LocalAddress().IsEmpty(false) 294 || address.EqualTo(*user->LocalAddress(), false)) { 295 if ((endpoint->socket->options & SO_REUSEADDR) == 0) 296 return EADDRINUSE; 297 // TODO lock endpoint before retriving state? 298 if (user->State() != TIME_WAIT && user->State() != CLOSED) 299 return EADDRINUSE; 300 } 301 } 302 303 return _Bind(endpoint, *address); 304 } 305 306 307 status_t 308 EndpointManager::_BindToEphemeral(TCPEndpoint *endpoint, 309 const sockaddr *address) 310 { 311 TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint)); 312 313 uint32 max = kFirstEphemeralPort + 65536; 314 315 for (int32 i = 1; i < 5; i++) { 316 // try to retrieve a more or less random port 317 uint32 counter = kFirstEphemeralPort; 318 uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1; 319 320 while (counter < max) { 321 uint16 port = counter & 0xffff; 322 if (port <= kLastReservedPort) 323 port += kLastReservedPort; 324 325 port = htons(port); 326 327 if (!fEndpointHash.Lookup(port).HasNext()) { 328 SocketAddressStorage newAddress(AddressModule()); 329 newAddress.SetTo(address); 330 newAddress.SetPort(port); 331 332 // found a port 333 TRACE((" EndpointManager::BindToEphemeral(%p) -> %s\n", endpoint, 334 AddressString(Domain(), *newAddress, true).Data())); 335 336 return _Bind(endpoint, *newAddress); 337 } 338 339 counter += step; 340 } 341 } 342 343 // could not find a port! 344 return EADDRINUSE; 345 } 346 347 348 status_t 349 EndpointManager::_Bind(TCPEndpoint *endpoint, const sockaddr *address) 350 { 351 // Thus far we have checked if the Bind() is allowed 352 353 status_t status = endpoint->next->module->bind(endpoint->next, address); 354 if (status < B_OK) 355 return status; 356 357 fEndpointHash.Insert(endpoint); 358 359 return B_OK; 360 } 361 362 363 status_t 364 EndpointManager::Unbind(TCPEndpoint *endpoint) 365 { 366 TRACE(("EndpointManager::Unbind(%p)\n", endpoint)); 367 368 if (endpoint == NULL || !endpoint->IsBound()) { 369 TRACE((" endpoint is unbound.\n")); 370 return B_BAD_VALUE; 371 } 372 373 BenaphoreLocker _(fLock); 374 375 if (!fEndpointHash.Remove(endpoint)) 376 panic("bound endpoint %p not in hash!", endpoint); 377 378 fConnectionHash.Remove(endpoint); 379 380 (*endpoint->LocalAddress())->sa_len = 0; 381 382 return B_OK; 383 } 384 385 386 status_t 387 EndpointManager::ReplyWithReset(tcp_segment_header &segment, 388 net_buffer *buffer) 389 { 390 TRACE(("TCP: Sending RST...\n")); 391 392 net_buffer *reply = gBufferModule->create(512); 393 if (reply == NULL) 394 return B_NO_MEMORY; 395 396 AddressModule()->set_to(reply->source, buffer->destination); 397 AddressModule()->set_to(reply->destination, buffer->source); 398 399 tcp_segment_header outSegment(TCP_FLAG_RESET); 400 outSegment.sequence = 0; 401 outSegment.acknowledge = 0; 402 outSegment.advertised_window = 0; 403 outSegment.urgent_offset = 0; 404 405 if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) { 406 outSegment.flags |= TCP_FLAG_ACKNOWLEDGE; 407 outSegment.acknowledge = segment.sequence + buffer->size; 408 // TODO: Confirm: 409 if ((segment.flags & (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH)) != 0) 410 outSegment.acknowledge++; 411 } else 412 outSegment.sequence = segment.acknowledge; 413 414 status_t status = add_tcp_header(AddressModule(), outSegment, reply); 415 if (status == B_OK) 416 status = Domain()->module->send_data(NULL, reply); 417 418 if (status != B_OK) 419 gBufferModule->free(reply); 420 421 return status; 422 } 423 424 425 void 426 EndpointManager::DumpEndpoints() const 427 { 428 kprintf("-------- TCP Domain %p ---------\n", this); 429 kprintf("%10s %20s %20s %8s %8s %12s\n", "address", "local", "peer", 430 "recv-q", "send-q", "state"); 431 432 ConnectionTable::Iterator it = fConnectionHash.GetIterator(); 433 434 while (it.HasNext()) { 435 TCPEndpoint *endpoint = it.Next(); 436 437 char localBuf[64], peerBuf[64]; 438 endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true); 439 endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true); 440 441 kprintf("%p %20s %20s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf, 442 endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(), 443 name_for_state(endpoint->State())); 444 } 445 } 446 447