1 /*
2 * Copyright 2006-2009, Haiku, Inc. All Rights Reserved.
3 * Distributed under the terms of the MIT License.
4 *
5 * Authors:
6 * Axel Dörfler, axeld@pinc-software.de
7 * Hugo Santos, hugosantos@gmail.com
8 */
9
10
11 #include "EndpointManager.h"
12
13 #include <new>
14 #include <unistd.h>
15
16 #include <KernelExport.h>
17
18 #include <NetUtilities.h>
19 #include <tracing.h>
20
21 #include "TCPEndpoint.h"
22
23
24 //#define TRACE_ENDPOINT_MANAGER
25 #ifdef TRACE_ENDPOINT_MANAGER
26 # define TRACE(x) dprintf x
27 #else
28 # define TRACE(x)
29 #endif
30
31 #if TCP_TRACING
32 # define ENDPOINT_TRACING
33 #endif
34 #ifdef ENDPOINT_TRACING
35 namespace EndpointTracing {
36
37 class Bind : public AbstractTraceEntry {
38 public:
Bind(TCPEndpoint * endpoint,ConstSocketAddress & address,bool ephemeral)39 Bind(TCPEndpoint* endpoint, ConstSocketAddress& address, bool ephemeral)
40 :
41 fEndpoint(endpoint),
42 fEphemeral(ephemeral)
43 {
44 address.AsString(fAddress, sizeof(fAddress), true);
45 Initialized();
46 }
47
Bind(TCPEndpoint * endpoint,SocketAddress & address,bool ephemeral)48 Bind(TCPEndpoint* endpoint, SocketAddress& address, bool ephemeral)
49 :
50 fEndpoint(endpoint),
51 fEphemeral(ephemeral)
52 {
53 address.AsString(fAddress, sizeof(fAddress), true);
54 Initialized();
55 }
56
AddDump(TraceOutput & out)57 virtual void AddDump(TraceOutput& out)
58 {
59 out.Print("tcp:e:%p bind%s address %s", fEndpoint,
60 fEphemeral ? " ephemeral" : "", fAddress);
61 }
62
63 protected:
64 TCPEndpoint* fEndpoint;
65 char fAddress[32];
66 bool fEphemeral;
67 };
68
69 class Connect : public AbstractTraceEntry {
70 public:
Connect(TCPEndpoint * endpoint)71 Connect(TCPEndpoint* endpoint)
72 :
73 fEndpoint(endpoint)
74 {
75 endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
76 endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
77 Initialized();
78 }
79
AddDump(TraceOutput & out)80 virtual void AddDump(TraceOutput& out)
81 {
82 out.Print("tcp:e:%p connect local %s, peer %s", fEndpoint, fLocal,
83 fPeer);
84 }
85
86 protected:
87 TCPEndpoint* fEndpoint;
88 char fLocal[32];
89 char fPeer[32];
90 };
91
92 class Unbind : public AbstractTraceEntry {
93 public:
Unbind(TCPEndpoint * endpoint)94 Unbind(TCPEndpoint* endpoint)
95 :
96 fEndpoint(endpoint)
97 {
98 //fStackTrace = capture_tracing_stack_trace(10, 0, false);
99
100 endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
101 endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
102 Initialized();
103 }
104
105 #if 0
106 virtual void DumpStackTrace(TraceOutput& out)
107 {
108 out.PrintStackTrace(fStackTrace);
109 }
110 #endif
111
AddDump(TraceOutput & out)112 virtual void AddDump(TraceOutput& out)
113 {
114 out.Print("tcp:e:%p unbind, local %s, peer %s", fEndpoint, fLocal,
115 fPeer);
116 }
117
118 protected:
119 TCPEndpoint* fEndpoint;
120 //tracing_stack_trace* fStackTrace;
121 char fLocal[32];
122 char fPeer[32];
123 };
124
125 } // namespace EndpointTracing
126
127 # define T(x) new(std::nothrow) EndpointTracing::x
128 #else
129 # define T(x)
130 #endif // ENDPOINT_TRACING
131
132
133 static const uint16 kLastReservedPort = 1023;
134 static const uint16 kFirstEphemeralPort = 40000;
135
136
ConnectionHashDefinition(EndpointManager * manager)137 ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager* manager)
138 :
139 fManager(manager)
140 {
141 }
142
143
144 size_t
HashKey(const KeyType & key) const145 ConnectionHashDefinition::HashKey(const KeyType& key) const
146 {
147 return ConstSocketAddress(fManager->AddressModule(),
148 key.first).HashPair(key.second);
149 }
150
151
152 size_t
Hash(TCPEndpoint * endpoint) const153 ConnectionHashDefinition::Hash(TCPEndpoint* endpoint) const
154 {
155 return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress());
156 }
157
158
159 bool
Compare(const KeyType & key,TCPEndpoint * endpoint) const160 ConnectionHashDefinition::Compare(const KeyType& key,
161 TCPEndpoint* endpoint) const
162 {
163 return endpoint->LocalAddress().EqualTo(key.first, true)
164 && endpoint->PeerAddress().EqualTo(key.second, true);
165 }
166
167
168 TCPEndpoint*&
GetLink(TCPEndpoint * endpoint) const169 ConnectionHashDefinition::GetLink(TCPEndpoint* endpoint) const
170 {
171 return endpoint->fConnectionHashLink;
172 }
173
174
175 // #pragma mark -
176
177
178 size_t
HashKey(uint16 port) const179 EndpointHashDefinition::HashKey(uint16 port) const
180 {
181 return port;
182 }
183
184
185 size_t
Hash(TCPEndpoint * endpoint) const186 EndpointHashDefinition::Hash(TCPEndpoint* endpoint) const
187 {
188 return endpoint->LocalAddress().Port();
189 }
190
191
192 bool
Compare(uint16 port,TCPEndpoint * endpoint) const193 EndpointHashDefinition::Compare(uint16 port, TCPEndpoint* endpoint) const
194 {
195 return endpoint->LocalAddress().Port() == port;
196 }
197
198
199 bool
CompareValues(TCPEndpoint * first,TCPEndpoint * second) const200 EndpointHashDefinition::CompareValues(TCPEndpoint* first,
201 TCPEndpoint* second) const
202 {
203 return first->LocalAddress().Port() == second->LocalAddress().Port();
204 }
205
206
207 TCPEndpoint*&
GetLink(TCPEndpoint * endpoint) const208 EndpointHashDefinition::GetLink(TCPEndpoint* endpoint) const
209 {
210 return endpoint->fEndpointHashLink;
211 }
212
213
214 // #pragma mark -
215
216
EndpointManager(net_domain * domain)217 EndpointManager::EndpointManager(net_domain* domain)
218 :
219 fDomain(domain),
220 fConnectionHash(this),
221 fLastPort(kFirstEphemeralPort)
222 {
223 rw_lock_init(&fLock, "TCP endpoint manager");
224 }
225
226
~EndpointManager()227 EndpointManager::~EndpointManager()
228 {
229 rw_lock_destroy(&fLock);
230 }
231
232
233 status_t
Init()234 EndpointManager::Init()
235 {
236 status_t status = fConnectionHash.Init();
237 if (status == B_OK)
238 status = fEndpointHash.Init();
239
240 return status;
241 }
242
243
244 // #pragma mark - connections
245
246
247 /*! Returns the endpoint matching the connection.
248 You must hold the manager's lock when calling this method (either read or
249 write).
250 */
251 TCPEndpoint*
_LookupConnection(const sockaddr * local,const sockaddr * peer)252 EndpointManager::_LookupConnection(const sockaddr* local, const sockaddr* peer)
253 {
254 return fConnectionHash.Lookup(std::make_pair(local, peer));
255 }
256
257
258 status_t
SetConnection(TCPEndpoint * endpoint,const sockaddr * _local,const sockaddr * peer,const sockaddr * interfaceLocal)259 EndpointManager::SetConnection(TCPEndpoint* endpoint, const sockaddr* _local,
260 const sockaddr* peer, const sockaddr* interfaceLocal)
261 {
262 TRACE(("EndpointManager::SetConnection(%p)\n", endpoint));
263
264 WriteLocker _(fLock);
265
266 SocketAddressStorage local(AddressModule());
267 local.SetTo(_local);
268
269 if (local.IsEmpty(false)) {
270 uint16 port = local.Port();
271 local.SetTo(interfaceLocal);
272 local.SetPort(port);
273 }
274
275 // We want to create a connection for (local, peer), so check to make sure
276 // that this pair is not already in use by an existing connection.
277 if (_LookupConnection(*local, peer) != NULL)
278 return EADDRINUSE;
279
280 endpoint->LocalAddress().SetTo(*local);
281 endpoint->PeerAddress().SetTo(peer);
282 T(Connect(endpoint));
283
284 // BOpenHashTable doesn't support inserting duplicate objects. Since
285 // BOpenHashTable is a chained hash table where the items are required to
286 // be intrusive linked list nodes, inserting the same object twice will
287 // create a cycle in the linked list, which is not handled currently.
288 //
289 // We need to makes sure to remove any existing copy of this endpoint
290 // object from the table in order to handle calling connect() on a closed
291 // socket to connect to a different remote (address, port) than it was
292 // originally used for.
293 //
294 // We use RemoveUnchecked here because we don't want the hash table to
295 // resize itself after this removal when we are planning to just add
296 // another.
297 fConnectionHash.RemoveUnchecked(endpoint);
298
299 fConnectionHash.Insert(endpoint);
300 return B_OK;
301 }
302
303
304 status_t
SetPassive(TCPEndpoint * endpoint)305 EndpointManager::SetPassive(TCPEndpoint* endpoint)
306 {
307 WriteLocker _(fLock);
308
309 if (!endpoint->IsBound()) {
310 // if the socket is unbound first bind it to ephemeral
311 SocketAddressStorage local(AddressModule());
312 local.SetToEmpty();
313
314 status_t status = _BindToEphemeral(endpoint, *local);
315 if (status < B_OK)
316 return status;
317 }
318
319 SocketAddressStorage passive(AddressModule());
320 passive.SetToEmpty();
321
322 if (_LookupConnection(*endpoint->LocalAddress(), *passive))
323 return EADDRINUSE;
324
325 endpoint->PeerAddress().SetTo(*passive);
326 fConnectionHash.Insert(endpoint);
327 return B_OK;
328 }
329
330
331 TCPEndpoint*
FindConnection(sockaddr * local,sockaddr * peer)332 EndpointManager::FindConnection(sockaddr* local, sockaddr* peer)
333 {
334 ReadLocker _(fLock);
335
336 TCPEndpoint *endpoint = _LookupConnection(local, peer);
337 if (endpoint != NULL) {
338 TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n",
339 endpoint));
340 if (gSocketModule->acquire_socket(endpoint->socket))
341 return endpoint;
342 }
343
344 // no explicit endpoint exists, check for wildcard endpoints
345
346 SocketAddressStorage wildcard(AddressModule());
347 wildcard.SetToEmpty();
348
349 endpoint = _LookupConnection(local, *wildcard);
350 if (endpoint != NULL) {
351 TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n",
352 endpoint));
353 if (gSocketModule->acquire_socket(endpoint->socket))
354 return endpoint;
355 }
356
357 SocketAddressStorage localWildcard(AddressModule());
358 localWildcard.SetToEmpty();
359 localWildcard.SetPort(AddressModule()->get_port(local));
360
361 endpoint = _LookupConnection(*localWildcard, *wildcard);
362 if (endpoint != NULL) {
363 TRACE(("TCP: Received packet corresponds to local wildcard endpoint "
364 "%p\n", endpoint));
365 if (gSocketModule->acquire_socket(endpoint->socket))
366 return endpoint;
367 }
368
369 // no matching endpoint exists
370 TRACE(("TCP: no matching endpoint!\n"));
371
372 return NULL;
373 }
374
375
376 // #pragma mark - endpoints
377
378
379 status_t
Bind(TCPEndpoint * endpoint,const sockaddr * address)380 EndpointManager::Bind(TCPEndpoint* endpoint, const sockaddr* address)
381 {
382 // check the family
383 if (!AddressModule()->is_same_family(address))
384 return EAFNOSUPPORT;
385
386 WriteLocker locker(fLock);
387
388 if (AddressModule()->get_port(address) == 0)
389 return _BindToEphemeral(endpoint, address);
390
391 return _BindToAddress(locker, endpoint, address);
392 }
393
394
395 status_t
BindChild(TCPEndpoint * endpoint,const sockaddr * address)396 EndpointManager::BindChild(TCPEndpoint* endpoint, const sockaddr* address)
397 {
398 WriteLocker _(fLock);
399 return _Bind(endpoint, address);
400 }
401
402
403 /*! You must have fLock write locked when calling this method. */
404 status_t
_BindToAddress(WriteLocker & locker,TCPEndpoint * endpoint,const sockaddr * _address)405 EndpointManager::_BindToAddress(WriteLocker& locker, TCPEndpoint* endpoint,
406 const sockaddr* _address)
407 {
408 ConstSocketAddress address(AddressModule(), _address);
409 uint16 port = address.Port();
410
411 TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint));
412 T(Bind(endpoint, address, false));
413
414 // TODO: this check follows very typical UNIX semantics
415 // and generally should be improved.
416 if (ntohs(port) <= kLastReservedPort && geteuid() != 0)
417 return B_PERMISSION_DENIED;
418
419 bool retrying = false;
420 int32 retry = 0;
421 do {
422 EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port);
423 retry = false;
424
425 while (portUsers.HasNext()) {
426 TCPEndpoint* user = portUsers.Next();
427
428 if (user->LocalAddress().IsEmpty(false)
429 || address.EqualTo(*user->LocalAddress(), false)) {
430 // Check if this belongs to a local connection
431
432 // Note, while we hold our lock, the endpoint cannot go away,
433 // it can only change its state - IsLocal() is safe to be used
434 // without having the endpoint locked.
435 tcp_state userState = user->State();
436 if (user->IsLocal()
437 && (userState > ESTABLISHED || userState == CLOSED)) {
438 // This is a closing local connection - wait until it's
439 // gone away for real
440 locker.Unlock();
441 snooze(10000);
442 locker.Lock();
443 // TODO: make this better
444 if (!retrying) {
445 retrying = true;
446 retry = 5;
447 }
448 break;
449 }
450
451 if ((endpoint->socket->options & SO_REUSEADDR) == 0)
452 return EADDRINUSE;
453
454 if (userState != TIME_WAIT && userState != CLOSED)
455 return EADDRINUSE;
456 }
457 }
458 } while (retry-- > 0);
459
460 return _Bind(endpoint, *address);
461 }
462
463
464 /*! You must have fLock write locked when calling this method. */
465 status_t
_BindToEphemeral(TCPEndpoint * endpoint,const sockaddr * address)466 EndpointManager::_BindToEphemeral(TCPEndpoint* endpoint,
467 const sockaddr* address)
468 {
469 TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint));
470
471 uint32 max = fLastPort + 65536;
472
473 for (int32 i = 1; i < 5; i++) {
474 // try to retrieve a more or less random port
475 uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1;
476 uint32 counter = fLastPort + step;
477
478 while (counter < max) {
479 uint16 port = counter & 0xffff;
480 if (port <= kLastReservedPort)
481 port += kLastReservedPort;
482
483 fLastPort = port;
484 port = htons(port);
485
486 if (!fEndpointHash.Lookup(port).HasNext()) {
487 // found a port
488 SocketAddressStorage newAddress(AddressModule());
489 newAddress.SetTo(address);
490 newAddress.SetPort(port);
491
492 TRACE((" EndpointManager::BindToEphemeral(%p) -> %s\n",
493 endpoint, AddressString(Domain(), *newAddress,
494 true).Data()));
495 T(Bind(endpoint, newAddress, true));
496
497 return _Bind(endpoint, *newAddress);
498 }
499
500 counter += step;
501 }
502 }
503
504 // could not find a port!
505 return EADDRINUSE;
506 }
507
508
509 status_t
_Bind(TCPEndpoint * endpoint,const sockaddr * address)510 EndpointManager::_Bind(TCPEndpoint* endpoint, const sockaddr* address)
511 {
512 // Thus far we have checked if the Bind() is allowed
513
514 status_t status = endpoint->next->module->bind(endpoint->next, address);
515 if (status < B_OK)
516 return status;
517
518 fEndpointHash.Insert(endpoint);
519
520 return B_OK;
521 }
522
523
524 status_t
Unbind(TCPEndpoint * endpoint)525 EndpointManager::Unbind(TCPEndpoint* endpoint)
526 {
527 TRACE(("EndpointManager::Unbind(%p)\n", endpoint));
528 T(Unbind(endpoint));
529
530 if (endpoint == NULL || !endpoint->IsBound()) {
531 TRACE((" endpoint is unbound.\n"));
532 return B_BAD_VALUE;
533 }
534
535 WriteLocker _(fLock);
536
537 if (!fEndpointHash.Remove(endpoint))
538 panic("bound endpoint %p not in hash!", endpoint);
539
540 fConnectionHash.Remove(endpoint);
541
542 (*endpoint->LocalAddress())->sa_len = 0;
543
544 return B_OK;
545 }
546
547
548 status_t
ReplyWithReset(tcp_segment_header & segment,net_buffer * buffer)549 EndpointManager::ReplyWithReset(tcp_segment_header& segment, net_buffer* buffer)
550 {
551 TRACE(("TCP: Sending RST...\n"));
552
553 net_buffer* reply = gBufferModule->create(512);
554 if (reply == NULL)
555 return B_NO_MEMORY;
556
557 AddressModule()->set_to(reply->source, buffer->destination);
558 AddressModule()->set_to(reply->destination, buffer->source);
559
560 tcp_segment_header outSegment(TCP_FLAG_RESET);
561 outSegment.sequence = 0;
562 outSegment.acknowledge = 0;
563 outSegment.advertised_window = 0;
564 outSegment.urgent_offset = 0;
565
566 if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) {
567 outSegment.flags |= TCP_FLAG_ACKNOWLEDGE;
568 outSegment.acknowledge = segment.sequence + buffer->size;
569 // TODO: Confirm:
570 if ((segment.flags & (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH)) != 0)
571 outSegment.acknowledge++;
572 } else
573 outSegment.sequence = segment.acknowledge;
574
575 status_t status = add_tcp_header(AddressModule(), outSegment, reply);
576 if (status == B_OK)
577 status = Domain()->module->send_data(NULL, reply);
578
579 if (status != B_OK)
580 gBufferModule->free(reply);
581
582 return status;
583 }
584
585
586 void
Dump() const587 EndpointManager::Dump() const
588 {
589 kprintf("-------- TCP Domain %p ---------\n", this);
590 kprintf("%10s %21s %21s %8s %8s %12s\n", "address", "local", "peer",
591 "recv-q", "send-q", "state");
592
593 ConnectionTable::Iterator iterator = fConnectionHash.GetIterator();
594
595 while (iterator.HasNext()) {
596 TCPEndpoint *endpoint = iterator.Next();
597
598 char localBuf[64], peerBuf[64];
599 endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true);
600 endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true);
601
602 kprintf("%p %21s %21s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf,
603 endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(),
604 name_for_state(endpoint->State()));
605 }
606 }
607
608