xref: /haiku/src/add-ons/kernel/network/protocols/tcp/EndpointManager.cpp (revision 9f3bdf3d039430b5172c424def20ce5d9f7367d4)
1 /*
2  * Copyright 2006-2009, Haiku, Inc. All Rights Reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Axel Dörfler, axeld@pinc-software.de
7  *		Hugo Santos, hugosantos@gmail.com
8  */
9 
10 
11 #include "EndpointManager.h"
12 
13 #include <new>
14 #include <unistd.h>
15 
16 #include <KernelExport.h>
17 
18 #include <NetUtilities.h>
19 #include <tracing.h>
20 
21 #include "TCPEndpoint.h"
22 
23 
//#define TRACE_ENDPOINT_MANAGER
	// Uncomment the line above to get dprintf() output from the endpoint
	// manager. TRACE() takes a parenthesized printf-style argument list,
	// e.g. TRACE(("%p\n", endpoint)).
#ifndef TRACE_ENDPOINT_MANAGER
#	define TRACE(x)
#else
#	define TRACE(x) dprintf x
#endif
30 
#if TCP_TRACING
#	define ENDPOINT_TRACING
#endif
#ifdef ENDPOINT_TRACING
namespace EndpointTracing {

/*!	Size of the address strings kept in the trace entries below.
	An IPv6 address together with its port needs more than 32 characters, so
	a 32 byte buffer truncated most IPv6 endpoints in the trace output.
*/
static const size_t kAddressStringSize = 64;


/*!	Trace entry recorded when an endpoint is bound to an explicit address or
	to an ephemeral port (\a ephemeral).
*/
class Bind : public AbstractTraceEntry {
public:
	Bind(TCPEndpoint* endpoint, ConstSocketAddress& address, bool ephemeral)
		:
		fEndpoint(endpoint),
		fEphemeral(ephemeral)
	{
		address.AsString(fAddress, sizeof(fAddress), true);
		Initialized();
	}

	Bind(TCPEndpoint* endpoint, SocketAddress& address, bool ephemeral)
		:
		fEndpoint(endpoint),
		fEphemeral(ephemeral)
	{
		address.AsString(fAddress, sizeof(fAddress), true);
		Initialized();
	}

	virtual void AddDump(TraceOutput& out)
	{
		out.Print("tcp:e:%p bind%s address %s", fEndpoint,
			fEphemeral ? " ephemeral" : "", fAddress);
	}

protected:
	TCPEndpoint*	fEndpoint;
	char			fAddress[kAddressStringSize];
	bool			fEphemeral;
};


/*!	Trace entry recorded when an endpoint gets its (local, peer) connection
	pair assigned. The addresses are captured at construction time.
*/
class Connect : public AbstractTraceEntry {
public:
	Connect(TCPEndpoint* endpoint)
		:
		fEndpoint(endpoint)
	{
		endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
		endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
		Initialized();
	}

	virtual void AddDump(TraceOutput& out)
	{
		out.Print("tcp:e:%p connect local %s, peer %s", fEndpoint, fLocal,
			fPeer);
	}

protected:
	TCPEndpoint*	fEndpoint;
	char			fLocal[kAddressStringSize];
	char			fPeer[kAddressStringSize];
};


/*!	Trace entry recorded when an endpoint is unbound. \a endpoint must not be
	NULL: its addresses are read in the constructor.
*/
class Unbind : public AbstractTraceEntry {
public:
	Unbind(TCPEndpoint* endpoint)
		:
		fEndpoint(endpoint)
	{
		//fStackTrace = capture_tracing_stack_trace(10, 0, false);

		endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
		endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
		Initialized();
	}

#if 0
	virtual void DumpStackTrace(TraceOutput& out)
	{
		out.PrintStackTrace(fStackTrace);
	}
#endif

	virtual void AddDump(TraceOutput& out)
	{
		out.Print("tcp:e:%p unbind, local %s, peer %s", fEndpoint, fLocal,
			fPeer);
	}

protected:
	TCPEndpoint*	fEndpoint;
	//tracing_stack_trace* fStackTrace;
	char			fLocal[kAddressStringSize];
	char			fPeer[kAddressStringSize];
};

}	// namespace EndpointTracing

#	define T(x)	new(std::nothrow) EndpointTracing::x
#else
#	define T(x)
#endif	// ENDPOINT_TRACING
131 
132 
133 static const uint16 kLastReservedPort = 1023;
134 static const uint16 kFirstEphemeralPort = 40000;
135 
136 
137 ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager* manager)
138 	:
139 	fManager(manager)
140 {
141 }
142 
143 
144 size_t
145 ConnectionHashDefinition::HashKey(const KeyType& key) const
146 {
147 	return ConstSocketAddress(fManager->AddressModule(),
148 		key.first).HashPair(key.second);
149 }
150 
151 
152 size_t
153 ConnectionHashDefinition::Hash(TCPEndpoint* endpoint) const
154 {
155 	return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress());
156 }
157 
158 
159 bool
160 ConnectionHashDefinition::Compare(const KeyType& key,
161 	TCPEndpoint* endpoint) const
162 {
163 	return endpoint->LocalAddress().EqualTo(key.first, true)
164 		&& endpoint->PeerAddress().EqualTo(key.second, true);
165 }
166 
167 
168 TCPEndpoint*&
169 ConnectionHashDefinition::GetLink(TCPEndpoint* endpoint) const
170 {
171 	return endpoint->fConnectionHashLink;
172 }
173 
174 
175 //	#pragma mark -
176 
177 
178 size_t
179 EndpointHashDefinition::HashKey(uint16 port) const
180 {
181 	return port;
182 }
183 
184 
185 size_t
186 EndpointHashDefinition::Hash(TCPEndpoint* endpoint) const
187 {
188 	return endpoint->LocalAddress().Port();
189 }
190 
191 
192 bool
193 EndpointHashDefinition::Compare(uint16 port, TCPEndpoint* endpoint) const
194 {
195 	return endpoint->LocalAddress().Port() == port;
196 }
197 
198 
199 bool
200 EndpointHashDefinition::CompareValues(TCPEndpoint* first,
201 	TCPEndpoint* second) const
202 {
203 	return first->LocalAddress().Port() == second->LocalAddress().Port();
204 }
205 
206 
207 TCPEndpoint*&
208 EndpointHashDefinition::GetLink(TCPEndpoint* endpoint) const
209 {
210 	return endpoint->fEndpointHashLink;
211 }
212 
213 
214 //	#pragma mark -
215 
216 
217 EndpointManager::EndpointManager(net_domain* domain)
218 	:
219 	fDomain(domain),
220 	fConnectionHash(this),
221 	fLastPort(kFirstEphemeralPort)
222 {
223 	rw_lock_init(&fLock, "TCP endpoint manager");
224 }
225 
226 
227 EndpointManager::~EndpointManager()
228 {
229 	rw_lock_destroy(&fLock);
230 }
231 
232 
233 status_t
234 EndpointManager::Init()
235 {
236 	status_t status = fConnectionHash.Init();
237 	if (status == B_OK)
238 		status = fEndpointHash.Init();
239 
240 	return status;
241 }
242 
243 
244 //	#pragma mark - connections
245 
246 
247 /*!	Returns the endpoint matching the connection.
248 	You must hold the manager's lock when calling this method (either read or
249 	write).
250 */
251 TCPEndpoint*
252 EndpointManager::_LookupConnection(const sockaddr* local, const sockaddr* peer)
253 {
254 	return fConnectionHash.Lookup(std::make_pair(local, peer));
255 }
256 
257 
258 status_t
259 EndpointManager::SetConnection(TCPEndpoint* endpoint, const sockaddr* _local,
260 	const sockaddr* peer, const sockaddr* interfaceLocal)
261 {
262 	TRACE(("EndpointManager::SetConnection(%p)\n", endpoint));
263 
264 	WriteLocker _(fLock);
265 
266 	SocketAddressStorage local(AddressModule());
267 	local.SetTo(_local);
268 
269 	if (local.IsEmpty(false)) {
270 		uint16 port = local.Port();
271 		local.SetTo(interfaceLocal);
272 		local.SetPort(port);
273 	}
274 
275 	// We want to create a connection for (local, peer), so check to make sure
276 	// that this pair is not already in use by an existing connection.
277 	if (_LookupConnection(*local, peer) != NULL)
278 		return EADDRINUSE;
279 
280 	endpoint->LocalAddress().SetTo(*local);
281 	endpoint->PeerAddress().SetTo(peer);
282 	T(Connect(endpoint));
283 
284 	// BOpenHashTable doesn't support inserting duplicate objects. Since
285 	// BOpenHashTable is a chained hash table where the items are required to
286 	// be intrusive linked list nodes, inserting the same object twice will
287 	// create a cycle in the linked list, which is not handled currently.
288 	//
289 	// We need to makes sure to remove any existing copy of this endpoint
290 	// object from the table in order to handle calling connect() on a closed
291 	// socket to connect to a different remote (address, port) than it was
292 	// originally used for.
293 	//
294 	// We use RemoveUnchecked here because we don't want the hash table to
295 	// resize itself after this removal when we are planning to just add
296 	// another.
297 	fConnectionHash.RemoveUnchecked(endpoint);
298 
299 	fConnectionHash.Insert(endpoint);
300 	return B_OK;
301 }
302 
303 
304 status_t
305 EndpointManager::SetPassive(TCPEndpoint* endpoint)
306 {
307 	WriteLocker _(fLock);
308 
309 	if (!endpoint->IsBound()) {
310 		// if the socket is unbound first bind it to ephemeral
311 		SocketAddressStorage local(AddressModule());
312 		local.SetToEmpty();
313 
314 		status_t status = _BindToEphemeral(endpoint, *local);
315 		if (status < B_OK)
316 			return status;
317 	}
318 
319 	SocketAddressStorage passive(AddressModule());
320 	passive.SetToEmpty();
321 
322 	if (_LookupConnection(*endpoint->LocalAddress(), *passive))
323 		return EADDRINUSE;
324 
325 	endpoint->PeerAddress().SetTo(*passive);
326 	fConnectionHash.Insert(endpoint);
327 	return B_OK;
328 }
329 
330 
331 TCPEndpoint*
332 EndpointManager::FindConnection(sockaddr* local, sockaddr* peer)
333 {
334 	ReadLocker _(fLock);
335 
336 	TCPEndpoint *endpoint = _LookupConnection(local, peer);
337 	if (endpoint != NULL) {
338 		TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n",
339 			endpoint));
340 		if (gSocketModule->acquire_socket(endpoint->socket))
341 			return endpoint;
342 	}
343 
344 	// no explicit endpoint exists, check for wildcard endpoints
345 
346 	SocketAddressStorage wildcard(AddressModule());
347 	wildcard.SetToEmpty();
348 
349 	endpoint = _LookupConnection(local, *wildcard);
350 	if (endpoint != NULL) {
351 		TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n",
352 			endpoint));
353 		if (gSocketModule->acquire_socket(endpoint->socket))
354 			return endpoint;
355 	}
356 
357 	SocketAddressStorage localWildcard(AddressModule());
358 	localWildcard.SetToEmpty();
359 	localWildcard.SetPort(AddressModule()->get_port(local));
360 
361 	endpoint = _LookupConnection(*localWildcard, *wildcard);
362 	if (endpoint != NULL) {
363 		TRACE(("TCP: Received packet corresponds to local wildcard endpoint "
364 			"%p\n", endpoint));
365 		if (gSocketModule->acquire_socket(endpoint->socket))
366 			return endpoint;
367 	}
368 
369 	// no matching endpoint exists
370 	TRACE(("TCP: no matching endpoint!\n"));
371 
372 	return NULL;
373 }
374 
375 
376 //	#pragma mark - endpoints
377 
378 
379 status_t
380 EndpointManager::Bind(TCPEndpoint* endpoint, const sockaddr* address)
381 {
382 	// check the family
383 	if (!AddressModule()->is_same_family(address))
384 		return EAFNOSUPPORT;
385 
386 	WriteLocker locker(fLock);
387 
388 	if (AddressModule()->get_port(address) == 0)
389 		return _BindToEphemeral(endpoint, address);
390 
391 	return _BindToAddress(locker, endpoint, address);
392 }
393 
394 
395 status_t
396 EndpointManager::BindChild(TCPEndpoint* endpoint, const sockaddr* address)
397 {
398 	WriteLocker _(fLock);
399 	return _Bind(endpoint, address);
400 }
401 
402 
403 /*! You must have fLock write locked when calling this method. */
404 status_t
405 EndpointManager::_BindToAddress(WriteLocker& locker, TCPEndpoint* endpoint,
406 	const sockaddr* _address)
407 {
408 	ConstSocketAddress address(AddressModule(), _address);
409 	uint16 port = address.Port();
410 
411 	TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint));
412 	T(Bind(endpoint, address, false));
413 
414 	// TODO: this check follows very typical UNIX semantics
415 	// and generally should be improved.
416 	if (ntohs(port) <= kLastReservedPort && geteuid() != 0)
417 		return B_PERMISSION_DENIED;
418 
419 	bool retrying = false;
420 	int32 retry = 0;
421 	do {
422 		EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port);
423 		retry = false;
424 
425 		while (portUsers.HasNext()) {
426 			TCPEndpoint* user = portUsers.Next();
427 
428 			if (user->LocalAddress().IsEmpty(false)
429 				|| address.EqualTo(*user->LocalAddress(), false)) {
430 				// Check if this belongs to a local connection
431 
432 				// Note, while we hold our lock, the endpoint cannot go away,
433 				// it can only change its state - IsLocal() is safe to be used
434 				// without having the endpoint locked.
435 				tcp_state userState = user->State();
436 				if (user->IsLocal()
437 					&& (userState > ESTABLISHED || userState == CLOSED)) {
438 					// This is a closing local connection - wait until it's
439 					// gone away for real
440 					locker.Unlock();
441 					snooze(10000);
442 					locker.Lock();
443 						// TODO: make this better
444 					if (!retrying) {
445 						retrying = true;
446 						retry = 5;
447 					}
448 					break;
449 				}
450 
451 				if ((endpoint->socket->options & SO_REUSEADDR) == 0)
452 					return EADDRINUSE;
453 
454 				if (userState != TIME_WAIT && userState != CLOSED)
455 					return EADDRINUSE;
456 			}
457 		}
458 	} while (retry-- > 0);
459 
460 	return _Bind(endpoint, *address);
461 }
462 
463 
464 /*! You must have fLock write locked when calling this method. */
465 status_t
466 EndpointManager::_BindToEphemeral(TCPEndpoint* endpoint,
467 	const sockaddr* address)
468 {
469 	TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint));
470 
471 	uint32 max = fLastPort + 65536;
472 
473 	for (int32 i = 1; i < 5; i++) {
474 		// try to retrieve a more or less random port
475 		uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1;
476 		uint32 counter = fLastPort + step;
477 
478 		while (counter < max) {
479 			uint16 port = counter & 0xffff;
480 			if (port <= kLastReservedPort)
481 				port += kLastReservedPort;
482 
483 			fLastPort = port;
484 			port = htons(port);
485 
486 			if (!fEndpointHash.Lookup(port).HasNext()) {
487 				// found a port
488 				SocketAddressStorage newAddress(AddressModule());
489 				newAddress.SetTo(address);
490 				newAddress.SetPort(port);
491 
492 				TRACE(("   EndpointManager::BindToEphemeral(%p) -> %s\n",
493 					endpoint, AddressString(Domain(), *newAddress,
494 					true).Data()));
495 				T(Bind(endpoint, newAddress, true));
496 
497 				return _Bind(endpoint, *newAddress);
498 			}
499 
500 			counter += step;
501 		}
502 	}
503 
504 	// could not find a port!
505 	return EADDRINUSE;
506 }
507 
508 
509 status_t
510 EndpointManager::_Bind(TCPEndpoint* endpoint, const sockaddr* address)
511 {
512 	// Thus far we have checked if the Bind() is allowed
513 
514 	status_t status = endpoint->next->module->bind(endpoint->next, address);
515 	if (status < B_OK)
516 		return status;
517 
518 	fEndpointHash.Insert(endpoint);
519 
520 	return B_OK;
521 }
522 
523 
524 status_t
525 EndpointManager::Unbind(TCPEndpoint* endpoint)
526 {
527 	TRACE(("EndpointManager::Unbind(%p)\n", endpoint));
528 	T(Unbind(endpoint));
529 
530 	if (endpoint == NULL || !endpoint->IsBound()) {
531 		TRACE(("  endpoint is unbound.\n"));
532 		return B_BAD_VALUE;
533 	}
534 
535 	WriteLocker _(fLock);
536 
537 	if (!fEndpointHash.Remove(endpoint))
538 		panic("bound endpoint %p not in hash!", endpoint);
539 
540 	fConnectionHash.Remove(endpoint);
541 
542 	(*endpoint->LocalAddress())->sa_len = 0;
543 
544 	return B_OK;
545 }
546 
547 
548 status_t
549 EndpointManager::ReplyWithReset(tcp_segment_header& segment, net_buffer* buffer)
550 {
551 	TRACE(("TCP: Sending RST...\n"));
552 
553 	net_buffer* reply = gBufferModule->create(512);
554 	if (reply == NULL)
555 		return B_NO_MEMORY;
556 
557 	AddressModule()->set_to(reply->source, buffer->destination);
558 	AddressModule()->set_to(reply->destination, buffer->source);
559 
560 	tcp_segment_header outSegment(TCP_FLAG_RESET);
561 	outSegment.sequence = 0;
562 	outSegment.acknowledge = 0;
563 	outSegment.advertised_window = 0;
564 	outSegment.urgent_offset = 0;
565 
566 	if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) {
567 		outSegment.flags |= TCP_FLAG_ACKNOWLEDGE;
568 		outSegment.acknowledge = segment.sequence + buffer->size;
569 		// TODO: Confirm:
570 		if ((segment.flags & (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH)) != 0)
571 			outSegment.acknowledge++;
572 	} else
573 		outSegment.sequence = segment.acknowledge;
574 
575 	status_t status = add_tcp_header(AddressModule(), outSegment, reply);
576 	if (status == B_OK)
577 		status = Domain()->module->send_data(NULL, reply);
578 
579 	if (status != B_OK)
580 		gBufferModule->free(reply);
581 
582 	return status;
583 }
584 
585 
586 void
587 EndpointManager::Dump() const
588 {
589 	kprintf("-------- TCP Domain %p ---------\n", this);
590 	kprintf("%10s %21s %21s %8s %8s %12s\n", "address", "local", "peer",
591 		"recv-q", "send-q", "state");
592 
593 	ConnectionTable::Iterator iterator = fConnectionHash.GetIterator();
594 
595 	while (iterator.HasNext()) {
596 		TCPEndpoint *endpoint = iterator.Next();
597 
598 		char localBuf[64], peerBuf[64];
599 		endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true);
600 		endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true);
601 
602 		kprintf("%p %21s %21s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf,
603 			endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(),
604 			name_for_state(endpoint->State()));
605 	}
606 }
607 
608