xref: /haiku/src/add-ons/kernel/network/protocols/udp/udp.cpp (revision 3c6e2dd68577c34d93e17f19711f6245bf6d0915)
1 /*
2  * Copyright 2006-2009, Haiku, Inc. All Rights Reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Oliver Tappe, zooey@hirschkaefer.de
7  *		Hugo Santos, hugosantos@gmail.com
8  */
9 
10 
11 #include <net_buffer.h>
12 #include <net_datalink.h>
13 #include <net_protocol.h>
14 #include <net_stack.h>
15 
16 #include <lock.h>
17 #include <util/AutoLock.h>
18 #include <util/DoublyLinkedList.h>
19 #include <util/OpenHashTable.h>
20 
21 #include <KernelExport.h>
22 
23 #include <NetBufferUtilities.h>
24 #include <NetUtilities.h>
25 #include <ProtocolUtilities.h>
26 
27 #include <netinet/in.h>
28 #include <new>
29 #include <stdlib.h>
30 #include <string.h>
31 #include <utility>
32 
33 
34 // NOTE the locking protocol dictates that we must hold UdpDomainSupport's
35 //      lock before holding a child UdpEndpoint's lock. This restriction
36 //      is dictated by the receive path as blind access to the endpoint
37 //      hash is required when holding the DomainSupport's lock.
38 
39 
40 //#define TRACE_UDP
#ifdef TRACE_UDP
// Tracing enabled: the macros below print through dprintf(), prefixing each
// message with "UDP" and the current system_time().
#	define TRACE_BLOCK(x) dump_block x
// do not remove the space before ', ##args' if you want this
// to compile with gcc 2.95
// TRACE_EP also prints `this`, so it is only usable inside member functions.
#	define TRACE_EP(format, args...)	dprintf("UDP [%llu] %p " format "\n", \
		system_time(), this , ##args)
// TRACE_EPM prints the message without any object pointer.
#	define TRACE_EPM(format, args...)	dprintf("UDP [%llu] " format "\n", \
		system_time() , ##args)
// TRACE_DOMAIN also prints Domain()->family, so a Domain() accessor has to be
// in scope where it is used.
#	define TRACE_DOMAIN(format, args...)	dprintf("UDP [%llu] (%d) " format \
		"\n", system_time(), Domain()->family , ##args)
#else
// Tracing disabled: the macros expand to nothing or to an empty statement, so
// their arguments are never evaluated.
#	define TRACE_BLOCK(x)
#	define TRACE_EP(args...)	do { } while (0)
#	define TRACE_EPM(args...)	do { } while (0)
#	define TRACE_DOMAIN(args...)	do { } while (0)
#endif
57 
58 
// Layout of the UDP header that is prepended to outgoing and read from
// incoming buffers; four 16 bit fields, packed.
struct udp_header {
	uint16 source_port;
		// copied to/from the port of the buffer's source address
	uint16 destination_port;
		// copied to/from the port of the buffer's destination address
	uint16 udp_length;
		// converted with htons()/ntohs() by the code in this file; on send it
		// is the buffer size including this header
	uint16 udp_checksum;
		// 0 on an incoming packet means the checksum is not verified
} _PACKED;
65 
66 
// Accessor for the 16 bit checksum field located at
// offsetof(udp_header, udp_checksum) within a net_buffer; used to patch the
// checksum in after it has been computed over the complete datagram.
typedef NetBufferField<uint16, offsetof(udp_header, udp_checksum)>
	UDPChecksumField;
69 
class UdpDomainSupport;

// A single UDP socket. It is the net_protocol handed to the stack and a
// DatagramSocket that queues received datagrams. Binding and connecting are
// delegated to the UdpDomainSupport the endpoint was opened on (fManager).
class UdpEndpoint : public net_protocol, public DatagramSocket<> {
public:
	UdpEndpoint(net_socket *socket);

	// address handling, all forwarded to the domain support
	status_t				Bind(const sockaddr *newAddr);
	status_t				Unbind(sockaddr *newAddr);
	status_t				Connect(const sockaddr *newAddr);

	// life cycle: Open() acquires the domain support, Free() releases it
	status_t				Open();
	status_t				Close();
	status_t				Free();

	// outbound path
	status_t				SendRoutedData(net_buffer *buffer, net_route *route);
	status_t				SendData(net_buffer *buffer);

	// inbound path, reader side
	ssize_t					BytesAvailable();
	status_t				FetchData(size_t numBytes, uint32 flags,
								net_buffer **_buffer);

	// inbound path, stack side: StoreData() enqueues an already deframed
	// buffer, DeliverData() clones and deframes the buffer first
	status_t				StoreData(net_buffer *buffer);
	status_t				DeliverData(net_buffer *buffer);

	// only the domain support will change/check the Active flag so
	// we don't really need to protect it with the socket lock.
	bool					IsActive() const { return fActive; }
	void					SetActive(bool newValue) { fActive = newValue; }

	// link used by the endpoint hash table (see UdpHashDefinition::GetLink())
	UdpEndpoint				*&HashTableLink() { return fLink; }

private:
	UdpDomainSupport		*fManager;
								// set by Open(), used by Bind()/Unbind()/
								// Connect()/Free()
	bool					fActive;
								// an active UdpEndpoint is part of the endpoint
								// hash (and it is bound and optionally connected)

	UdpEndpoint				*fLink;
								// hash table chaining pointer
};
109 
110 
111 class UdpDomainSupport;
112 
113 struct UdpHashDefinition {
114 	typedef net_address_module_info ParentType;
115 	typedef std::pair<const sockaddr *, const sockaddr *> KeyType;
116 	typedef UdpEndpoint ValueType;
117 
118 	UdpHashDefinition(net_address_module_info *_module)
119 		: module(_module) {}
120 	UdpHashDefinition(const UdpHashDefinition& definition)
121 		: module(definition.module) {}
122 
123 	size_t HashKey(const KeyType &key) const
124 	{
125 		return _Mix(module->hash_address_pair(key.first, key.second));
126 	}
127 
128 	size_t Hash(UdpEndpoint *endpoint) const
129 	{
130 		return _Mix(endpoint->LocalAddress().HashPair(*endpoint->PeerAddress()));
131 	}
132 
133 	static size_t _Mix(size_t hash)
134 	{
135 		// move the bits into the relevant range (as defined by kNumHashBuckets):
136 		return (hash & 0x000007FF) ^ (hash & 0x003FF800) >> 11
137 			^ (hash & 0xFFC00000UL) >> 22;
138 	}
139 
140 	bool Compare(const KeyType &key, UdpEndpoint *endpoint) const
141 	{
142 		return endpoint->LocalAddress().EqualTo(key.first, true)
143 			&& endpoint->PeerAddress().EqualTo(key.second, true);
144 	}
145 
146 	UdpEndpoint *&GetLink(UdpEndpoint *endpoint) const
147 	{
148 		return endpoint->HashTableLink();
149 	}
150 
151 	net_address_module_info *module;
152 };
153 
154 
// Per-domain UDP state: owns the hash table of active (bound) endpoints of
// one net_domain, hands out ephemeral ports and dispatches incoming buffers
// to the matching endpoint(s). The public entry points take fLock.
class UdpDomainSupport : public DoublyLinkedListLinkImpl<UdpDomainSupport> {
public:
	UdpDomainSupport(net_domain *domain);
	~UdpDomainSupport();

	// allocates the endpoint hash table; must succeed before use
	status_t Init();

	net_domain *Domain() const { return fDomain; }

	// plain (non-atomic) count of the endpoints opened on this domain;
	// Put() returns true when the last one is gone
	void Ref() { fEndpointCount++; }
	bool Put() { fEndpointCount--; return fEndpointCount == 0; }

	status_t DemuxIncomingBuffer(net_buffer *buffer);

	status_t BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
	status_t ConnectEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
	status_t UnbindEndpoint(UdpEndpoint *endpoint);

	// debugger helper, prints all active endpoints via kprintf()
	void DumpEndpoints() const;

private:
	// bind helpers; none of them takes fLock itself
	status_t _BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
	status_t _Bind(UdpEndpoint *endpoint, const sockaddr *address);
	status_t _BindToEphemeral(UdpEndpoint *endpoint, const sockaddr *address);
	status_t _FinishBind(UdpEndpoint *endpoint, const sockaddr *address);

	UdpEndpoint *_FindActiveEndpoint(const sockaddr *ourAddress,
		const sockaddr *peerAddress);
	status_t _DemuxBroadcast(net_buffer *buffer);
	status_t _DemuxUnicast(net_buffer *buffer);

	uint16 _GetNextEphemeral();
	UdpEndpoint *_EndpointWithPort(uint16 port) const;

	net_address_module_info *AddressModule() const
		{ return fDomain->address_module; }

	typedef BOpenHashTable<UdpHashDefinition, false> EndpointTable;

	mutex			fLock;
	net_domain		*fDomain;
	uint16			fLastUsedEphemeral;
						// host byte order, within [kFirst, kLast]
	EndpointTable	fActiveEndpoints;
	uint32			fEndpointCount;

	// ephemeral port range (host byte order)
	static const uint16		kFirst = 49152;
	static const uint16		kLast = 65535;
	static const uint32		kNumHashBuckets = 0x800;
							// if you change this, adjust the shifting in
							// UdpHashDefinition::_Mix() accordingly!
};
206 
207 
// List of all UdpDomainSupport objects, as kept by the UdpEndpointManager.
typedef DoublyLinkedList<UdpDomainSupport> UdpDomainList;
209 
210 
// Module-wide singleton (see sUdpEndpointManager): keeps one UdpDomainSupport
// per net_domain in use and is the entry point for data received from the
// lower layer.
class UdpEndpointManager {
public:
	UdpEndpointManager();
	~UdpEndpointManager();

	// validates and strips the UDP header (Deframe()) and hands the buffer
	// to the domain support of the receiving interface's domain
	status_t		ReceiveData(net_buffer *buffer);
	status_t		Deframe(net_buffer *buffer);

	// acquire/release a reference to the domain support of an endpoint;
	// FreeEndpoint() deletes the domain support with its last reference
	UdpDomainSupport *OpenEndpoint(UdpEndpoint *endpoint);
	status_t FreeEndpoint(UdpDomainSupport *domain);

	status_t		InitCheck() const;

	// kernel debugger command ("udp_endpoints")
	static int DumpEndpoints(int argc, char *argv[]);

private:
	// looks up (and optionally creates) the domain support; does not take
	// fLock itself
	UdpDomainSupport *_GetDomain(net_domain *domain, bool create);

	mutex			fLock;
	status_t		fStatus;
	UdpDomainList	fDomains;
};
233 
234 
// The one and only endpoint manager; created in init_udp() and deleted in
// uninit_udp().
static UdpEndpointManager *sUdpEndpointManager;

// Modules this add-on depends on; filled in through module_dependencies[].
net_buffer_module_info *gBufferModule;
net_datalink_module_info *gDatalinkModule;
net_stack_module_info *gStackModule;
240 
241 
242 // #pragma mark -
243 
244 
245 UdpDomainSupport::UdpDomainSupport(net_domain *domain)
246 	:
247 	fDomain(domain),
248 	fActiveEndpoints(domain->address_module),
249 	fEndpointCount(0)
250 {
251 	mutex_init(&fLock, "udp domain");
252 
253 	fLastUsedEphemeral = kFirst + rand() % (kLast - kFirst);
254 }
255 
256 
/*!	Destroys the lock initialized by the constructor. */
UdpDomainSupport::~UdpDomainSupport()
{
	mutex_destroy(&fLock);
}
261 
262 
263 status_t
264 UdpDomainSupport::Init()
265 {
266 	return fActiveEndpoints.Init(kNumHashBuckets);
267 }
268 
269 
270 status_t
271 UdpDomainSupport::DemuxIncomingBuffer(net_buffer *buffer)
272 {
273 	// NOTE multicast is delivered directly to the endpoint
274 
275 	MutexLocker _(fLock);
276 
277 	if (buffer->flags & MSG_BCAST)
278 		return _DemuxBroadcast(buffer);
279 	else if (buffer->flags & MSG_MCAST)
280 		return B_ERROR;
281 
282 	return _DemuxUnicast(buffer);
283 }
284 
285 
286 status_t
287 UdpDomainSupport::BindEndpoint(UdpEndpoint *endpoint,
288 	const sockaddr *address)
289 {
290 	MutexLocker _(fLock);
291 
292 	if (endpoint->IsActive())
293 		return EINVAL;
294 
295 	return _BindEndpoint(endpoint, address);
296 }
297 
298 
299 status_t
300 UdpDomainSupport::ConnectEndpoint(UdpEndpoint *endpoint,
301 	const sockaddr *address)
302 {
303 	MutexLocker _(fLock);
304 
305 	if (endpoint->IsActive()) {
306 		fActiveEndpoints.Remove(endpoint);
307 		endpoint->SetActive(false);
308 	}
309 
310 	if (address->sa_family == AF_UNSPEC) {
311 		// [Stevens-UNP1, p226]: specifying AF_UNSPEC requests a "disconnect",
312 		// so we reset the peer address:
313 		endpoint->PeerAddress().SetToEmpty();
314 	} else {
315 		status_t status = endpoint->PeerAddress().SetTo(address);
316 		if (status < B_OK)
317 			return status;
318 		struct net_route *routeToDestination
319 			= gDatalinkModule->get_route(fDomain, address);
320 		if (routeToDestination) {
321 			status = endpoint->LocalAddress().SetTo(
322 				routeToDestination->interface->address);
323 			gDatalinkModule->put_route(fDomain, routeToDestination);
324 			if (status < B_OK)
325 				return status;
326 		}
327 	}
328 
329 	// we need to activate no matter whether or not we have just disconnected,
330 	// as calling connect() always triggers an implicit bind():
331 	return _BindEndpoint(endpoint, *endpoint->LocalAddress());
332 }
333 
334 
335 status_t
336 UdpDomainSupport::UnbindEndpoint(UdpEndpoint *endpoint)
337 {
338 	MutexLocker _(fLock);
339 
340 	if (endpoint->IsActive())
341 		fActiveEndpoints.Remove(endpoint);
342 
343 	endpoint->SetActive(false);
344 
345 	return B_OK;
346 }
347 
348 
349 void
350 UdpDomainSupport::DumpEndpoints() const
351 {
352 	kprintf("-------- UDP Domain %p ---------\n", this);
353 	kprintf("%10s %20s %20s %8s\n", "address", "local", "peer", "recv-q");
354 
355 	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
356 
357 	while (it.HasNext()) {
358 		UdpEndpoint *endpoint = it.Next();
359 
360 		char localBuf[64], peerBuf[64];
361 		endpoint->LocalAddress().AsString(localBuf, sizeof(localBuf), true);
362 		endpoint->PeerAddress().AsString(peerBuf, sizeof(peerBuf), true);
363 
364 		kprintf("%p %20s %20s %8lu\n", endpoint, localBuf, peerBuf,
365 			endpoint->AvailableData());
366 	}
367 }
368 
369 
370 status_t
371 UdpDomainSupport::_BindEndpoint(UdpEndpoint *endpoint,
372 	const sockaddr *address)
373 {
374 	if (AddressModule()->get_port(address) == 0)
375 		return _BindToEphemeral(endpoint, address);
376 
377 	return _Bind(endpoint, address);
378 }
379 
380 
381 status_t
382 UdpDomainSupport::_Bind(UdpEndpoint *endpoint, const sockaddr *address)
383 {
384 	int socketOptions = endpoint->Socket()->options;
385 
386 	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
387 
388 	// Iterate over all active UDP-endpoints and check if the requested bind
389 	// is allowed (see figure 22.24 in [Stevens - TCP2, p735]):
390 	TRACE_DOMAIN("CheckBindRequest() for %s...", AddressString(fDomain,
391 		address, true).Data());
392 
393 	while (it.HasNext()) {
394 		UdpEndpoint *otherEndpoint = it.Next();
395 
396 		TRACE_DOMAIN("  ...checking endpoint %p (port=%u)...", otherEndpoint,
397 			ntohs(otherEndpoint->LocalAddress().Port()));
398 
399 		if (otherEndpoint->LocalAddress().EqualPorts(address)) {
400 			// port is already bound, SO_REUSEADDR or SO_REUSEPORT is required:
401 			if ((otherEndpoint->Socket()->options & (SO_REUSEADDR | SO_REUSEPORT)) == 0
402 				|| (socketOptions & (SO_REUSEADDR | SO_REUSEPORT)) == 0)
403 				return EADDRINUSE;
404 
405 			// if both addresses are the same, SO_REUSEPORT is required:
406 			if (otherEndpoint->LocalAddress().EqualTo(address, false)
407 				&& ((otherEndpoint->Socket()->options & SO_REUSEPORT) == 0
408 					|| (socketOptions & SO_REUSEPORT) == 0))
409 				return EADDRINUSE;
410 		}
411 	}
412 
413 	return _FinishBind(endpoint, address);
414 }
415 
416 
417 status_t
418 UdpDomainSupport::_BindToEphemeral(UdpEndpoint *endpoint,
419 	const sockaddr *address)
420 {
421 	SocketAddressStorage newAddress(AddressModule());
422 	status_t status = newAddress.SetTo(address);
423 	if (status < B_OK)
424 		return status;
425 
426 	uint16 allocedPort = _GetNextEphemeral();
427 	if (allocedPort == 0)
428 		return ENOBUFS;
429 
430 	newAddress.SetPort(htons(allocedPort));
431 
432 	return _FinishBind(endpoint, *newAddress);
433 }
434 
435 
436 status_t
437 UdpDomainSupport::_FinishBind(UdpEndpoint *endpoint, const sockaddr *address)
438 {
439 	status_t status = endpoint->next->module->bind(endpoint->next, address);
440 	if (status < B_OK)
441 		return status;
442 
443 	fActiveEndpoints.Insert(endpoint);
444 	endpoint->SetActive(true);
445 
446 	return B_OK;
447 }
448 
449 
450 UdpEndpoint *
451 UdpDomainSupport::_FindActiveEndpoint(const sockaddr *ourAddress,
452 	const sockaddr *peerAddress)
453 {
454 	TRACE_DOMAIN("finding Endpoint for %s <- %s",
455 		AddressString(fDomain, ourAddress, true).Data(),
456 		AddressString(fDomain, peerAddress, true).Data());
457 
458 	return fActiveEndpoints.Lookup(std::make_pair(ourAddress, peerAddress));
459 }
460 
461 
462 status_t
463 UdpDomainSupport::_DemuxBroadcast(net_buffer *buffer)
464 {
465 	sockaddr *peerAddr = buffer->source;
466 	sockaddr *broadcastAddr = buffer->destination;
467 	sockaddr *mask = NULL;
468 	if (buffer->interface)
469 		mask = (sockaddr *)buffer->interface->mask;
470 
471 	TRACE_DOMAIN("_DemuxBroadcast(%p)", buffer);
472 
473 	uint16 incomingPort = AddressModule()->get_port(broadcastAddr);
474 
475 	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
476 
477 	while (it.HasNext()) {
478 		UdpEndpoint *endpoint = it.Next();
479 
480 		TRACE_DOMAIN("  _DemuxBroadcast(): checking endpoint %s...",
481 			AddressString(fDomain, *endpoint->LocalAddress(), true).Data());
482 
483 		if (endpoint->LocalAddress().Port() != incomingPort) {
484 			// ports don't match, so we do not dispatch to this endpoint...
485 			continue;
486 		}
487 
488 		if (!endpoint->PeerAddress().IsEmpty(true)) {
489 			// endpoint is connected to a specific destination, we check if
490 			// this datagram is from there:
491 			if (!endpoint->PeerAddress().EqualTo(peerAddr, true)) {
492 				// no, datagram is from another peer, so we do not dispatch to
493 				// this endpoint...
494 				continue;
495 			}
496 		}
497 
498 		if (endpoint->LocalAddress().MatchMasked(broadcastAddr, mask)
499 			|| endpoint->LocalAddress().IsEmpty(false)) {
500 			// address matches, dispatch to this endpoint:
501 			endpoint->StoreData(buffer);
502 		}
503 	}
504 
505 	return B_OK;
506 }
507 
508 
509 status_t
510 UdpDomainSupport::_DemuxUnicast(net_buffer *buffer)
511 {
512 	struct sockaddr *peerAddr = buffer->source;
513 	struct sockaddr *localAddr = buffer->destination;
514 
515 	TRACE_DOMAIN("_DemuxUnicast(%p)", buffer);
516 
517 	UdpEndpoint *endpoint;
518 	// look for full (most special) match:
519 	endpoint = _FindActiveEndpoint(localAddr, peerAddr);
520 	if (!endpoint) {
521 		// look for endpoint matching local address & port:
522 		endpoint = _FindActiveEndpoint(localAddr, NULL);
523 		if (!endpoint) {
524 			// look for endpoint matching peer address & port and local port:
525 			SocketAddressStorage local(AddressModule());
526 			local.SetToEmpty();
527 			local.SetPort(AddressModule()->get_port(localAddr));
528 			endpoint = _FindActiveEndpoint(*local, peerAddr);
529 			if (!endpoint) {
530 				// last chance: look for endpoint matching local port only:
531 				endpoint = _FindActiveEndpoint(*local, NULL);
532 			}
533 		}
534 	}
535 
536 	if (!endpoint) {
537 		TRACE_DOMAIN("_DemuxBroadcast(%p) - no matching endpoint found!", buffer);
538 		return B_NAME_NOT_FOUND;
539 	}
540 
541 	endpoint->StoreData(buffer);
542 	return B_OK;
543 }
544 
545 
546 uint16
547 UdpDomainSupport::_GetNextEphemeral()
548 {
549 	uint16 stop, curr;
550 	if (fLastUsedEphemeral < kLast) {
551 		stop = fLastUsedEphemeral;
552 		curr = fLastUsedEphemeral + 1;
553 	} else {
554 		stop = kLast;
555 		curr = kFirst;
556 	}
557 
558 	TRACE_DOMAIN("_GetNextEphemeral(), last %hu, curr %hu, stop %hu",
559 		fLastUsedEphemeral, curr, stop);
560 
561 	// TODO: a free list could be used to avoid the impact of these two
562 	//        nested loops most of the time... let's see how bad this really is
563 	for (; curr != stop; curr = (curr < kLast) ? (curr + 1) : kFirst) {
564 		TRACE_DOMAIN("  _GetNextEphemeral(): trying port %hu...", curr);
565 
566 		if (_EndpointWithPort(htons(curr)) == NULL) {
567 			TRACE_DOMAIN("  _GetNextEphemeral(): ...using port %hu", curr);
568 			fLastUsedEphemeral = curr;
569 			return curr;
570 		}
571 	}
572 
573 	return 0;
574 }
575 
576 
577 UdpEndpoint *
578 UdpDomainSupport::_EndpointWithPort(uint16 port) const
579 {
580 	EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
581 
582 	while (it.HasNext()) {
583 		UdpEndpoint *endpoint = it.Next();
584 		if (endpoint->LocalAddress().Port() == port)
585 			return endpoint;
586 	}
587 
588 	return NULL;
589 }
590 
591 
592 // #pragma mark -
593 
594 
595 UdpEndpointManager::UdpEndpointManager()
596 {
597 	mutex_init(&fLock, "UDP endpoints");
598 	fStatus = B_OK;
599 }
600 
601 
/*!	Destroys the manager lock. Domain supports still in fDomains are not
	touched here.
*/
UdpEndpointManager::~UdpEndpointManager()
{
	mutex_destroy(&fLock);
}
606 
607 
/*!	Returns the status set by the constructor. */
status_t
UdpEndpointManager::InitCheck() const
{
	return fStatus;
}
613 
614 
615 int
616 UdpEndpointManager::DumpEndpoints(int argc, char *argv[])
617 {
618 	UdpDomainList::Iterator it = sUdpEndpointManager->fDomains.GetIterator();
619 
620 	while (it.HasNext())
621 		it.Next()->DumpEndpoints();
622 
623 	return 0;
624 }
625 
626 
627 // #pragma mark - inbound
628 
629 
630 status_t
631 UdpEndpointManager::ReceiveData(net_buffer *buffer)
632 {
633 	status_t status = Deframe(buffer);
634 	if (status < B_OK)
635 		return status;
636 
637 	TRACE_EPM("ReceiveData(%p [%ld bytes])", buffer, buffer->size);
638 
639 	net_domain *domain = buffer->interface->domain;
640 
641 	UdpDomainSupport *domainSupport = NULL;
642 
643 	{
644 		MutexLocker _(fLock);
645 		domainSupport = _GetDomain(domain, false);
646 		// TODO we don't want to hold to the manager's lock
647 		//      during the whole RX path, we may not hold an
648 		//      endpoint's lock with the manager lock held.
649 		//      But we should increase the domain's refcount
650 		//      here.
651 	}
652 
653 	if (domainSupport == NULL) {
654 		// we don't instantiate domain supports in the
655 		// RX path as we are only interested in delivering
656 		// data to existing sockets.
657 		return B_ERROR;
658 	}
659 
660 	status = domainSupport->DemuxIncomingBuffer(buffer);
661 	if (status < B_OK) {
662 		TRACE_EPM("  ReceiveData(): no endpoint.");
663 		// TODO: send ICMP-error
664 		return B_ERROR;
665 	}
666 
667 	gBufferModule->free(buffer);
668 	return B_OK;
669 }
670 
671 
672 status_t
673 UdpEndpointManager::Deframe(net_buffer *buffer)
674 {
675 	TRACE_EPM("Deframe(%p [%ld bytes])", buffer, buffer->size);
676 
677 	NetBufferHeaderReader<udp_header> bufferHeader(buffer);
678 	if (bufferHeader.Status() < B_OK)
679 		return bufferHeader.Status();
680 
681 	udp_header &header = bufferHeader.Data();
682 
683 	if (buffer->interface == NULL || buffer->interface->domain == NULL) {
684 		TRACE_EPM("  Deframe(): UDP packed dropped as there was no domain "
685 			"specified (interface %p).", buffer->interface);
686 		return B_BAD_VALUE;
687 	}
688 
689 	net_domain *domain = buffer->interface->domain;
690 	net_address_module_info *addressModule = domain->address_module;
691 
692 	SocketAddress source(addressModule, buffer->source);
693 	SocketAddress destination(addressModule, buffer->destination);
694 
695 	source.SetPort(header.source_port);
696 	destination.SetPort(header.destination_port);
697 
698 	TRACE_EPM("  Deframe(): data from %s to %s", source.AsString(true).Data(),
699 		destination.AsString(true).Data());
700 
701 	uint16 udpLength = ntohs(header.udp_length);
702 	if (udpLength > buffer->size) {
703 		TRACE_EPM("  Deframe(): buffer is too short, expected %hu.",
704 			udpLength);
705 		return B_MISMATCHED_VALUES;
706 	}
707 
708 	if (buffer->size > udpLength)
709 		gBufferModule->trim(buffer, udpLength);
710 
711 	if (header.udp_checksum != 0) {
712 		// check UDP-checksum (simulating a so-called "pseudo-header"):
713 		uint16 sum = Checksum::PseudoHeader(addressModule, gBufferModule,
714 			buffer, IPPROTO_UDP);
715 		if (sum != 0) {
716 			TRACE_EPM("  Deframe(): bad checksum 0x%hx.", sum);
717 			return B_BAD_VALUE;
718 		}
719 	}
720 
721 	bufferHeader.Remove();
722 		// remove UDP-header from buffer before passing it on
723 
724 	return B_OK;
725 }
726 
727 
728 UdpDomainSupport *
729 UdpEndpointManager::OpenEndpoint(UdpEndpoint *endpoint)
730 {
731 	MutexLocker _(fLock);
732 
733 	UdpDomainSupport *domain = _GetDomain(endpoint->Domain(), true);
734 	if (domain)
735 		domain->Ref();
736 	return domain;
737 }
738 
739 
740 status_t
741 UdpEndpointManager::FreeEndpoint(UdpDomainSupport *domain)
742 {
743 	MutexLocker _(fLock);
744 
745 	if (domain->Put()) {
746 		fDomains.Remove(domain);
747 		delete domain;
748 	}
749 
750 	return B_OK;
751 }
752 
753 
754 // #pragma mark -
755 
756 
757 UdpDomainSupport *
758 UdpEndpointManager::_GetDomain(net_domain *domain, bool create)
759 {
760 	UdpDomainList::Iterator it = fDomains.GetIterator();
761 
762 	// TODO convert this into a Hashtable or install per-domain
763 	//      receiver handlers that forward the requests to the
764 	//      appropriate DemuxIncomingBuffer(). For instance, while
765 	//      being constructed UdpDomainSupport could call
766 	//      register_domain_receiving_protocol() with the right
767 	//      family.
768 	while (it.HasNext()) {
769 		UdpDomainSupport *domainSupport = it.Next();
770 		if (domainSupport->Domain() == domain)
771 			return domainSupport;
772 	}
773 
774 	if (!create)
775 		return NULL;
776 
777 	UdpDomainSupport *domainSupport =
778 		new (std::nothrow) UdpDomainSupport(domain);
779 	if (domainSupport == NULL || domainSupport->Init() < B_OK) {
780 		delete domainSupport;
781 		return NULL;
782 	}
783 
784 	fDomains.Add(domainSupport);
785 	return domainSupport;
786 }
787 
788 
789 // #pragma mark -
790 
791 
792 UdpEndpoint::UdpEndpoint(net_socket *socket)
793 	: DatagramSocket<>("udp endpoint", socket), fActive(false) {}
794 
795 
796 // #pragma mark - activation
797 
798 
799 status_t
800 UdpEndpoint::Bind(const sockaddr *address)
801 {
802 	TRACE_EP("Bind(%s)", AddressString(Domain(), address, true).Data());
803 	return fManager->BindEndpoint(this, address);
804 }
805 
806 
807 status_t
808 UdpEndpoint::Unbind(sockaddr *address)
809 {
810 	TRACE_EP("Unbind()");
811 	return fManager->UnbindEndpoint(this);
812 }
813 
814 
815 status_t
816 UdpEndpoint::Connect(const sockaddr *address)
817 {
818 	TRACE_EP("Connect(%s)", AddressString(Domain(), address, true).Data());
819 	return fManager->ConnectEndpoint(this, address);
820 }
821 
822 
823 status_t
824 UdpEndpoint::Open()
825 {
826 	TRACE_EP("Open()");
827 
828 	AutoLocker _(fLock);
829 
830 	status_t status = ProtocolSocket::Open();
831 	if (status < B_OK)
832 		return status;
833 
834 	fManager = sUdpEndpointManager->OpenEndpoint(this);
835 	if (fManager == NULL)
836 		return EAFNOSUPPORT;
837 
838 	return B_OK;
839 }
840 
841 
/*!	Nothing to do on close besides tracing; always returns B_OK. */
status_t
UdpEndpoint::Close()
{
	TRACE_EP("Close()");
	return B_OK;
}
848 
849 
850 status_t
851 UdpEndpoint::Free()
852 {
853 	TRACE_EP("Free()");
854 	fManager->UnbindEndpoint(this);
855 	return sUdpEndpointManager->FreeEndpoint(fManager);
856 }
857 
858 
859 // #pragma mark - outbound
860 
861 
862 status_t
863 UdpEndpoint::SendRoutedData(net_buffer *buffer, net_route *route)
864 {
865 	TRACE_EP("SendRoutedData(%p [%lu bytes], %p)", buffer, buffer->size, route);
866 
867 	if (buffer->size > (0xffff - sizeof(udp_header)))
868 		return EMSGSIZE;
869 
870 	buffer->protocol = IPPROTO_UDP;
871 
872 	// add and fill UDP-specific header:
873 	NetBufferPrepend<udp_header> header(buffer);
874 	if (header.Status() < B_OK)
875 		return header.Status();
876 
877 	header->source_port = AddressModule()->get_port(buffer->source);
878 	header->destination_port = AddressModule()->get_port(buffer->destination);
879 	header->udp_length = htons(buffer->size);
880 		// the udp-header is already included in the buffer-size
881 	header->udp_checksum = 0;
882 
883 	header.Sync();
884 
885 	uint16 calculatedChecksum = Checksum::PseudoHeader(AddressModule(),
886 		gBufferModule, buffer, IPPROTO_UDP);
887 	if (calculatedChecksum == 0)
888 		calculatedChecksum = 0xffff;
889 
890 	*UDPChecksumField(buffer) = calculatedChecksum;
891 
892 	return next->module->send_routed_data(next, route, buffer);
893 }
894 
895 
896 status_t
897 UdpEndpoint::SendData(net_buffer *buffer)
898 {
899 	TRACE_EP("SendData(%p [%lu bytes])", buffer, buffer->size);
900 
901 	return gDatalinkModule->send_datagram(this, NULL, buffer);
902 }
903 
904 
905 // #pragma mark - inbound
906 
907 
908 ssize_t
909 UdpEndpoint::BytesAvailable()
910 {
911 	size_t bytes = AvailableData();
912 	TRACE_EP("BytesAvailable(): %lu", bytes);
913 	return bytes;
914 }
915 
916 
917 status_t
918 UdpEndpoint::FetchData(size_t numBytes, uint32 flags, net_buffer **_buffer)
919 {
920 	TRACE_EP("FetchData(%ld, 0x%lx)", numBytes, flags);
921 
922 	status_t status = SocketDequeue(flags, _buffer);
923 	TRACE_EP("  FetchData(): returned from fifo status=0x%lx", status);
924 	if (status < B_OK)
925 		return status;
926 
927 	TRACE_EP("  FetchData(): returns buffer with %ld bytes", (*_buffer)->size);
928 	return B_OK;
929 }
930 
931 
932 status_t
933 UdpEndpoint::StoreData(net_buffer *buffer)
934 {
935 	TRACE_EP("StoreData(%p [%ld bytes])", buffer, buffer->size);
936 
937 	return SocketEnqueue(buffer);
938 }
939 
940 
941 status_t
942 UdpEndpoint::DeliverData(net_buffer *_buffer)
943 {
944 	TRACE_EP("DeliverData(%p [%ld bytes])", _buffer, _buffer->size);
945 
946 	net_buffer *buffer = gBufferModule->clone(_buffer, false);
947 	if (buffer == NULL)
948 		return B_NO_MEMORY;
949 
950 	status_t status = sUdpEndpointManager->Deframe(buffer);
951 	if (status < B_OK) {
952 		gBufferModule->free(buffer);
953 		return status;
954 	}
955 
956 	// we call Enqueue() instead of SocketEnqueue() as there is
957 	// no need to clone the buffer again.
958 	return Enqueue(buffer);
959 }
960 
961 
962 // #pragma mark - protocol interface
963 
964 
965 net_protocol *
966 udp_init_protocol(net_socket *socket)
967 {
968 	socket->protocol = IPPROTO_UDP;
969 
970 	UdpEndpoint *endpoint = new (std::nothrow) UdpEndpoint(socket);
971 	if (endpoint == NULL || endpoint->InitCheck() < B_OK) {
972 		delete endpoint;
973 		return NULL;
974 	}
975 
976 	return endpoint;
977 }
978 
979 
980 status_t
981 udp_uninit_protocol(net_protocol *protocol)
982 {
983 	delete (UdpEndpoint *)protocol;
984 	return B_OK;
985 }
986 
987 
988 status_t
989 udp_open(net_protocol *protocol)
990 {
991 	return ((UdpEndpoint *)protocol)->Open();
992 }
993 
994 
995 status_t
996 udp_close(net_protocol *protocol)
997 {
998 	return ((UdpEndpoint *)protocol)->Close();
999 }
1000 
1001 
1002 status_t
1003 udp_free(net_protocol *protocol)
1004 {
1005 	return ((UdpEndpoint *)protocol)->Free();
1006 }
1007 
1008 
1009 status_t
1010 udp_connect(net_protocol *protocol, const struct sockaddr *address)
1011 {
1012 	return ((UdpEndpoint *)protocol)->Connect(address);
1013 }
1014 
1015 
/*!	Module hook: accepting connections is not supported, always returns
	EOPNOTSUPP.
*/
status_t
udp_accept(net_protocol *protocol, struct net_socket **_acceptedSocket)
{
	return EOPNOTSUPP;
}
1021 
1022 
1023 status_t
1024 udp_control(net_protocol *protocol, int level, int option, void *value,
1025 	size_t *_length)
1026 {
1027 	return protocol->next->module->control(protocol->next, level, option,
1028 		value, _length);
1029 }
1030 
1031 
1032 status_t
1033 udp_getsockopt(net_protocol *protocol, int level, int option, void *value,
1034 	int *length)
1035 {
1036 	return protocol->next->module->getsockopt(protocol->next, level, option,
1037 		value, length);
1038 }
1039 
1040 
1041 status_t
1042 udp_setsockopt(net_protocol *protocol, int level, int option,
1043 	const void *value, int length)
1044 {
1045 	return protocol->next->module->setsockopt(protocol->next, level, option,
1046 		value, length);
1047 }
1048 
1049 
1050 status_t
1051 udp_bind(net_protocol *protocol, const struct sockaddr *address)
1052 {
1053 	return ((UdpEndpoint *)protocol)->Bind(address);
1054 }
1055 
1056 
1057 status_t
1058 udp_unbind(net_protocol *protocol, struct sockaddr *address)
1059 {
1060 	return ((UdpEndpoint *)protocol)->Unbind(address);
1061 }
1062 
1063 
/*!	Module hook: listening is not supported, always returns EOPNOTSUPP. */
status_t
udp_listen(net_protocol *protocol, int count)
{
	return EOPNOTSUPP;
}
1069 
1070 
/*!	Module hook: shutdown is not implemented, always returns EOPNOTSUPP. */
status_t
udp_shutdown(net_protocol *protocol, int direction)
{
	return EOPNOTSUPP;
}
1076 
1077 
1078 status_t
1079 udp_send_routed_data(net_protocol *protocol, struct net_route *route,
1080 	net_buffer *buffer)
1081 {
1082 	return ((UdpEndpoint *)protocol)->SendRoutedData(buffer, route);
1083 }
1084 
1085 
1086 status_t
1087 udp_send_data(net_protocol *protocol, net_buffer *buffer)
1088 {
1089 	return ((UdpEndpoint *)protocol)->SendData(buffer);
1090 }
1091 
1092 
1093 ssize_t
1094 udp_send_avail(net_protocol *protocol)
1095 {
1096 	return protocol->socket->send.buffer_size;
1097 }
1098 
1099 
1100 status_t
1101 udp_read_data(net_protocol *protocol, size_t numBytes, uint32 flags,
1102 	net_buffer **_buffer)
1103 {
1104 	return ((UdpEndpoint *)protocol)->FetchData(numBytes, flags, _buffer);
1105 }
1106 
1107 
1108 ssize_t
1109 udp_read_avail(net_protocol *protocol)
1110 {
1111 	return ((UdpEndpoint *)protocol)->BytesAvailable();
1112 }
1113 
1114 
1115 struct net_domain *
1116 udp_get_domain(net_protocol *protocol)
1117 {
1118 	return protocol->next->module->get_domain(protocol->next);
1119 }
1120 
1121 
1122 size_t
1123 udp_get_mtu(net_protocol *protocol, const struct sockaddr *address)
1124 {
1125 	return protocol->next->module->get_mtu(protocol->next, address);
1126 }
1127 
1128 
1129 status_t
1130 udp_receive_data(net_buffer *buffer)
1131 {
1132 	return sUdpEndpointManager->ReceiveData(buffer);
1133 }
1134 
1135 
1136 status_t
1137 udp_deliver_data(net_protocol *protocol, net_buffer *buffer)
1138 {
1139 	return ((UdpEndpoint *)protocol)->DeliverData(buffer);
1140 }
1141 
1142 
/*!	Module hook: error notifications are not handled, always returns
	B_ERROR.
*/
status_t
udp_error(uint32 code, net_buffer *data)
{
	return B_ERROR;
}
1148 
1149 
/*!	Module hook: error replies are not generated, always returns B_ERROR. */
status_t
udp_error_reply(net_protocol *protocol, net_buffer *causedError, uint32 code,
	void *errorData)
{
	return B_ERROR;
}
1156 
1157 
1158 ssize_t
1159 udp_process_ancillary_data_no_container(net_protocol *protocol,
1160 	net_buffer* buffer, void *data, size_t dataSize)
1161 {
1162 	return protocol->next->module->process_ancillary_data_no_container(
1163 		protocol, buffer, data, dataSize);
1164 }
1165 
1166 
1167 //	#pragma mark - module interface
1168 
1169 
1170 static status_t
1171 init_udp()
1172 {
1173 	status_t status;
1174 	TRACE_EPM("init_udp()");
1175 
1176 	sUdpEndpointManager = new (std::nothrow) UdpEndpointManager;
1177 	if (sUdpEndpointManager == NULL)
1178 		return B_NO_MEMORY;
1179 
1180 	status = sUdpEndpointManager->InitCheck();
1181 	if (status != B_OK)
1182 		goto err1;
1183 
1184 	status = gStackModule->register_domain_protocols(AF_INET, SOCK_DGRAM, IPPROTO_IP,
1185 		"network/protocols/udp/v1",
1186 		"network/protocols/ipv4/v1",
1187 		NULL);
1188 	if (status < B_OK)
1189 		goto err1;
1190 	status = gStackModule->register_domain_protocols(AF_INET, SOCK_DGRAM, IPPROTO_UDP,
1191 		"network/protocols/udp/v1",
1192 		"network/protocols/ipv4/v1",
1193 		NULL);
1194 	if (status < B_OK)
1195 		goto err1;
1196 
1197 	status = gStackModule->register_domain_receiving_protocol(AF_INET, IPPROTO_UDP,
1198 		"network/protocols/udp/v1");
1199 	if (status < B_OK)
1200 		goto err1;
1201 
1202 	add_debugger_command("udp_endpoints", UdpEndpointManager::DumpEndpoints,
1203 		"lists all open UDP endpoints");
1204 
1205 	return B_OK;
1206 
1207 err1:
1208 	delete sUdpEndpointManager;
1209 
1210 	TRACE_EPM("init_udp() fails with %lx (%s)", status, strerror(status));
1211 	return status;
1212 }
1213 
1214 
/*!	Module teardown: removes the "udp_endpoints" debugger command and deletes
	the global endpoint manager. Always returns B_OK.
*/
static status_t
uninit_udp()
{
	TRACE_EPM("uninit_udp()");
	// remove the command first, it dereferences sUdpEndpointManager
	remove_debugger_command("udp_endpoints",
		UdpEndpointManager::DumpEndpoints);
	delete sUdpEndpointManager;
	return B_OK;
}
1224 
1225 
1226 static status_t
1227 udp_std_ops(int32 op, ...)
1228 {
1229 	switch (op) {
1230 		case B_MODULE_INIT:
1231 			return init_udp();
1232 
1233 		case B_MODULE_UNINIT:
1234 			return uninit_udp();
1235 
1236 		default:
1237 			return B_ERROR;
1238 	}
1239 }
1240 
1241 
// The protocol module exported by this add-on. The hooks are positional, so
// their order must match net_protocol_module_info.
net_protocol_module_info sUDPModule = {
	{
		"network/protocols/udp/v1",
		0,
		udp_std_ops
	},
	NET_PROTOCOL_ATOMIC_MESSAGES,

	udp_init_protocol,
	udp_uninit_protocol,
	udp_open,
	udp_close,
	udp_free,
	udp_connect,
	udp_accept,
	udp_control,
	udp_getsockopt,
	udp_setsockopt,
	udp_bind,
	udp_unbind,
	udp_listen,
	udp_shutdown,
	udp_send_data,
	udp_send_routed_data,
	udp_send_avail,
	udp_read_data,
	udp_read_avail,
	udp_get_domain,
	udp_get_mtu,
	udp_receive_data,
	udp_deliver_data,
	udp_error,
	udp_error_reply,
	NULL,		// add_ancillary_data()
	NULL,		// process_ancillary_data()
	udp_process_ancillary_data_no_container,
	NULL,		// send_data_no_buffer()
	NULL		// read_data_no_buffer()
};
1281 
// Modules this add-on depends on, each paired with the global pointer that
// is to receive the module; the empty entry terminates the list.
module_dependency module_dependencies[] = {
	{NET_STACK_MODULE_NAME, (module_info **)&gStackModule},
	{NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule},
	{NET_DATALINK_MODULE_NAME, (module_info **)&gDatalinkModule},
	{}
};
1288 
// NULL terminated list of the modules exported by this add-on.
module_info *modules[] = {
	(module_info *)&sUDPModule,
	NULL
};
1293