xref: /haiku/src/add-ons/kernel/network/protocols/unix/unix.cpp (revision 344ded80d400028c8f561b4b876257b94c12db4a)
1 /*
2  * Copyright 2008, Ingo Weinhold, ingo_weinhold@gmx.de.
3  * Distributed under the terms of the MIT License.
4  */
5 
6 
7 #include <stdio.h>
8 #include <sys/un.h>
9 
10 #include <new>
11 
12 #include <AutoDeleter.h>
13 #include <StackOrHeapArray.h>
14 
15 #include <fs/fd.h>
16 #include <lock.h>
17 #include <util/AutoLock.h>
18 #include <vfs.h>
19 
20 #include <net_buffer.h>
21 #include <net_protocol.h>
22 #include <net_socket.h>
23 #include <net_stack.h>
24 
25 #include "unix.h"
26 #include "UnixAddressManager.h"
27 #include "UnixEndpoint.h"
28 
29 
30 #define UNIX_MODULE_DEBUG_LEVEL	0
31 #define UNIX_DEBUG_LEVEL		UNIX_MODULE_DEBUG_LEVEL
32 #include "UnixDebug.h"
33 
34 
35 extern net_protocol_module_info gUnixModule;
36 	// extern only for forwarding
37 
38 net_stack_module_info *gStackModule;
39 net_socket_module_info *gSocketModule;
40 net_buffer_module_info *gBufferModule;
41 UnixAddressManager gAddressManager;
42 
43 static struct net_domain *sDomain;
44 
45 
46 void
47 destroy_scm_rights_descriptors(const ancillary_data_header* header,
48 	void* data)
49 {
50 	int count = header->len / sizeof(file_descriptor*);
51 	file_descriptor** descriptors = (file_descriptor**)data;
52 	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
53 
54 	for (int i = 0; i < count; i++) {
55 		if (descriptors[i] != NULL) {
56 			close_fd(ioContext, descriptors[i]);
57 			put_fd(descriptors[i]);
58 		}
59 	}
60 }
61 
62 
63 // #pragma mark -
64 
65 
66 net_protocol *
67 unix_init_protocol(net_socket *socket)
68 {
69 	TRACE("[%" B_PRId32 "] unix_init_protocol(%p)\n", find_thread(NULL),
70 		socket);
71 
72 	UnixEndpoint* endpoint;
73 	status_t error = UnixEndpoint::Create(socket, &endpoint);
74 	if (error != B_OK)
75 		return NULL;
76 
77 	error = endpoint->Init();
78 	if (error != B_OK) {
79 		delete endpoint;
80 		return NULL;
81 	}
82 
83 	return endpoint;
84 }
85 
86 
87 status_t
88 unix_uninit_protocol(net_protocol *_protocol)
89 {
90 	TRACE("[%" B_PRId32 "] unix_uninit_protocol(%p)\n", find_thread(NULL),
91 		_protocol);
92 	((UnixEndpoint*)_protocol)->Uninit();
93 	return B_OK;
94 }
95 
96 
97 status_t
98 unix_open(net_protocol *_protocol)
99 {
100 	return ((UnixEndpoint*)_protocol)->Open();
101 }
102 
103 
104 status_t
105 unix_close(net_protocol *_protocol)
106 {
107 	return ((UnixEndpoint*)_protocol)->Close();
108 }
109 
110 
111 status_t
112 unix_free(net_protocol *_protocol)
113 {
114 	return ((UnixEndpoint*)_protocol)->Free();
115 }
116 
117 
118 status_t
119 unix_connect(net_protocol *_protocol, const struct sockaddr *address)
120 {
121 	return ((UnixEndpoint*)_protocol)->Connect(address);
122 }
123 
124 
125 status_t
126 unix_accept(net_protocol *_protocol, struct net_socket **_acceptedSocket)
127 {
128 	return ((UnixEndpoint*)_protocol)->Accept(_acceptedSocket);
129 }
130 
131 
132 status_t
133 unix_control(net_protocol *protocol, int level, int option, void *value,
134 	size_t *_length)
135 {
136 	return B_BAD_VALUE;
137 }
138 
139 
140 status_t
141 unix_getsockopt(net_protocol *protocol, int level, int option, void *value,
142 	int *_length)
143 {
144 	UnixEndpoint* endpoint = (UnixEndpoint*)protocol;
145 
146 	if (level == SOL_SOCKET && option == SO_PEERCRED) {
147 		if (*_length < (int)sizeof(ucred))
148 			return B_BAD_VALUE;
149 
150 		*_length = sizeof(ucred);
151 
152 		return endpoint->GetPeerCredentials((ucred*)value);
153 	}
154 
155 	return gSocketModule->get_option(protocol->socket, level, option, value,
156 		_length);
157 }
158 
159 
160 status_t
161 unix_setsockopt(net_protocol *protocol, int level, int option,
162 	const void *_value, int length)
163 {
164 	UnixEndpoint* endpoint = (UnixEndpoint*)protocol;
165 
166 	if (level == SOL_SOCKET) {
167 		if (option == SO_RCVBUF) {
168 			if (length != sizeof(int))
169 				return B_BAD_VALUE;
170 
171 			status_t error = endpoint->SetReceiveBufferSize(*(int*)_value);
172 			if (error != B_OK)
173 				return error;
174 		} else if (option == SO_SNDBUF) {
175 			// We don't have a receive buffer, so silently ignore this one.
176 		}
177 	}
178 
179 	return gSocketModule->set_option(protocol->socket, level, option,
180 		_value, length);
181 }
182 
183 
184 status_t
185 unix_bind(net_protocol *_protocol, const struct sockaddr *_address)
186 {
187 	return ((UnixEndpoint*)_protocol)->Bind(_address);
188 }
189 
190 
191 status_t
192 unix_unbind(net_protocol *_protocol, struct sockaddr *_address)
193 {
194 	return ((UnixEndpoint*)_protocol)->Unbind();
195 }
196 
197 
198 status_t
199 unix_listen(net_protocol *_protocol, int count)
200 {
201 	return ((UnixEndpoint*)_protocol)->Listen(count);
202 }
203 
204 
205 status_t
206 unix_shutdown(net_protocol *_protocol, int direction)
207 {
208 	return ((UnixEndpoint*)_protocol)->Shutdown(direction);
209 }
210 
211 
212 status_t
213 unix_send_routed_data(net_protocol *_protocol, struct net_route *route,
214 	net_buffer *buffer)
215 {
216 	return B_ERROR;
217 }
218 
219 
220 status_t
221 unix_send_data(net_protocol *_protocol, net_buffer *buffer)
222 {
223 	return B_ERROR;
224 }
225 
226 
227 ssize_t
228 unix_send_avail(net_protocol *_protocol)
229 {
230 	return ((UnixEndpoint*)_protocol)->Sendable();
231 }
232 
233 
234 status_t
235 unix_read_data(net_protocol *_protocol, size_t numBytes, uint32 flags,
236 	net_buffer **_buffer)
237 {
238 	return B_ERROR;
239 }
240 
241 
242 ssize_t
243 unix_read_avail(net_protocol *_protocol)
244 {
245 	return ((UnixEndpoint*)_protocol)->Receivable();
246 }
247 
248 
249 struct net_domain *
250 unix_get_domain(net_protocol *protocol)
251 {
252 	return sDomain;
253 }
254 
255 
256 size_t
257 unix_get_mtu(net_protocol *protocol, const struct sockaddr *address)
258 {
259 	return UNIX_MAX_TRANSFER_UNIT;
260 }
261 
262 
263 status_t
264 unix_receive_data(net_buffer *buffer)
265 {
266 	return B_ERROR;
267 }
268 
269 
270 status_t
271 unix_deliver_data(net_protocol *_protocol, net_buffer *buffer)
272 {
273 	return B_ERROR;
274 }
275 
276 
277 status_t
278 unix_error_received(net_error error, net_buffer *data)
279 {
280 	return B_ERROR;
281 }
282 
283 
284 status_t
285 unix_error_reply(net_protocol *protocol, net_buffer *cause, net_error error,
286 	net_error_data *errorData)
287 {
288 	return B_ERROR;
289 }
290 
291 
292 status_t
293 unix_add_ancillary_data(net_protocol *self, ancillary_data_container *container,
294 	const cmsghdr *header)
295 {
296 	TRACE("[%" B_PRId32 "] unix_add_ancillary_data(%p, %p, %p (level: %d, type: %d, "
297 		"len: %" B_PRId32 "))\n", find_thread(NULL), self, container, header,
298 		header->cmsg_level, header->cmsg_type, header->cmsg_len);
299 
300 	// we support only SCM_RIGHTS
301 	if (header->cmsg_level != SOL_SOCKET || header->cmsg_type != SCM_RIGHTS)
302 		return B_BAD_VALUE;
303 
304 	int* fds = (int*)CMSG_DATA(header);
305 	int count = (header->cmsg_len - CMSG_LEN(0)) / sizeof(int);
306 	if (count == 0)
307 		return B_BAD_VALUE;
308 
309 	BStackOrHeapArray<file_descriptor*, 8> descriptors(count);
310 	if (!descriptors.IsValid())
311 		return ENOBUFS;
312 	memset(descriptors, 0, sizeof(file_descriptor*) * count);
313 
314 	// get the file descriptors
315 	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
316 
317 	status_t error = B_OK;
318 	for (int i = 0; i < count; i++) {
319 		descriptors[i] = get_open_fd(ioContext, fds[i]);
320 		if (descriptors[i] == NULL) {
321 			error = EBADF;
322 			break;
323 		}
324 	}
325 
326 	// attach the ancillary data to the container
327 	if (error == B_OK) {
328 		ancillary_data_header header;
329 		header.level = SOL_SOCKET;
330 		header.type = SCM_RIGHTS;
331 		header.len = count * sizeof(file_descriptor*);
332 
333 		TRACE("[%" B_PRId32 "] unix_add_ancillary_data(): adding %d FDs to "
334 			"container\n", find_thread(NULL), count);
335 
336 		error = gStackModule->add_ancillary_data(container, &header,
337 			descriptors, destroy_scm_rights_descriptors, NULL);
338 	}
339 
340 	// cleanup on error
341 	if (error != B_OK) {
342 		for (int i = 0; i < count; i++) {
343 			if (descriptors[i] != NULL) {
344 				close_fd(ioContext, descriptors[i]);
345 				put_fd(descriptors[i]);
346 			}
347 		}
348 	}
349 
350 	return error;
351 }
352 
353 
354 ssize_t
355 unix_process_ancillary_data(net_protocol *self,
356 	const ancillary_data_container *container, void *buffer,
357 	size_t bufferSize)
358 {
359 	TRACE("[%" B_PRId32 "] unix_process_ancillary_data(%p, %p, %p, %p, %lu)\n",
360 		find_thread(NULL), self, container, buffer, bufferSize);
361 
362 	int totalCount = 0;
363 
364 	ancillary_data_header header;
365 	void* data = NULL;
366 	while ((data = gStackModule->next_ancillary_data(container, data, &header)) != NULL) {
367 		// we support only SCM_RIGHTS
368 		if (header.level != SOL_SOCKET || header.type != SCM_RIGHTS)
369 			return B_BAD_VALUE;
370 
371 		totalCount += header.len / sizeof(file_descriptor*);
372 	}
373 
374 	// check if there's enough space in the buffer
375 	size_t neededBufferSpace = CMSG_SPACE(sizeof(int) * totalCount);
376 	if (bufferSize < neededBufferSpace)
377 		return B_BAD_VALUE;
378 
379 	// init header
380 	cmsghdr* messageHeader = (cmsghdr*)buffer;
381 	messageHeader->cmsg_level = SOL_SOCKET;
382 	messageHeader->cmsg_type = SCM_RIGHTS;
383 	messageHeader->cmsg_len = CMSG_LEN(sizeof(int) * totalCount);
384 
385 	// create FDs for the current process
386 	int* fds = (int*)CMSG_DATA(messageHeader);
387 	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
388 
389 	status_t error = B_OK;
390 	int i = 0;
391 	data = NULL;
392 	while ((data = gStackModule->next_ancillary_data(container, data, &header)) != NULL) {
393 		int count = header.len / sizeof(file_descriptor*);
394 		file_descriptor** descriptors = (file_descriptor**)data;
395 
396 		for (int k = 0; k < count; k++, i++) {
397 			// Get an additional reference which will go to the FD table index. The
398 			// reference and open reference acquired in unix_add_ancillary_data()
399 			// will be released when the container is destroyed.
400 			inc_fd_ref_count(descriptors[k]);
401 			fds[i] = new_fd(ioContext, descriptors[k]);
402 
403 			if (fds[i] < 0) {
404 				error = fds[i];
405 				put_fd(descriptors[k]);
406 
407 				// close FD indices
408 				for (int j = i - 1; j >= 0; j--)
409 					close_fd_index(ioContext, fds[j]);
410 				break;
411 			}
412 		}
413 		if (error != B_OK)
414 			break;
415 	}
416 
417 	return error == B_OK ? neededBufferSpace : error;
418 }
419 
420 
421 ssize_t
422 unix_send_data_no_buffer(net_protocol *_protocol, const iovec *vecs,
423 	size_t vecCount, ancillary_data_container *ancillaryData,
424 	const struct sockaddr *address, socklen_t addressLength, int flags)
425 {
426 	return ((UnixEndpoint*)_protocol)->Send(vecs, vecCount, ancillaryData,
427 		address, addressLength, flags);
428 }
429 
430 
431 ssize_t
432 unix_read_data_no_buffer(net_protocol *_protocol, const iovec *vecs,
433 	size_t vecCount, ancillary_data_container **_ancillaryData,
434 	struct sockaddr *_address, socklen_t *_addressLength, int flags)
435 {
436 	return ((UnixEndpoint*)_protocol)->Receive(vecs, vecCount, _ancillaryData,
437 		_address, _addressLength, flags);
438 }
439 
440 
441 // #pragma mark -
442 
443 
444 status_t
445 init_unix()
446 {
447 	new(&gAddressManager) UnixAddressManager;
448 	status_t error = gAddressManager.Init();
449 	if (error != B_OK)
450 		return error;
451 
452 	error = gStackModule->register_domain_protocols(AF_UNIX, SOCK_STREAM, 0,
453 		"network/protocols/unix/v1", NULL);
454 	if (error == B_OK) {
455 		error = gStackModule->register_domain_protocols(AF_UNIX, SOCK_DGRAM, 0,
456 			"network/protocols/unix/v1", NULL);
457 	}
458 
459 	if (error != B_OK) {
460 		gAddressManager.~UnixAddressManager();
461 		return error;
462 	}
463 
464 	error = gStackModule->register_domain(AF_UNIX, "unix", &gUnixModule,
465 		&gAddressModule, &sDomain);
466 	if (error != B_OK) {
467 		gAddressManager.~UnixAddressManager();
468 		return error;
469 	}
470 
471 	return B_OK;
472 }
473 
474 
475 status_t
476 uninit_unix()
477 {
478 	gStackModule->unregister_domain(sDomain);
479 
480 	gAddressManager.~UnixAddressManager();
481 
482 	return B_OK;
483 }
484 
485 
486 static status_t
487 unix_std_ops(int32 op, ...)
488 {
489 	switch (op) {
490 		case B_MODULE_INIT:
491 			return init_unix();
492 		case B_MODULE_UNINIT:
493 			return uninit_unix();
494 
495 		default:
496 			return B_ERROR;
497 	}
498 }
499 
500 
// Protocol module descriptor registered with the stack in init_unix().
// The initializer is positional: every entry must stay in the field order
// of net_protocol_module_info.
net_protocol_module_info gUnixModule = {
	{
		"network/protocols/unix/v1",
		0,
		unix_std_ops
	},
	0,	// NET_PROTOCOL_ATOMIC_MESSAGES,
		// (flags: none set)

	unix_init_protocol,
	unix_uninit_protocol,
	unix_open,
	unix_close,
	unix_free,
	unix_connect,
	unix_accept,
	unix_control,
	unix_getsockopt,
	unix_setsockopt,
	unix_bind,
	unix_unbind,
	unix_listen,
	unix_shutdown,
	// net_buffer based send path: stubs, data goes through the *_no_buffer
	// hooks at the end of this table
	unix_send_data,
	unix_send_routed_data,
	unix_send_avail,
	unix_read_data,
	unix_read_avail,
	unix_get_domain,
	unix_get_mtu,
	unix_receive_data,
	unix_deliver_data,
	unix_error_received,
	unix_error_reply,
	unix_add_ancillary_data,
	unix_process_ancillary_data,
	NULL,
		// NOTE(review): presumably the process_ancillary_data_no_container
		// hook -- confirm against net_protocol.h
	unix_send_data_no_buffer,
	unix_read_data_no_buffer
};
540 
// Modules this add-on depends on. Each entry pairs a module name with the
// global that receives its module_info pointer; the list ends with an empty
// entry.
module_dependency module_dependencies[] = {
	{NET_STACK_MODULE_NAME, (module_info **)&gStackModule},
	{NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule},
//	{NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule},
	{NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule},
	{}
};

// Modules exported by this add-on (NULL-terminated).
module_info *modules[] = {
	(module_info *)&gUnixModule,
	NULL
};
553