xref: /haiku/src/add-ons/kernel/network/protocols/tcp/EndpointManager.cpp (revision cfc3fa87da824bdf593eb8b817a83b6376e77935)
1 /*
2  * Copyright 2006-2007, Haiku, Inc. All Rights Reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Axel Dörfler, axeld@pinc-software.de
7  *		Hugo Santos, hugosantos@gmail.com
8  */
9 
10 
11 #include "EndpointManager.h"
12 
13 #include <unistd.h>
14 
15 #include <KernelExport.h>
16 
17 #include <NetUtilities.h>
18 #include <util/AutoLock.h>
19 
20 #include "TCPEndpoint.h"
21 
22 
23 //#define TRACE_ENDPOINT_MANAGER
24 #ifdef TRACE_ENDPOINT_MANAGER
25 #	define TRACE(x) dprintf x
26 #else
27 #	define TRACE(x)
28 #endif
29 
30 
31 static const uint16 kLastReservedPort = 1023;
32 static const uint16 kFirstEphemeralPort = 40000;
33 
34 
35 ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager *manager)
36 	:
37 	fManager(manager)
38 {
39 }
40 
41 
42 size_t
43 ConnectionHashDefinition::HashKey(const KeyType &key) const
44 {
45 	return ConstSocketAddress(fManager->AddressModule(),
46 		key.first).HashPair(key.second);
47 }
48 
49 
50 size_t
51 ConnectionHashDefinition::Hash(TCPEndpoint *endpoint) const
52 {
53 	return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress());
54 }
55 
56 
57 bool
58 ConnectionHashDefinition::Compare(const KeyType &key,
59 	TCPEndpoint *endpoint) const
60 {
61 	return endpoint->LocalAddress().EqualTo(key.first, true)
62 		&& endpoint->PeerAddress().EqualTo(key.second, true);
63 }
64 
65 
66 HashTableLink<TCPEndpoint> *
67 ConnectionHashDefinition::GetLink(TCPEndpoint *endpoint) const
68 {
69 	return &endpoint->fConnectionHashLink;
70 }
71 
72 
73 size_t
74 EndpointHashDefinition::HashKey(uint16 port) const
75 {
76 	return port;
77 }
78 
79 
80 size_t
81 EndpointHashDefinition::Hash(TCPEndpoint *endpoint) const
82 {
83 	return endpoint->LocalAddress().Port();
84 }
85 
86 
87 bool
88 EndpointHashDefinition::Compare(uint16 port, TCPEndpoint *endpoint) const
89 {
90 	return endpoint->LocalAddress().Port() == port;
91 }
92 
93 
94 bool
95 EndpointHashDefinition::CompareValues(TCPEndpoint *first,
96 	TCPEndpoint *second) const
97 {
98 	return first->LocalAddress().Port() == second->LocalAddress().Port();
99 }
100 
101 
102 HashTableLink<TCPEndpoint> *
103 EndpointHashDefinition::GetLink(TCPEndpoint *endpoint) const
104 {
105 	return &endpoint->fEndpointHashLink;
106 }
107 
108 
109 EndpointManager::EndpointManager(net_domain *domain)
110 	: fDomain(domain), fConnectionHash(this)
111 {
112 	benaphore_init(&fLock, "endpoint manager");
113 }
114 
115 
116 EndpointManager::~EndpointManager()
117 {
118 	benaphore_destroy(&fLock);
119 }
120 
121 
122 status_t
123 EndpointManager::InitCheck() const
124 {
125 	if (fConnectionHash.InitCheck() < B_OK)
126 		return fConnectionHash.InitCheck();
127 
128 	if (fEndpointHash.InitCheck() < B_OK)
129 		return fEndpointHash.InitCheck();
130 
131 	if (fLock.sem < B_OK)
132 		return fLock.sem;
133 
134 	return B_OK;
135 }
136 
137 
138 //	#pragma mark - connections
139 
140 
141 /*!
142 	Returns the endpoint matching the connection.
143 	You must hold the manager's lock when calling this method.
144 */
145 TCPEndpoint *
146 EndpointManager::_LookupConnection(const sockaddr *local, const sockaddr *peer)
147 {
148 	return fConnectionHash.Lookup(std::make_pair(local, peer));
149 }
150 
151 
152 status_t
153 EndpointManager::SetConnection(TCPEndpoint *endpoint,
154 	const sockaddr *_local, const sockaddr *peer, const sockaddr *interfaceLocal)
155 {
156 	TRACE(("EndpointManager::SetConnection(%p)\n", endpoint));
157 
158 	BenaphoreLocker _(fLock);
159 
160 	SocketAddressStorage local(AddressModule());
161 	local.SetTo(_local);
162 
163 	if (local.IsEmpty(false)) {
164 		uint16 port = local.Port();
165 		local.SetTo(interfaceLocal);
166 		local.SetPort(port);
167 	}
168 
169 	if (_LookupConnection(*local, peer) != NULL)
170 		return EADDRINUSE;
171 
172 	endpoint->LocalAddress().SetTo(*local);
173 	endpoint->PeerAddress().SetTo(peer);
174 
175 	fConnectionHash.Insert(endpoint);
176 	return B_OK;
177 }
178 
179 
180 status_t
181 EndpointManager::SetPassive(TCPEndpoint *endpoint)
182 {
183 	BenaphoreLocker _(fLock);
184 
185 	if (!endpoint->IsBound()) {
186 		// if the socket is unbound first bind it to ephemeral
187 		SocketAddressStorage local(AddressModule());
188 		local.SetToEmpty();
189 
190 		status_t status = _BindToEphemeral(endpoint, *local);
191 		if (status < B_OK)
192 			return status;
193 	}
194 
195 	SocketAddressStorage passive(AddressModule());
196 	passive.SetToEmpty();
197 
198 	if (_LookupConnection(*endpoint->LocalAddress(), *passive))
199 		return EADDRINUSE;
200 
201 	endpoint->PeerAddress().SetTo(*passive);
202 	fConnectionHash.Insert(endpoint);
203 	return B_OK;
204 }
205 
206 
207 TCPEndpoint *
208 EndpointManager::FindConnection(sockaddr *local, sockaddr *peer)
209 {
210 	BenaphoreLocker _(fLock);
211 
212 	TCPEndpoint *endpoint = _LookupConnection(local, peer);
213 	if (endpoint != NULL) {
214 		TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n", endpoint));
215 		return endpoint;
216 	}
217 
218 	// no explicit endpoint exists, check for wildcard endpoints
219 
220 	SocketAddressStorage wildcard(AddressModule());
221 	wildcard.SetToEmpty();
222 
223 	endpoint = _LookupConnection(local, *wildcard);
224 	if (endpoint != NULL) {
225 		TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", endpoint));
226 		return endpoint;
227 	}
228 
229 	SocketAddressStorage localWildcard(AddressModule());
230 	localWildcard.SetToEmpty();
231 	localWildcard.SetPort(AddressModule()->get_port(local));
232 
233 	endpoint = _LookupConnection(*localWildcard, *wildcard);
234 	if (endpoint != NULL) {
235 		TRACE(("TCP: Received packet corresponds to local wildcard endpoint %p\n", endpoint));
236 		return endpoint;
237 	}
238 
239 	// no matching endpoint exists
240 	TRACE(("TCP: no matching endpoint!\n"));
241 
242 	return NULL;
243 }
244 
245 
246 //	#pragma mark - endpoints
247 
248 
249 status_t
250 EndpointManager::Bind(TCPEndpoint *endpoint, const sockaddr *address)
251 {
252 	// TODO check the family:
253 	//
254 	// if (!AddressModule()->is_understandable(address))
255 	//	return EAFNOSUPPORT;
256 
257 	BenaphoreLocker _(fLock);
258 
259 	if (AddressModule()->get_port(address) == 0)
260 		return _BindToEphemeral(endpoint, address);
261 
262 	return _BindToAddress(endpoint, address);
263 }
264 
265 
266 status_t
267 EndpointManager::BindChild(TCPEndpoint *endpoint)
268 {
269 	BenaphoreLocker _(fLock);
270 	return _Bind(endpoint, *endpoint->LocalAddress());
271 }
272 
273 
274 status_t
275 EndpointManager::_BindToAddress(TCPEndpoint *endpoint, const sockaddr *_address)
276 {
277 	TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint));
278 
279 	ConstSocketAddress address(AddressModule(), _address);
280 
281 	uint16 port = address.Port();
282 
283 	// TODO this check follows very typical UNIX semantics
284 	//      and generally should be improved.
285 	if (ntohs(port) <= kLastReservedPort && geteuid() != 0)
286 		return B_PERMISSION_DENIED;
287 
288 	EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port);
289 
290 	while (portUsers.HasNext()) {
291 		TCPEndpoint *user = portUsers.Next();
292 
293 		if (user->LocalAddress().IsEmpty(false)
294 			|| address.EqualTo(*user->LocalAddress(), false)) {
295 			if ((endpoint->socket->options & SO_REUSEADDR) == 0)
296 				return EADDRINUSE;
297 			// TODO lock endpoint before retriving state?
298 			if (user->State() != TIME_WAIT && user->State() != CLOSED)
299 				return EADDRINUSE;
300 		}
301 	}
302 
303 	return _Bind(endpoint, *address);
304 }
305 
306 
307 status_t
308 EndpointManager::_BindToEphemeral(TCPEndpoint *endpoint,
309 	const sockaddr *address)
310 {
311 	TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint));
312 
313 	uint32 max = kFirstEphemeralPort + 65536;
314 
315 	for (int32 i = 1; i < 5; i++) {
316 		// try to retrieve a more or less random port
317 		uint32 counter = kFirstEphemeralPort;
318 		uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1;
319 
320 		while (counter < max) {
321 			uint16 port = counter & 0xffff;
322 			if (port <= kLastReservedPort)
323 				port += kLastReservedPort;
324 
325 			port = htons(port);
326 
327 			if (!fEndpointHash.Lookup(port).HasNext()) {
328 				SocketAddressStorage newAddress(AddressModule());
329 				newAddress.SetTo(address);
330 				newAddress.SetPort(port);
331 
332 				// found a port
333 				TRACE(("   EndpointManager::BindToEphemeral(%p) -> %s\n", endpoint,
334 					AddressString(Domain(), *newAddress, true).Data()));
335 
336 				return _Bind(endpoint, *newAddress);
337 			}
338 
339 			counter += step;
340 		}
341 	}
342 
343 	// could not find a port!
344 	return EADDRINUSE;
345 }
346 
347 
348 status_t
349 EndpointManager::_Bind(TCPEndpoint *endpoint, const sockaddr *address)
350 {
351 	// Thus far we have checked if the Bind() is allowed
352 
353 	status_t status = endpoint->next->module->bind(endpoint->next, address);
354 	if (status < B_OK)
355 		return status;
356 
357 	fEndpointHash.Insert(endpoint);
358 
359 	return B_OK;
360 }
361 
362 
363 status_t
364 EndpointManager::Unbind(TCPEndpoint *endpoint)
365 {
366 	TRACE(("EndpointManager::Unbind(%p)\n", endpoint));
367 
368 	if (endpoint == NULL || !endpoint->IsBound()) {
369 		TRACE(("  endpoint is unbound.\n"));
370 		return B_BAD_VALUE;
371 	}
372 
373 	BenaphoreLocker _(fLock);
374 
375 	if (!fEndpointHash.Remove(endpoint))
376 		panic("bound endpoint %p not in hash!", endpoint);
377 
378 	fConnectionHash.Remove(endpoint);
379 
380 	(*endpoint->LocalAddress())->sa_len = 0;
381 
382 	return B_OK;
383 }
384 
385 
386 status_t
387 EndpointManager::ReplyWithReset(tcp_segment_header &segment,
388 	net_buffer *buffer)
389 {
390 	TRACE(("TCP: Sending RST...\n"));
391 
392 	net_buffer *reply = gBufferModule->create(512);
393 	if (reply == NULL)
394 		return B_NO_MEMORY;
395 
396 	AddressModule()->set_to(reply->source, buffer->destination);
397 	AddressModule()->set_to(reply->destination, buffer->source);
398 
399 	tcp_segment_header outSegment(TCP_FLAG_RESET);
400 	outSegment.sequence = 0;
401 	outSegment.acknowledge = 0;
402 	outSegment.advertised_window = 0;
403 	outSegment.urgent_offset = 0;
404 
405 	if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) {
406 		outSegment.flags |= TCP_FLAG_ACKNOWLEDGE;
407 		outSegment.acknowledge = segment.sequence + buffer->size;
408 		// TODO: Confirm:
409 		if ((segment.flags & (TCP_FLAG_SYNCHRONIZE | TCP_FLAG_FINISH)) != 0)
410 			outSegment.acknowledge++;
411 	} else
412 		outSegment.sequence = segment.acknowledge;
413 
414 	status_t status = add_tcp_header(AddressModule(), outSegment, reply);
415 	if (status == B_OK)
416 		status = Domain()->module->send_data(NULL, reply);
417 
418 	if (status != B_OK)
419 		gBufferModule->free(reply);
420 
421 	return status;
422 }
423 
424 
425 void
426 EndpointManager::DumpEndpoints() const
427 {
428 	kprintf("-------- TCP Domain %p ---------\n", this);
429 	kprintf("%10s %20s %20s %8s %8s %12s\n", "address", "local", "peer",
430 		"recv-q", "send-q", "state");
431 
432 	ConnectionTable::Iterator it = fConnectionHash.GetIterator();
433 
434 	while (it.HasNext()) {
435 		TCPEndpoint *endpoint = it.Next();
436 
437 		char localBuf[64], peerBuf[64];
438 		endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true);
439 		endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true);
440 
441 		kprintf("%p %20s %20s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf,
442 			endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(),
443 			name_for_state(endpoint->State()));
444 	}
445 }
446 
447