xref: /haiku/src/add-ons/kernel/network/protocols/udp/udp.cpp (revision e0bc2fcce4c5ceb7b57d978b752cda01639d0098)
1 /*
2  * Copyright 2006-2010, Haiku, Inc. All Rights Reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Oliver Tappe, zooey@hirschkaefer.de
7  *		Hugo Santos, hugosantos@gmail.com
8  */
9 
10 
11 #include <net_buffer.h>
12 #include <net_datalink.h>
13 #include <net_protocol.h>
14 #include <net_stack.h>
15 
16 #include <lock.h>
17 #include <util/AutoLock.h>
18 #include <util/DoublyLinkedList.h>
19 #include <util/OpenHashTable.h>
20 
21 #include <KernelExport.h>
22 
23 #include <NetBufferUtilities.h>
24 #include <NetUtilities.h>
25 #include <ProtocolUtilities.h>
26 
27 #include <algorithm>
28 #include <netinet/in.h>
29 #include <netinet/ip.h>
30 #include <new>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <utility>
34 
35 
36 // NOTE the locking protocol dictates that we must hold UdpDomainSupport's
37 //      lock before holding a child UdpEndpoint's lock. This restriction
38 //      is dictated by the receive path as blind access to the endpoint
39 //      hash is required when holding the DomainSupport's lock.
40 
41 
//#define TRACE_UDP
#ifdef TRACE_UDP
#	define TRACE_BLOCK(x) dump_block x
// do not remove the space before ', ##args' if you want this
// to compile with gcc 2.95
// system_time() returns a signed bigtime_t (int64), hence B_PRId64.
#	define TRACE_EP(format, args...)	dprintf("UDP [%" B_PRId64 "] %p " \
		format "\n", system_time(), this , ##args)
#	define TRACE_EPM(format, args...)	dprintf("UDP [%" B_PRId64 "] " \
		format "\n", system_time() , ##args)
#	define TRACE_DOMAIN(format, args...)	dprintf("UDP [%" B_PRId64 "] (%d) " \
		format "\n", system_time(), Domain()->family , ##args)
#else
#	define TRACE_BLOCK(x)
#	define TRACE_EP(args...)	do { } while (0)
#	define TRACE_EPM(args...)	do { } while (0)
#	define TRACE_DOMAIN(args...)	do { } while (0)
#endif
59 
60 
61 struct udp_header {
62 	uint16 source_port;
63 	uint16 destination_port;
64 	uint16 udp_length;
65 	uint16 udp_checksum;
66 } _PACKED;
67 
68 
69 typedef NetBufferField<uint16, offsetof(udp_header, udp_checksum)>
70 	UDPChecksumField;
71 
72 class UdpDomainSupport;
73 
74 class UdpEndpoint : public net_protocol, public DatagramSocket<> {
75 public:
76 	UdpEndpoint(net_socket *socket);
77 
78 	status_t				Bind(const sockaddr *newAddr);
79 	status_t				Unbind(sockaddr *newAddr);
80 	status_t				Connect(const sockaddr *newAddr);
81 
82 	status_t				Open();
83 	status_t				Close();
84 	status_t				Free();
85 
86 	status_t				SendRoutedData(net_buffer *buffer,
87 								net_route *route);
88 	status_t				SendData(net_buffer *buffer);
89 
90 	ssize_t					BytesAvailable();
91 	status_t				FetchData(size_t numBytes, uint32 flags,
92 								net_buffer **_buffer);
93 
94 	status_t				StoreData(net_buffer *buffer);
95 	status_t				DeliverData(net_buffer *buffer);
96 
97 	// only the domain support will change/check the Active flag so
98 	// we don't really need to protect it with the socket lock.
99 	bool					IsActive() const { return fActive; }
100 	void					SetActive(bool newValue) { fActive = newValue; }
101 
102 	UdpEndpoint				*&HashTableLink() { return fLink; }
103 
104 private:
105 	UdpDomainSupport		*fManager;
106 	bool					fActive;
107 								// an active UdpEndpoint is part of the
108 								// endpoint hash (and it is bound and optionally
109 								// connected)
110 
111 	UdpEndpoint				*fLink;
112 };
113 
114 
115 class UdpDomainSupport;
116 
117 struct UdpHashDefinition {
118 	typedef net_address_module_info ParentType;
119 	typedef std::pair<const sockaddr *, const sockaddr *> KeyType;
120 	typedef UdpEndpoint ValueType;
121 
122 	UdpHashDefinition(net_address_module_info *_module)
123 		: module(_module) {}
124 	UdpHashDefinition(const UdpHashDefinition& definition)
125 		: module(definition.module) {}
126 
127 	size_t HashKey(const KeyType &key) const
128 	{
129 		return _Mix(module->hash_address_pair(key.first, key.second));
130 	}
131 
132 	size_t Hash(UdpEndpoint *endpoint) const
133 	{
134 		return _Mix(endpoint->LocalAddress().HashPair(
135 			*endpoint->PeerAddress()));
136 	}
137 
138 	static size_t _Mix(size_t hash)
139 	{
140 		// move the bits into the relevant range (as defined by kNumHashBuckets)
141 		return (hash & 0x000007FF) ^ (hash & 0x003FF800) >> 11
142 			^ (hash & 0xFFC00000UL) >> 22;
143 	}
144 
145 	bool Compare(const KeyType &key, UdpEndpoint *endpoint) const
146 	{
147 		return endpoint->LocalAddress().EqualTo(key.first, true)
148 			&& endpoint->PeerAddress().EqualTo(key.second, true);
149 	}
150 
151 	UdpEndpoint *&GetLink(UdpEndpoint *endpoint) const
152 	{
153 		return endpoint->HashTableLink();
154 	}
155 
156 	net_address_module_info *module;
157 };
158 
159 
160 class UdpDomainSupport : public DoublyLinkedListLinkImpl<UdpDomainSupport> {
161 public:
162 	UdpDomainSupport(net_domain *domain);
163 	~UdpDomainSupport();
164 
165 	status_t Init();
166 
167 	net_domain *Domain() const { return fDomain; }
168 
169 	void Ref() { fEndpointCount++; }
170 	bool Put() { fEndpointCount--; return fEndpointCount == 0; }
171 
172 	status_t DemuxIncomingBuffer(net_buffer* buffer);
173 	status_t DeliverError(status_t error, net_buffer* buffer);
174 
175 	status_t BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
176 	status_t ConnectEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
177 	status_t UnbindEndpoint(UdpEndpoint *endpoint);
178 
179 	void DumpEndpoints() const;
180 
181 private:
182 	status_t _BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
183 	status_t _Bind(UdpEndpoint *endpoint, const sockaddr *address);
184 	status_t _BindToEphemeral(UdpEndpoint *endpoint, const sockaddr *address);
185 	status_t _FinishBind(UdpEndpoint *endpoint, const sockaddr *address);
186 
187 	UdpEndpoint *_FindActiveEndpoint(const sockaddr *ourAddress,
188 		const sockaddr *peerAddress);
189 	status_t _DemuxBroadcast(net_buffer *buffer);
190 	status_t _DemuxUnicast(net_buffer *buffer);
191 
192 	uint16 _GetNextEphemeral();
193 	UdpEndpoint *_EndpointWithPort(uint16 port) const;
194 
195 	net_address_module_info *AddressModule() const
196 		{ return fDomain->address_module; }
197 
198 	typedef BOpenHashTable<UdpHashDefinition, false> EndpointTable;
199 
200 	mutex			fLock;
201 	net_domain		*fDomain;
202 	uint16			fLastUsedEphemeral;
203 	EndpointTable	fActiveEndpoints;
204 	uint32			fEndpointCount;
205 
206 	static const uint16		kFirst = 49152;
207 	static const uint16		kLast = 65535;
208 	static const uint32		kNumHashBuckets = 0x800;
209 							// if you change this, adjust the shifting in
210 							// Hash() accordingly!
211 };
212 
213 
214 typedef DoublyLinkedList<UdpDomainSupport> UdpDomainList;
215 
216 
217 class UdpEndpointManager {
218 public:
219 								UdpEndpointManager();
220 								~UdpEndpointManager();
221 
222 			status_t			InitCheck() const;
223 
224 			status_t			ReceiveData(net_buffer* buffer);
225 			status_t			ReceiveError(status_t error,
226 									net_buffer* buffer);
227 			status_t			Deframe(net_buffer* buffer);
228 
229 			UdpDomainSupport*	OpenEndpoint(UdpEndpoint* endpoint);
230 			status_t			FreeEndpoint(UdpDomainSupport* domain);
231 
232 	static	int					DumpEndpoints(int argc, char *argv[]);
233 
234 private:
235 			UdpDomainSupport*	_GetDomain(net_domain *domain, bool create);
236 			UdpDomainSupport*	_GetDomain(net_buffer* buffer);
237 
238 			mutex				fLock;
239 			status_t			fStatus;
240 			UdpDomainList		fDomains;
241 };
242 
243 
244 static UdpEndpointManager *sUdpEndpointManager;
245 
246 net_buffer_module_info *gBufferModule;
247 net_datalink_module_info *gDatalinkModule;
248 net_stack_module_info *gStackModule;
249 net_socket_module_info *gSocketModule;
250 
251 
252 // #pragma mark -
253 
254 
255 UdpDomainSupport::UdpDomainSupport(net_domain *domain)
256 	:
257 	fDomain(domain),
258 	fActiveEndpoints(domain->address_module),
259 	fEndpointCount(0)
260 {
261 	mutex_init(&fLock, "udp domain");
262 
263 	fLastUsedEphemeral = kFirst + rand() % (kLast - kFirst);
264 }
265 
266 
267 UdpDomainSupport::~UdpDomainSupport()
268 {
269 	mutex_destroy(&fLock);
270 }
271 
272 
273 status_t
274 UdpDomainSupport::Init()
275 {
276 	return fActiveEndpoints.Init(kNumHashBuckets);
277 }
278 
279 
280 status_t
281 UdpDomainSupport::DemuxIncomingBuffer(net_buffer *buffer)
282 {
283 	// NOTE: multicast is delivered directly to the endpoint
284 	MutexLocker _(fLock);
285 
286 	if ((buffer->flags & MSG_BCAST) != 0)
287 		return _DemuxBroadcast(buffer);
288 	if ((buffer->flags & MSG_MCAST) != 0)
289 		return B_ERROR;
290 
291 	return _DemuxUnicast(buffer);
292 }
293 
294 
295 status_t
296 UdpDomainSupport::DeliverError(status_t error, net_buffer* buffer)
297 {
298 	if ((buffer->flags & (MSG_BCAST | MSG_MCAST)) != 0)
299 		return B_ERROR;
300 
301 	MutexLocker _(fLock);
302 
303 	// Forward the error to the socket
304 	UdpEndpoint* endpoint = _FindActiveEndpoint(buffer->source,
305 		buffer->destination);
306 	if (endpoint != NULL) {
307 		gSocketModule->notify(endpoint->Socket(), B_SELECT_ERROR, error);
308 		endpoint->NotifyOne();
309 	}
310 
311 	gBufferModule->free(buffer);
312 	return B_OK;
313 }
314 
315 
316 status_t
317 UdpDomainSupport::BindEndpoint(UdpEndpoint *endpoint,
318 	const sockaddr *address)
319 {
320 	if (!AddressModule()->is_same_family(address))
321 		return EAFNOSUPPORT;
322 
323 	MutexLocker _(fLock);
324 
325 	if (endpoint->IsActive())
326 		return EINVAL;
327 
328 	return _BindEndpoint(endpoint, address);
329 }
330 
331 
332 status_t
333 UdpDomainSupport::ConnectEndpoint(UdpEndpoint *endpoint,
334 	const sockaddr *address)
335 {
336 	MutexLocker _(fLock);
337 
338 	if (endpoint->IsActive()) {
339 		fActiveEndpoints.Remove(endpoint);
340 		endpoint->SetActive(false);
341 	}
342 
343 	if (address->sa_family == AF_UNSPEC) {
344 		// [Stevens-UNP1, p226]: specifying AF_UNSPEC requests a "disconnect",
345 		// so we reset the peer address:
346 		endpoint->PeerAddress().SetToEmpty();
347 	} else {
348 		if (!AddressModule()->is_same_family(address))
349 			return EAFNOSUPPORT;
350 
351 		// consider destination address INADDR_ANY as INADDR_LOOPBACK
352 		sockaddr_storage _address;
353 		if (AddressModule()->is_empty_address(address, false)) {
354 			AddressModule()->get_loopback_address((sockaddr *)&_address);
355 			// for IPv4 and IPv6 the port is at the same offset
356 			((sockaddr_in&)_address).sin_port
357 				= ((sockaddr_in *)address)->sin_port;
358 			address = (sockaddr *)&_address;
359 		}
360 
361 		status_t status = endpoint->PeerAddress().SetTo(address);
362 		if (status < B_OK)
363 			return status;
364 		struct net_route *routeToDestination
365 			= gDatalinkModule->get_route(fDomain, address);
366 		if (routeToDestination) {
367 			status = endpoint->LocalAddress().SetTo(
368 				routeToDestination->interface_address->local);
369 			gDatalinkModule->put_route(fDomain, routeToDestination);
370 			if (status < B_OK)
371 				return status;
372 		}
373 	}
374 
375 	// we need to activate no matter whether or not we have just disconnected,
376 	// as calling connect() always triggers an implicit bind():
377 	return _BindEndpoint(endpoint, *endpoint->LocalAddress());
378 }
379 
380 
381 status_t
382 UdpDomainSupport::UnbindEndpoint(UdpEndpoint *endpoint)
383 {
384 	MutexLocker _(fLock);
385 
386 	if (endpoint->IsActive())
387 		fActiveEndpoints.Remove(endpoint);
388 
389 	endpoint->SetActive(false);
390 
391 	return B_OK;
392 }
393 
394 
395 void
396 UdpDomainSupport::DumpEndpoints() const
397 {
398 	kprintf("-------- UDP Domain %p ---------\n", this);
399 	kprintf("%10s %20s %20s %8s\n", "address", "local", "peer", "recv-q");
400 
401 	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
402 
403 	while (it.HasNext()) {
404 		UdpEndpoint *endpoint = it.Next();
405 
406 		char localBuf[64], peerBuf[64];
407 		endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true);
408 		endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true);
409 
410 		kprintf("%p %20s %20s %8lu\n", endpoint, localBuf, peerBuf,
411 			endpoint->AvailableData());
412 	}
413 }
414 
415 
416 status_t
417 UdpDomainSupport::_BindEndpoint(UdpEndpoint *endpoint,
418 	const sockaddr *address)
419 {
420 	if (AddressModule()->get_port(address) == 0)
421 		return _BindToEphemeral(endpoint, address);
422 
423 	return _Bind(endpoint, address);
424 }
425 
426 
427 status_t
428 UdpDomainSupport::_Bind(UdpEndpoint *endpoint, const sockaddr *address)
429 {
430 	int socketOptions = endpoint->Socket()->options;
431 
432 	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
433 
434 	// Iterate over all active UDP-endpoints and check if the requested bind
435 	// is allowed (see figure 22.24 in [Stevens - TCP2, p735]):
436 	TRACE_DOMAIN("CheckBindRequest() for %s...", AddressString(fDomain,
437 		address, true).Data());
438 
439 	while (it.HasNext()) {
440 		UdpEndpoint *otherEndpoint = it.Next();
441 
442 		TRACE_DOMAIN("  ...checking endpoint %p (port=%u)...", otherEndpoint,
443 			ntohs(otherEndpoint->LocalAddress().Port()));
444 
445 		if (otherEndpoint->LocalAddress().EqualPorts(address)) {
446 			// port is already bound, SO_REUSEADDR or SO_REUSEPORT is required:
447 			if ((otherEndpoint->Socket()->options
448 					& (SO_REUSEADDR | SO_REUSEPORT)) == 0
449 				|| (socketOptions & (SO_REUSEADDR | SO_REUSEPORT)) == 0)
450 				return EADDRINUSE;
451 
452 			// if both addresses are the same, SO_REUSEPORT is required:
453 			if (otherEndpoint->LocalAddress().EqualTo(address, false)
454 				&& ((otherEndpoint->Socket()->options & SO_REUSEPORT) == 0
455 					|| (socketOptions & SO_REUSEPORT) == 0))
456 				return EADDRINUSE;
457 		}
458 	}
459 
460 	return _FinishBind(endpoint, address);
461 }
462 
463 
464 status_t
465 UdpDomainSupport::_BindToEphemeral(UdpEndpoint *endpoint,
466 	const sockaddr *address)
467 {
468 	SocketAddressStorage newAddress(AddressModule());
469 	status_t status = newAddress.SetTo(address);
470 	if (status < B_OK)
471 		return status;
472 
473 	uint16 allocedPort = _GetNextEphemeral();
474 	if (allocedPort == 0)
475 		return ENOBUFS;
476 
477 	newAddress.SetPort(htons(allocedPort));
478 
479 	return _FinishBind(endpoint, *newAddress);
480 }
481 
482 
483 status_t
484 UdpDomainSupport::_FinishBind(UdpEndpoint *endpoint, const sockaddr *address)
485 {
486 	status_t status = endpoint->next->module->bind(endpoint->next, address);
487 	if (status < B_OK)
488 		return status;
489 
490 	fActiveEndpoints.Insert(endpoint);
491 	endpoint->SetActive(true);
492 
493 	return B_OK;
494 }
495 
496 
497 UdpEndpoint *
498 UdpDomainSupport::_FindActiveEndpoint(const sockaddr *ourAddress,
499 	const sockaddr *peerAddress)
500 {
501 	ASSERT_LOCKED_MUTEX(&fLock);
502 
503 	TRACE_DOMAIN("finding Endpoint for %s <- %s",
504 		AddressString(fDomain, ourAddress, true).Data(),
505 		AddressString(fDomain, peerAddress, true).Data());
506 
507 	return fActiveEndpoints.Lookup(std::make_pair(ourAddress, peerAddress));
508 }
509 
510 
511 status_t
512 UdpDomainSupport::_DemuxBroadcast(net_buffer *buffer)
513 {
514 	sockaddr *peerAddr = buffer->source;
515 	sockaddr *broadcastAddr = buffer->destination;
516 	sockaddr *mask = NULL;
517 	if (buffer->interface_address != NULL)
518 		mask = (sockaddr *)buffer->interface_address->mask;
519 
520 	TRACE_DOMAIN("_DemuxBroadcast(%p)", buffer);
521 
522 	uint16 incomingPort = AddressModule()->get_port(broadcastAddr);
523 
524 	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
525 
526 	while (it.HasNext()) {
527 		UdpEndpoint *endpoint = it.Next();
528 
529 		TRACE_DOMAIN("  _DemuxBroadcast(): checking endpoint %s...",
530 			AddressString(fDomain, *endpoint->LocalAddress(), true).Data());
531 
532 		if (endpoint->LocalAddress().Port() != incomingPort) {
533 			// ports don't match, so we do not dispatch to this endpoint...
534 			continue;
535 		}
536 
537 		if (!endpoint->PeerAddress().IsEmpty(true)) {
538 			// endpoint is connected to a specific destination, we check if
539 			// this datagram is from there:
540 			if (!endpoint->PeerAddress().EqualTo(peerAddr, true)) {
541 				// no, datagram is from another peer, so we do not dispatch to
542 				// this endpoint...
543 				continue;
544 			}
545 		}
546 
547 		if (endpoint->LocalAddress().MatchMasked(broadcastAddr, mask)
548 			|| endpoint->LocalAddress().IsEmpty(false)) {
549 			// address matches, dispatch to this endpoint:
550 			endpoint->StoreData(buffer);
551 		}
552 	}
553 
554 	return B_OK;
555 }
556 
557 
558 status_t
559 UdpDomainSupport::_DemuxUnicast(net_buffer* buffer)
560 {
561 	TRACE_DOMAIN("_DemuxUnicast(%p)", buffer);
562 
563 	const sockaddr* localAddress = buffer->destination;
564 	const sockaddr* peerAddress = buffer->source;
565 
566 	// look for full (most special) match:
567 	UdpEndpoint* endpoint = _FindActiveEndpoint(localAddress, peerAddress);
568 	if (endpoint == NULL) {
569 		// look for endpoint matching local address & port:
570 		endpoint = _FindActiveEndpoint(localAddress, NULL);
571 		if (endpoint == NULL) {
572 			// look for endpoint matching peer address & port and local port:
573 			SocketAddressStorage local(AddressModule());
574 			local.SetToEmpty();
575 			local.SetPort(AddressModule()->get_port(localAddress));
576 			endpoint = _FindActiveEndpoint(*local, peerAddress);
577 			if (endpoint == NULL) {
578 				// last chance: look for endpoint matching local port only:
579 				endpoint = _FindActiveEndpoint(*local, NULL);
580 			}
581 		}
582 	}
583 
584 	if (endpoint == NULL) {
585 		TRACE_DOMAIN("_DemuxUnicast(%p) - no matching endpoint found!", buffer);
586 		return B_NAME_NOT_FOUND;
587 	}
588 
589 	endpoint->StoreData(buffer);
590 	return B_OK;
591 }
592 
593 
594 uint16
595 UdpDomainSupport::_GetNextEphemeral()
596 {
597 	uint16 stop, curr;
598 	if (fLastUsedEphemeral < kLast) {
599 		stop = fLastUsedEphemeral;
600 		curr = fLastUsedEphemeral + 1;
601 	} else {
602 		stop = kLast;
603 		curr = kFirst;
604 	}
605 
606 	TRACE_DOMAIN("_GetNextEphemeral(), last %hu, curr %hu, stop %hu",
607 		fLastUsedEphemeral, curr, stop);
608 
609 	// TODO: a free list could be used to avoid the impact of these two
610 	//        nested loops most of the time... let's see how bad this really is
611 	for (; curr != stop; curr = (curr < kLast) ? (curr + 1) : kFirst) {
612 		TRACE_DOMAIN("  _GetNextEphemeral(): trying port %hu...", curr);
613 
614 		if (_EndpointWithPort(htons(curr)) == NULL) {
615 			TRACE_DOMAIN("  _GetNextEphemeral(): ...using port %hu", curr);
616 			fLastUsedEphemeral = curr;
617 			return curr;
618 		}
619 	}
620 
621 	return 0;
622 }
623 
624 
625 UdpEndpoint *
626 UdpDomainSupport::_EndpointWithPort(uint16 port) const
627 {
628 	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
629 
630 	while (it.HasNext()) {
631 		UdpEndpoint *endpoint = it.Next();
632 		if (endpoint->LocalAddress().Port() == port)
633 			return endpoint;
634 	}
635 
636 	return NULL;
637 }
638 
639 
640 // #pragma mark -
641 
642 
643 UdpEndpointManager::UdpEndpointManager()
644 {
645 	mutex_init(&fLock, "UDP endpoints");
646 	fStatus = B_OK;
647 }
648 
649 
650 UdpEndpointManager::~UdpEndpointManager()
651 {
652 	mutex_destroy(&fLock);
653 }
654 
655 
656 status_t
657 UdpEndpointManager::InitCheck() const
658 {
659 	return fStatus;
660 }
661 
662 
663 int
664 UdpEndpointManager::DumpEndpoints(int argc, char *argv[])
665 {
666 	UdpDomainList::Iterator it = sUdpEndpointManager->fDomains.GetIterator();
667 
668 	while (it.HasNext())
669 		it.Next()->DumpEndpoints();
670 
671 	return 0;
672 }
673 
674 
675 // #pragma mark - inbound
676 
677 
678 status_t
679 UdpEndpointManager::ReceiveData(net_buffer *buffer)
680 {
681 	TRACE_EPM("ReceiveData(%p [%" B_PRIu32 " bytes])", buffer, buffer->size);
682 
683 	UdpDomainSupport* domainSupport = _GetDomain(buffer);
684 	if (domainSupport == NULL) {
685 		// we don't instantiate domain supports in the receiving path, as
686 		// we are only interested in delivering data to existing sockets.
687 		return B_ERROR;
688 	}
689 
690 	status_t status = Deframe(buffer);
691 	if (status != B_OK)
692 		return status;
693 
694 	status = domainSupport->DemuxIncomingBuffer(buffer);
695 	if (status != B_OK) {
696 		TRACE_EPM("  ReceiveData(): no endpoint.");
697 		// Send port unreachable error
698 		domainSupport->Domain()->module->error_reply(NULL, buffer,
699 			B_NET_ERROR_UNREACH_PORT, NULL);
700 		return B_ERROR;
701 	}
702 
703 	gBufferModule->free(buffer);
704 	return B_OK;
705 }
706 
707 
708 status_t
709 UdpEndpointManager::ReceiveError(status_t error, net_buffer* buffer)
710 {
711 	TRACE_EPM("ReceiveError(code %" B_PRId32 " %p [%" B_PRIu32 " bytes])",
712 		error, buffer, buffer->size);
713 
714 	// We only really need the port information
715 	if (buffer->size < 4)
716 		return B_BAD_VALUE;
717 
718 	UdpDomainSupport* domainSupport = _GetDomain(buffer);
719 	if (domainSupport == NULL) {
720 		// we don't instantiate domain supports in the receiving path, as
721 		// we are only interested in delivering data to existing sockets.
722 		return B_ERROR;
723 	}
724 
725 	// Deframe the buffer manually, as we usually only get 8 bytes from the
726 	// original packet
727 	udp_header header;
728 	if (gBufferModule->read(buffer, 0, &header,
729 			std::min(buffer->size, sizeof(udp_header))) != B_OK)
730 		return B_BAD_VALUE;
731 
732 	net_domain* domain = buffer->interface_address->domain;
733 	net_address_module_info* addressModule = domain->address_module;
734 
735 	SocketAddress source(addressModule, buffer->source);
736 	SocketAddress destination(addressModule, buffer->destination);
737 
738 	source.SetPort(header.source_port);
739 	destination.SetPort(header.destination_port);
740 
741 	status_t status = domainSupport->DeliverError(error, buffer);
742 	if (status != B_OK)
743 		return status;
744 
745 	gBufferModule->free(buffer);
746 	return B_OK;
747 }
748 
749 
750 status_t
751 UdpEndpointManager::Deframe(net_buffer *buffer)
752 {
753 	TRACE_EPM("Deframe(%p [%ld bytes])", buffer, buffer->size);
754 
755 	NetBufferHeaderReader<udp_header> bufferHeader(buffer);
756 	if (bufferHeader.Status() < B_OK)
757 		return bufferHeader.Status();
758 
759 	udp_header &header = bufferHeader.Data();
760 
761 	if (buffer->interface_address == NULL
762 		|| buffer->interface_address->domain == NULL) {
763 		TRACE_EPM("  Deframe(): UDP packed dropped as there was no domain "
764 			"specified (interface address %p).", buffer->interface_address);
765 		return B_BAD_VALUE;
766 	}
767 
768 	net_domain *domain = buffer->interface_address->domain;
769 	net_address_module_info *addressModule = domain->address_module;
770 
771 	SocketAddress source(addressModule, buffer->source);
772 	SocketAddress destination(addressModule, buffer->destination);
773 
774 	source.SetPort(header.source_port);
775 	destination.SetPort(header.destination_port);
776 
777 	TRACE_EPM("  Deframe(): data from %s to %s", source.AsString(true).Data(),
778 		destination.AsString(true).Data());
779 
780 	uint16 udpLength = ntohs(header.udp_length);
781 	if (udpLength > buffer->size) {
782 		TRACE_EPM("  Deframe(): buffer is too short, expected %hu.",
783 			udpLength);
784 		return B_MISMATCHED_VALUES;
785 	}
786 
787 	if (buffer->size > udpLength)
788 		gBufferModule->trim(buffer, udpLength);
789 
790 	if (header.udp_checksum != 0) {
791 		// check UDP-checksum (simulating a so-called "pseudo-header"):
792 		uint16 sum = Checksum::PseudoHeader(addressModule, gBufferModule,
793 			buffer, IPPROTO_UDP);
794 		if (sum != 0) {
795 			TRACE_EPM("  Deframe(): bad checksum 0x%hx.", sum);
796 			return B_BAD_VALUE;
797 		}
798 	}
799 
800 	bufferHeader.Remove();
801 		// remove UDP-header from buffer before passing it on
802 
803 	return B_OK;
804 }
805 
806 
807 UdpDomainSupport *
808 UdpEndpointManager::OpenEndpoint(UdpEndpoint *endpoint)
809 {
810 	MutexLocker _(fLock);
811 
812 	UdpDomainSupport *domain = _GetDomain(endpoint->Domain(), true);
813 	if (domain)
814 		domain->Ref();
815 	return domain;
816 }
817 
818 
819 status_t
820 UdpEndpointManager::FreeEndpoint(UdpDomainSupport *domain)
821 {
822 	MutexLocker _(fLock);
823 
824 	if (domain->Put()) {
825 		fDomains.Remove(domain);
826 		delete domain;
827 	}
828 
829 	return B_OK;
830 }
831 
832 
833 // #pragma mark -
834 
835 
836 UdpDomainSupport *
837 UdpEndpointManager::_GetDomain(net_domain *domain, bool create)
838 {
839 	UdpDomainList::Iterator it = fDomains.GetIterator();
840 
841 	// TODO convert this into a Hashtable or install per-domain
842 	//      receiver handlers that forward the requests to the
843 	//      appropriate DemuxIncomingBuffer(). For instance, while
844 	//      being constructed UdpDomainSupport could call
845 	//      register_domain_receiving_protocol() with the right
846 	//      family.
847 	while (it.HasNext()) {
848 		UdpDomainSupport *domainSupport = it.Next();
849 		if (domainSupport->Domain() == domain)
850 			return domainSupport;
851 	}
852 
853 	if (!create)
854 		return NULL;
855 
856 	UdpDomainSupport *domainSupport =
857 		new (std::nothrow) UdpDomainSupport(domain);
858 	if (domainSupport == NULL || domainSupport->Init() < B_OK) {
859 		delete domainSupport;
860 		return NULL;
861 	}
862 
863 	fDomains.Add(domainSupport);
864 	return domainSupport;
865 }
866 
867 
868 UdpDomainSupport*
869 UdpEndpointManager::_GetDomain(net_buffer* buffer)
870 {
871 	if (buffer->interface_address == NULL)
872 		return NULL;
873 
874 	MutexLocker _(fLock);
875 	return _GetDomain(buffer->interface_address->domain, false);
876 		// TODO: we don't want to hold to the manager's lock during the
877 		// whole RX path, we may not hold an endpoint's lock with the
878 		// manager lock held.
879 		// But we should increase the domain's refcount here.
880 }
881 
882 
883 // #pragma mark -
884 
885 
886 UdpEndpoint::UdpEndpoint(net_socket *socket)
887 	: DatagramSocket<>("udp endpoint", socket), fActive(false) {}
888 
889 
890 // #pragma mark - activation
891 
892 
893 status_t
894 UdpEndpoint::Bind(const sockaddr *address)
895 {
896 	TRACE_EP("Bind(%s)", AddressString(Domain(), address, true).Data());
897 	return fManager->BindEndpoint(this, address);
898 }
899 
900 
901 status_t
902 UdpEndpoint::Unbind(sockaddr *address)
903 {
904 	TRACE_EP("Unbind()");
905 	return fManager->UnbindEndpoint(this);
906 }
907 
908 
909 status_t
910 UdpEndpoint::Connect(const sockaddr *address)
911 {
912 	TRACE_EP("Connect(%s)", AddressString(Domain(), address, true).Data());
913 	return fManager->ConnectEndpoint(this, address);
914 }
915 
916 
917 status_t
918 UdpEndpoint::Open()
919 {
920 	TRACE_EP("Open()");
921 
922 	AutoLocker _(fLock);
923 
924 	status_t status = ProtocolSocket::Open();
925 	if (status < B_OK)
926 		return status;
927 
928 	fManager = sUdpEndpointManager->OpenEndpoint(this);
929 	if (fManager == NULL)
930 		return EAFNOSUPPORT;
931 
932 	return B_OK;
933 }
934 
935 
936 status_t
937 UdpEndpoint::Close()
938 {
939 	TRACE_EP("Close()");
940 	return B_OK;
941 }
942 
943 
944 status_t
945 UdpEndpoint::Free()
946 {
947 	TRACE_EP("Free()");
948 	fManager->UnbindEndpoint(this);
949 	return sUdpEndpointManager->FreeEndpoint(fManager);
950 }
951 
952 
953 // #pragma mark - outbound
954 
955 
956 status_t
957 UdpEndpoint::SendRoutedData(net_buffer *buffer, net_route *route)
958 {
959 	TRACE_EP("SendRoutedData(%p [%lu bytes], %p)", buffer, buffer->size, route);
960 
961 	if (buffer->size > (0xffff - sizeof(udp_header)))
962 		return EMSGSIZE;
963 
964 	buffer->protocol = IPPROTO_UDP;
965 
966 	// add and fill UDP-specific header:
967 	NetBufferPrepend<udp_header> header(buffer);
968 	if (header.Status() < B_OK)
969 		return header.Status();
970 
971 	header->source_port = AddressModule()->get_port(buffer->source);
972 	header->destination_port = AddressModule()->get_port(buffer->destination);
973 	header->udp_length = htons(buffer->size);
974 		// the udp-header is already included in the buffer-size
975 	header->udp_checksum = 0;
976 
977 	header.Sync();
978 
979 	uint16 calculatedChecksum = Checksum::PseudoHeader(AddressModule(),
980 		gBufferModule, buffer, IPPROTO_UDP);
981 	if (calculatedChecksum == 0)
982 		calculatedChecksum = 0xffff;
983 
984 	*UDPChecksumField(buffer) = calculatedChecksum;
985 
986 	return next->module->send_routed_data(next, route, buffer);
987 }
988 
989 
990 status_t
991 UdpEndpoint::SendData(net_buffer *buffer)
992 {
993 	TRACE_EP("SendData(%p [%lu bytes])", buffer, buffer->size);
994 
995 	return gDatalinkModule->send_data(this, NULL, buffer);
996 }
997 
998 
999 // #pragma mark - inbound
1000 
1001 
1002 ssize_t
1003 UdpEndpoint::BytesAvailable()
1004 {
1005 	size_t bytes = AvailableData();
1006 	TRACE_EP("BytesAvailable(): %lu", bytes);
1007 	return bytes;
1008 }
1009 
1010 
1011 status_t
1012 UdpEndpoint::FetchData(size_t numBytes, uint32 flags, net_buffer **_buffer)
1013 {
1014 	TRACE_EP("FetchData(%ld, 0x%lx)", numBytes, flags);
1015 
1016 	status_t status = Dequeue(flags, _buffer);
1017 	TRACE_EP("  FetchData(): returned from fifo status=0x%lx", status);
1018 	if (status != B_OK)
1019 		return status;
1020 
1021 	TRACE_EP("  FetchData(): returns buffer with %ld bytes", (*_buffer)->size);
1022 	return B_OK;
1023 }
1024 
1025 
1026 status_t
1027 UdpEndpoint::StoreData(net_buffer *buffer)
1028 {
1029 	TRACE_EP("StoreData(%p [%ld bytes])", buffer, buffer->size);
1030 
1031 	return EnqueueClone(buffer);
1032 }
1033 
1034 
1035 status_t
1036 UdpEndpoint::DeliverData(net_buffer *_buffer)
1037 {
1038 	TRACE_EP("DeliverData(%p [%ld bytes])", _buffer, _buffer->size);
1039 
1040 	net_buffer *buffer = gBufferModule->clone(_buffer, false);
1041 	if (buffer == NULL)
1042 		return B_NO_MEMORY;
1043 
1044 	status_t status = sUdpEndpointManager->Deframe(buffer);
1045 	if (status < B_OK) {
1046 		gBufferModule->free(buffer);
1047 		return status;
1048 	}
1049 
1050 	return Enqueue(buffer);
1051 }
1052 
1053 
1054 // #pragma mark - protocol interface
1055 
1056 
1057 net_protocol *
1058 udp_init_protocol(net_socket *socket)
1059 {
1060 	socket->protocol = IPPROTO_UDP;
1061 
1062 	UdpEndpoint *endpoint = new (std::nothrow) UdpEndpoint(socket);
1063 	if (endpoint == NULL || endpoint->InitCheck() < B_OK) {
1064 		delete endpoint;
1065 		return NULL;
1066 	}
1067 
1068 	return endpoint;
1069 }
1070 
1071 
1072 status_t
1073 udp_uninit_protocol(net_protocol *protocol)
1074 {
1075 	delete (UdpEndpoint *)protocol;
1076 	return B_OK;
1077 }
1078 
1079 
1080 status_t
1081 udp_open(net_protocol *protocol)
1082 {
1083 	return ((UdpEndpoint *)protocol)->Open();
1084 }
1085 
1086 
1087 status_t
1088 udp_close(net_protocol *protocol)
1089 {
1090 	return ((UdpEndpoint *)protocol)->Close();
1091 }
1092 
1093 
1094 status_t
1095 udp_free(net_protocol *protocol)
1096 {
1097 	return ((UdpEndpoint *)protocol)->Free();
1098 }
1099 
1100 
1101 status_t
1102 udp_connect(net_protocol *protocol, const struct sockaddr *address)
1103 {
1104 	return ((UdpEndpoint *)protocol)->Connect(address);
1105 }
1106 
1107 
1108 status_t
1109 udp_accept(net_protocol *protocol, struct net_socket **_acceptedSocket)
1110 {
1111 	return B_NOT_SUPPORTED;
1112 }
1113 
1114 
1115 status_t
1116 udp_control(net_protocol *protocol, int level, int option, void *value,
1117 	size_t *_length)
1118 {
1119 	return protocol->next->module->control(protocol->next, level, option,
1120 		value, _length);
1121 }
1122 
1123 
1124 status_t
1125 udp_getsockopt(net_protocol *protocol, int level, int option, void *value,
1126 	int *length)
1127 {
1128 	return protocol->next->module->getsockopt(protocol->next, level, option,
1129 		value, length);
1130 }
1131 
1132 
1133 status_t
1134 udp_setsockopt(net_protocol *protocol, int level, int option,
1135 	const void *value, int length)
1136 {
1137 	return protocol->next->module->setsockopt(protocol->next, level, option,
1138 		value, length);
1139 }
1140 
1141 
1142 status_t
1143 udp_bind(net_protocol *protocol, const struct sockaddr *address)
1144 {
1145 	return ((UdpEndpoint *)protocol)->Bind(address);
1146 }
1147 
1148 
1149 status_t
1150 udp_unbind(net_protocol *protocol, struct sockaddr *address)
1151 {
1152 	return ((UdpEndpoint *)protocol)->Unbind(address);
1153 }
1154 
1155 
1156 status_t
1157 udp_listen(net_protocol *protocol, int count)
1158 {
1159 	return B_NOT_SUPPORTED;
1160 }
1161 
1162 
1163 status_t
1164 udp_shutdown(net_protocol *protocol, int direction)
1165 {
1166 	return B_NOT_SUPPORTED;
1167 }
1168 
1169 
1170 status_t
1171 udp_send_routed_data(net_protocol *protocol, struct net_route *route,
1172 	net_buffer *buffer)
1173 {
1174 	return ((UdpEndpoint *)protocol)->SendRoutedData(buffer, route);
1175 }
1176 
1177 
1178 status_t
1179 udp_send_data(net_protocol *protocol, net_buffer *buffer)
1180 {
1181 	return ((UdpEndpoint *)protocol)->SendData(buffer);
1182 }
1183 
1184 
1185 ssize_t
1186 udp_send_avail(net_protocol *protocol)
1187 {
1188 	return protocol->socket->send.buffer_size;
1189 }
1190 
1191 
1192 status_t
1193 udp_read_data(net_protocol *protocol, size_t numBytes, uint32 flags,
1194 	net_buffer **_buffer)
1195 {
1196 	return ((UdpEndpoint *)protocol)->FetchData(numBytes, flags, _buffer);
1197 }
1198 
1199 
1200 ssize_t
1201 udp_read_avail(net_protocol *protocol)
1202 {
1203 	return ((UdpEndpoint *)protocol)->BytesAvailable();
1204 }
1205 
1206 
1207 struct net_domain *
1208 udp_get_domain(net_protocol *protocol)
1209 {
1210 	return protocol->next->module->get_domain(protocol->next);
1211 }
1212 
1213 
1214 size_t
1215 udp_get_mtu(net_protocol *protocol, const struct sockaddr *address)
1216 {
1217 	return protocol->next->module->get_mtu(protocol->next, address);
1218 }
1219 
1220 
1221 status_t
1222 udp_receive_data(net_buffer *buffer)
1223 {
1224 	return sUdpEndpointManager->ReceiveData(buffer);
1225 }
1226 
1227 
1228 status_t
1229 udp_deliver_data(net_protocol *protocol, net_buffer *buffer)
1230 {
1231 	return ((UdpEndpoint *)protocol)->DeliverData(buffer);
1232 }
1233 
1234 
1235 status_t
1236 udp_error_received(net_error error, net_buffer* buffer)
1237 {
1238 	status_t notifyError = B_OK;
1239 
1240 	switch (error) {
1241 		case B_NET_ERROR_UNREACH_NET:
1242 			notifyError = ENETUNREACH;
1243 			break;
1244 		case B_NET_ERROR_UNREACH_HOST:
1245 		case B_NET_ERROR_TRANSIT_TIME_EXCEEDED:
1246 			notifyError = EHOSTUNREACH;
1247 			break;
1248 		case B_NET_ERROR_UNREACH_PROTOCOL:
1249 		case B_NET_ERROR_UNREACH_PORT:
1250 			notifyError = ECONNREFUSED;
1251 			break;
1252 		case B_NET_ERROR_MESSAGE_SIZE:
1253 			notifyError = EMSGSIZE;
1254 			break;
1255 		case B_NET_ERROR_PARAMETER_PROBLEM:
1256 			notifyError = ENOPROTOOPT;
1257 			break;
1258 
1259 		case B_NET_ERROR_QUENCH:
1260 		default:
1261 			// ignore them
1262 			break;
1263 	}
1264 
1265 	if (notifyError != B_OK)
1266 		sUdpEndpointManager->ReceiveError(notifyError, buffer);
1267 
1268 	gBufferModule->free(buffer);
1269 	return B_OK;
1270 }
1271 
1272 
1273 status_t
1274 udp_error_reply(net_protocol *protocol, net_buffer *cause, net_error error,
1275 	net_error_data *errorData)
1276 {
1277 	return B_ERROR;
1278 }
1279 
1280 
1281 ssize_t
1282 udp_process_ancillary_data_no_container(net_protocol *protocol,
1283 	net_buffer* buffer, void *data, size_t dataSize)
1284 {
1285 	return protocol->next->module->process_ancillary_data_no_container(
1286 		protocol, buffer, data, dataSize);
1287 }
1288 
1289 
1290 //	#pragma mark - module interface
1291 
1292 
1293 static status_t
1294 init_udp()
1295 {
1296 	status_t status;
1297 	TRACE_EPM("init_udp()");
1298 
1299 	sUdpEndpointManager = new (std::nothrow) UdpEndpointManager;
1300 	if (sUdpEndpointManager == NULL)
1301 		return B_NO_MEMORY;
1302 
1303 	status = sUdpEndpointManager->InitCheck();
1304 	if (status != B_OK)
1305 		goto err1;
1306 
1307 	status = gStackModule->register_domain_protocols(AF_INET, SOCK_DGRAM,
1308 		IPPROTO_IP,
1309 		"network/protocols/udp/v1",
1310 		"network/protocols/ipv4/v1",
1311 		NULL);
1312 	if (status < B_OK)
1313 		goto err1;
1314 	status = gStackModule->register_domain_protocols(AF_INET6, SOCK_DGRAM,
1315 		IPPROTO_IP,
1316 		"network/protocols/udp/v1",
1317 		"network/protocols/ipv6/v1",
1318 		NULL);
1319 	if (status < B_OK)
1320 		goto err1;
1321 
1322 	status = gStackModule->register_domain_protocols(AF_INET, SOCK_DGRAM,
1323 		IPPROTO_UDP,
1324 		"network/protocols/udp/v1",
1325 		"network/protocols/ipv4/v1",
1326 		NULL);
1327 	if (status < B_OK)
1328 		goto err1;
1329 	status = gStackModule->register_domain_protocols(AF_INET6, SOCK_DGRAM,
1330 		IPPROTO_UDP,
1331 		"network/protocols/udp/v1",
1332 		"network/protocols/ipv6/v1",
1333 		NULL);
1334 	if (status < B_OK)
1335 		goto err1;
1336 
1337 	status = gStackModule->register_domain_receiving_protocol(AF_INET,
1338 		IPPROTO_UDP, "network/protocols/udp/v1");
1339 	if (status < B_OK)
1340 		goto err1;
1341 	status = gStackModule->register_domain_receiving_protocol(AF_INET6,
1342 		IPPROTO_UDP, "network/protocols/udp/v1");
1343 	if (status < B_OK)
1344 		goto err1;
1345 
1346 	add_debugger_command("udp_endpoints", UdpEndpointManager::DumpEndpoints,
1347 		"lists all open UDP endpoints");
1348 
1349 	return B_OK;
1350 
1351 err1:
1352 	// TODO: shouldn't unregister the protocols here?
1353 	delete sUdpEndpointManager;
1354 
1355 	TRACE_EPM("init_udp() fails with %lx (%s)", status, strerror(status));
1356 	return status;
1357 }
1358 
1359 
1360 static status_t
1361 uninit_udp()
1362 {
1363 	TRACE_EPM("uninit_udp()");
1364 	remove_debugger_command("udp_endpoints",
1365 		UdpEndpointManager::DumpEndpoints);
1366 	delete sUdpEndpointManager;
1367 	return B_OK;
1368 }
1369 
1370 
1371 static status_t
1372 udp_std_ops(int32 op, ...)
1373 {
1374 	switch (op) {
1375 		case B_MODULE_INIT:
1376 			return init_udp();
1377 
1378 		case B_MODULE_UNINIT:
1379 			return uninit_udp();
1380 
1381 		default:
1382 			return B_ERROR;
1383 	}
1384 }
1385 
1386 
// The UDP protocol module descriptor: the module_info header followed by the
// protocol hooks. This is a positional aggregate initializer, so the entries
// must stay in the exact order the net_protocol_module_info structure
// declares its members. Hooks this module does not implement are NULL.
net_protocol_module_info sUDPModule = {
	{
		"network/protocols/udp/v1",
			// the module name used with register_domain_*() in init_udp()
		0,
		udp_std_ops
			// handles B_MODULE_INIT / B_MODULE_UNINIT
	},
	NET_PROTOCOL_ATOMIC_MESSAGES,

	udp_init_protocol,
	udp_uninit_protocol,
	udp_open,
	udp_close,
	udp_free,
	udp_connect,
	udp_accept,
	udp_control,
	udp_getsockopt,
	udp_setsockopt,
	udp_bind,
	udp_unbind,
	udp_listen,
	udp_shutdown,
	udp_send_data,
	udp_send_routed_data,
	udp_send_avail,
	udp_read_data,
	udp_read_avail,
	udp_get_domain,
	udp_get_mtu,
	udp_receive_data,
	udp_deliver_data,
	udp_error_received,
	udp_error_reply,
	NULL,		// add_ancillary_data()
	NULL,		// process_ancillary_data()
	udp_process_ancillary_data_no_container,
	NULL,		// send_data_no_buffer()
	NULL		// read_data_no_buffer()
};
1425 };
1426 
// Other modules this one uses. Each entry pairs a module name with the
// global pointer that is to receive that module's info. These globals are
// used throughout this file, e.g. gStackModule in init_udp() and
// gBufferModule in udp_error_received(). The list ends with an empty entry.
module_dependency module_dependencies[] = {
	{NET_STACK_MODULE_NAME, (module_info **)&gStackModule},
	{NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule},
	{NET_DATALINK_MODULE_NAME, (module_info **)&gDatalinkModule},
	{NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule},
	{}
};
1433 };
1434 
// NULL-terminated list of the modules exported by this add-on; here only
// the UDP protocol module.
module_info *modules[] = {
	(module_info *)&sUDPModule,
	NULL
};
1438 };
1439