1 /* 2 * Copyright 2006-2007, Haiku, Inc. All Rights Reserved. 3 * Distributed under the terms of the MIT License. 4 * 5 * Authors: 6 * Axel Dörfler, axeld@pinc-software.de 7 */ 8 9 10 #include "EndpointManager.h" 11 #include "TCPEndpoint.h" 12 13 #include <NetUtilities.h> 14 15 #include <util/AutoLock.h> 16 17 #include <KernelExport.h> 18 19 20 //#define TRACE_ENDPOINT_MANAGER 21 #ifdef TRACE_ENDPOINT_MANAGER 22 # define TRACE(x) dprintf x 23 #else 24 # define TRACE(x) 25 #endif 26 27 28 static const uint16 kLastReservedPort = 1023; 29 static const uint16 kFirstEphemeralPort = 40000; 30 31 32 ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager *manager) 33 : fManager(manager) {} 34 35 size_t 36 ConnectionHashDefinition::HashKey(const KeyType &key) const 37 { 38 return ConstSocketAddress(fManager->AddressModule(), 39 key.first).HashPair(key.second); 40 } 41 42 43 size_t 44 ConnectionHashDefinition::Hash(TCPEndpoint *endpoint) const 45 { 46 return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress()); 47 } 48 49 50 bool 51 ConnectionHashDefinition::Compare(const KeyType &key, 52 TCPEndpoint *endpoint) const 53 { 54 return endpoint->LocalAddress().EqualTo(key.first, true) 55 && endpoint->PeerAddress().EqualTo(key.second, true); 56 } 57 58 59 HashTableLink<TCPEndpoint> * 60 ConnectionHashDefinition::GetLink(TCPEndpoint *endpoint) const 61 { 62 return &endpoint->fConnectionHashLink; 63 } 64 65 66 size_t 67 EndpointHashDefinition::HashKey(uint16 port) const 68 { 69 return port; 70 } 71 72 73 size_t 74 EndpointHashDefinition::Hash(TCPEndpoint *endpoint) const 75 { 76 return endpoint->LocalAddress().Port(); 77 } 78 79 80 bool 81 EndpointHashDefinition::Compare(uint16 port, TCPEndpoint *endpoint) const 82 { 83 return endpoint->LocalAddress().Port() == port; 84 } 85 86 87 bool 88 EndpointHashDefinition::CompareValues(TCPEndpoint *first, 89 TCPEndpoint *second) const 90 { 91 return first->LocalAddress().Port() == second->LocalAddress().Port(); 92 } 93 94 95 HashTableLink<TCPEndpoint> * 96 EndpointHashDefinition::GetLink(TCPEndpoint *endpoint) const 97 { 98 return &endpoint->fEndpointHashLink; 99 } 100 101 102 EndpointManager::EndpointManager(net_domain *domain) 103 : fDomain(domain), fConnectionHash(this) 104 { 105 benaphore_init(&fLock, "endpoint manager"); 106 } 107 108 109 EndpointManager::~EndpointManager() 110 { 111 benaphore_destroy(&fLock); 112 } 113 114 115 status_t 116 EndpointManager::InitCheck() const 117 { 118 if (fConnectionHash.InitCheck() < B_OK) 119 return fConnectionHash.InitCheck(); 120 121 if (fEndpointHash.InitCheck() < B_OK) 122 return fEndpointHash.InitCheck(); 123 124 if (fLock.sem < B_OK) 125 return fLock.sem; 126 127 return B_OK; 128 } 129 130 131 // #pragma mark - connections 132 133 134 /*! 135 Returns the endpoint matching the connection. 136 You must hold the manager's lock when calling this method. 137 */ 138 TCPEndpoint * 139 EndpointManager::_LookupConnection(const sockaddr *local, const sockaddr *peer) 140 { 141 return fConnectionHash.Lookup(std::make_pair(local, peer)); 142 } 143 144 145 status_t 146 EndpointManager::SetConnection(TCPEndpoint *endpoint, 147 const sockaddr *_local, const sockaddr *peer, const sockaddr *interfaceLocal) 148 { 149 TRACE(("EndpointManager::SetConnection(%p)\n", endpoint)); 150 151 BenaphoreLocker _(fLock); 152 153 SocketAddressStorage local(AddressModule()); 154 local.SetTo(_local); 155 156 if (local.IsEmpty(false)) { 157 uint16 port = local.Port(); 158 local.SetTo(interfaceLocal); 159 local.SetPort(port); 160 } 161 162 if (_LookupConnection(*local, peer) != NULL) 163 return EADDRINUSE; 164 165 endpoint->LocalAddress().SetTo(*local); 166 endpoint->PeerAddress().SetTo(peer); 167 168 fConnectionHash.Insert(endpoint); 169 return B_OK; 170 } 171 172 173 status_t 174 EndpointManager::SetPassive(TCPEndpoint *endpoint) 175 { 176 BenaphoreLocker _(fLock); 177 178 if (!endpoint->IsBound()) { 179 // if the socket is unbound first bind it to ephemeral 180 SocketAddressStorage local(AddressModule()); 181 local.SetToEmpty(); 182 183 status_t status = _BindToEphemeral(endpoint, *local); 184 if (status < B_OK) 185 return status; 186 } 187 188 SocketAddressStorage passive(AddressModule()); 189 passive.SetToEmpty(); 190 191 if (_LookupConnection(*endpoint->LocalAddress(), *passive)) 192 return EADDRINUSE; 193 194 endpoint->PeerAddress().SetTo(*passive); 195 fConnectionHash.Insert(endpoint); 196 return B_OK; 197 } 198 199 200 TCPEndpoint * 201 EndpointManager::FindConnection(sockaddr *local, sockaddr *peer) 202 { 203 BenaphoreLocker _(fLock); 204 205 TCPEndpoint *endpoint = _LookupConnection(local, peer); 206 if (endpoint != NULL) { 207 TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n", endpoint)); 208 return endpoint; 209 } 210 211 // no explicit endpoint exists, check for wildcard endpoints 212 213 SocketAddressStorage wildcard(AddressModule()); 214 wildcard.SetToEmpty(); 215 216 endpoint = _LookupConnection(local, *wildcard); 217 if (endpoint != NULL) { 218 TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", endpoint)); 219 return endpoint; 220 } 221 222 SocketAddressStorage localWildcard(AddressModule()); 223 localWildcard.SetToEmpty(); 224 localWildcard.SetPort(AddressModule()->get_port(local)); 225 226 endpoint = _LookupConnection(*localWildcard, *wildcard); 227 if (endpoint != NULL) { 228 TRACE(("TCP: Received packet corresponds to local wildcard endpoint %p\n", endpoint)); 229 return endpoint; 230 } 231 232 // no matching endpoint exists 233 TRACE(("TCP: no matching endpoint!\n")); 234 235 return NULL; 236 } 237 238 239 // #pragma mark - endpoints 240 241 242 status_t 243 EndpointManager::Bind(TCPEndpoint *endpoint, const sockaddr *address) 244 { 245 // TODO check the family: 246 // 247 // if (!AddressModule()->is_understandable(address)) 248 // return EAFNOSUPPORT; 249 250 BenaphoreLocker _(fLock); 251 252 if (AddressModule()->get_port(address) == 0) 253 return _BindToEphemeral(endpoint, address); 254 255 return _BindToAddress(endpoint, address); 256 } 257 258 259 status_t 260 EndpointManager::BindChild(TCPEndpoint *endpoint) 261 { 262 BenaphoreLocker _(fLock); 263 return _Bind(endpoint, *endpoint->LocalAddress()); 264 } 265 266 267 status_t 268 EndpointManager::_BindToAddress(TCPEndpoint *endpoint, const sockaddr *_address) 269 { 270 TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint)); 271 272 ConstSocketAddress address(AddressModule(), _address); 273 274 uint16 port = address.Port(); 275 276 // TODO this check follows very typical UNIX semantics 277 // and generally should be improved. 278 if (ntohs(port) <= kLastReservedPort && geteuid() != 0) 279 return B_PERMISSION_DENIED; 280 281 EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port); 282 283 while (portUsers.HasNext()) { 284 TCPEndpoint *user = portUsers.Next(); 285 286 if (user->LocalAddress().IsEmpty(false) 287 || address.EqualTo(*user->LocalAddress(), false)) { 288 if ((endpoint->socket->options & SO_REUSEADDR) == 0) 289 return EADDRINUSE; 290 // TODO lock endpoint before retriving state? 291 if (user->State() != TIME_WAIT && user->State() != CLOSED) 292 return EADDRINUSE; 293 } 294 } 295 296 return _Bind(endpoint, *address); 297 } 298 299 300 status_t 301 EndpointManager::_BindToEphemeral(TCPEndpoint *endpoint, 302 const sockaddr *address) 303 { 304 TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint)); 305 306 uint32 max = kFirstEphemeralPort + 65536; 307 308 for (int32 i = 1; i < 5; i++) { 309 // try to retrieve a more or less random port 310 uint32 counter = kFirstEphemeralPort; 311 uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1; 312 313 while (counter < max) { 314 uint16 port = counter & 0xffff; 315 if (port <= kLastReservedPort) 316 port += kLastReservedPort; 317 318 port = htons(port); 319 320 if (!fEndpointHash.Lookup(port).HasNext()) { 321 SocketAddressStorage newAddress(AddressModule()); 322 newAddress.SetTo(address); 323 newAddress.SetPort(port); 324 325 // found a port 326 TRACE((" EndpointManager::BindToEphemeral(%p) -> %s\n", endpoint, 327 AddressString(Domain(), *newAddress, true).Data())); 328 329 return _Bind(endpoint, *newAddress); 330 } 331 332 counter += step; 333 } 334 } 335 336 // could not find a port! 337 return EADDRINUSE; 338 } 339 340 341 status_t 342 EndpointManager::_Bind(TCPEndpoint *endpoint, const sockaddr *address) 343 { 344 // Thus far we have checked if the Bind() is allowed 345 346 status_t status = endpoint->next->module->bind(endpoint->next, address); 347 if (status < B_OK) 348 return status; 349 350 fEndpointHash.Insert(endpoint); 351 352 return B_OK; 353 } 354 355 356 status_t 357 EndpointManager::Unbind(TCPEndpoint *endpoint) 358 { 359 TRACE(("EndpointManager::Unbind(%p)\n", endpoint)); 360 361 if (endpoint == NULL || !endpoint->IsBound()) { 362 TRACE((" endpoint is unbound.\n")); 363 return B_BAD_VALUE; 364 } 365 366 BenaphoreLocker _(fLock); 367 368 if (!fEndpointHash.Remove(endpoint)) 369 panic("bound endpoint %p not in hash!", endpoint); 370 371 fConnectionHash.Remove(endpoint); 372 373 (*endpoint->LocalAddress())->sa_len = 0; 374 375 return B_OK; 376 } 377 378 379 status_t 380 EndpointManager::ReplyWithReset(tcp_segment_header &segment, 381 net_buffer *buffer) 382 { 383 TRACE(("TCP: Sending RST...\n")); 384 385 net_buffer *reply = gBufferModule->create(512); 386 if (reply == NULL) 387 return B_NO_MEMORY; 388 389 AddressModule()->set_to(reply->source, buffer->destination); 390 AddressModule()->set_to(reply->destination, buffer->source); 391 392 tcp_segment_header outSegment(TCP_FLAG_RESET); 393 outSegment.sequence = 0; 394 outSegment.acknowledge = 0; 395 outSegment.advertised_window = 0; 396 outSegment.urgent_offset = 0; 397 398 if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) { 399 outSegment.flags |= TCP_FLAG_ACKNOWLEDGE; 400 outSegment.acknowledge = segment.sequence + buffer->size; 401 } else 402 outSegment.sequence = segment.acknowledge; 403 404 status_t status = add_tcp_header(AddressModule(), outSegment, reply); 405 if (status == B_OK) 406 status = Domain()->module->send_data(NULL, reply); 407 408 if (status != B_OK) 409 gBufferModule->free(reply); 410 411 return status; 412 } 413 414 415 void 416 EndpointManager::DumpEndpoints() const 417 { 418 kprintf("-------- TCP Domain %p ---------\n", this); 419 kprintf("%10s %20s %20s %8s %8s %12s\n", "address", "local", "peer", 420 "recv-q", "send-q", "state"); 421 422 ConnectionTable::Iterator it = fConnectionHash.GetIterator(); 423 424 while (it.HasNext()) { 425 TCPEndpoint *endpoint = it.Next(); 426 427 char localBuf[64], peerBuf[64]; 428 endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true); 429 endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true); 430 431 kprintf("%p %20s %20s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf, 432 endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(), 433 name_for_state(endpoint->State())); 434 } 435 } 436 437