xref: /haiku/src/add-ons/kernel/network/stack/routes.cpp (revision 1acbe440b8dd798953bec31d18ee589aa3f71b73)
1 /*
2  * Copyright 2006-2007, Haiku, Inc. All Rights Reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Axel Dörfler, axeld@pinc-software.de
7  */
8 
9 
10 #include "domains.h"
11 #include "routes.h"
12 #include "stack_private.h"
13 #include "utility.h"
14 
15 #include <net_device.h>
16 #include <NetUtilities.h>
17 
18 #include <lock.h>
19 #include <util/AutoLock.h>
20 
21 #include <KernelExport.h>
22 
23 #include <net/if_dl.h>
24 #include <net/route.h>
25 #include <new>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <sys/sockio.h>
29 
30 
31 //#define TRACE_ROUTES
32 #ifdef TRACE_ROUTES
33 #	define TRACE(x) dprintf x
34 #else
35 #	define TRACE(x) ;
36 #endif
37 
38 
39 net_route_private::net_route_private()
40 {
41 	destination = mask = gateway = NULL;
42 }
43 
44 
45 net_route_private::~net_route_private()
46 {
47 	free(destination);
48 	free(mask);
49 	free(gateway);
50 }
51 
52 
53 //	#pragma mark -
54 
55 
56 static status_t
57 user_copy_address(const sockaddr *from, sockaddr **to)
58 {
59 	if (from == NULL) {
60 		*to = NULL;
61 		return B_OK;
62 	}
63 
64 	sockaddr address;
65 	if (user_memcpy(&address, from, sizeof(struct sockaddr)) < B_OK)
66 		return B_BAD_ADDRESS;
67 
68 	*to = (sockaddr *)malloc(address.sa_len);
69 	if (*to == NULL)
70 		return B_NO_MEMORY;
71 
72 	if (address.sa_len > sizeof(struct sockaddr)) {
73 		if (user_memcpy(*to, from, address.sa_len) < B_OK)
74 			return B_BAD_ADDRESS;
75 	} else
76 		memcpy(*to, &address, address.sa_len);
77 
78 	return B_OK;
79 }
80 
81 
82 static status_t
83 user_copy_address(const sockaddr *from, sockaddr_storage *to)
84 {
85 	if (from == NULL)
86 		return B_BAD_ADDRESS;
87 
88 	if (user_memcpy(to, from, sizeof(sockaddr)) < B_OK)
89 		return B_BAD_ADDRESS;
90 
91 	if (to->ss_len > sizeof(sockaddr)) {
92 		if (to->ss_len > sizeof(sockaddr_storage))
93 			return B_BAD_VALUE;
94 		if (user_memcpy(to, from, to->ss_len) < B_OK)
95 			return B_BAD_ADDRESS;
96 	}
97 
98 	return B_OK;
99 }
100 
101 static net_route_private *
102 find_route(struct net_domain *_domain, const net_route *description)
103 {
104 	struct net_domain_private *domain = (net_domain_private *)_domain;
105 	RouteList::Iterator iterator = domain->routes.GetIterator();
106 
107 	while (iterator.HasNext()) {
108 		net_route_private *route = iterator.Next();
109 
110 		if ((route->flags & RTF_DEFAULT) != 0
111 			&& (description->flags & RTF_DEFAULT) != 0) {
112 			// there can only be one default route
113 			return route;
114 		}
115 
116 		if ((route->flags & (RTF_GATEWAY | RTF_HOST | RTF_LOCAL)) ==
117 				(description->flags & (RTF_GATEWAY | RTF_HOST | RTF_LOCAL))
118 			&& domain->address_module->equal_masked_addresses(route->destination,
119 				description->destination, description->mask)
120 			&& domain->address_module->equal_addresses(route->mask,
121 				description->mask)
122 			&& domain->address_module->equal_addresses(route->gateway,
123 				description->gateway))
124 			return route;
125 	}
126 
127 	return NULL;
128 }
129 
130 
131 static net_route_private *
132 find_route(struct net_domain *_domain, const struct sockaddr *address)
133 {
134 	struct net_domain_private *domain = (net_domain_private *)_domain;
135 
136 	// TODO: the following only works for IPv4 routes!
137 	if (domain->family != AF_INET)
138 		panic("you should have known better...");
139 
140 	// find last matching route
141 
142 	RouteList::Iterator iterator = domain->routes.GetIterator();
143 	TRACE(("test address %s for routes...\n", AddressString(domain, address).Data()));
144 
145 	while (iterator.HasNext()) {
146 		net_route_private *route = iterator.Next();
147 
148 		bool found;
149  		if (route->mask != NULL) {
150 			sockaddr maskedAddress;
151 			domain->address_module->mask_address(address, route->mask,
152 				&maskedAddress);
153 			found = domain->address_module->equal_addresses(&maskedAddress,
154 				route->destination);
155 		} else {
156 			found = domain->address_module->equal_addresses(address,
157 				route->destination);
158  		}
159 
160 		if (found) {
161 			TRACE(("  found route: %s, flags %lx\n",
162 				AddressString(domain, route->destination).Data(), route->flags));
163 			return route;
164 		}
165 	}
166 
167 	return NULL;
168 }
169 
170 
171 static void
172 put_route_internal(struct net_domain_private *domain, net_route *_route)
173 {
174 	net_route_private *route = (net_route_private *)_route;
175 	if (route == NULL || atomic_add(&route->ref_count, -1) != 1)
176 		return;
177 
178 	// remove route
179 
180 	domain->routes.Remove(route);
181 	delete route;
182 }
183 
184 
185 struct net_route *
186 get_route_internal(struct net_domain_private *domain, const struct sockaddr *address)
187 {
188 	net_route_private *route = find_route(domain, address);
189 	if (route != NULL && atomic_add(&route->ref_count, 1) == 0) {
190 		// route has been deleted already
191 		route = NULL;
192 	}
193 
194 	return route;
195 }
196 
197 
198 void
199 update_route_infos(struct net_domain_private *domain)
200 {
201 	RouteInfoList::Iterator iterator = domain->route_infos.GetIterator();
202 
203 	while (iterator.HasNext()) {
204 		net_route_info *info = iterator.Next();
205 
206 		put_route_internal(domain, info->route);
207 		info->route = get_route_internal(domain, &info->address);
208 	}
209 }
210 
211 
212 //	#pragma mark -
213 
214 
215 /*!
216 	Determines the size of a buffer large enough to contain the whole
217 	routing table.
218 */
219 uint32
220 route_table_size(net_domain_private *domain)
221 {
222 	BenaphoreLocker locker(domain->lock);
223 	uint32 size = 0;
224 
225 	RouteList::Iterator iterator = domain->routes.GetIterator();
226 	while (iterator.HasNext()) {
227 		net_route_private *route = iterator.Next();
228 		size += IF_NAMESIZE + sizeof(route_entry);
229 
230 		if (route->destination)
231 			size += route->destination->sa_len;
232 		if (route->mask)
233 			size += route->mask->sa_len;
234 		if (route->gateway)
235 			size += route->gateway->sa_len;
236 	}
237 
238 	return size;
239 }
240 
241 
242 /*!
243 	Dumps a list of all routes into the supplied userland buffer.
244 	If the routes don't fit into the buffer, an error (\c ENOBUFS) is
245 	returned.
246 */
247 status_t
248 list_routes(net_domain_private *domain, void *buffer, size_t size)
249 {
250 	RouteList::Iterator iterator = domain->routes.GetIterator();
251 	size_t spaceLeft = size;
252 
253 	sockaddr zeros;
254 	memset(&zeros, 0, sizeof(sockaddr));
255 	zeros.sa_family = domain->family;
256 	zeros.sa_len = sizeof(sockaddr);
257 
258 	while (iterator.HasNext()) {
259 		net_route *route = iterator.Next();
260 
261 		size = IF_NAMESIZE + sizeof(route_entry);
262 
263 		sockaddr *destination = NULL;
264 		sockaddr *mask = NULL;
265 		sockaddr *gateway = NULL;
266 		uint8 *next = (uint8 *)buffer + size;
267 
268 		if (route->destination != NULL) {
269 			destination = (sockaddr *)next;
270 			next += route->destination->sa_len;
271 			size += route->destination->sa_len;
272 		}
273 		if (route->mask != NULL) {
274 			mask = (sockaddr *)next;
275 			next += route->mask->sa_len;
276 			size += route->mask->sa_len;
277 		}
278 		if (route->gateway != NULL) {
279 			gateway = (sockaddr *)next;
280 			next += route->gateway->sa_len;
281 			size += route->gateway->sa_len;
282 		}
283 
284 		if (spaceLeft < size)
285 			return ENOBUFS;
286 
287 		ifreq request;
288 		strlcpy(request.ifr_name, route->interface->name, IF_NAMESIZE);
289 		request.ifr_route.destination = destination;
290 		request.ifr_route.mask = mask;
291 		request.ifr_route.gateway = gateway;
292 		request.ifr_route.mtu = route->mtu;
293 		request.ifr_route.flags = route->flags;
294 
295 		if (user_memcpy(buffer, &request, size) < B_OK
296 			|| (route->destination != NULL && user_memcpy(request.ifr_route.destination, route->destination, route->destination->sa_len) < B_OK)
297 			|| (route->mask != NULL && user_memcpy(request.ifr_route.mask, route->mask, route->mask->sa_len) < B_OK)
298 			|| (route->gateway != NULL && user_memcpy(request.ifr_route.gateway, route->gateway, route->gateway->sa_len) < B_OK))
299 			return B_BAD_ADDRESS;
300 
301 		buffer = (void *)next;
302 		spaceLeft -= size;
303 	}
304 
305 	return B_OK;
306 }
307 
308 
309 status_t
310 control_routes(struct net_interface *interface, int32 option, void *argument, size_t length)
311 {
312 	net_domain_private *domain = (net_domain_private *)interface->domain;
313 
314 	switch (option) {
315 		case SIOCADDRT:
316 		case SIOCDELRT:
317 		{
318 			// add or remove a route
319 			if (length != sizeof(struct ifreq))
320 				return B_BAD_VALUE;
321 
322 			route_entry entry;
323 			if (user_memcpy(&entry, &((ifreq *)argument)->ifr_route, sizeof(route_entry)) != B_OK)
324 				return B_BAD_ADDRESS;
325 
326 			net_route_private route;
327 			status_t status;
328 			if ((status = user_copy_address(entry.destination, &route.destination)) != B_OK
329 				|| (status = user_copy_address(entry.mask, &route.mask)) != B_OK
330 				|| (status = user_copy_address(entry.gateway, &route.gateway)) != B_OK)
331 				return status;
332 
333 			route.mtu = entry.mtu;
334 			route.flags = entry.flags;
335 			route.interface = interface;
336 
337 			if (option == SIOCADDRT)
338 				return add_route(domain, &route);
339 
340 			return remove_route(domain, &route);
341 		}
342 	}
343 	return B_BAD_VALUE;
344 }
345 
346 
347 status_t
348 add_route(struct net_domain *_domain, const struct net_route *newRoute)
349 {
350 	struct net_domain_private *domain = (net_domain_private *)_domain;
351 
352 	TRACE(("add route to domain %s: dest %s, mask %s, gw %s, flags %lx\n",
353 		domain->name,
354 		AddressString(domain, newRoute->destination ? newRoute->destination : NULL).Data(),
355 		AddressString(domain, newRoute->mask ? newRoute->mask : NULL).Data(),
356 		AddressString(domain, newRoute->gateway ? newRoute->gateway : NULL).Data(),
357 		newRoute->flags));
358 
359 	if (domain == NULL || newRoute == NULL || newRoute->interface == NULL
360 		|| ((newRoute->flags & RTF_HOST) != 0 && newRoute->mask != NULL)
361 		|| ((newRoute->flags & RTF_DEFAULT) == 0 && newRoute->destination == NULL)
362 		|| ((newRoute->flags & RTF_GATEWAY) != 0 && newRoute->gateway == NULL)
363 		|| !domain->address_module->check_mask(newRoute->mask))
364 		return B_BAD_VALUE;
365 
366 	net_route_private *route = find_route(domain, newRoute);
367 	if (route != NULL)
368 		return B_FILE_EXISTS;
369 
370 	route = new (std::nothrow) net_route_private;
371 	if (route == NULL)
372 		return B_NO_MEMORY;
373 
374 	if (domain->address_module->copy_address(newRoute->destination,
375 		&route->destination, (newRoute->flags & RTF_DEFAULT) != 0,
376 		newRoute->mask) != B_OK
377 		|| domain->address_module->copy_address(newRoute->mask, &route->mask,
378 				(newRoute->flags & RTF_DEFAULT) != 0, NULL) != B_OK
379 		|| domain->address_module->copy_address(newRoute->gateway,
380 			&route->gateway, false, NULL) != B_OK) {
381 		delete route;
382 		return B_NO_MEMORY;
383 	}
384 
385 	route->flags = newRoute->flags;
386 	route->interface = newRoute->interface;
387 	route->mtu = 0;
388 	route->ref_count = 1;
389 
390 	// TODO: for now...
391 	//BenaphoreLocker locker(domain->lock);
392 
393 	// Insert the route sorted by completeness of its mask
394 
395 	RouteList::Iterator iterator = domain->routes.GetIterator();
396 	net_route_private *before = NULL;
397 
398 	while ((before = iterator.Next()) != NULL) {
399 		// if the before mask is less specific than the one of the route,
400 		// we can insert it before that route.
401 		if (domain->address_module->first_mask_bit(before->mask)
402 			> domain->address_module->first_mask_bit(route->mask))
403 			break;
404 	}
405 
406 	domain->routes.Insert(before, route);
407 	update_route_infos(domain);
408 
409 	return B_OK;
410 }
411 
412 
413 status_t
414 remove_route(struct net_domain *_domain, const struct net_route *removeRoute)
415 {
416 	struct net_domain_private *domain = (net_domain_private *)_domain;
417 
418 	TRACE(("remove route from domain %s: dest %s, mask %s, gw %s, flags %lx\n",
419 		domain->name,
420 		AddressString(domain, removeRoute->destination ? removeRoute->destination : NULL).Data(),
421 		AddressString(domain, removeRoute->mask ? removeRoute->mask : NULL).Data(),
422 		AddressString(domain, removeRoute->gateway ? removeRoute->gateway : NULL).Data(),
423 		removeRoute->flags));
424 
425 	// TODO: for now...
426 	//BenaphoreLocker locker(domain->lock);
427 
428 	net_route_private *route = find_route(domain, removeRoute);
429 	if (route == NULL)
430 		return B_ENTRY_NOT_FOUND;
431 
432 	put_route_internal(domain, route);
433 	update_route_infos(domain);
434 	return B_OK;
435 }
436 
437 
438 static sockaddr *
439 copy_address(UserBuffer &buffer, sockaddr *address)
440 {
441 	if (address == NULL)
442 		return NULL;
443 
444 	return (sockaddr *)buffer.Copy(address, address->sa_len);
445 }
446 
447 static status_t
448 fill_route_entry(route_entry *target, void *_buffer, size_t bufferSize,
449 		 net_route *route)
450 {
451 	UserBuffer buffer(((uint8 *)_buffer) + sizeof(route_entry),
452 		bufferSize - sizeof(route_entry));
453 
454 	target->destination = copy_address(buffer, route->destination);
455 	target->mask = copy_address(buffer, route->mask);
456 	target->gateway = copy_address(buffer, route->gateway);
457 	target->source = copy_address(buffer, route->interface->address);
458 	target->flags = route->flags;
459 	target->mtu = route->mtu;
460 
461 	return buffer.Status();
462 }
463 
464 
465 status_t
466 get_route_information(struct net_domain *_domain, void *value, size_t length)
467 {
468 	struct net_domain_private *domain = (net_domain_private *)_domain;
469 
470 	if (length < sizeof(route_entry))
471 		return B_BAD_VALUE;
472 
473 	route_entry entry;
474 	if (user_memcpy(&entry, value, sizeof(route_entry)) < B_OK)
475 		return B_BAD_ADDRESS;
476 
477 	sockaddr_storage destination;
478 	status_t status = user_copy_address(entry.destination, &destination);
479 	if (status != B_OK)
480 		return status;
481 
482 	BenaphoreLocker locker(domain->lock);
483 
484 	net_route_private *route = find_route(domain, (sockaddr *)&destination);
485 	if (route == NULL)
486 		return B_ENTRY_NOT_FOUND;
487 
488 	status = fill_route_entry(&entry, value, length, route);
489 	if (status != B_OK)
490 		return status;
491 
492 	return user_memcpy(value, &entry, sizeof(route_entry));
493 }
494 
495 
496 void
497 invalidate_routes(net_domain *_domain, net_interface *interface)
498 {
499 	// this function is called with the domain locked
500 	// (see domain_interface_went_down)
501 	net_domain_private *domain = (net_domain_private *)_domain;
502 
503 	dprintf("invalidate_routes(%i, %s)\n", domain->family, interface->name);
504 
505 	RouteList::Iterator iterator = domain->routes.GetIterator();
506 	while (iterator.HasNext()) {
507 		net_route *route = iterator.Next();
508 
509 		// TODO handle refcounting, if the route needs to linger
510 		//      for some reason we should set interface or
511 		//      something of the sorts that invalidates it's reference
512 		if (route->interface == interface)
513 			remove_route(domain, route);
514 	}
515 }
516 
517 
518 struct net_route *
519 get_route(struct net_domain *_domain, const struct sockaddr *address)
520 {
521 	struct net_domain_private *domain = (net_domain_private *)_domain;
522 	BenaphoreLocker locker(domain->lock);
523 
524 	return get_route_internal(domain, address);
525 }
526 
527 
528 void
529 put_route(struct net_domain *_domain, net_route *route)
530 {
531 	struct net_domain_private *domain = (net_domain_private *)_domain;
532 	BenaphoreLocker locker(domain->lock);
533 
534 	put_route_internal(domain, (net_route *)route);
535 }
536 
537 
538 status_t
539 register_route_info(struct net_domain *_domain,
540 	struct net_route_info *info)
541 {
542 	struct net_domain_private *domain = (net_domain_private *)_domain;
543 	BenaphoreLocker locker(domain->lock);
544 
545 	domain->route_infos.Add(info);
546 	info->route = get_route_internal(domain, &info->address);
547 
548 	return B_OK;
549 }
550 
551 
552 status_t
553 unregister_route_info(struct net_domain *_domain,
554 	struct net_route_info *info)
555 {
556 	struct net_domain_private *domain = (net_domain_private *)_domain;
557 	BenaphoreLocker locker(domain->lock);
558 
559 	domain->route_infos.Remove(info);
560 	if (info->route != NULL)
561 		put_route_internal(domain, info->route);
562 
563 	return B_OK;
564 }
565 
566 
567 status_t
568 update_route_info(struct net_domain *domain,
569 	struct net_route_info *info)
570 {
571 	return B_ERROR;
572 }
573 
574