xref: /haiku/src/add-ons/kernel/network/protocols/ipv4/ipv4.cpp (revision 3cb015b1ee509d69c643506e8ff573808c86dcfc)
1 /*
2  * Copyright 2006, Haiku, Inc. All Rights Reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Axel Dörfler, axeld@pinc-software.de
7  */
8 
9 
10 #include "ipv4_address.h"
11 
12 #include <net_datalink.h>
13 #include <net_protocol.h>
14 #include <net_stack.h>
15 #include <NetBufferUtilities.h>
16 
17 #include <ByteOrder.h>
18 #include <KernelExport.h>
19 #include <util/AutoLock.h>
20 #include <util/list.h>
21 #include <util/khash.h>
22 #include <util/DoublyLinkedList.h>
23 
24 #include <netinet/in.h>
25 #include <netinet/ip.h>
26 #include <new>
27 #include <stdlib.h>
28 #include <string.h>
29 
30 
31 //#define TRACE_IPV4
32 #ifdef TRACE_IPV4
33 #	define TRACE(x) dprintf x
34 #else
35 #	define TRACE(x) ;
36 #endif
37 
38 struct ipv4_header {
39 #if B_HOST_IS_LENDIAN == 1
40 	uint8		header_length : 4;	// header length in 32-bit words
41 	uint8		version : 4;
42 #else
43 	uint8		version : 4;
44 	uint8		header_length : 4;
45 #endif
46 	uint8		service_type;
47 	uint16		total_length;
48 	uint16		id;
49 	uint16		fragment_offset;
50 	uint8		time_to_live;
51 	uint8		protocol;
52 	uint16		checksum;
53 	in_addr_t	source;
54 	in_addr_t	destination;
55 
56 	uint16 HeaderLength() const { return header_length << 2; }
57 	uint16 TotalLength() const { return ntohs(total_length); }
58 	uint16 FragmentOffset() const { return ntohs(fragment_offset); }
59 } _PACKED;
60 
61 #define IP_VERSION				4
62 
63 // fragment flags
64 #define IP_RESERVED_FLAG		0x8000
65 #define IP_DONT_FRAGMENT		0x4000
66 #define IP_MORE_FRAGMENTS		0x2000
67 #define IP_FRAGMENT_OFFSET_MASK	0x1fff
68 
69 #define MAX_HASH_FRAGMENTS 		64
70 	// slots in the fragment packet's hash
71 #define FRAGMENT_TIMEOUT		60000000LL
72 	// discard fragment after 60 seconds
73 
74 typedef DoublyLinkedList<struct net_buffer, DoublyLinkedListCLink<struct net_buffer> > FragmentList;
75 
76 struct ipv4_packet_key {
77 	in_addr_t	source;
78 	in_addr_t	destination;
79 	uint16		id;
80 	uint8		protocol;
81 };
82 
83 class FragmentPacket {
84 	public:
85 		FragmentPacket(const ipv4_packet_key &key);
86 		~FragmentPacket();
87 
88 		status_t AddFragment(uint16 start, uint16 end, net_buffer *buffer,
89 					bool lastFragment);
90 		status_t Reassemble(net_buffer *to);
91 
92 		bool IsComplete() const { return fReceivedLastFragment && fBytesLeft == 0; }
93 
94 		static uint32 Hash(void *_packet, const void *_key, uint32 range);
95 		static int Compare(void *_packet, const void *_key);
96 		static int32 NextOffset() { return offsetof(FragmentPacket, fNext); }
97 		static void StaleTimer(struct net_timer *timer, void *data);
98 
99 	private:
100 		FragmentPacket	*fNext;
101 		struct ipv4_packet_key fKey;
102 		bool			fReceivedLastFragment;
103 		int32			fBytesLeft;
104 		FragmentList	fFragments;
105 		net_timer		fTimer;
106 };
107 
108 typedef DoublyLinkedList<class RawSocket> RawSocketList;
109 
110 class RawSocket : public DoublyLinkedListLinkImpl<RawSocket> {
111 	public:
112 		RawSocket(net_socket *socket);
113 		~RawSocket();
114 
115 		status_t InitCheck();
116 
117 		status_t Read(size_t numBytes, uint32 flags, bigtime_t timeout,
118 					net_buffer **_buffer);
119 		ssize_t BytesAvailable();
120 
121 		status_t Write(net_buffer *buffer);
122 
123 	private:
124 		net_socket	*fSocket;
125 		net_fifo	fFifo;
126 };
127 
128 struct ipv4_protocol : net_protocol {
129 	RawSocket	*raw;
130 	uint8		service_type;
131 	uint8		time_to_live;
132 	uint32		flags;
133 };
134 
135 // protocol flags
136 #define IP_FLAG_HEADER_INCLUDED	0x01
137 
138 
139 extern net_protocol_module_info gIPv4Module;
140 	// we need this in ipv4_std_ops() for registering the AF_INET domain
141 
142 static struct net_domain *sDomain;
143 static net_datalink_module_info *sDatalinkModule;
144 static net_stack_module_info *sStackModule;
145 struct net_buffer_module_info *gBufferModule;
146 static int32 sPacketID;
147 static RawSocketList sRawSockets;
148 static benaphore sRawSocketsLock;
149 static benaphore sFragmentLock;
150 static hash_table *sFragmentHash;
151 
152 
153 RawSocket::RawSocket(net_socket *socket)
154 	:
155 	fSocket(socket)
156 {
157 	status_t status = sStackModule->init_fifo(&fFifo, "ipv4 raw socket", 65536);
158 	if (status < B_OK)
159 		fFifo.notify = status;
160 }
161 
162 
163 RawSocket::~RawSocket()
164 {
165 	if (fFifo.notify >= B_OK)
166 		sStackModule->uninit_fifo(&fFifo);
167 }
168 
169 
170 status_t
171 RawSocket::InitCheck()
172 {
173 	return fFifo.notify >= B_OK ? B_OK : fFifo.notify;
174 }
175 
176 
177 status_t
178 RawSocket::Read(size_t numBytes, uint32 flags, bigtime_t timeout,
179 	net_buffer **_buffer)
180 {
181 	net_buffer *buffer;
182 	status_t status = sStackModule->fifo_dequeue_buffer(&fFifo,
183 		flags, timeout, &buffer);
184 	if (status < B_OK)
185 		return status;
186 
187 	if (numBytes < buffer->size) {
188 		// discard any data behind the amount requested
189 		gBufferModule->trim(buffer, numBytes);
190 	}
191 
192 	*_buffer = buffer;
193 	return B_OK;
194 }
195 
196 
197 ssize_t
198 RawSocket::BytesAvailable()
199 {
200 	return fFifo.current_bytes;
201 }
202 
203 
204 status_t
205 RawSocket::Write(net_buffer *source)
206 {
207 	// we need to make a clone for that buffer and pass it to the socket
208 	net_buffer *buffer = gBufferModule->clone(source, false);
209 	TRACE(("ipv4::RawSocket::Write(): cloned buffer %p\n", buffer));
210 	if (buffer == NULL)
211 		return B_NO_MEMORY;
212 
213 	status_t status = sStackModule->fifo_enqueue_buffer(&fFifo, buffer);
214 	if (status >= B_OK)
215 		sStackModule->notify_socket(fSocket, B_SELECT_READ, BytesAvailable());
216 	else
217 		gBufferModule->free(buffer);
218 
219 	return status;
220 }
221 
222 
223 //	#pragma mark -
224 
225 
226 FragmentPacket::FragmentPacket(const ipv4_packet_key &key)
227 	:
228 	fKey(key),
229 	fReceivedLastFragment(false),
230 	fBytesLeft(IP_MAXPACKET)
231 {
232 	sStackModule->init_timer(&fTimer, StaleTimer, this);
233 }
234 
235 
236 FragmentPacket::~FragmentPacket()
237 {
238 	// cancel the kill timer
239 	sStackModule->set_timer(&fTimer, -1);
240 
241 	// delete all fragments
242 	net_buffer *buffer;
243 	while ((buffer = fFragments.RemoveHead()) != NULL) {
244 		gBufferModule->free(buffer);
245 	}
246 }
247 
248 
249 status_t
250 FragmentPacket::AddFragment(uint16 start, uint16 end, net_buffer *buffer,
251 	bool lastFragment)
252 {
253 	// restart the timer
254 	sStackModule->set_timer(&fTimer, FRAGMENT_TIMEOUT);
255 
256 	if (start >= end) {
257 		// invalid fragment
258 		return B_BAD_DATA;
259 	}
260 
261 	// Search for a position in the list to insert the fragment
262 
263 	FragmentList::ReverseIterator iterator = fFragments.GetReverseIterator();
264 	net_buffer *previous = NULL;
265 	net_buffer *next = NULL;
266 	while ((previous = iterator.Next()) != NULL) {
267 		if (previous->fragment.start <= start) {
268 			// The new fragment can be inserted after this one
269 			break;
270 		}
271 
272 		next = previous;
273 	}
274 
275 	// See if we already have the fragment's data
276 
277 	if (previous != NULL && previous->fragment.start <= start
278 		&& previous->fragment.end >= end) {
279 		// we do, so we can just drop this fragment
280 		gBufferModule->free(buffer);
281 		return B_OK;
282 	}
283 
284 	TRACE(("    previous: %p, next: %p\n", previous, next));
285 
286 	// If we have parts of the data already, truncate as needed
287 
288 	if (previous != NULL && previous->fragment.end > start) {
289 		TRACE(("    remove header %d bytes\n", previous->fragment.end - start));
290 		gBufferModule->remove_header(buffer, previous->fragment.end - start);
291 		start = previous->fragment.end;
292 	}
293 	if (next != NULL && next->fragment.start < end) {
294 		TRACE(("    remove trailer %d bytes\n", next->fragment.start - end));
295 		gBufferModule->remove_trailer(buffer, next->fragment.start - end);
296 		end = next->fragment.start;
297 	}
298 
299 	// Now try if we can already merge the fragments together
300 
301 	// We will always keep the last buffer received, so that we can still
302 	// report an error (in which case we're not responsible for freeing it)
303 
304 	if (previous != NULL && previous->fragment.end == start) {
305 		fFragments.Remove(previous);
306 
307 		buffer->fragment.start = previous->fragment.start;
308 		buffer->fragment.end = end;
309 
310 		status_t status = gBufferModule->merge(buffer, previous, false);
311 		TRACE(("    merge previous: %s\n", strerror(status)));
312 		if (status < B_OK) {
313 			fFragments.Insert(next, previous);
314 			return status;
315 		}
316 
317 		fFragments.Insert(next, buffer);
318 
319 		// cut down existing hole
320 		fBytesLeft -= end - start;
321 
322 		if (lastFragment && !fReceivedLastFragment) {
323 			fReceivedLastFragment = true;
324 			fBytesLeft -= IP_MAXPACKET - end;
325 		}
326 
327 		TRACE(("    hole length: %d\n", (int)fBytesLeft));
328 
329 		return B_OK;
330 	} else if (next != NULL && next->fragment.start == end) {
331 		fFragments.Remove(next);
332 
333 		buffer->fragment.start = start;
334 		buffer->fragment.end = next->fragment.end;
335 
336 		status_t status = gBufferModule->merge(buffer, next, true);
337 		TRACE(("    merge next: %s\n", strerror(status)));
338 		if (status < B_OK) {
339 			fFragments.Insert((net_buffer *)previous->link.next, next);
340 			return status;
341 		}
342 
343 		fFragments.Insert((net_buffer *)previous->link.next, buffer);
344 
345 		// cut down existing hole
346 		fBytesLeft -= end - start;
347 
348 		if (lastFragment && !fReceivedLastFragment) {
349 			fReceivedLastFragment = true;
350 			fBytesLeft -= IP_MAXPACKET - end;
351 		}
352 
353 		TRACE(("    hole length: %d\n", (int)fBytesLeft));
354 
355 		return B_OK;
356 	}
357 
358 	// We couldn't merge the fragments, so we need to add it as is
359 
360 	TRACE(("    new fragment: %p, bytes %d-%d\n", buffer, start, end));
361 
362 	buffer->fragment.start = start;
363 	buffer->fragment.end = end;
364 	fFragments.Insert(next, buffer);
365 
366 	// update length of the hole, if any
367 	fBytesLeft -= end - start;
368 
369 	if (lastFragment && !fReceivedLastFragment) {
370 		fReceivedLastFragment = true;
371 		fBytesLeft -= IP_MAXPACKET - end;
372 	}
373 
374 	TRACE(("    hole length: %d\n", (int)fBytesLeft));
375 
376 	return B_OK;
377 }
378 
379 
380 /*!
381 	Reassembles the fragments to the specified buffer \a to.
382 	This buffer must have been added via AddFragment() before.
383 */
384 status_t
385 FragmentPacket::Reassemble(net_buffer *to)
386 {
387 	if (!IsComplete())
388 		return NULL;
389 
390 	net_buffer *buffer = NULL;
391 
392 	net_buffer *fragment;
393 	while ((fragment = fFragments.RemoveHead()) != NULL) {
394 		if (buffer != NULL) {
395 			status_t status;
396 			if (to == fragment) {
397 				status = gBufferModule->merge(fragment, buffer, false);
398 				buffer = fragment;
399 			} else
400 				status = gBufferModule->merge(buffer, fragment, true);
401 			if (status < B_OK)
402 				return status;
403 		} else
404 			buffer = fragment;
405 	}
406 
407 	if (buffer != to)
408 		panic("ipv4 packet reassembly did not work correctly.\n");
409 
410 	return B_OK;
411 }
412 
413 
414 int
415 FragmentPacket::Compare(void *_packet, const void *_key)
416 {
417 	const ipv4_packet_key *key = (ipv4_packet_key *)_key;
418 	ipv4_packet_key *packetKey = &((FragmentPacket *)_packet)->fKey;
419 
420 	if (packetKey->id == key->id
421 		&& packetKey->source == key->source
422 		&& packetKey->destination == key->destination
423 		&& packetKey->protocol == key->protocol)
424 		return 0;
425 
426 	return 1;
427 }
428 
429 
430 uint32
431 FragmentPacket::Hash(void *_packet, const void *_key, uint32 range)
432 {
433 	const struct ipv4_packet_key *key = (struct ipv4_packet_key *)_key;
434 	FragmentPacket *packet = (FragmentPacket *)_packet;
435 	if (packet != NULL)
436 		key = &packet->fKey;
437 
438 	return (key->source ^ key->destination ^ key->protocol ^ key->id) % range;
439 }
440 
441 
442 /*static*/ void
443 FragmentPacket::StaleTimer(struct net_timer *timer, void *data)
444 {
445 	FragmentPacket *packet = (FragmentPacket *)data;
446 	TRACE(("Assembling FragmentPacket %p timed out!\n", packet));
447 
448 	BenaphoreLocker locker(&sFragmentLock);
449 
450 	hash_remove(sFragmentHash, packet);
451 	delete packet;
452 }
453 
454 
455 //	#pragma mark -
456 
457 
#if 0
/*!	Debug helper (currently compiled out): prints all fields of \a header
	via dprintf().
*/
static void
dump_ipv4_header(ipv4_header &header)
{
	// The addresses are kept in network byte order, so the first byte in
	// memory is the first octet of the dotted notation on every host; no
	// endian specific handling is needed (the former big endian variant
	// printed the octets in reverse).
	const uint8 *src = (const uint8 *)&header.source;
	const uint8 *dst = (const uint8 *)&header.destination;

	dprintf("  version: %d\n", header.version);
	dprintf("  header_length: 4 * %d\n", header.header_length);
	dprintf("  service_type: %d\n", header.service_type);
	dprintf("  total_length: %d\n", header.TotalLength());
	dprintf("  id: %d\n", ntohs(header.id));
	dprintf("  fragment_offset: %d (flags: %c%c%c)\n",
		header.FragmentOffset() & IP_FRAGMENT_OFFSET_MASK,
		(header.FragmentOffset() & IP_RESERVED_FLAG) ? 'r' : '-',
		(header.FragmentOffset() & IP_DONT_FRAGMENT) ? 'd' : '-',
		(header.FragmentOffset() & IP_MORE_FRAGMENTS) ? 'm' : '-');
	dprintf("  time_to_live: %d\n", header.time_to_live);
	dprintf("  protocol: %d\n", header.protocol);
	dprintf("  checksum: %d\n", ntohs(header.checksum));
	dprintf("  source: %d.%d.%d.%d\n", src[0], src[1], src[2], src[3]);
	dprintf("  destination: %d.%d.%d.%d\n", dst[0], dst[1], dst[2], dst[3]);
}
#endif
494 
495 
496 /*!
497 	Attempts to re-assemble fragmented packets.
498 	\return B_OK if everything went well; if it could reassemble the packet, \a _buffer
499 		will point to its buffer, otherwise, it will be \c NULL.
500 	\return various error codes if something went wrong (mostly B_NO_MEMORY)
501 */
502 static status_t
503 reassemble_fragments(const ipv4_header &header, net_buffer **_buffer)
504 {
505 	net_buffer *buffer = *_buffer;
506 	status_t status;
507 
508 	struct ipv4_packet_key key;
509 	key.source = (in_addr_t)header.source;
510 	key.destination = (in_addr_t)header.destination;
511 	key.id = header.id;
512 	key.protocol = header.protocol;
513 
514 	// TODO: Make locking finer grained.
515 	BenaphoreLocker locker(&sFragmentLock);
516 
517 	FragmentPacket *packet = (FragmentPacket *)hash_lookup(sFragmentHash, &key);
518 	if (packet == NULL) {
519 		// New fragment packet
520 		packet = new (std::nothrow) FragmentPacket(key);
521 		if (packet == NULL)
522 			return B_NO_MEMORY;
523 
524 		// add packet to hash
525 		status = hash_insert(sFragmentHash, packet);
526 		if (status != B_OK) {
527 			delete packet;
528 			return status;
529 		}
530 	}
531 
532 	uint16 fragmentOffset = header.FragmentOffset();
533 	uint16 start = (fragmentOffset & IP_FRAGMENT_OFFSET_MASK) << 3;
534 	uint16 end = start + header.TotalLength() - header.HeaderLength();
535 	bool lastFragment = (fragmentOffset & IP_MORE_FRAGMENTS) == 0;
536 
537 	TRACE(("   Received IPv4 %sfragment of size %d, offset %d.\n",
538 		lastFragment ? "last ": "", end - start, start));
539 
540 	// Remove header unless this is the first fragment
541 	if (start != 0)
542 		gBufferModule->remove_header(buffer, header.HeaderLength());
543 
544 	status = packet->AddFragment(start, end, buffer, lastFragment);
545 	if (status != B_OK)
546 		return status;
547 
548 	if (packet->IsComplete()) {
549 		hash_remove(sFragmentHash, packet);
550 			// no matter if reassembling succeeds, we won't need this packet anymore
551 
552 		status = packet->Reassemble(buffer);
553 		delete packet;
554 
555 		// _buffer does not change
556 		return status;
557 	}
558 
559 	// This indicates that the packet is not yet complete
560 	*_buffer = NULL;
561 	return B_OK;
562 }
563 
564 
565 /*!
566 	Fragments the incoming buffer and send all fragments via the specified
567 	\a route.
568 */
569 static status_t
570 send_fragments(ipv4_protocol *protocol, struct net_route *route,
571 	net_buffer *buffer, uint32 mtu)
572 {
573 	TRACE(("ipv4 needs to fragment (size %lu, MTU %lu)...\n",
574 		buffer->size, mtu));
575 
576 	NetBufferHeader<ipv4_header> bufferHeader(buffer);
577 	if (bufferHeader.Status() < B_OK)
578 		return bufferHeader.Status();
579 
580 	ipv4_header *header = &bufferHeader.Data();
581 	bufferHeader.Detach();
582 
583 	uint16 headerLength = header->HeaderLength();
584 	uint32 bytesLeft = buffer->size - headerLength;
585 	uint32 fragmentOffset = 0;
586 	status_t status = B_OK;
587 
588 	net_buffer *headerBuffer = gBufferModule->split(buffer, headerLength);
589 	if (headerBuffer == NULL)
590 		return B_NO_MEMORY;
591 
592 	bufferHeader.SetTo(headerBuffer);
593 	header = &bufferHeader.Data();
594 	bufferHeader.Detach();
595 
596 	// adapt MTU to be a multiple of 8 (fragment offsets can only be specified this way)
597 	mtu -= headerLength;
598 	mtu &= ~7;
599 	dprintf("  adjusted MTU to %ld\n", mtu);
600 
601 	dprintf("  bytesLeft = %ld\n", bytesLeft);
602 	while (bytesLeft > 0) {
603 		uint32 fragmentLength = min_c(bytesLeft, mtu);
604 		bytesLeft -= fragmentLength;
605 		bool lastFragment = bytesLeft == 0;
606 
607 		header->total_length = htons(fragmentLength + headerLength);
608 		header->fragment_offset = htons((lastFragment ? 0 : IP_MORE_FRAGMENTS)
609 			| (fragmentOffset >> 3));
610 		header->checksum = 0;
611 		header->checksum = sStackModule->checksum((uint8 *)header, headerLength);
612 			// TODO: compute the checksum only for those parts that changed?
613 
614 		dprintf("  send fragment of %ld bytes (%ld bytes left)\n", fragmentLength, bytesLeft);
615 
616 		net_buffer *fragmentBuffer;
617 		if (!lastFragment) {
618 			fragmentBuffer = gBufferModule->split(buffer, fragmentLength);
619 			fragmentOffset += fragmentLength;
620 		} else
621 			fragmentBuffer = buffer;
622 
623 		if (fragmentBuffer == NULL) {
624 			status = B_NO_MEMORY;
625 			break;
626 		}
627 
628 		// copy header to fragment
629 		status = gBufferModule->prepend(fragmentBuffer, header, headerLength);
630 
631 		// send fragment
632 		if (status == B_OK)
633 			status = sDatalinkModule->send_data(route, fragmentBuffer);
634 
635 		if (lastFragment) {
636 			// we don't own the last buffer, so we don't have to free it
637 			break;
638 		}
639 
640 		if (status < B_OK) {
641 			gBufferModule->free(fragmentBuffer);
642 			break;
643 		}
644 	}
645 
646 	gBufferModule->free(headerBuffer);
647 	return status;
648 }
649 
650 
651 static void
652 raw_receive_data(net_buffer *buffer)
653 {
654 	BenaphoreLocker locker(sRawSocketsLock);
655 	RawSocketList::Iterator iterator = sRawSockets.GetIterator();
656 
657 	while (iterator.HasNext()) {
658 		RawSocket *raw = iterator.Next();
659 		raw->Write(buffer);
660 	}
661 }
662 
663 
664 //	#pragma mark -
665 
666 
667 net_protocol *
668 ipv4_init_protocol(net_socket *socket)
669 {
670 	ipv4_protocol *protocol = new (std::nothrow) ipv4_protocol;
671 	if (protocol == NULL)
672 		return NULL;
673 
674 	protocol->raw = NULL;
675 	protocol->service_type = 0;
676 	protocol->time_to_live = 254;
677 	protocol->flags = 0;
678 	return protocol;
679 }
680 
681 
682 status_t
683 ipv4_uninit_protocol(net_protocol *_protocol)
684 {
685 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
686 
687 	delete protocol->raw;
688 	delete protocol;
689 	return B_OK;
690 }
691 
692 
693 /*!
694 	Since open() is only called on the top level protocol, when we get here
695 	it means we are on a SOCK_RAW socket.
696 */
697 status_t
698 ipv4_open(net_protocol *_protocol)
699 {
700 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
701 
702 	RawSocket *raw = new (std::nothrow) RawSocket(protocol->socket);
703 	if (raw == NULL)
704 		return B_NO_MEMORY;
705 
706 	status_t status = raw->InitCheck();
707 	if (status < B_OK) {
708 		delete raw;
709 		return status;
710 	}
711 
712 	protocol->raw = raw;
713 
714 	BenaphoreLocker locker(sRawSocketsLock);
715 	sRawSockets.Add(raw);
716 	return B_OK;
717 }
718 
719 
720 status_t
721 ipv4_close(net_protocol *_protocol)
722 {
723 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
724 	RawSocket *raw = protocol->raw;
725 	if (raw == NULL)
726 		return B_ERROR;
727 
728 	BenaphoreLocker locker(sRawSocketsLock);
729 	sRawSockets.Remove(raw);
730 	delete raw;
731 	protocol->raw = NULL;
732 
733 	return B_OK;
734 }
735 
736 
737 status_t
738 ipv4_free(net_protocol *protocol)
739 {
740 	return B_OK;
741 }
742 
743 
744 status_t
745 ipv4_connect(net_protocol *protocol, const struct sockaddr *address)
746 {
747 	return B_ERROR;
748 }
749 
750 
751 status_t
752 ipv4_accept(net_protocol *protocol, struct net_socket **_acceptedSocket)
753 {
754 	return EOPNOTSUPP;
755 }
756 
757 
758 status_t
759 ipv4_control(net_protocol *_protocol, int level, int option, void *value,
760 	size_t *_length)
761 {
762 	if ((level & LEVEL_MASK) != IPPROTO_IP)
763 		return sDatalinkModule->control(sDomain, option, value, _length);
764 
765 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
766 
767 	if (level & LEVEL_GET_OPTION) {
768 		// get options
769 
770 		switch (option) {
771 			case IP_HDRINCL:
772 			{
773 				if (*_length != sizeof(int))
774 					return B_BAD_VALUE;
775 
776 				int headerIncluded = (protocol->flags & IP_FLAG_HEADER_INCLUDED) != 0;
777 				return user_memcpy(value, &headerIncluded, sizeof(headerIncluded));
778 			}
779 
780 			case IP_TTL:
781 			{
782 				if (*_length != sizeof(int))
783 					return B_BAD_VALUE;
784 
785 				int timeToLive = protocol->time_to_live;
786 				return user_memcpy(value, &timeToLive, sizeof(timeToLive));
787 			}
788 
789 			case IP_TOS:
790 			{
791 				if (*_length != sizeof(int))
792 					return B_BAD_VALUE;
793 
794 				int serviceType = protocol->service_type;
795 				return user_memcpy(value, &serviceType, sizeof(serviceType));
796 			}
797 
798 			default:
799 				dprintf("IPv4::control(): get unknown option: %d\n", option);
800 				return ENOPROTOOPT;
801 		}
802 	} else {
803 		// set options
804 
805 		switch (option) {
806 			case IP_HDRINCL:
807 			{
808 				int headerIncluded;
809 				if (*_length != sizeof(int))
810 					return B_BAD_VALUE;
811 				if (user_memcpy(&headerIncluded, value, sizeof(headerIncluded)) < B_OK)
812 					return B_BAD_ADDRESS;
813 
814 				if (headerIncluded)
815 					protocol->flags |= IP_FLAG_HEADER_INCLUDED;
816 				else
817 					protocol->flags &= ~IP_FLAG_HEADER_INCLUDED;
818 				break;
819 			}
820 
821 			case IP_TTL:
822 			{
823 				int timeToLive;
824 				if (*_length != sizeof(int))
825 					return B_BAD_VALUE;
826 				if (user_memcpy(&timeToLive, value, sizeof(timeToLive)) < B_OK)
827 					return B_BAD_ADDRESS;
828 
829 				protocol->time_to_live = timeToLive;
830 				break;
831 			}
832 
833 			case IP_TOS:
834 			{
835 				int serviceType;
836 				if (*_length != sizeof(int))
837 					return B_BAD_VALUE;
838 				if (user_memcpy(&serviceType, value, sizeof(serviceType)) < B_OK)
839 					return B_BAD_ADDRESS;
840 
841 				protocol->service_type = serviceType;
842 				break;
843 			}
844 
845 			default:
846 				dprintf("IPv4::control(): set unknown option: %d\n", option);
847 				return ENOPROTOOPT;
848 		}
849 	}
850 
851 	return B_BAD_VALUE;
852 }
853 
854 
855 status_t
856 ipv4_bind(net_protocol *protocol, struct sockaddr *address)
857 {
858 	if (address->sa_family != AF_INET)
859 		return EAFNOSUPPORT;
860 
861 	// only INADDR_ANY and addresses of local interfaces are accepted:
862 	if (((sockaddr_in *)address)->sin_addr.s_addr == INADDR_ANY
863 		|| sDatalinkModule->is_local_address(sDomain, address, NULL, NULL)) {
864 		protocol->socket->address.ss_len = sizeof(struct sockaddr_in);
865 			// explicitly set length, as our callers can't be trusted to
866 			// always provide the correct length!
867 		return B_OK;
868 	}
869 
870 	return B_ERROR;
871 		// address is unknown on this host
872 }
873 
874 
875 status_t
876 ipv4_unbind(net_protocol *protocol, struct sockaddr *address)
877 {
878 	// nothing to do here
879 	return B_OK;
880 }
881 
882 
883 status_t
884 ipv4_listen(net_protocol *protocol, int count)
885 {
886 	return EOPNOTSUPP;
887 }
888 
889 
890 status_t
891 ipv4_shutdown(net_protocol *protocol, int direction)
892 {
893 	return EOPNOTSUPP;
894 }
895 
896 
897 status_t
898 ipv4_send_routed_data(net_protocol *_protocol, struct net_route *route,
899 	net_buffer *buffer)
900 {
901 	if (route == NULL)
902 		return B_BAD_VALUE;
903 
904 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
905 	net_interface *interface = route->interface;
906 
907 	TRACE(("someone tries to send some actual routed data!\n"));
908 
909 	sockaddr_in &source = *(sockaddr_in *)&buffer->source;
910 	if (source.sin_addr.s_addr == INADDR_ANY && route->interface->address != NULL) {
911 		// replace an unbound source address with the address of the interface
912 		// TODO: couldn't we replace all addresses here?
913 		source.sin_addr.s_addr = ((sockaddr_in *)route->interface->address)->sin_addr.s_addr;
914 	}
915 
916 	// Add IP header (if needed)
917 
918 	if (protocol == NULL || (protocol->flags & IP_FLAG_HEADER_INCLUDED) == 0) {
919 		NetBufferPrepend<ipv4_header> bufferHeader(buffer);
920 		if (bufferHeader.Status() < B_OK)
921 			return bufferHeader.Status();
922 
923 		ipv4_header &header = bufferHeader.Data();
924 
925 		header.version = IP_VERSION;
926 		header.header_length = sizeof(ipv4_header) >> 2;
927 		header.service_type = protocol ? protocol->service_type : 0;
928 		header.total_length = htons(buffer->size);
929 		header.id = htons(atomic_add(&sPacketID, 1));
930 		header.fragment_offset = 0;
931 		header.time_to_live = protocol ? protocol->time_to_live : 254;
932 		header.protocol = protocol ? protocol->socket->protocol : buffer->protocol;
933 		header.checksum = 0;
934 		if (route->interface->address != NULL) {
935 			header.source = ((sockaddr_in *)route->interface->address)->sin_addr.s_addr;
936 				// always use the actual used source address
937 		} else
938 			header.source = 0;
939 
940 		header.destination = ((sockaddr_in *)&buffer->destination)->sin_addr.s_addr;
941 
942 		header.checksum = gBufferModule->checksum(buffer, 0, sizeof(ipv4_header), true);
943 		//dump_ipv4_header(header);
944 
945 		bufferHeader.Detach();
946 			// make sure the IP-header is already written to the buffer at this point
947 	}
948 
949 	TRACE(("header chksum: %ld, buffer checksum: %ld\n",
950 		gBufferModule->checksum(buffer, 0, sizeof(ipv4_header), true),
951 		gBufferModule->checksum(buffer, 0, buffer->size, true)));
952 
953 	TRACE(("destination-IP: buffer=%p addr=%p %08lx\n", buffer, &buffer->destination,
954 		ntohl(((sockaddr_in *)&buffer->destination)->sin_addr.s_addr)));
955 
956 	uint32 mtu = route->mtu ? route->mtu : interface->mtu;
957 	if (buffer->size > mtu) {
958 		// we need to fragment the packet
959 		return send_fragments(protocol, route, buffer, mtu);
960 	}
961 
962 	return sDatalinkModule->send_data(route, buffer);
963 }
964 
965 
966 status_t
967 ipv4_send_data(net_protocol *protocol, net_buffer *buffer)
968 {
969 	TRACE(("someone tries to send some actual data!\n"));
970 
971 	// find route
972 	struct net_route *route = sDatalinkModule->get_route(sDomain,
973 		(sockaddr *)&buffer->destination);
974 	if (route == NULL)
975 		return ENETUNREACH;
976 
977 	status_t status = ipv4_send_routed_data(protocol, route, buffer);
978 	sDatalinkModule->put_route(sDomain, route);
979 
980 	return status;
981 }
982 
983 
984 ssize_t
985 ipv4_send_avail(net_protocol *protocol)
986 {
987 	return B_ERROR;
988 }
989 
990 
991 status_t
992 ipv4_read_data(net_protocol *_protocol, size_t numBytes, uint32 flags,
993 	net_buffer **_buffer)
994 {
995 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
996 	RawSocket *raw = protocol->raw;
997 	if (raw == NULL)
998 		return B_ERROR;
999 
1000 	TRACE(("read is waiting for data...\n"));
1001 	return raw->Read(numBytes, flags, protocol->socket->receive.timeout, _buffer);
1002 }
1003 
1004 
1005 ssize_t
1006 ipv4_read_avail(net_protocol *_protocol)
1007 {
1008 	ipv4_protocol *protocol = (ipv4_protocol *)_protocol;
1009 	RawSocket *raw = protocol->raw;
1010 	if (raw == NULL)
1011 		return B_ERROR;
1012 
1013 	return raw->BytesAvailable();
1014 }
1015 
1016 
1017 struct net_domain *
1018 ipv4_get_domain(net_protocol *protocol)
1019 {
1020 	return sDomain;
1021 }
1022 
1023 
1024 size_t
1025 ipv4_get_mtu(net_protocol *protocol, const struct sockaddr *address)
1026 {
1027 	net_route *route = sDatalinkModule->get_route(sDomain, address);
1028 	if (route == NULL)
1029 		return 0;
1030 
1031 	size_t mtu;
1032 	if (route->mtu != 0)
1033 		mtu = route->mtu;
1034 	else
1035 		mtu = route->interface->mtu;
1036 
1037 	sDatalinkModule->put_route(sDomain, route);
1038 	return mtu - sizeof(ipv4_header);
1039 }
1040 
1041 
1042 status_t
1043 ipv4_receive_data(net_buffer *buffer)
1044 {
1045 	TRACE(("IPv4 received a packet (%p) of %ld size!\n", buffer, buffer->size));
1046 
1047 	NetBufferHeader<ipv4_header> bufferHeader(buffer);
1048 	if (bufferHeader.Status() < B_OK)
1049 		return bufferHeader.Status();
1050 
1051 	ipv4_header &header = bufferHeader.Data();
1052 	bufferHeader.Detach();
1053 	//dump_ipv4_header(header);
1054 
1055 	if (header.version != IP_VERSION)
1056 		return B_BAD_TYPE;
1057 
1058 	uint16 packetLength = header.TotalLength();
1059 	uint16 headerLength = header.HeaderLength();
1060 	if (packetLength > buffer->size
1061 		|| headerLength < sizeof(ipv4_header))
1062 		return B_BAD_DATA;
1063 
1064 	// TODO: would be nice to have a direct checksum function somewhere
1065 	if (gBufferModule->checksum(buffer, 0, headerLength, true) != 0)
1066 		return B_BAD_DATA;
1067 
1068 	struct sockaddr_in &source = *(struct sockaddr_in *)&buffer->source;
1069 	struct sockaddr_in &destination = *(struct sockaddr_in *)&buffer->destination;
1070 
1071 	source.sin_len = sizeof(sockaddr_in);
1072 	source.sin_family = AF_INET;
1073 	source.sin_addr.s_addr = header.source;
1074 
1075 	destination.sin_len = sizeof(sockaddr_in);
1076 	destination.sin_family = AF_INET;
1077 	destination.sin_addr.s_addr = header.destination;
1078 
1079 	// test if the packet is really for us
1080 	uint32 matchedAddressType;
1081 	if (!sDatalinkModule->is_local_address(sDomain, (sockaddr*)&destination,
1082 		&buffer->interface, &matchedAddressType)) {
1083 		TRACE(("this packet was not for us\n"));
1084 		return B_ERROR;
1085 	}
1086 	if (matchedAddressType != 0) {
1087 		// copy over special address types (MSG_BCAST or MSG_MCAST):
1088 		buffer->flags |= matchedAddressType;
1089 	}
1090 
1091 	uint8 protocol = buffer->protocol = header.protocol;
1092 
1093 	// remove any trailing/padding data
1094 	status_t status = gBufferModule->trim(buffer, packetLength);
1095 	if (status < B_OK)
1096 		return status;
1097 
1098 	// check for fragmentation
1099 	uint16 fragmentOffset = ntohs(header.fragment_offset);
1100 	if ((fragmentOffset & IP_MORE_FRAGMENTS) != 0
1101 		|| (fragmentOffset & IP_FRAGMENT_OFFSET_MASK) != 0) {
1102 		// this is a fragment
1103 		TRACE(("   Found a Fragment!\n"));
1104 		status = reassemble_fragments(header, &buffer);
1105 		TRACE(("   -> %s!\n", strerror(status)));
1106 		if (status != B_OK)
1107 			return status;
1108 
1109 		if (buffer == NULL) {
1110 			// buffer was put into fragment packet
1111 			TRACE(("   Not yet assembled...\n"));
1112 			return B_OK;
1113 		}
1114 	}
1115 
1116 	// Since the buffer might have been changed (reassembled fragment)
1117 	// we must no longer access bufferHeader or header anymore after
1118 	// this point
1119 
1120 	if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) {
1121 		// SOCK_RAW doesn't get all packets
1122 		raw_receive_data(buffer);
1123 	}
1124 
1125 	gBufferModule->remove_header(buffer, headerLength);
1126 		// the header is of variable size and may include IP options
1127 		// (that we ignore for now)
1128 
	// TODO: since we'll be doing this for every packet, we may want to cache the module
	//	(and only put them when we're about to be unloaded)
1131 	net_protocol_module_info *module;
1132 	status = sStackModule->get_domain_receiving_protocol(sDomain, protocol, &module);
1133 	if (status < B_OK) {
1134 		// no handler for this packet
1135 		return status;
1136 	}
1137 
1138 	status = module->receive_data(buffer);
1139 	sStackModule->put_domain_receiving_protocol(sDomain, protocol);
1140 
1141 	return status;
1142 }
1143 
1144 
1145 status_t
1146 ipv4_error(uint32 code, net_buffer *data)
1147 {
1148 	return B_ERROR;
1149 }
1150 
1151 
1152 status_t
1153 ipv4_error_reply(net_protocol *protocol, net_buffer *causedError, uint32 code,
1154 	void *errorData)
1155 {
1156 	return B_ERROR;
1157 }
1158 
1159 
1160 //	#pragma mark -
1161 
1162 
1163 status_t
1164 init_ipv4()
1165 {
1166 	status_t status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule);
1167 	if (status < B_OK)
1168 		return status;
1169 	status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
1170 	if (status < B_OK)
1171 		goto err1;
1172 	status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule);
1173 	if (status < B_OK)
1174 		goto err2;
1175 
1176 	sPacketID = (int32)system_time();
1177 
1178 	status = benaphore_init(&sRawSocketsLock, "raw sockets");
1179 	if (status < B_OK)
1180 		goto err3;
1181 
1182 	status = benaphore_init(&sFragmentLock, "IPv4 Fragments");
1183 	if (status < B_OK)
1184 		goto err4;
1185 
1186 	sFragmentHash = hash_init(MAX_HASH_FRAGMENTS, FragmentPacket::NextOffset(),
1187 		&FragmentPacket::Compare, &FragmentPacket::Hash);
1188 	if (sFragmentHash == NULL)
1189 		goto err5;
1190 
1191 	new (&sRawSockets) RawSocketList;
1192 		// static initializers do not work in the kernel,
1193 		// so we have to do it here, manually
1194 		// TODO: for modules, this shouldn't be required
1195 
1196 	status = sStackModule->register_domain_protocols(AF_INET, SOCK_RAW, 0,
1197 		"network/protocols/ipv4/v1", NULL);
1198 	if (status < B_OK)
1199 		goto err6;
1200 
1201 	status = sStackModule->register_domain(AF_INET, "internet", &gIPv4Module,
1202 		&gIPv4AddressModule, &sDomain);
1203 	if (status < B_OK)
1204 		goto err6;
1205 
1206 	return B_OK;
1207 
1208 err6:
1209 	hash_uninit(sFragmentHash);
1210 err5:
1211 	benaphore_destroy(&sFragmentLock);
1212 err4:
1213 	benaphore_destroy(&sRawSocketsLock);
1214 err3:
1215 	put_module(NET_DATALINK_MODULE_NAME);
1216 err2:
1217 	put_module(NET_BUFFER_MODULE_NAME);
1218 err1:
1219 	put_module(NET_STACK_MODULE_NAME);
1220 	return status;
1221 }
1222 
1223 
1224 status_t
1225 uninit_ipv4()
1226 {
1227 	hash_uninit(sFragmentHash);
1228 
1229 	benaphore_destroy(&sFragmentLock);
1230 	benaphore_destroy(&sRawSocketsLock);
1231 
1232 	sStackModule->unregister_domain(sDomain);
1233 	put_module(NET_DATALINK_MODULE_NAME);
1234 	put_module(NET_BUFFER_MODULE_NAME);
1235 	put_module(NET_STACK_MODULE_NAME);
1236 	return B_OK;
1237 }
1238 
1239 
1240 static status_t
1241 ipv4_std_ops(int32 op, ...)
1242 {
1243 	switch (op) {
1244 		case B_MODULE_INIT:
1245 			return init_ipv4();
1246 		case B_MODULE_UNINIT:
1247 			return uninit_ipv4();
1248 
1249 		default:
1250 			return B_ERROR;
1251 	}
1252 }
1253 
1254 
// The IPv4 protocol module descriptor. All members are initialized
// positionally, so the order of the entries below has to match the member
// order of net_protocol_module_info (declared elsewhere, presumably in
// <net_protocol.h>) - do not reorder them.
net_protocol_module_info gIPv4Module = {
	{
		// module name; init_ipv4() passes the same string to
		// register_domain_protocols()
		"network/protocols/ipv4/v1",
		0,
			// presumably the module flags - TODO confirm
		ipv4_std_ops
			// handles B_MODULE_INIT and B_MODULE_UNINIT
	},
	// the protocol hooks
	ipv4_init_protocol,
	ipv4_uninit_protocol,
	ipv4_open,
	ipv4_close,
	ipv4_free,
	ipv4_connect,
	ipv4_accept,
	ipv4_control,
	ipv4_bind,
	ipv4_unbind,
	ipv4_listen,
	ipv4_shutdown,
	ipv4_send_data,
	ipv4_send_routed_data,
	ipv4_send_avail,
	ipv4_read_data,
	ipv4_read_avail,
	ipv4_get_domain,
	ipv4_get_mtu,
	ipv4_receive_data,
	ipv4_error,
		// stub, always returns B_ERROR
	ipv4_error_reply,
		// stub, always returns B_ERROR
};
1284 
// NULL-terminated list of the modules defined in this file; it contains
// only the IPv4 protocol module (presumably this is the table the kernel
// module loader looks up - confirm).
module_info *modules[] = {
	(module_info *)&gIPv4Module,
	NULL
};
1289