xref: /haiku/src/add-ons/kernel/network/protocols/tcp/EndpointManager.cpp (revision 14e3d1b5768e7110b3d5c0855833267409b71dbb)
1 /*
2  * Copyright 2006-2007, Haiku, Inc. All Rights Reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Axel Dörfler, axeld@pinc-software.de
7  */
8 
9 
10 #include "EndpointManager.h"
11 #include "TCPEndpoint.h"
12 
13 #include <NetUtilities.h>
14 
15 #include <util/AutoLock.h>
16 
17 #include <KernelExport.h>
18 
19 
20 //#define TRACE_ENDPOINT_MANAGER
21 #ifdef TRACE_ENDPOINT_MANAGER
22 #	define TRACE(x) dprintf x
23 #else
24 #	define TRACE(x)
25 #endif
26 
27 
28 static const uint16 kLastReservedPort = 1023;
29 static const uint16 kFirstEphemeralPort = 40000;
30 
31 
32 ConnectionHashDefinition::ConnectionHashDefinition(EndpointManager *manager)
33 	: fManager(manager) {}
34 
35 size_t
36 ConnectionHashDefinition::HashKey(const KeyType &key) const
37 {
38 	return ConstSocketAddress(fManager->AddressModule(),
39 		key.first).HashPair(key.second);
40 }
41 
42 
43 size_t
44 ConnectionHashDefinition::Hash(TCPEndpoint *endpoint) const
45 {
46 	return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress());
47 }
48 
49 
50 bool
51 ConnectionHashDefinition::Compare(const KeyType &key,
52 	TCPEndpoint *endpoint) const
53 {
54 	return endpoint->LocalAddress().EqualTo(key.first, true)
55 		&& endpoint->PeerAddress().EqualTo(key.second, true);
56 }
57 
58 
59 HashTableLink<TCPEndpoint> *
60 ConnectionHashDefinition::GetLink(TCPEndpoint *endpoint) const
61 {
62 	return &endpoint->fConnectionHashLink;
63 }
64 
65 
66 size_t
67 EndpointHashDefinition::HashKey(uint16 port) const
68 {
69 	return port;
70 }
71 
72 
73 size_t
74 EndpointHashDefinition::Hash(TCPEndpoint *endpoint) const
75 {
76 	return endpoint->LocalAddress().Port();
77 }
78 
79 
80 bool
81 EndpointHashDefinition::Compare(uint16 port, TCPEndpoint *endpoint) const
82 {
83 	return endpoint->LocalAddress().Port() == port;
84 }
85 
86 
87 bool
88 EndpointHashDefinition::CompareValues(TCPEndpoint *first,
89 	TCPEndpoint *second) const
90 {
91 	return first->LocalAddress().Port() == second->LocalAddress().Port();
92 }
93 
94 
95 HashTableLink<TCPEndpoint> *
96 EndpointHashDefinition::GetLink(TCPEndpoint *endpoint) const
97 {
98 	return &endpoint->fEndpointHashLink;
99 }
100 
101 
102 EndpointManager::EndpointManager(net_domain *domain)
103 	: fDomain(domain), fConnectionHash(this)
104 {
105 	benaphore_init(&fLock, "endpoint manager");
106 }
107 
108 
109 EndpointManager::~EndpointManager()
110 {
111 	benaphore_destroy(&fLock);
112 }
113 
114 
115 status_t
116 EndpointManager::InitCheck() const
117 {
118 	if (fConnectionHash.InitCheck() < B_OK)
119 		return fConnectionHash.InitCheck();
120 
121 	if (fEndpointHash.InitCheck() < B_OK)
122 		return fEndpointHash.InitCheck();
123 
124 	if (fLock.sem < B_OK)
125 		return fLock.sem;
126 
127 	return B_OK;
128 }
129 
130 
131 //	#pragma mark - connections
132 
133 
134 /*!
135 	Returns the endpoint matching the connection.
136 	You must hold the manager's lock when calling this method.
137 */
138 TCPEndpoint *
139 EndpointManager::_LookupConnection(const sockaddr *local, const sockaddr *peer)
140 {
141 	return fConnectionHash.Lookup(std::make_pair(local, peer));
142 }
143 
144 
145 status_t
146 EndpointManager::SetConnection(TCPEndpoint *endpoint,
147 	const sockaddr *_local, const sockaddr *peer, const sockaddr *interfaceLocal)
148 {
149 	TRACE(("EndpointManager::SetConnection(%p)\n", endpoint));
150 
151 	BenaphoreLocker _(fLock);
152 
153 	SocketAddressStorage local(AddressModule());
154 	local.SetTo(_local);
155 
156 	if (local.IsEmpty(false)) {
157 		uint16 port = local.Port();
158 		local.SetTo(interfaceLocal);
159 		local.SetPort(port);
160 	}
161 
162 	if (_LookupConnection(*local, peer) != NULL)
163 		return EADDRINUSE;
164 
165 	endpoint->LocalAddress().SetTo(*local);
166 	endpoint->PeerAddress().SetTo(peer);
167 
168 	fConnectionHash.Insert(endpoint);
169 	return B_OK;
170 }
171 
172 
173 status_t
174 EndpointManager::SetPassive(TCPEndpoint *endpoint)
175 {
176 	BenaphoreLocker _(fLock);
177 
178 	if (!endpoint->IsBound()) {
179 		// if the socket is unbound first bind it to ephemeral
180 		SocketAddressStorage local(AddressModule());
181 		local.SetToEmpty();
182 
183 		status_t status = _BindToEphemeral(endpoint, *local);
184 		if (status < B_OK)
185 			return status;
186 	}
187 
188 	SocketAddressStorage passive(AddressModule());
189 	passive.SetToEmpty();
190 
191 	if (_LookupConnection(*endpoint->LocalAddress(), *passive))
192 		return EADDRINUSE;
193 
194 	endpoint->PeerAddress().SetTo(*passive);
195 	fConnectionHash.Insert(endpoint);
196 	return B_OK;
197 }
198 
199 
200 TCPEndpoint *
201 EndpointManager::FindConnection(sockaddr *local, sockaddr *peer)
202 {
203 	BenaphoreLocker _(fLock);
204 
205 	TCPEndpoint *endpoint = _LookupConnection(local, peer);
206 	if (endpoint != NULL) {
207 		TRACE(("TCP: Received packet corresponds to explicit endpoint %p\n", endpoint));
208 		return endpoint;
209 	}
210 
211 	// no explicit endpoint exists, check for wildcard endpoints
212 
213 	SocketAddressStorage wildcard(AddressModule());
214 	wildcard.SetToEmpty();
215 
216 	endpoint = _LookupConnection(local, *wildcard);
217 	if (endpoint != NULL) {
218 		TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", endpoint));
219 		return endpoint;
220 	}
221 
222 	SocketAddressStorage localWildcard(AddressModule());
223 	localWildcard.SetToEmpty();
224 	localWildcard.SetPort(AddressModule()->get_port(local));
225 
226 	endpoint = _LookupConnection(*localWildcard, *wildcard);
227 	if (endpoint != NULL) {
228 		TRACE(("TCP: Received packet corresponds to local wildcard endpoint %p\n", endpoint));
229 		return endpoint;
230 	}
231 
232 	// no matching endpoint exists
233 	TRACE(("TCP: no matching endpoint!\n"));
234 
235 	return NULL;
236 }
237 
238 
239 //	#pragma mark - endpoints
240 
241 
242 status_t
243 EndpointManager::Bind(TCPEndpoint *endpoint, const sockaddr *address)
244 {
245 	// TODO check the family:
246 	//
247 	// if (!AddressModule()->is_understandable(address))
248 	//	return EAFNOSUPPORT;
249 
250 	BenaphoreLocker _(fLock);
251 
252 	if (AddressModule()->get_port(address) == 0)
253 		return _BindToEphemeral(endpoint, address);
254 
255 	return _BindToAddress(endpoint, address);
256 }
257 
258 
259 status_t
260 EndpointManager::BindChild(TCPEndpoint *endpoint)
261 {
262 	BenaphoreLocker _(fLock);
263 	return _Bind(endpoint, *endpoint->LocalAddress());
264 }
265 
266 
267 status_t
268 EndpointManager::_BindToAddress(TCPEndpoint *endpoint, const sockaddr *_address)
269 {
270 	TRACE(("EndpointManager::BindToAddress(%p)\n", endpoint));
271 
272 	ConstSocketAddress address(AddressModule(), _address);
273 
274 	uint16 port = address.Port();
275 
276 	// TODO this check follows very typical UNIX semantics
277 	//      and generally should be improved.
278 	if (ntohs(port) <= kLastReservedPort && geteuid() != 0)
279 		return B_PERMISSION_DENIED;
280 
281 	EndpointTable::ValueIterator portUsers = fEndpointHash.Lookup(port);
282 
283 	while (portUsers.HasNext()) {
284 		TCPEndpoint *user = portUsers.Next();
285 
286 		if (user->LocalAddress().IsEmpty(false)
287 			|| address.EqualTo(*user->LocalAddress(), false)) {
288 			if ((endpoint->socket->options & SO_REUSEADDR) == 0)
289 				return EADDRINUSE;
290 			// TODO lock endpoint before retriving state?
291 			if (user->State() != TIME_WAIT && user->State() != CLOSED)
292 				return EADDRINUSE;
293 		}
294 	}
295 
296 	return _Bind(endpoint, *address);
297 }
298 
299 
300 status_t
301 EndpointManager::_BindToEphemeral(TCPEndpoint *endpoint,
302 	const sockaddr *address)
303 {
304 	TRACE(("EndpointManager::BindToEphemeral(%p)\n", endpoint));
305 
306 	uint32 max = kFirstEphemeralPort + 65536;
307 
308 	for (int32 i = 1; i < 5; i++) {
309 		// try to retrieve a more or less random port
310 		uint32 counter = kFirstEphemeralPort;
311 		uint32 step = i == 4 ? 1 : (system_time() & 0x1f) + 1;
312 
313 		while (counter < max) {
314 			uint16 port = counter & 0xffff;
315 			if (port <= kLastReservedPort)
316 				port += kLastReservedPort;
317 
318 			port = htons(port);
319 
320 			if (!fEndpointHash.Lookup(port).HasNext()) {
321 				SocketAddressStorage newAddress(AddressModule());
322 				newAddress.SetTo(address);
323 				newAddress.SetPort(port);
324 
325 				// found a port
326 				TRACE(("   EndpointManager::BindToEphemeral(%p) -> %s\n", endpoint,
327 					AddressString(Domain(), *newAddress, true).Data()));
328 
329 				return _Bind(endpoint, *newAddress);
330 			}
331 
332 			counter += step;
333 		}
334 	}
335 
336 	// could not find a port!
337 	return EADDRINUSE;
338 }
339 
340 
341 status_t
342 EndpointManager::_Bind(TCPEndpoint *endpoint, const sockaddr *address)
343 {
344 	// Thus far we have checked if the Bind() is allowed
345 
346 	status_t status = endpoint->next->module->bind(endpoint->next, address);
347 	if (status < B_OK)
348 		return status;
349 
350 	fEndpointHash.Insert(endpoint);
351 
352 	return B_OK;
353 }
354 
355 
356 status_t
357 EndpointManager::Unbind(TCPEndpoint *endpoint)
358 {
359 	TRACE(("EndpointManager::Unbind(%p)\n", endpoint));
360 
361 	if (endpoint == NULL || !endpoint->IsBound()) {
362 		TRACE(("  endpoint is unbound.\n"));
363 		return B_BAD_VALUE;
364 	}
365 
366 	BenaphoreLocker _(fLock);
367 
368 	if (!fEndpointHash.Remove(endpoint))
369 		panic("bound endpoint %p not in hash!", endpoint);
370 
371 	fConnectionHash.Remove(endpoint);
372 
373 	(*endpoint->LocalAddress())->sa_len = 0;
374 
375 	return B_OK;
376 }
377 
378 
379 status_t
380 EndpointManager::ReplyWithReset(tcp_segment_header &segment,
381 	net_buffer *buffer)
382 {
383 	TRACE(("TCP: Sending RST...\n"));
384 
385 	net_buffer *reply = gBufferModule->create(512);
386 	if (reply == NULL)
387 		return B_NO_MEMORY;
388 
389 	AddressModule()->set_to(reply->source, buffer->destination);
390 	AddressModule()->set_to(reply->destination, buffer->source);
391 
392 	tcp_segment_header outSegment(TCP_FLAG_RESET);
393 	outSegment.sequence = 0;
394 	outSegment.acknowledge = 0;
395 	outSegment.advertised_window = 0;
396 	outSegment.urgent_offset = 0;
397 
398 	if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) == 0) {
399 		outSegment.flags |= TCP_FLAG_ACKNOWLEDGE;
400 		outSegment.acknowledge = segment.sequence + buffer->size;
401 	} else
402 		outSegment.sequence = segment.acknowledge;
403 
404 	status_t status = add_tcp_header(AddressModule(), outSegment, reply);
405 	if (status == B_OK)
406 		status = Domain()->module->send_data(NULL, reply);
407 
408 	if (status != B_OK)
409 		gBufferModule->free(reply);
410 
411 	return status;
412 }
413 
414 
415 void
416 EndpointManager::DumpEndpoints() const
417 {
418 	kprintf("-------- TCP Domain %p ---------\n", this);
419 	kprintf("%10s %20s %20s %8s %8s %12s\n", "address", "local", "peer",
420 		"recv-q", "send-q", "state");
421 
422 	ConnectionTable::Iterator it = fConnectionHash.GetIterator();
423 
424 	while (it.HasNext()) {
425 		TCPEndpoint *endpoint = it.Next();
426 
427 		char localBuf[64], peerBuf[64];
428 		endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true);
429 		endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true);
430 
431 		kprintf("%p %20s %20s %8lu %8lu %12s\n", endpoint, localBuf, peerBuf,
432 			endpoint->fReceiveQueue.Available(), endpoint->fSendQueue.Used(),
433 			name_for_state(endpoint->State()));
434 	}
435 }
436 
437