xref: /haiku/src/add-ons/kernel/network/protocols/ipv4/ipv4.cpp (revision 1acbe440b8dd798953bec31d18ee589aa3f71b73)
1 /*
2  * Copyright 2006-2007, Haiku, Inc. All Rights Reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Axel Dörfler, axeld@pinc-software.de
7  */
8 
9 
10 #include "ipv4_address.h"
11 
12 #include <net_datalink.h>
13 #include <net_protocol.h>
14 #include <net_stack.h>
15 #include <NetBufferUtilities.h>
16 
17 #include <ByteOrder.h>
18 #include <KernelExport.h>
19 #include <util/AutoLock.h>
20 #include <util/list.h>
21 #include <util/khash.h>
22 #include <util/DoublyLinkedList.h>
23 
24 #include <netinet/in.h>
25 #include <netinet/ip.h>
26 #include <new>
27 #include <stdlib.h>
28 #include <string.h>
29 
30 
// Debug output of this module: uncomment the line below to route TRACE()
// through dprintf(). TRACE() takes its printf-style arguments wrapped in an
// extra pair of parentheses, e.g. TRACE(("value %d\n", value)).
//#define TRACE_IPV4
#ifdef TRACE_IPV4
#	define TRACE(x) dprintf x
#else
#	define TRACE(x) ;
#endif
37 
38 struct ipv4_header {
39 #if B_HOST_IS_LENDIAN == 1
40 	uint8		header_length : 4;	// header length in 32-bit words
41 	uint8		version : 4;
42 #else
43 	uint8		version : 4;
44 	uint8		header_length : 4;
45 #endif
46 	uint8		service_type;
47 	uint16		total_length;
48 	uint16		id;
49 	uint16		fragment_offset;
50 	uint8		time_to_live;
51 	uint8		protocol;
52 	uint16		checksum;
53 	in_addr_t	source;
54 	in_addr_t	destination;
55 
56 	uint16 HeaderLength() const { return header_length << 2; }
57 	uint16 TotalLength() const { return ntohs(total_length); }
58 	uint16 FragmentOffset() const { return ntohs(fragment_offset); }
59 } _PACKED;
60 
61 #define IP_VERSION				4
62 
63 // fragment flags
64 #define IP_RESERVED_FLAG		0x8000
65 #define IP_DONT_FRAGMENT		0x4000
66 #define IP_MORE_FRAGMENTS		0x2000
67 #define IP_FRAGMENT_OFFSET_MASK	0x1fff
68 
69 #define MAX_HASH_FRAGMENTS 		64
70 	// slots in the fragment packet's hash
71 #define FRAGMENT_TIMEOUT		60000000LL
72 	// discard fragment after 60 seconds
73 
74 typedef DoublyLinkedList<struct net_buffer,
75 	DoublyLinkedListCLink<struct net_buffer> > FragmentList;
76 
77 struct ipv4_packet_key {
78 	in_addr_t	source;
79 	in_addr_t	destination;
80 	uint16		id;
81 	uint8		protocol;
82 };
83 
84 class FragmentPacket {
85 	public:
86 		FragmentPacket(const ipv4_packet_key &key);
87 		~FragmentPacket();
88 
89 		status_t AddFragment(uint16 start, uint16 end, net_buffer *buffer,
90 					bool lastFragment);
91 		status_t Reassemble(net_buffer *to);
92 
93 		bool IsComplete() const { return fReceivedLastFragment && fBytesLeft == 0; }
94 
95 		static uint32 Hash(void *_packet, const void *_key, uint32 range);
96 		static int Compare(void *_packet, const void *_key);
97 		static int32 NextOffset() { return offsetof(FragmentPacket, fNext); }
98 		static void StaleTimer(struct net_timer *timer, void *data);
99 
100 	private:
101 		FragmentPacket	*fNext;
102 		struct ipv4_packet_key fKey;
103 		bool			fReceivedLastFragment;
104 		int32			fBytesLeft;
105 		FragmentList	fFragments;
106 		net_timer		fTimer;
107 };
108 
109 typedef DoublyLinkedList<class RawSocket> RawSocketList;
110 
111 class RawSocket : public DoublyLinkedListLinkImpl<RawSocket> {
112 	public:
113 		RawSocket(net_socket *socket);
114 		~RawSocket();
115 
116 		status_t InitCheck();
117 
118 		status_t Read(size_t numBytes, uint32 flags, bigtime_t timeout,
119 					net_buffer **_buffer);
120 		ssize_t BytesAvailable();
121 
122 		status_t Write(net_buffer *buffer);
123 
124 	private:
125 		net_socket	*fSocket;
126 		net_fifo	fFifo;
127 };
128 
129 struct ipv4_protocol : net_protocol {
130 	RawSocket	*raw;
131 	uint8		service_type;
132 	uint8		time_to_live;
133 	uint32		flags;
134 };
135 
136 // protocol flags
137 #define IP_FLAG_HEADER_INCLUDED	0x01
138 
139 
140 extern net_protocol_module_info gIPv4Module;
141 	// we need this in ipv4_std_ops() for registering the AF_INET domain
142 
143 static struct net_domain *sDomain;
144 static net_datalink_module_info *sDatalinkModule;
145 static net_stack_module_info *sStackModule;
146 struct net_buffer_module_info *gBufferModule;
147 static int32 sPacketID;
148 static RawSocketList sRawSockets;
149 static benaphore sRawSocketsLock;
150 static benaphore sFragmentLock;
151 static hash_table *sFragmentHash;
152 static net_protocol_module_info *sReceivingProtocol[256];
153 static benaphore sReceivingProtocolLock;
154 
155 
156 RawSocket::RawSocket(net_socket *socket)
157 	:
158 	fSocket(socket)
159 {
160 	status_t status = sStackModule->init_fifo(&fFifo, "ipv4 raw socket", 65536);
161 	if (status < B_OK)
162 		fFifo.notify = status;
163 }
164 
165 
166 RawSocket::~RawSocket()
167 {
168 	if (fFifo.notify >= B_OK)
169 		sStackModule->uninit_fifo(&fFifo);
170 }
171 
172 
173 status_t
174 RawSocket::InitCheck()
175 {
176 	return fFifo.notify >= B_OK ? B_OK : fFifo.notify;
177 }
178 
179 
180 status_t
181 RawSocket::Read(size_t numBytes, uint32 flags, bigtime_t timeout,
182 	net_buffer **_buffer)
183 {
184 	net_buffer *buffer;
185 	status_t status = sStackModule->fifo_dequeue_buffer(&fFifo,
186 		flags, timeout, &buffer);
187 	if (status < B_OK)
188 		return status;
189 
190 	*_buffer = buffer;
191 	return B_OK;
192 }
193 
194 
195 ssize_t
196 RawSocket::BytesAvailable()
197 {
198 	return fFifo.current_bytes;
199 }
200 
201 
202 status_t
203 RawSocket::Write(net_buffer *source)
204 {
205 	return sStackModule->fifo_socket_enqueue_buffer(&fFifo, fSocket,
206 			B_SELECT_READ, source);
207 }
208 
209 
210 //	#pragma mark -
211 
212 
213 FragmentPacket::FragmentPacket(const ipv4_packet_key &key)
214 	:
215 	fKey(key),
216 	fReceivedLastFragment(false),
217 	fBytesLeft(IP_MAXPACKET)
218 {
219 	sStackModule->init_timer(&fTimer, StaleTimer, this);
220 }
221 
222 
223 FragmentPacket::~FragmentPacket()
224 {
225 	// cancel the kill timer
226 	sStackModule->set_timer(&fTimer, -1);
227 
228 	// delete all fragments
229 	net_buffer *buffer;
230 	while ((buffer = fFragments.RemoveHead()) != NULL) {
231 		gBufferModule->free(buffer);
232 	}
233 }
234 
235 
236 status_t
237 FragmentPacket::AddFragment(uint16 start, uint16 end, net_buffer *buffer,
238 	bool lastFragment)
239 {
240 	// restart the timer
241 	sStackModule->set_timer(&fTimer, FRAGMENT_TIMEOUT);
242 
243 	if (start >= end) {
244 		// invalid fragment
245 		return B_BAD_DATA;
246 	}
247 
248 	// Search for a position in the list to insert the fragment
249 
250 	FragmentList::ReverseIterator iterator = fFragments.GetReverseIterator();
251 	net_buffer *previous = NULL;
252 	net_buffer *next = NULL;
253 	while ((previous = iterator.Next()) != NULL) {
254 		if (previous->fragment.start <= start) {
255 			// The new fragment can be inserted after this one
256 			break;
257 		}
258 
259 		next = previous;
260 	}
261 
262 	// See if we already have the fragment's data
263 
264 	if (previous != NULL && previous->fragment.start <= start
265 		&& previous->fragment.end >= end) {
266 		// we do, so we can just drop this fragment
267 		gBufferModule->free(buffer);
268 		return B_OK;
269 	}
270 
271 	TRACE(("    previous: %p, next: %p\n", previous, next));
272 
273 	// If we have parts of the data already, truncate as needed
274 
275 	if (previous != NULL && previous->fragment.end > start) {
276 		TRACE(("    remove header %d bytes\n", previous->fragment.end - start));
277 		gBufferModule->remove_header(buffer, previous->fragment.end - start);
278 		start = previous->fragment.end;
279 	}
280 	if (next != NULL && next->fragment.start < end) {
281 		TRACE(("    remove trailer %d bytes\n", next->fragment.start - end));
282 		gBufferModule->remove_trailer(buffer, next->fragment.start - end);
283 		end = next->fragment.start;
284 	}
285 
286 	// Now try if we can already merge the fragments together
287 
288 	// We will always keep the last buffer received, so that we can still
289 	// report an error (in which case we're not responsible for freeing it)
290 
291 	if (previous != NULL && previous->fragment.end == start) {
292 		fFragments.Remove(previous);
293 
294 		buffer->fragment.start = previous->fragment.start;
295 		buffer->fragment.end = end;
296 
297 		status_t status = gBufferModule->merge(buffer, previous, false);
298 		TRACE(("    merge previous: %s\n", strerror(status)));
299 		if (status < B_OK) {
300 			fFragments.Insert(next, previous);
301 			return status;
302 		}
303 
304 		fFragments.Insert(next, buffer);
305 
306 		// cut down existing hole
307 		fBytesLeft -= end - start;
308 
309 		if (lastFragment && !fReceivedLastFragment) {
310 			fReceivedLastFragment = true;
311 			fBytesLeft -= IP_MAXPACKET - end;
312 		}
313 
314 		TRACE(("    hole length: %d\n", (int)fBytesLeft));
315 
316 		return B_OK;
317 	} else if (next != NULL && next->fragment.start == end) {
318 		fFragments.Remove(next);
319 
320 		buffer->fragment.start = start;
321 		buffer->fragment.end = next->fragment.end;
322 
323 		status_t status = gBufferModule->merge(buffer, next, true);
324 		TRACE(("    merge next: %s\n", strerror(status)));
325 		if (status < B_OK) {
326 			fFragments.Insert((net_buffer *)previous->link.next, next);
327 			return status;
328 		}
329 
330 		fFragments.Insert((net_buffer *)previous->link.next, buffer);
331 
332 		// cut down existing hole
333 		fBytesLeft -= end - start;
334 
335 		if (lastFragment && !fReceivedLastFragment) {
336 			fReceivedLastFragment = true;
337 			fBytesLeft -= IP_MAXPACKET - end;
338 		}
339 
340 		TRACE(("    hole length: %d\n", (int)fBytesLeft));
341 
342 		return B_OK;
343 	}
344 
345 	// We couldn't merge the fragments, so we need to add it as is
346 
347 	TRACE(("    new fragment: %p, bytes %d-%d\n", buffer, start, end));
348 
349 	buffer->fragment.start = start;
350 	buffer->fragment.end = end;
351 	fFragments.Insert(next, buffer);
352 
353 	// update length of the hole, if any
354 	fBytesLeft -= end - start;
355 
356 	if (lastFragment && !fReceivedLastFragment) {
357 		fReceivedLastFragment = true;
358 		fBytesLeft -= IP_MAXPACKET - end;
359 	}
360 
361 	TRACE(("    hole length: %d\n", (int)fBytesLeft));
362 
363 	return B_OK;
364 }
365 
366 
367 /*!
368 	Reassembles the fragments to the specified buffer \a to.
369 	This buffer must have been added via AddFragment() before.
370 */
371 status_t
372 FragmentPacket::Reassemble(net_buffer *to)
373 {
374 	if (!IsComplete())
375 		return NULL;
376 
377 	net_buffer *buffer = NULL;
378 
379 	net_buffer *fragment;
380 	while ((fragment = fFragments.RemoveHead()) != NULL) {
381 		if (buffer != NULL) {
382 			status_t status;
383 			if (to == fragment) {
384 				status = gBufferModule->merge(fragment, buffer, false);
385 				buffer = fragment;
386 			} else
387 				status = gBufferModule->merge(buffer, fragment, true);
388 			if (status < B_OK)
389 				return status;
390 		} else
391 			buffer = fragment;
392 	}
393 
394 	if (buffer != to)
395 		panic("ipv4 packet reassembly did not work correctly.\n");
396 
397 	return B_OK;
398 }
399 
400 
401 int
402 FragmentPacket::Compare(void *_packet, const void *_key)
403 {
404 	const ipv4_packet_key *key = (ipv4_packet_key *)_key;
405 	ipv4_packet_key *packetKey = &((FragmentPacket *)_packet)->fKey;
406 
407 	if (packetKey->id == key->id
408 		&& packetKey->source == key->source
409 		&& packetKey->destination == key->destination
410 		&& packetKey->protocol == key->protocol)
411 		return 0;
412 
413 	return 1;
414 }
415 
416 
417 uint32
418 FragmentPacket::Hash(void *_packet, const void *_key, uint32 range)
419 {
420 	const struct ipv4_packet_key *key = (struct ipv4_packet_key *)_key;
421 	FragmentPacket *packet = (FragmentPacket *)_packet;
422 	if (packet != NULL)
423 		key = &packet->fKey;
424 
425 	return (key->source ^ key->destination ^ key->protocol ^ key->id) % range;
426 }
427 
428 
429 /*static*/ void
430 FragmentPacket::StaleTimer(struct net_timer *timer, void *data)
431 {
432 	FragmentPacket *packet = (FragmentPacket *)data;
433 	TRACE(("Assembling FragmentPacket %p timed out!\n", packet));
434 
435 	BenaphoreLocker locker(&sFragmentLock);
436 
437 	hash_remove(sFragmentHash, packet);
438 	delete packet;
439 }
440 
441 
442 //	#pragma mark -
443 
444 
#if 0
/*!	Prints \a header to the kernel debug output.
	The addresses are in network byte order, so their bytes are laid out in
	a.b.c.d order in memory on every host, independent of its endianness.
*/
static void
dump_ipv4_header(ipv4_header &header)
{
	const uint8 *src = (const uint8 *)&header.source;
	const uint8 *dst = (const uint8 *)&header.destination;
	dprintf("  version: %d\n", header.version);
	dprintf("  header_length: 4 * %d\n", header.header_length);
	dprintf("  service_type: %d\n", header.service_type);
	dprintf("  total_length: %d\n", header.TotalLength());
	dprintf("  id: %d\n", ntohs(header.id));
	dprintf("  fragment_offset: %d (flags: %c%c%c)\n",
		header.FragmentOffset() & IP_FRAGMENT_OFFSET_MASK,
		(header.FragmentOffset() & IP_RESERVED_FLAG) ? 'r' : '-',
		(header.FragmentOffset() & IP_DONT_FRAGMENT) ? 'd' : '-',
		(header.FragmentOffset() & IP_MORE_FRAGMENTS) ? 'm' : '-');
	dprintf("  time_to_live: %d\n", header.time_to_live);
	dprintf("  protocol: %d\n", header.protocol);
	dprintf("  checksum: %d\n", ntohs(header.checksum));
	dprintf("  source: %d.%d.%d.%d\n", src[0], src[1], src[2], src[3]);
	dprintf("  destination: %d.%d.%d.%d\n", dst[0], dst[1], dst[2], dst[3]);
}
#endif
481 
482 
483 /*!
484 	Attempts to re-assemble fragmented packets.
485 	\return B_OK if everything went well; if it could reassemble the packet, \a _buffer
486 		will point to its buffer, otherwise, it will be \c NULL.
487 	\return various error codes if something went wrong (mostly B_NO_MEMORY)
488 */
489 static status_t
490 reassemble_fragments(const ipv4_header &header, net_buffer **_buffer)
491 {
492 	net_buffer *buffer = *_buffer;
493 	status_t status;
494 
495 	struct ipv4_packet_key key;
496 	key.source = (in_addr_t)header.source;
497 	key.destination = (in_addr_t)header.destination;
498 	key.id = header.id;
499 	key.protocol = header.protocol;
500 
501 	// TODO: Make locking finer grained.
502 	BenaphoreLocker locker(&sFragmentLock);
503 
504 	FragmentPacket *packet = (FragmentPacket *)hash_lookup(sFragmentHash, &key);
505 	if (packet == NULL) {
506 		// New fragment packet
507 		packet = new (std::nothrow) FragmentPacket(key);
508 		if (packet == NULL)
509 			return B_NO_MEMORY;
510 
511 		// add packet to hash
512 		status = hash_insert(sFragmentHash, packet);
513 		if (status != B_OK) {
514 			delete packet;
515 			return status;
516 		}
517 	}
518 
519 	uint16 fragmentOffset = header.FragmentOffset();
520 	uint16 start = (fragmentOffset & IP_FRAGMENT_OFFSET_MASK) << 3;
521 	uint16 end = start + header.TotalLength() - header.HeaderLength();
522 	bool lastFragment = (fragmentOffset & IP_MORE_FRAGMENTS) == 0;
523 
524 	TRACE(("   Received IPv4 %sfragment of size %d, offset %d.\n",
525 		lastFragment ? "last ": "", end - start, start));
526 
527 	// Remove header unless this is the first fragment
528 	if (start != 0)
529 		gBufferModule->remove_header(buffer, header.HeaderLength());
530 
531 	status = packet->AddFragment(start, end, buffer, lastFragment);
532 	if (status != B_OK)
533 		return status;
534 
535 	if (packet->IsComplete()) {
536 		hash_remove(sFragmentHash, packet);
537 			// no matter if reassembling succeeds, we won't need this packet anymore
538 
539 		status = packet->Reassemble(buffer);
540 		delete packet;
541 
542 		// _buffer does not change
543 		return status;
544 	}
545 
546 	// This indicates that the packet is not yet complete
547 	*_buffer = NULL;
548 	return B_OK;
549 }
550 
551 
552 /*!
553 	Fragments the incoming buffer and send all fragments via the specified
554 	\a route.
555 */
556 static status_t
557 send_fragments(ipv4_protocol *protocol, struct net_route *route,
558 	net_buffer *buffer, uint32 mtu)
559 {
560 	TRACE(("ipv4 needs to fragment (size %lu, MTU %lu)...\n",
561 		buffer->size, mtu));
562 
563 	NetBufferHeader<ipv4_header> bufferHeader(buffer);
564 	if (bufferHeader.Status() < B_OK)
565 		return bufferHeader.Status();
566 
567 	ipv4_header *header = &bufferHeader.Data();
568 	bufferHeader.Detach();
569 
570 	uint16 headerLength = header->HeaderLength();
571 	uint32 bytesLeft = buffer->size - headerLength;
572 	uint32 fragmentOffset = 0;
573 	status_t status = B_OK;
574 
575 	net_buffer *headerBuffer = gBufferModule->split(buffer, headerLength);
576 	if (headerBuffer == NULL)
577 		return B_NO_MEMORY;
578 
579 	bufferHeader.SetTo(headerBuffer);
580 	header = &bufferHeader.Data();
581 	bufferHeader.Detach();
582 
583 	// adapt MTU to be a multiple of 8 (fragment offsets can only be specified this way)
584 	mtu -= headerLength;
585 	mtu &= ~7;
586 	dprintf("  adjusted MTU to %ld\n", mtu);
587 
588 	dprintf("  bytesLeft = %ld\n", bytesLeft);
589 	while (bytesLeft > 0) {
590 		uint32 fragmentLength = min_c(bytesLeft, mtu);
591 		bytesLeft -= fragmentLength;
592 		bool lastFragment = bytesLeft == 0;
593 
594 		header->total_length = htons(fragmentLength + headerLength);
595 		header->fragment_offset = htons((lastFragment ? 0 : IP_MORE_FRAGMENTS)
596 			| (fragmentOffset >> 3));
597 		header->checksum = 0;
598 		header->checksum = sStackModule->checksum((uint8 *)header, headerLength);
599 			// TODO: compute the checksum only for those parts that changed?
600 
601 		dprintf("  send fragment of %ld bytes (%ld bytes left)\n", fragmentLength, bytesLeft);
602 
603 		net_buffer *fragmentBuffer;
604 		if (!lastFragment) {
605 			fragmentBuffer = gBufferModule->split(buffer, fragmentLength);
606 			fragmentOffset += fragmentLength;
607 		} else
608 			fragmentBuffer = buffer;
609 
610 		if (fragmentBuffer == NULL) {
611 			status = B_NO_MEMORY;
612 			break;
613 		}
614 
615 		// copy header to fragment
616 		status = gBufferModule->prepend(fragmentBuffer, header, headerLength);
617 
618 		// send fragment
619 		if (status == B_OK)
620 			status = sDatalinkModule->send_data(route, fragmentBuffer);
621 
622 		if (lastFragment) {
623 			// we don't own the last buffer, so we don't have to free it
624 			break;
625 		}
626 
627 		if (status < B_OK) {
628 			gBufferModule->free(fragmentBuffer);
629 			break;
630 		}
631 	}
632 
633 	gBufferModule->free(headerBuffer);
634 	return status;
635 }
636 
637 
638 static void
639 raw_receive_data(net_buffer *buffer)
640 {
641 	BenaphoreLocker locker(sRawSocketsLock);
642 
643 	TRACE(("ipv4:raw_receive_data(): protocol %i\n", buffer->protocol));
644 
645 	RawSocketList::Iterator iterator = sRawSockets.GetIterator();
646 
647 	while (iterator.HasNext()) {
648 		RawSocket *raw = iterator.Next();
649 		raw->Write(buffer);
650 	}
651 }
652 
653 
654 static net_protocol_module_info *
655 receiving_protocol(uint8 protocol)
656 {
657 	net_protocol_module_info *module = sReceivingProtocol[protocol];
658 	if (module != NULL)
659 		return module;
660 
661 	BenaphoreLocker locker(sReceivingProtocolLock);
662 
663 	module = sReceivingProtocol[protocol];
664 	if (module != NULL)
665 		return module;
666 
667 	if (sStackModule->get_domain_receiving_protocol(sDomain, protocol, &module) == B_OK)
668 		sReceivingProtocol[protocol] = module;
669 
670 	return module;
671 }
672 
673 
674 //	#pragma mark -
675 
676 
677 net_protocol *
678 ipv4_init_protocol(net_socket *socket)
679 {
680 	ipv4_protocol *protocol = new (std::nothrow) ipv4_protocol;
681 	if (protocol == NULL)
682 		return NULL;
683 
684 	protocol->raw = NULL;
685 	protocol->service_type = 0;
686 	protocol->time_to_live = 254;
687 	protocol->flags = 0;
688 	return protocol;
689 }
690 
691 
692 status_t
693 ipv4_uninit_protocol(net_protocol *_protocol)
694 {
695 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
696 
697 	delete protocol->raw;
698 	delete protocol;
699 	return B_OK;
700 }
701 
702 
703 /*!
704 	Since open() is only called on the top level protocol, when we get here
705 	it means we are on a SOCK_RAW socket.
706 */
707 status_t
708 ipv4_open(net_protocol *_protocol)
709 {
710 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
711 
712 	RawSocket *raw = new (std::nothrow) RawSocket(protocol->socket);
713 	if (raw == NULL)
714 		return B_NO_MEMORY;
715 
716 	status_t status = raw->InitCheck();
717 	if (status < B_OK) {
718 		delete raw;
719 		return status;
720 	}
721 
722 	protocol->raw = raw;
723 
724 	BenaphoreLocker locker(sRawSocketsLock);
725 	sRawSockets.Add(raw);
726 	return B_OK;
727 }
728 
729 
730 status_t
731 ipv4_close(net_protocol *_protocol)
732 {
733 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
734 	RawSocket *raw = protocol->raw;
735 	if (raw == NULL)
736 		return B_ERROR;
737 
738 	BenaphoreLocker locker(sRawSocketsLock);
739 	sRawSockets.Remove(raw);
740 	delete raw;
741 	protocol->raw = NULL;
742 
743 	return B_OK;
744 }
745 
746 
747 status_t
748 ipv4_free(net_protocol *protocol)
749 {
750 	return B_OK;
751 }
752 
753 
754 status_t
755 ipv4_connect(net_protocol *protocol, const struct sockaddr *address)
756 {
757 	return B_ERROR;
758 }
759 
760 
761 status_t
762 ipv4_accept(net_protocol *protocol, struct net_socket **_acceptedSocket)
763 {
764 	return EOPNOTSUPP;
765 }
766 
767 
768 status_t
769 ipv4_control(net_protocol *_protocol, int level, int option, void *value,
770 	size_t *_length)
771 {
772 	if ((level & LEVEL_MASK) != IPPROTO_IP)
773 		return sDatalinkModule->control(sDomain, option, value, _length);
774 
775 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
776 
777 	if (level & LEVEL_GET_OPTION) {
778 		// get options
779 
780 		switch (option) {
781 			case IP_HDRINCL:
782 			{
783 				if (*_length != sizeof(int))
784 					return B_BAD_VALUE;
785 
786 				int headerIncluded = (protocol->flags & IP_FLAG_HEADER_INCLUDED) != 0;
787 				return user_memcpy(value, &headerIncluded, sizeof(headerIncluded));
788 			}
789 
790 			case IP_TTL:
791 			{
792 				if (*_length != sizeof(int))
793 					return B_BAD_VALUE;
794 
795 				int timeToLive = protocol->time_to_live;
796 				return user_memcpy(value, &timeToLive, sizeof(timeToLive));
797 			}
798 
799 			case IP_TOS:
800 			{
801 				if (*_length != sizeof(int))
802 					return B_BAD_VALUE;
803 
804 				int serviceType = protocol->service_type;
805 				return user_memcpy(value, &serviceType, sizeof(serviceType));
806 			}
807 
808 			default:
809 				dprintf("IPv4::control(): get unknown option: %d\n", option);
810 				return ENOPROTOOPT;
811 		}
812 	}
813 
814 	// set options
815 
816 	switch (option) {
817 		case IP_HDRINCL:
818 		{
819 			int headerIncluded;
820 			if (*_length != sizeof(int))
821 				return B_BAD_VALUE;
822 			if (user_memcpy(&headerIncluded, value, sizeof(headerIncluded)) < B_OK)
823 				return B_BAD_ADDRESS;
824 
825 			if (headerIncluded)
826 				protocol->flags |= IP_FLAG_HEADER_INCLUDED;
827 			else
828 				protocol->flags &= ~IP_FLAG_HEADER_INCLUDED;
829 			return B_OK;
830 		}
831 
832 		case IP_TTL:
833 		{
834 			int timeToLive;
835 			if (*_length != sizeof(int))
836 				return B_BAD_VALUE;
837 			if (user_memcpy(&timeToLive, value, sizeof(timeToLive)) < B_OK)
838 				return B_BAD_ADDRESS;
839 
840 			protocol->time_to_live = timeToLive;
841 			return B_OK;
842 		}
843 
844 		case IP_TOS:
845 		{
846 			int serviceType;
847 			if (*_length != sizeof(int))
848 				return B_BAD_VALUE;
849 			if (user_memcpy(&serviceType, value, sizeof(serviceType)) < B_OK)
850 				return B_BAD_ADDRESS;
851 
852 			protocol->service_type = serviceType;
853 			return B_OK;
854 		}
855 
856 		default:
857 			dprintf("IPv4::control(): set unknown option: %d\n", option);
858 			return ENOPROTOOPT;
859 	}
860 
861 	// never gets here
862 	return B_BAD_VALUE;
863 }
864 
865 
866 status_t
867 ipv4_bind(net_protocol *protocol, struct sockaddr *address)
868 {
869 	if (address->sa_family != AF_INET)
870 		return EAFNOSUPPORT;
871 
872 	// only INADDR_ANY and addresses of local interfaces are accepted:
873 	if (((sockaddr_in *)address)->sin_addr.s_addr == INADDR_ANY
874 		|| sDatalinkModule->is_local_address(sDomain, address, NULL, NULL)) {
875 		protocol->socket->address.ss_len = sizeof(struct sockaddr_in);
876 			// explicitly set length, as our callers can't be trusted to
877 			// always provide the correct length!
878 		return B_OK;
879 	}
880 
881 	return B_ERROR;
882 		// address is unknown on this host
883 }
884 
885 
886 status_t
887 ipv4_unbind(net_protocol *protocol, struct sockaddr *address)
888 {
889 	// nothing to do here
890 	return B_OK;
891 }
892 
893 
894 status_t
895 ipv4_listen(net_protocol *protocol, int count)
896 {
897 	return EOPNOTSUPP;
898 }
899 
900 
901 status_t
902 ipv4_shutdown(net_protocol *protocol, int direction)
903 {
904 	return EOPNOTSUPP;
905 }
906 
907 
908 status_t
909 ipv4_send_routed_data(net_protocol *_protocol, struct net_route *route,
910 	net_buffer *buffer)
911 {
912 	if (route == NULL)
913 		return B_BAD_VALUE;
914 
915 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
916 	net_interface *interface = route->interface;
917 
918 	TRACE(("someone tries to send some actual routed data!\n"));
919 
920 	sockaddr_in &source = *(sockaddr_in *)&buffer->source;
921 	if (source.sin_addr.s_addr == INADDR_ANY && route->interface->address != NULL) {
922 		// replace an unbound source address with the address of the interface
923 		// TODO: couldn't we replace all addresses here?
924 		source.sin_addr.s_addr = ((sockaddr_in *)route->interface->address)->sin_addr.s_addr;
925 	}
926 
927 	bool headerIncluded = false;
928 	if (protocol != NULL)
929 		headerIncluded = (protocol->flags & IP_FLAG_HEADER_INCLUDED) != 0;
930 
931 	// Add IP header (if needed)
932 
933 	if (!headerIncluded) {
934 		NetBufferPrepend<ipv4_header> bufferHeader(buffer);
935 		if (bufferHeader.Status() < B_OK)
936 			return bufferHeader.Status();
937 
938 		ipv4_header &header = bufferHeader.Data();
939 
940 		header.version = IP_VERSION;
941 		header.header_length = sizeof(ipv4_header) >> 2;
942 		header.service_type = protocol ? protocol->service_type : 0;
943 		header.total_length = htons(buffer->size);
944 		header.id = htons(atomic_add(&sPacketID, 1));
945 		header.fragment_offset = 0;
946 		header.time_to_live = protocol ? protocol->time_to_live : 254;
947 		header.protocol = protocol ? protocol->socket->protocol : buffer->protocol;
948 		header.checksum = 0;
949 		if (route->interface->address != NULL) {
950 			header.source = ((sockaddr_in *)route->interface->address)->sin_addr.s_addr;
951 				// always use the actual used source address
952 		} else
953 			header.source = 0;
954 
955 		header.destination = ((sockaddr_in *)&buffer->destination)->sin_addr.s_addr;
956 
957 		header.checksum = gBufferModule->checksum(buffer, 0,
958 			sizeof(ipv4_header), true);
959 		//dump_ipv4_header(header);
960 
961 		bufferHeader.Detach();
962 			// make sure the IP-header is already written to the
963 			// buffer at this point
964 	} else {
965 		// if IP_HDRINCL, check if the source address is set
966 		NetBufferHeader<ipv4_header> bufferHeader(buffer);
967 		if (bufferHeader.Status() < B_OK)
968 			return bufferHeader.Status();
969 
970 		ipv4_header &header = bufferHeader.Data();
971 		if (header.source == 0) {
972 			header.source = source.sin_addr.s_addr;
973 			header.checksum = gBufferModule->checksum(buffer,
974 				sizeof(ipv4_header), sizeof(ipv4_header), true);
975 		}
976 
977 		bufferHeader.Detach();
978 	}
979 
980 	if (buffer->size > 0xffff)
981 		return EMSGSIZE;
982 
983 	TRACE(("header chksum: %ld, buffer checksum: %ld\n",
984 		gBufferModule->checksum(buffer, 0, sizeof(ipv4_header), true),
985 		gBufferModule->checksum(buffer, 0, buffer->size, true)));
986 
987 	TRACE(("destination-IP: buffer=%p addr=%p %08lx\n", buffer, &buffer->destination,
988 		ntohl(((sockaddr_in *)&buffer->destination)->sin_addr.s_addr)));
989 
990 	uint32 mtu = route->mtu ? route->mtu : interface->mtu;
991 	if (buffer->size > mtu) {
992 		// we need to fragment the packet
993 		return send_fragments(protocol, route, buffer, mtu);
994 	}
995 
996 	return sDatalinkModule->send_data(route, buffer);
997 }
998 
999 
1000 status_t
1001 ipv4_send_data(net_protocol *protocol, net_buffer *buffer)
1002 {
1003 	TRACE(("someone tries to send some actual data!\n"));
1004 
1005 	// find route
1006 	struct net_route *route = sDatalinkModule->get_route(sDomain,
1007 		(sockaddr *)&buffer->destination);
1008 	if (route == NULL)
1009 		return ENETUNREACH;
1010 
1011 	status_t status = ipv4_send_routed_data(protocol, route, buffer);
1012 	sDatalinkModule->put_route(sDomain, route);
1013 
1014 	return status;
1015 }
1016 
1017 
1018 ssize_t
1019 ipv4_send_avail(net_protocol *protocol)
1020 {
1021 	return B_ERROR;
1022 }
1023 
1024 
1025 status_t
1026 ipv4_read_data(net_protocol *_protocol, size_t numBytes, uint32 flags,
1027 	net_buffer **_buffer)
1028 {
1029 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
1030 	RawSocket *raw = protocol->raw;
1031 	if (raw == NULL)
1032 		return B_ERROR;
1033 
1034 	TRACE(("read is waiting for data...\n"));
1035 	return raw->Read(numBytes, flags, protocol->socket->receive.timeout, _buffer);
1036 }
1037 
1038 
1039 ssize_t
1040 ipv4_read_avail(net_protocol *_protocol)
1041 {
1042 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
1043 	RawSocket *raw = protocol->raw;
1044 	if (raw == NULL)
1045 		return B_ERROR;
1046 
1047 	return raw->BytesAvailable();
1048 }
1049 
1050 
1051 struct net_domain *
1052 ipv4_get_domain(net_protocol *protocol)
1053 {
1054 	return sDomain;
1055 }
1056 
1057 
1058 size_t
1059 ipv4_get_mtu(net_protocol *protocol, const struct sockaddr *address)
1060 {
1061 	net_route *route = sDatalinkModule->get_route(sDomain, address);
1062 	if (route == NULL)
1063 		return 0;
1064 
1065 	size_t mtu;
1066 	if (route->mtu != 0)
1067 		mtu = route->mtu;
1068 	else
1069 		mtu = route->interface->mtu;
1070 
1071 	sDatalinkModule->put_route(sDomain, route);
1072 	return mtu - sizeof(ipv4_header);
1073 }
1074 
1075 
1076 status_t
1077 ipv4_receive_data(net_buffer *buffer)
1078 {
1079 	TRACE(("IPv4 received a packet (%p) of %ld size!\n", buffer, buffer->size));
1080 
1081 	NetBufferHeader<ipv4_header> bufferHeader(buffer);
1082 	if (bufferHeader.Status() < B_OK)
1083 		return bufferHeader.Status();
1084 
1085 	ipv4_header &header = bufferHeader.Data();
1086 	bufferHeader.Detach();
1087 	//dump_ipv4_header(header);
1088 
1089 	if (header.version != IP_VERSION)
1090 		return B_BAD_TYPE;
1091 
1092 	uint16 packetLength = header.TotalLength();
1093 	uint16 headerLength = header.HeaderLength();
1094 	if (packetLength > buffer->size
1095 		|| headerLength < sizeof(ipv4_header))
1096 		return B_BAD_DATA;
1097 
1098 	// TODO: would be nice to have a direct checksum function somewhere
1099 	if (gBufferModule->checksum(buffer, 0, headerLength, true) != 0)
1100 		return B_BAD_DATA;
1101 
1102 	struct sockaddr_in &source = *(struct sockaddr_in *)&buffer->source;
1103 	struct sockaddr_in &destination = *(struct sockaddr_in *)&buffer->destination;
1104 
1105 	source.sin_len = sizeof(sockaddr_in);
1106 	source.sin_family = AF_INET;
1107 	source.sin_addr.s_addr = header.source;
1108 
1109 	destination.sin_len = sizeof(sockaddr_in);
1110 	destination.sin_family = AF_INET;
1111 	destination.sin_addr.s_addr = header.destination;
1112 
1113 	// test if the packet is really for us
1114 	uint32 matchedAddressType;
1115 	if (!sDatalinkModule->is_local_address(sDomain, (sockaddr*)&destination,
1116 		&buffer->interface, &matchedAddressType)) {
1117 		TRACE(("this packet was not for us %lx -> %lx\n",
1118 			ntohl(header.source), ntohl(header.destination)));
1119 		return B_ERROR;
1120 	}
1121 	if (matchedAddressType != 0) {
1122 		// copy over special address types (MSG_BCAST or MSG_MCAST):
1123 		buffer->flags |= matchedAddressType;
1124 	}
1125 
1126 	uint8 protocol = buffer->protocol = header.protocol;
1127 
1128 	// remove any trailing/padding data
1129 	status_t status = gBufferModule->trim(buffer, packetLength);
1130 	if (status < B_OK)
1131 		return status;
1132 
1133 	// check for fragmentation
1134 	uint16 fragmentOffset = ntohs(header.fragment_offset);
1135 	if ((fragmentOffset & IP_MORE_FRAGMENTS) != 0
1136 		|| (fragmentOffset & IP_FRAGMENT_OFFSET_MASK) != 0) {
1137 		// this is a fragment
1138 		TRACE(("   Found a Fragment!\n"));
1139 		status = reassemble_fragments(header, &buffer);
1140 		TRACE(("   -> %s!\n", strerror(status)));
1141 		if (status != B_OK)
1142 			return status;
1143 
1144 		if (buffer == NULL) {
1145 			// buffer was put into fragment packet
1146 			TRACE(("   Not yet assembled...\n"));
1147 			return B_OK;
1148 		}
1149 	}
1150 
1151 	// Since the buffer might have been changed (reassembled fragment)
1152 	// we must no longer access bufferHeader or header anymore after
1153 	// this point
1154 
1155 	if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) {
1156 		// SOCK_RAW doesn't get all packets
1157 		raw_receive_data(buffer);
1158 	}
1159 
1160 	gBufferModule->remove_header(buffer, headerLength);
1161 		// the header is of variable size and may include IP options
1162 		// (that we ignore for now)
1163 
1164 	net_protocol_module_info *module = receiving_protocol(protocol);
1165 	if (module == NULL) {
1166 		// no handler for this packet
1167 		return EAFNOSUPPORT;
1168 	}
1169 
1170 	return module->receive_data(buffer);
1171 }
1172 
1173 
1174 status_t
1175 ipv4_error(uint32 code, net_buffer *data)
1176 {
1177 	return B_ERROR;
1178 }
1179 
1180 
1181 status_t
1182 ipv4_error_reply(net_protocol *protocol, net_buffer *causedError, uint32 code,
1183 	void *errorData)
1184 {
1185 	return B_ERROR;
1186 }
1187 
1188 
1189 //	#pragma mark -
1190 
1191 
1192 status_t
1193 init_ipv4()
1194 {
1195 	status_t status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule);
1196 	if (status < B_OK)
1197 		return status;
1198 	status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
1199 	if (status < B_OK)
1200 		goto err1;
1201 	status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule);
1202 	if (status < B_OK)
1203 		goto err2;
1204 
1205 	sPacketID = (int32)system_time();
1206 
1207 	status = benaphore_init(&sRawSocketsLock, "raw sockets");
1208 	if (status < B_OK)
1209 		goto err3;
1210 
1211 	status = benaphore_init(&sFragmentLock, "IPv4 Fragments");
1212 	if (status < B_OK)
1213 		goto err4;
1214 
1215 	status = benaphore_init(&sReceivingProtocolLock, "IPv4 receiving protocols");
1216 	if (status < B_OK)
1217 		goto err5;
1218 
1219 	sFragmentHash = hash_init(MAX_HASH_FRAGMENTS, FragmentPacket::NextOffset(),
1220 		&FragmentPacket::Compare, &FragmentPacket::Hash);
1221 	if (sFragmentHash == NULL)
1222 		goto err6;
1223 
1224 	new (&sRawSockets) RawSocketList;
1225 		// static initializers do not work in the kernel,
1226 		// so we have to do it here, manually
1227 		// TODO: for modules, this shouldn't be required
1228 
1229 	status = sStackModule->register_domain_protocols(AF_INET, SOCK_RAW, 0,
1230 		"network/protocols/ipv4/v1", NULL);
1231 	if (status < B_OK)
1232 		goto err7;
1233 
1234 	status = sStackModule->register_domain(AF_INET, "internet", &gIPv4Module,
1235 		&gIPv4AddressModule, &sDomain);
1236 	if (status < B_OK)
1237 		goto err7;
1238 
1239 	return B_OK;
1240 
1241 err7:
1242 	hash_uninit(sFragmentHash);
1243 err6:
1244 	benaphore_destroy(&sReceivingProtocolLock);
1245 err5:
1246 	benaphore_destroy(&sFragmentLock);
1247 err4:
1248 	benaphore_destroy(&sRawSocketsLock);
1249 err3:
1250 	put_module(NET_DATALINK_MODULE_NAME);
1251 err2:
1252 	put_module(NET_BUFFER_MODULE_NAME);
1253 err1:
1254 	put_module(NET_STACK_MODULE_NAME);
1255 	return status;
1256 }
1257 
1258 
1259 status_t
1260 uninit_ipv4()
1261 {
1262 	benaphore_lock(&sReceivingProtocolLock);
1263 
1264 	// put all the domain receiving protocols we gathered so far
1265 	for (uint32 i = 0; i < 256; i++) {
1266 		if (sReceivingProtocol[i] != NULL)
1267 			sStackModule->put_domain_receiving_protocol(sDomain, i);
1268 	}
1269 
1270 	sStackModule->unregister_domain(sDomain);
1271 	benaphore_unlock(&sReceivingProtocolLock);
1272 
1273 	hash_uninit(sFragmentHash);
1274 
1275 	benaphore_destroy(&sFragmentLock);
1276 	benaphore_destroy(&sRawSocketsLock);
1277 	benaphore_destroy(&sReceivingProtocolLock);
1278 
1279 	put_module(NET_DATALINK_MODULE_NAME);
1280 	put_module(NET_BUFFER_MODULE_NAME);
1281 	put_module(NET_STACK_MODULE_NAME);
1282 	return B_OK;
1283 }
1284 
1285 
1286 static status_t
1287 ipv4_std_ops(int32 op, ...)
1288 {
1289 	switch (op) {
1290 		case B_MODULE_INIT:
1291 			return init_ipv4();
1292 		case B_MODULE_UNINIT:
1293 			return uninit_ipv4();
1294 
1295 		default:
1296 			return B_ERROR;
1297 	}
1298 }
1299 
1300 
// Protocol module descriptor of the IPv4 add-on. The initializer is
// positional, so the hook order below must match the member order of
// net_protocol_module_info (presumably declared in <net_protocol.h>).
// The module name is the same one init_ipv4() passes to
// register_domain_protocols(), and this descriptor is what init_ipv4()
// registers as the AF_INET domain's protocol module.
net_protocol_module_info gIPv4Module = {
	{
		// module_info: name, flags, std_ops hook
		"network/protocols/ipv4/v1",
		0,
		ipv4_std_ops
	},
	ipv4_init_protocol,
	ipv4_uninit_protocol,
	ipv4_open,
	ipv4_close,
	ipv4_free,
	ipv4_connect,
	ipv4_accept,
	ipv4_control,
	ipv4_bind,
	ipv4_unbind,
	ipv4_listen,
	ipv4_shutdown,
	ipv4_send_data,
	ipv4_send_routed_data,
	ipv4_send_avail,
	ipv4_read_data,
	ipv4_read_avail,
	ipv4_get_domain,
	ipv4_get_mtu,
	ipv4_receive_data,
	ipv4_error,
	ipv4_error_reply,
};
1330 
// NULL-terminated list of the modules this add-on exports; it contains only
// the IPv4 protocol module. The cast relies on the module_info being the
// first member of net_protocol_module_info (see the nested initializer of
// gIPv4Module).
module_info *modules[] = {
	(module_info *)&gIPv4Module,
	NULL
};
1335