xref: /haiku/src/add-ons/kernel/network/protocols/tcp/EndpointManager.cpp (revision c9060eb991e10e477ece52478d6743fc7691c143)
1 /*
2  * Copyright 2006-2008, Haiku, Inc. All Rights Reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Axel Dörfler, axeld@pinc-software.de
7  *		Hugo Santos, hugosantos@gmail.com
8  */
9 
10 
11 #include "EndpointManager.h"
12 
13 #include <new>
14 #include <unistd.h>
15 
16 #include <KernelExport.h>
17 
18 #include <NetUtilities.h>
19 #include <util/AutoLock.h>
20 #include <tracing.h>
21 
22 #include "TCPEndpoint.h"
23 
24 
25 //#define TRACE_ENDPOINT_MANAGER
26 #ifdef TRACE_ENDPOINT_MANAGER
27 #	define TRACE(x) dprintf x
28 #else
29 #	define TRACE(x)
30 #endif
31 
32 #if TCP_TRACING
33 #	define ENDPOINT_TRACING
34 #endif
35 #ifdef ENDPOINT_TRACING
36 namespace EndpointTracing {
37 
38 class Bind : public AbstractTraceEntry {
39 public:
40 	Bind(TCPEndpoint* endpoint, ConstSocketAddress& address, bool ephemeral)
41 		:
42 		fEndpoint(endpoint),
43 		fEphemeral(ephemeral)
44 	{
45 		address.AsString(fAddress, sizeof(fAddress), true);
46 		Initialized();
47 	}
48 
49 	Bind(TCPEndpoint* endpoint, SocketAddress& address, bool ephemeral)
50 		:
51 		fEndpoint(endpoint),
52 		fEphemeral(ephemeral)
53 	{
54 		address.AsString(fAddress, sizeof(fAddress), true);
55 		Initialized();
56 	}
57 
58 	virtual void AddDump(TraceOutput& out)
59 	{
60 		out.Print("tcp:e:%p bind%s address %s", fEndpoint,
61 			fEphemeral ? " ephemeral" : "", fAddress);
62 	}
63 
64 protected:
65 	TCPEndpoint*	fEndpoint;
66 	char			fAddress[32];
67 	bool			fEphemeral;
68 };
69 
70 class Connect : public AbstractTraceEntry {
71 public:
72 	Connect(TCPEndpoint* endpoint)
73 		:
74 		fEndpoint(endpoint)
75 	{
76 		endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
77 		endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
78 		Initialized();
79 	}
80 
81 	virtual void AddDump(TraceOutput& out)
82 	{
83 		out.Print("tcp:e:%p connect local %s, peer %s", fEndpoint, fLocal,
84 			fPeer);
85 	}
86 
87 protected:
88 	TCPEndpoint*	fEndpoint;
89 	char			fLocal[32];
90 	char			fPeer[32];
91 };
92 
93 class Unbind : public AbstractTraceEntry {
94 public:
95 	Unbind(TCPEndpoint* endpoint)
96 		:
97 		fEndpoint(endpoint)
98 	{
99 		//fStackTrace = capture_tracing_stack_trace(10, 0, false);
100 
101 		endpoint->LocalAddress().AsString(fLocal, sizeof(fLocal), true);
102 		endpoint->PeerAddress().AsString(fPeer, sizeof(fPeer), true);
103 		Initialized();
104 	}
105 
106 #if 0
107 	virtual void DumpStackTrace(TraceOutput& out)
108 	{
109 		out.PrintStackTrace(fStackTrace);
110 	}
111 #endif
112 
113 	virtual void AddDump(TraceOutput& out)
114 	{
115 		out.Print("tcp:e:%p unbind, local %s, peer %s", fEndpoint, fLocal,
116 			fPeer);
117 	}
118 
119 protected:
120 	TCPEndpoint*	fEndpoint;
121 	//tracing_stack_trace* fStackTrace;
122 	char			fLocal[32];
123 	char			fPeer[32];
124 };
125 
126 }	// namespace EndpointTracing
127 
128 #	define T(x)	new(std::nothrow) EndpointTracing::x
129 #else
130 #	define T(x)
131 #endif	// ENDPOINT_TRACING
132 
133 
134 static const uint16 kLastReservedPort = 1023;
135 static const uint16 kFirstEphemeralPort = 40000;
136 
137 
138 ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager* manager)
139 	:
140 	fManager(manager)
141 {
142 }
143 
144 
145 size_t
146 ConnectionHashDefinition::HashKey(const KeyType& key) const
147 {
148 	return ConstSocketAddress(fManager->AddressModule(),
149 		key.first).HashPair(key.second);
150 }
151 
152 
153 size_t
154 ConnectionHashDefinition::Hash(TCPEndpoint* endpoint) const
155 {
156 	return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress());
157 }
158 
159 
160 bool
161 ConnectionHashDefinition::Compare(const KeyType& key,
162 	TCPEndpoint* endpoint) const
163 {
164 	return endpoint->LocalAddress().EqualTo(key.first, true)
165 		&& endpoint->PeerAddress().EqualTo(key.second, true);
166 }
167 
168 
169 HashTableLink<TCPEndpoint>*
170 ConnectionHashDefinition::GetLink(TCPEndpoint* endpoint) const
171 {
172 	return &endpoint->fConnectionHashLink;
173 }
174 
175 
176 //	#pragma mark -
177 
178 
179 size_t
180 EndpointHashDefinition::HashKey(uint16 port) const
181 {
182 	return port;
183 }
184 
185 
186 size_t
187 EndpointHashDefinition::Hash(TCPEndpoint* endpoint) const
188 {
189 	return endpoint->LocalAddress().Port();
190 }
191 
192 
193 bool
194 EndpointHashDefinition::Compare(uint16 port, TCPEndpoint* endpoint) const
195 {
196 	return endpoint->LocalAddress().Port() == port;
197 }
198 
199 
200 bool
201 EndpointHashDefinition::CompareValues(TCPEndpoint* first,
202 	TCPEndpoint* second) const
203 {
204 	return first->LocalAddress().Port() == second->LocalAddress().Port();
205 }
206 
207 
208 HashTableLink<TCPEndpoint>*
209 EndpointHashDefinition::GetLink(TCPEndpoint* endpoint) const
210 {
211 	return &endpoint->fEndpointHashLink;
212 }
213 
214 
215 //	#pragma mark -
216 
217 
218 EndpointManager::EndpointManager(net_domain* domain)
219 	:
220 	fDomain(domain),
221 	fConnectionHash(this),
222 	fLastPort(kFirstEphemeralPort)
223 {
224 	mutex_init(&fLock, "endpoint manager");
225 }
226 
227 
228 EndpointManager::~EndpointManager()
229 {
230 	mutex_destroy(&fLock);
231 }
232 
233 
234 status_t
235 EndpointManager::Init()
236 {
237 	status_t status = fConnectionHash.Init();
238 	if (status == B_OK)
239 		status = fEndpointHash.Init();
240 
241 	return status;
242 }
243 
244 
245 //	#pragma mark - connections
246 
247 
248 /*!	Returns the endpoint matching the connection.
249 	You must hold the manager's lock when calling this method.
250 */
251 TCPEndpoint*
252 EndpointManager::_LookupConnection(const sockaddr* local, const sockaddr* peer)
253 {
254 	return fConnectionHash.Lookup(std::make_pair(local, peer));
255 }
256 
257 
258 status_t
259 EndpointManager::SetConnection(TCPEndpoint* endpoint, const sockaddr* _local,
260 	const sockaddr* peer, const sockaddr* interfaceLocal)
261 {
262 	TRACE(("EndpointManager::SetConnection(%p)\n", endpoint));
263 
264 	MutexLocker _(fLock);
265 
266 	SocketAddressStorage local(AddressModule());
267 	local.SetTo(_local);
268 
269 	if (local.IsEmpty(false)) {
270 		uint16 port = local.Port();
271 		local.SetTo(interfaceLocal);
272 		local.SetPort(port);
273 	}
274 
275 	if (_LookupConnection(*local, peer) != NULL)
276 		return EADDRINUSE;
277 
278 	endpoint->LocalAddress().SetTo(*local);
279 	endpoint->PeerAddress().SetTo(peer);
280 	T(Connect(endpoint));
281 
282 	fConnectionHash.Insert(endpoint);
283 	return B_OK;
284 }
285 
286 
287 status_t
288 EndpointManager::SetPassive(TCPEndpoint* endpoint)
289 {
290 	MutexLocker _(fLock);
291 
292 	if (!endpoint->IsBound()) {
293 		// if the socket is unbound first bind it to ephemeral
294 		SocketAddressStorage local(AddressModule());
295 		local.SetToEmpty();
296 
297 		status_t status = _BindToEphemeral(endpoint, *local);
298 		if (status < B_OK)
299 			return status;
300 	}
301 
302 	SocketAddressStorage passive(AddressModule());
303 	passive.SetToEmpty();
304 
305 	if (_LookupConnection(*endpoint->LocalAddress(), *passive))
306 		return EADDRINUSE;
307 
308 	endpoint->PeerAddress().SetTo(*passive);
309 	fConnectionHash.Insert(endpoint);
310 	return B_OK;
311 }
312 
313 
314 TCPEndpoint*
315 EndpointManager::FindConnection(sockaddr* local, sockaddr* peer)
316 {
317 	MutexLocker _(fLock);
318 
319 	TCPEndpoint *endpoint = _LookupConnection(local, peer);
320 	if (endpoint != NULL) {
321 		TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n", endpoint));
322 		return endpoint;
323 	}
324 
325 	// no explicit endpoint exists, check for wildcard endpoints
326 
327 	SocketAddressStorage wildcard(AddressModule());
328 	wildcard.SetToEmpty();
329 
330 	endpoint = _LookupConnection(local, *wildcard);
331 	if (endpoint != NULL) {
332 		TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", endpoint));
333 		return endpoint;
334 	}
335 
336 	SocketAddressStorage localWildcard(AddressModule());
337 	localWildcard.SetToEmpty();
338 	localWildcard.SetPort(AddressModule()->get_port(local));
339 
340 	endpoint = _LookupConnection(*localWildcard, *wildcard);
341 	if (endpoint != NULL) {
342 		TRACE(("TCP: Received packet corresponds to local wildcard endpoint %p\n", endpoint));
343 		return endpoint;
344 	}
345 
346 	// no matching endpoint exists
347 	TRACE(("TCP: no matching endpoint!\n"));
348 
349 	return NULL;
350 }
351 
352 
353 //	#pragma mark - endpoints
354 
355 
356 status_t
357 EndpointManager::Bind(TCPEndpoint* endpoint, const sockaddr* address)
358 {
359 	// TODO check the family:
360 	//
361 	// if (!AddressModule()->is_understandable(address))
362 	//	return EAFNOSUPPORT;
363 
364 	MutexLocker _(fLock);
365 
366 	if (AddressModule()->get_port(address) == 0)
367 		return _BindToEphemeral(endpoint, address);
368 
369 	return _BindToAddress(endpoint, address);
370 }
371 
372 
373 status_t
374 EndpointManager::BindChild(TCPEndpoint* endpoint)
375 {
376 	MutexLocker _(fLock);
377 	return _Bind(endpoint, *endpoint->LocalAddress());
378 }
379 
380 
381 /*! You must hold fLock when calling this method. */
382 status_t
383 EndpointManager::_BindToAddress(TCPEndpoint* endpoint, const sockaddr* _address)
384 {
385 	ConstSocketAddress address(AddressModule(), _address);
386 	uint16 port = address.Port();
387 
388 	TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint));
389 	T(Bind(endpoint, address, false));
390 
391 	// TODO this check follows very typical UNIX semantics
392 	//      and generally should be improved.
393 	if (ntohs(port) <= kLastReservedPort && geteuid() != 0)
394 		return B_PERMISSION_DENIED;
395 
396 	bool retrying = false;
397 	int32 retry = 0;
398 	do {
399 		EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port);
400 		retry = false;
401 
402 		while (portUsers.HasNext()) {
403 			TCPEndpoint* user = portUsers.Next();
404 
405 			if (user->LocalAddress().IsEmpty(false)
406 				|| address.EqualTo(*user->LocalAddress(), false)) {
407 				// Check if this belongs to a local connection
408 
409 				// Note, while we hold our lock, the endpoint cannot go away,
410 				// it can only change its state - IsLocal() is safe to be used
411 				// without having the endpoint locked.
412 				tcp_state userState = user->State();
413 				if (user->IsLocal()
414 					&& (userState > ESTABLISHED || userState == CLOSED)) {
415 					// This is a closing local connection - wait until it's
416 					// gone away for real
417 					mutex_unlock(&fLock);
418 					snooze(10000);
419 					mutex_lock(&fLock);
420 						// TODO: make this better
421 					if (!retrying) {
422 						retrying = true;
423 						retry = 5;
424 					}
425 					break;
426 				}
427 
428 				if ((endpoint->socket->options & SO_REUSEADDR) == 0)
429 					return EADDRINUSE;
430 
431 				if (userState != TIME_WAIT && userState != CLOSED)
432 					return EADDRINUSE;
433 			}
434 		}
435 	} while (retry-- > 0);
436 
437 	return _Bind(endpoint, *address);
438 }
439 
440 
441 /*! You must hold fLock when calling this method. */
442 status_t
443 EndpointManager::_BindToEphemeral(TCPEndpoint* endpoint,
444 	const sockaddr* address)
445 {
446 	TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint));
447 
448 	uint32 max = fLastPort + 65536;
449 
450 	for (int32 i = 1; i < 5; i++) {
451 		// try to retrieve a more or less random port
452 		uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1;
453 		uint32 counter = fLastPort + step;
454 
455 		while (counter < max) {
456 			uint16 port = counter & 0xffff;
457 			if (port <= kLastReservedPort)
458 				port += kLastReservedPort;
459 
460 			fLastPort = port;
461 			port = htons(port);
462 
463 			if (!fEndpointHash.Lookup(port).HasNext()) {
464 				// found a port
465 				SocketAddressStorage newAddress(AddressModule());
466 				newAddress.SetTo(address);
467 				newAddress.SetPort(port);
468 
469 				TRACE(("   EndpointManager::BindToEphemeral(%p) -> %s\n",
470 					endpoint, AddressString(Domain(), *newAddress,
471 					true).Data()));
472 				T(Bind(endpoint, newAddress, true));
473 
474 				return _Bind(endpoint, *newAddress);
475 			}
476 
477 			counter += step;
478 		}
479 	}
480 
481 	// could not find a port!
482 	return EADDRINUSE;
483 }
484 
485 
486 status_t
487 EndpointManager::_Bind(TCPEndpoint* endpoint, const sockaddr* address)
488 {
489 	// Thus far we have checked if the Bind() is allowed
490 
491 	status_t status = endpoint->next->module->bind(endpoint->next, address);
492 	if (status < B_OK)
493 		return status;
494 
495 	fEndpointHash.Insert(endpoint);
496 
497 	return B_OK;
498 }
499 
500 
501 status_t
502 EndpointManager::Unbind(TCPEndpoint* endpoint)
503 {
504 	TRACE(("EndpointManager::Unbind(%p)\n", endpoint));
505 	T(Unbind(endpoint));
506 
507 	if (endpoint == NULL || !endpoint->IsBound()) {
508 		TRACE(("  endpoint is unbound.\n"));
509 		return B_BAD_VALUE;
510 	}
511 
512 	MutexLocker _(fLock);
513 
514 	if (!fEndpointHash.Remove(endpoint))
515 		panic("bound endpoint %p not in hash!", endpoint);
516 
517 	fConnectionHash.Remove(endpoint);
518 
519 	(*endpoint->LocalAddress())->sa_len = 0;
520 
521 	return B_OK;
522 }
523 
524 
525 status_t
526 EndpointManager::ReplyWithReset(tcp_segment_header& segment, net_buffer* buffer)
527 {
528 	TRACE(("TCP: Sending RST...\n"));
529 
530 	net_buffer* reply = gBufferModule->create(512);
531 	if (reply == NULL)
532 		return B_NO_MEMORY;
533 
534 	AddressModule()->set_to(reply->source, buffer->destination);
535 	AddressModule()->set_to(reply->destination, buffer->source);
536 
537 	tcp_segment_header outSegment(TCP_FLAG_RESET);
538 	outSegment.sequence = 0;
539 	outSegment.acknowledge = 0;
540 	outSegment.advertised_window = 0;
541 	outSegment.urgent_offset = 0;
542 
543 	if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) {
544 		outSegment.flags |= TCP_FLAG_ACKNOWLEDGE;
545 		outSegment.acknowledge = segment.sequence + buffer->size;
546 		// TODO: Confirm:
547 		if ((segment.flags & (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH)) != 0)
548 			outSegment.acknowledge++;
549 	} else
550 		outSegment.sequence = segment.acknowledge;
551 
552 	status_t status = add_tcp_header(AddressModule(), outSegment, reply);
553 	if (status == B_OK)
554 		status = Domain()->module->send_data(NULL, reply);
555 
556 	if (status != B_OK)
557 		gBufferModule->free(reply);
558 
559 	return status;
560 }
561 
562 
563 void
564 EndpointManager::Dump() const
565 {
566 	kprintf("-------- TCP Domain %p ---------\n", this);
567 	kprintf("%10s %20s %20s %8s %8s %12s\n", "address", "local", "peer",
568 		"recv-q", "send-q", "state");
569 
570 	ConnectionTable::Iterator iterator = fConnectionHash.GetIterator();
571 
572 	while (iterator.HasNext()) {
573 		TCPEndpoint *endpoint = iterator.Next();
574 
575 		char localBuf[64], peerBuf[64];
576 		endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true);
577 		endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true);
578 
579 		kprintf("%p %20s %20s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf,
580 			endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(),
581 			name_for_state(endpoint->State()));
582 	}
583 }
584 
585