xref: /haiku/src/add-ons/kernel/network/protocols/unix/unix.cpp (revision 9a6a20d4689307142a7ed26a1437ba47e244e73f)
1 /*
2  * Copyright 2008, Ingo Weinhold, ingo_weinhold@gmx.de.
3  * Distributed under the terms of the MIT License.
4  */
5 
6 
7 #include <stdio.h>
8 #include <sys/un.h>
9 
10 #include <new>
11 
12 #include <AutoDeleter.h>
13 #include <StackOrHeapArray.h>
14 
15 #include <fs/fd.h>
16 #include <lock.h>
17 #include <util/AutoLock.h>
18 #include <vfs.h>
19 
20 #include <net_buffer.h>
21 #include <net_protocol.h>
22 #include <net_socket.h>
23 #include <net_stack.h>
24 
25 #include "unix.h"
26 #include "UnixAddressManager.h"
27 #include "UnixEndpoint.h"
28 
29 
30 #define UNIX_MODULE_DEBUG_LEVEL	0
31 #define UNIX_DEBUG_LEVEL		UNIX_MODULE_DEBUG_LEVEL
32 #include "UnixDebug.h"
33 
34 
// Module descriptor, defined at the end of this file. Declared up front so
// that init_unix() can hand it to register_domain() before its definition.
extern net_protocol_module_info gUnixModule;
	// extern only for forwarding

// Interfaces of the modules listed in module_dependencies[] at the end of this
// file. NOTE(review): presumably filled in by the module loader before
// B_MODULE_INIT is delivered -- confirm against the module loader.
net_stack_module_info *gStackModule;
net_socket_module_info *gSocketModule;
net_buffer_module_info *gBufferModule;
// Constructed with placement new in init_unix() and destroyed through explicit
// destructor calls in init_unix()'s error paths and in uninit_unix().
UnixAddressManager gAddressManager;

// The AF_UNIX domain registered by init_unix() and unregistered by
// uninit_unix(); also returned by unix_get_domain().
static struct net_domain *sDomain;
44 
45 
46 void
47 destroy_scm_rights_descriptors(const ancillary_data_header* header,
48 	void* data)
49 {
50 	int count = header->len / sizeof(file_descriptor*);
51 	file_descriptor** descriptors = (file_descriptor**)data;
52 	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
53 
54 	for (int i = 0; i < count; i++) {
55 		if (descriptors[i] != NULL) {
56 			close_fd(ioContext, descriptors[i]);
57 			put_fd(descriptors[i]);
58 		}
59 	}
60 }
61 
62 
63 // #pragma mark -
64 
65 
66 net_protocol *
67 unix_init_protocol(net_socket *socket)
68 {
69 	TRACE("[%" B_PRId32 "] unix_init_protocol(%p)\n", find_thread(NULL),
70 		socket);
71 
72 	UnixEndpoint* endpoint;
73 	status_t error = UnixEndpoint::Create(socket, &endpoint);
74 	if (error != B_OK)
75 		return NULL;
76 
77 	error = endpoint->Init();
78 	if (error != B_OK) {
79 		delete endpoint;
80 		return NULL;
81 	}
82 
83 	return endpoint;
84 }
85 
86 
87 status_t
88 unix_uninit_protocol(net_protocol *_protocol)
89 {
90 	TRACE("[%" B_PRId32 "] unix_uninit_protocol(%p)\n", find_thread(NULL),
91 		_protocol);
92 	((UnixEndpoint*)_protocol)->Uninit();
93 	return B_OK;
94 }
95 
96 
97 status_t
98 unix_open(net_protocol *_protocol)
99 {
100 	return ((UnixEndpoint*)_protocol)->Open();
101 }
102 
103 
104 status_t
105 unix_close(net_protocol *_protocol)
106 {
107 	return ((UnixEndpoint*)_protocol)->Close();
108 }
109 
110 
111 status_t
112 unix_free(net_protocol *_protocol)
113 {
114 	return ((UnixEndpoint*)_protocol)->Free();
115 }
116 
117 
118 status_t
119 unix_connect(net_protocol *_protocol, const struct sockaddr *address)
120 {
121 	return ((UnixEndpoint*)_protocol)->Connect(address);
122 }
123 
124 
125 status_t
126 unix_accept(net_protocol *_protocol, struct net_socket **_acceptedSocket)
127 {
128 	return ((UnixEndpoint*)_protocol)->Accept(_acceptedSocket);
129 }
130 
131 
132 status_t
133 unix_control(net_protocol *protocol, int level, int option, void *value,
134 	size_t *_length)
135 {
136 	return B_BAD_VALUE;
137 }
138 
139 
140 status_t
141 unix_getsockopt(net_protocol *protocol, int level, int option, void *value,
142 	int *_length)
143 {
144 	UnixEndpoint* endpoint = (UnixEndpoint*)protocol;
145 
146 	if (level == SOL_SOCKET && option == SO_PEERCRED) {
147 		if (*_length < (int)sizeof(ucred))
148 			return B_BAD_VALUE;
149 
150 		*_length = sizeof(ucred);
151 
152 		return endpoint->GetPeerCredentials((ucred*)value);
153 	}
154 
155 	return gSocketModule->get_option(protocol->socket, level, option, value,
156 		_length);
157 }
158 
159 
160 status_t
161 unix_setsockopt(net_protocol *protocol, int level, int option,
162 	const void *_value, int length)
163 {
164 	UnixEndpoint* endpoint = (UnixEndpoint*)protocol;
165 
166 	if (level == SOL_SOCKET) {
167 		if (option == SO_RCVBUF) {
168 			if (length != sizeof(int))
169 				return B_BAD_VALUE;
170 
171 			status_t error = endpoint->SetReceiveBufferSize(*(int*)_value);
172 			if (error != B_OK)
173 				return error;
174 		} else if (option == SO_SNDBUF) {
175 			// We don't have a receive buffer, so silently ignore this one.
176 		}
177 	}
178 
179 	return gSocketModule->set_option(protocol->socket, level, option,
180 		_value, length);
181 }
182 
183 
184 status_t
185 unix_bind(net_protocol *_protocol, const struct sockaddr *_address)
186 {
187 	return ((UnixEndpoint*)_protocol)->Bind(_address);
188 }
189 
190 
191 status_t
192 unix_unbind(net_protocol *_protocol, struct sockaddr *_address)
193 {
194 	return ((UnixEndpoint*)_protocol)->Unbind();
195 }
196 
197 
198 status_t
199 unix_listen(net_protocol *_protocol, int count)
200 {
201 	return ((UnixEndpoint*)_protocol)->Listen(count);
202 }
203 
204 
205 status_t
206 unix_shutdown(net_protocol *_protocol, int direction)
207 {
208 	return ((UnixEndpoint*)_protocol)->Shutdown(direction);
209 }
210 
211 
212 status_t
213 unix_send_routed_data(net_protocol *_protocol, struct net_route *route,
214 	net_buffer *buffer)
215 {
216 	return B_ERROR;
217 }
218 
219 
220 status_t
221 unix_send_data(net_protocol *_protocol, net_buffer *buffer)
222 {
223 	return B_ERROR;
224 }
225 
226 
227 ssize_t
228 unix_send_avail(net_protocol *_protocol)
229 {
230 	return ((UnixEndpoint*)_protocol)->Sendable();
231 }
232 
233 
234 status_t
235 unix_read_data(net_protocol *_protocol, size_t numBytes, uint32 flags,
236 	net_buffer **_buffer)
237 {
238 	return B_ERROR;
239 }
240 
241 
242 ssize_t
243 unix_read_avail(net_protocol *_protocol)
244 {
245 	return ((UnixEndpoint*)_protocol)->Receivable();
246 }
247 
248 
249 struct net_domain *
250 unix_get_domain(net_protocol *protocol)
251 {
252 	return sDomain;
253 }
254 
255 
256 size_t
257 unix_get_mtu(net_protocol *protocol, const struct sockaddr *address)
258 {
259 	return UNIX_MAX_TRANSFER_UNIT;
260 }
261 
262 
263 status_t
264 unix_receive_data(net_buffer *buffer)
265 {
266 	return B_ERROR;
267 }
268 
269 
270 status_t
271 unix_deliver_data(net_protocol *_protocol, net_buffer *buffer)
272 {
273 	return B_ERROR;
274 }
275 
276 
277 status_t
278 unix_error_received(net_error error, net_buffer *data)
279 {
280 	return B_ERROR;
281 }
282 
283 
284 status_t
285 unix_error_reply(net_protocol *protocol, net_buffer *cause, net_error error,
286 	net_error_data *errorData)
287 {
288 	return B_ERROR;
289 }
290 
291 
292 status_t
293 unix_add_ancillary_data(net_protocol *self, ancillary_data_container *container,
294 	const cmsghdr *header)
295 {
296 	TRACE("[%" B_PRId32 "] unix_add_ancillary_data(%p, %p, %p (level: %d, type: %d, "
297 		"len: %" B_PRId32 "))\n", find_thread(NULL), self, container, header,
298 		header->cmsg_level, header->cmsg_type, header->cmsg_len);
299 
300 	// we support only SCM_RIGHTS
301 	if (header->cmsg_level != SOL_SOCKET || header->cmsg_type != SCM_RIGHTS)
302 		return B_BAD_VALUE;
303 
304 	int* fds = (int*)CMSG_DATA(header);
305 	int count = (header->cmsg_len - CMSG_LEN(0)) / sizeof(int);
306 	if (count == 0)
307 		return B_BAD_VALUE;
308 
309 	BStackOrHeapArray<file_descriptor*, 8> descriptors(count);
310 	if (!descriptors.IsValid())
311 		return ENOBUFS;
312 	memset(descriptors, 0, sizeof(file_descriptor*) * count);
313 
314 	// get the file descriptors
315 	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
316 
317 	status_t error = B_OK;
318 	for (int i = 0; i < count; i++) {
319 		descriptors[i] = get_open_fd(ioContext, fds[i]);
320 		if (descriptors[i] == NULL) {
321 			error = EBADF;
322 			break;
323 		}
324 	}
325 
326 	// attach the ancillary data to the container
327 	if (error == B_OK) {
328 		ancillary_data_header header;
329 		header.level = SOL_SOCKET;
330 		header.type = SCM_RIGHTS;
331 		header.len = count * sizeof(file_descriptor*);
332 
333 		TRACE("[%" B_PRId32 "] unix_add_ancillary_data(): adding %d FDs to "
334 			"container\n", find_thread(NULL), count);
335 
336 		error = gStackModule->add_ancillary_data(container, &header,
337 			descriptors, destroy_scm_rights_descriptors, NULL);
338 	}
339 
340 	// cleanup on error
341 	if (error != B_OK) {
342 		for (int i = 0; i < count; i++) {
343 			if (descriptors[i] != NULL) {
344 				close_fd(ioContext, descriptors[i]);
345 				put_fd(descriptors[i]);
346 			}
347 		}
348 	}
349 
350 	return error;
351 }
352 
353 
354 ssize_t
355 unix_process_ancillary_data(net_protocol *self,
356 	const ancillary_data_header *header, const void *data, void *buffer,
357 	size_t bufferSize)
358 {
359 	TRACE("[%" B_PRId32 "] unix_process_ancillary_data(%p, %p (level: %d, "
360 		"type: %d, len: %lu), %p, %p, %lu)\n", find_thread(NULL), self, header,
361 		header->level, header->type, header->len, data, buffer, bufferSize);
362 
363 	// we support only SCM_RIGHTS
364 	if (header->level != SOL_SOCKET || header->type != SCM_RIGHTS)
365 		return B_BAD_VALUE;
366 
367 	int count = header->len / sizeof(file_descriptor*);
368 	file_descriptor** descriptors = (file_descriptor**)data;
369 
370 	// check if there's enough space in the buffer
371 	size_t neededBufferSpace = CMSG_SPACE(sizeof(int) * count);
372 	if (bufferSize < neededBufferSpace)
373 		return B_BAD_VALUE;
374 
375 	// init header
376 	cmsghdr* messageHeader = (cmsghdr*)buffer;
377 	messageHeader->cmsg_level = header->level;
378 	messageHeader->cmsg_type = header->type;
379 	messageHeader->cmsg_len = CMSG_LEN(sizeof(int) * count);
380 
381 	// create FDs for the current process
382 	int* fds = (int*)CMSG_DATA(messageHeader);
383 	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
384 
385 	status_t error = B_OK;
386 	for (int i = 0; i < count; i++) {
387 		// Get an additional reference which will go to the FD table index. The
388 		// reference and open reference acquired in unix_add_ancillary_data()
389 		// will be released when the container is destroyed.
390 		inc_fd_ref_count(descriptors[i]);
391 		fds[i] = new_fd(ioContext, descriptors[i]);
392 
393 		if (fds[i] < 0) {
394 			error = fds[i];
395 			put_fd(descriptors[i]);
396 
397 			// close FD indices
398 			for (int k = i - 1; k >= 0; k--)
399 				close_fd_index(ioContext, fds[k]);
400 			break;
401 		}
402 	}
403 
404 	return error == B_OK ? neededBufferSpace : error;
405 }
406 
407 
408 ssize_t
409 unix_send_data_no_buffer(net_protocol *_protocol, const iovec *vecs,
410 	size_t vecCount, ancillary_data_container *ancillaryData,
411 	const struct sockaddr *address, socklen_t addressLength, int flags)
412 {
413 	return ((UnixEndpoint*)_protocol)->Send(vecs, vecCount, ancillaryData,
414 		address, addressLength, flags);
415 }
416 
417 
418 ssize_t
419 unix_read_data_no_buffer(net_protocol *_protocol, const iovec *vecs,
420 	size_t vecCount, ancillary_data_container **_ancillaryData,
421 	struct sockaddr *_address, socklen_t *_addressLength, int flags)
422 {
423 	return ((UnixEndpoint*)_protocol)->Receive(vecs, vecCount, _ancillaryData,
424 		_address, _addressLength, flags);
425 }
426 
427 
428 // #pragma mark -
429 
430 
431 status_t
432 init_unix()
433 {
434 	new(&gAddressManager) UnixAddressManager;
435 	status_t error = gAddressManager.Init();
436 	if (error != B_OK)
437 		return error;
438 
439 	error = gStackModule->register_domain_protocols(AF_UNIX, SOCK_STREAM, 0,
440 		"network/protocols/unix/v1", NULL);
441 	if (error == B_OK) {
442 		error = gStackModule->register_domain_protocols(AF_UNIX, SOCK_DGRAM, 0,
443 			"network/protocols/unix/v1", NULL);
444 	}
445 
446 	if (error != B_OK) {
447 		gAddressManager.~UnixAddressManager();
448 		return error;
449 	}
450 
451 	error = gStackModule->register_domain(AF_UNIX, "unix", &gUnixModule,
452 		&gAddressModule, &sDomain);
453 	if (error != B_OK) {
454 		gAddressManager.~UnixAddressManager();
455 		return error;
456 	}
457 
458 	return B_OK;
459 }
460 
461 
462 status_t
463 uninit_unix()
464 {
465 	gStackModule->unregister_domain(sDomain);
466 
467 	gAddressManager.~UnixAddressManager();
468 
469 	return B_OK;
470 }
471 
472 
473 static status_t
474 unix_std_ops(int32 op, ...)
475 {
476 	switch (op) {
477 		case B_MODULE_INIT:
478 			return init_unix();
479 		case B_MODULE_UNINIT:
480 			return uninit_unix();
481 
482 		default:
483 			return B_ERROR;
484 	}
485 }
486 
487 
// Protocol module descriptor for "network/protocols/unix/v1".
// This is a positional aggregate initializer: the order of the hooks below
// must match the member order of net_protocol_module_info exactly.
net_protocol_module_info gUnixModule = {
	{
		"network/protocols/unix/v1",
		0,
		unix_std_ops
	},
	0,	// NET_PROTOCOL_ATOMIC_MESSAGES,
		// (flags: atomic messages are deliberately left unset)

	unix_init_protocol,
	unix_uninit_protocol,
	unix_open,
	unix_close,
	unix_free,
	unix_connect,
	unix_accept,
	unix_control,
	unix_getsockopt,
	unix_setsockopt,
	unix_bind,
	unix_unbind,
	unix_listen,
	unix_shutdown,
	unix_send_data,
	unix_send_routed_data,
	unix_send_avail,
	unix_read_data,
	unix_read_avail,
	unix_get_domain,
	unix_get_mtu,
	unix_receive_data,
	unix_deliver_data,
	unix_error_received,
	unix_error_reply,
	unix_add_ancillary_data,
	unix_process_ancillary_data,
	NULL,
		// NOTE(review): hook slot left unimplemented; confirm which member of
		// net_protocol_module_info this position corresponds to.
	unix_send_data_no_buffer,
	unix_read_data_no_buffer
};
527 
// Modules this add-on needs. Each entry pairs a module name with the global
// pointer that receives its module_info. The empty entry terminates the list.
module_dependency module_dependencies[] = {
	{NET_STACK_MODULE_NAME, (module_info **)&gStackModule},
	{NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule},
//	{NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule},
	{NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule},
	{}
};

// Modules exported by this add-on; the list is NULL-terminated.
module_info *modules[] = {
	(module_info *)&gUnixModule,
	NULL
};
540