xref: /haiku/src/add-ons/kernel/network/protocols/unix/unix.cpp (revision e1c4049fed1047bdb957b0529e1921e97ef94770)
1 /*
2  * Copyright 2008, Ingo Weinhold, ingo_weinhold@gmx.de.
3  * Distributed under the terms of the MIT License.
4  */
5 
6 
7 #include <stdio.h>
8 #include <sys/un.h>
9 
10 #include <new>
11 
12 #include <AutoDeleter.h>
13 
14 #include <fs/fd.h>
15 #include <lock.h>
16 #include <util/AutoLock.h>
17 #include <vfs.h>
18 
19 #include <net_buffer.h>
20 #include <net_protocol.h>
21 #include <net_socket.h>
22 #include <net_stack.h>
23 
24 #include "unix.h"
25 #include "UnixAddressManager.h"
26 #include "UnixEndpoint.h"
27 
28 
29 #define UNIX_MODULE_DEBUG_LEVEL	0
30 #define UNIX_DEBUG_LEVEL		UNIX_MODULE_DEBUG_LEVEL
31 #include "UnixDebug.h"
32 
33 
extern net_protocol_module_info gUnixModule;
	// extern only for forwarding: init_unix() passes &gUnixModule to
	// register_domain() before the definition further down in this file

// Interfaces of the modules this add-on depends on; they are listed in the
// module_dependencies[] table at the end of this file.
net_stack_module_info *gStackModule;
net_socket_module_info *gSocketModule;
net_buffer_module_info *gBufferModule;

// Constructed in place by init_unix(); destroyed explicitly when init_unix()
// fails after construction and in uninit_unix().
UnixAddressManager gAddressManager;

// The AF_UNIX domain obtained from register_domain() in init_unix(); returned
// by unix_get_domain() and unregistered in uninit_unix().
static struct net_domain *sDomain;
43 
44 
45 void
46 destroy_scm_rights_descriptors(const ancillary_data_header* header,
47 	void* data)
48 {
49 	int count = header->len / sizeof(file_descriptor*);
50 	file_descriptor** descriptors = (file_descriptor**)data;
51 	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
52 
53 	for (int i = 0; i < count; i++) {
54 		if (descriptors[i] != NULL) {
55 			close_fd(ioContext, descriptors[i]);
56 			put_fd(descriptors[i]);
57 		}
58 	}
59 }
60 
61 
62 // #pragma mark -
63 
64 
65 net_protocol *
66 unix_init_protocol(net_socket *socket)
67 {
68 	TRACE("[%" B_PRId32 "] unix_init_protocol(%p)\n", find_thread(NULL),
69 		socket);
70 
71 	UnixEndpoint* endpoint;
72 	status_t error = UnixEndpoint::Create(socket, &endpoint);
73 	if (error != B_OK)
74 		return NULL;
75 
76 	error = endpoint->Init();
77 	if (error != B_OK) {
78 		delete endpoint;
79 		return NULL;
80 	}
81 
82 	return endpoint;
83 }
84 
85 
86 status_t
87 unix_uninit_protocol(net_protocol *_protocol)
88 {
89 	TRACE("[%" B_PRId32 "] unix_uninit_protocol(%p)\n", find_thread(NULL),
90 		_protocol);
91 	((UnixEndpoint*)_protocol)->Uninit();
92 	return B_OK;
93 }
94 
95 
96 status_t
97 unix_open(net_protocol *_protocol)
98 {
99 	return ((UnixEndpoint*)_protocol)->Open();
100 }
101 
102 
103 status_t
104 unix_close(net_protocol *_protocol)
105 {
106 	return ((UnixEndpoint*)_protocol)->Close();
107 }
108 
109 
110 status_t
111 unix_free(net_protocol *_protocol)
112 {
113 	return ((UnixEndpoint*)_protocol)->Free();
114 }
115 
116 
117 status_t
118 unix_connect(net_protocol *_protocol, const struct sockaddr *address)
119 {
120 	return ((UnixEndpoint*)_protocol)->Connect(address);
121 }
122 
123 
124 status_t
125 unix_accept(net_protocol *_protocol, struct net_socket **_acceptedSocket)
126 {
127 	return ((UnixEndpoint*)_protocol)->Accept(_acceptedSocket);
128 }
129 
130 
131 status_t
132 unix_control(net_protocol *protocol, int level, int option, void *value,
133 	size_t *_length)
134 {
135 	return B_BAD_VALUE;
136 }
137 
138 
139 status_t
140 unix_getsockopt(net_protocol *protocol, int level, int option, void *value,
141 	int *_length)
142 {
143 	UnixEndpoint* endpoint = (UnixEndpoint*)protocol;
144 
145 	if (level == SOL_SOCKET && option == SO_PEERCRED) {
146 		if (*_length < (int)sizeof(ucred))
147 			return B_BAD_VALUE;
148 
149 		*_length = sizeof(ucred);
150 
151 		return endpoint->GetPeerCredentials((ucred*)value);
152 	}
153 
154 	return gSocketModule->get_option(protocol->socket, level, option, value,
155 		_length);
156 }
157 
158 
159 status_t
160 unix_setsockopt(net_protocol *protocol, int level, int option,
161 	const void *_value, int length)
162 {
163 	UnixEndpoint* endpoint = (UnixEndpoint*)protocol;
164 
165 	if (level == SOL_SOCKET) {
166 		if (option == SO_RCVBUF) {
167 			if (length != sizeof(int))
168 				return B_BAD_VALUE;
169 
170 			status_t error = endpoint->SetReceiveBufferSize(*(int*)_value);
171 			if (error != B_OK)
172 				return error;
173 		} else if (option == SO_SNDBUF) {
174 			// We don't have a receive buffer, so silently ignore this one.
175 		}
176 	}
177 
178 	return gSocketModule->set_option(protocol->socket, level, option,
179 		_value, length);
180 }
181 
182 
183 status_t
184 unix_bind(net_protocol *_protocol, const struct sockaddr *_address)
185 {
186 	return ((UnixEndpoint*)_protocol)->Bind(_address);
187 }
188 
189 
190 status_t
191 unix_unbind(net_protocol *_protocol, struct sockaddr *_address)
192 {
193 	return ((UnixEndpoint*)_protocol)->Unbind();
194 }
195 
196 
197 status_t
198 unix_listen(net_protocol *_protocol, int count)
199 {
200 	return ((UnixEndpoint*)_protocol)->Listen(count);
201 }
202 
203 
204 status_t
205 unix_shutdown(net_protocol *_protocol, int direction)
206 {
207 	return ((UnixEndpoint*)_protocol)->Shutdown(direction);
208 }
209 
210 
211 status_t
212 unix_send_routed_data(net_protocol *_protocol, struct net_route *route,
213 	net_buffer *buffer)
214 {
215 	return B_ERROR;
216 }
217 
218 
219 status_t
220 unix_send_data(net_protocol *_protocol, net_buffer *buffer)
221 {
222 	return B_ERROR;
223 }
224 
225 
226 ssize_t
227 unix_send_avail(net_protocol *_protocol)
228 {
229 	return ((UnixEndpoint*)_protocol)->Sendable();
230 }
231 
232 
233 status_t
234 unix_read_data(net_protocol *_protocol, size_t numBytes, uint32 flags,
235 	net_buffer **_buffer)
236 {
237 	return B_ERROR;
238 }
239 
240 
241 ssize_t
242 unix_read_avail(net_protocol *_protocol)
243 {
244 	return ((UnixEndpoint*)_protocol)->Receivable();
245 }
246 
247 
248 struct net_domain *
249 unix_get_domain(net_protocol *protocol)
250 {
251 	return sDomain;
252 }
253 
254 
255 size_t
256 unix_get_mtu(net_protocol *protocol, const struct sockaddr *address)
257 {
258 	return UNIX_MAX_TRANSFER_UNIT;
259 }
260 
261 
262 status_t
263 unix_receive_data(net_buffer *buffer)
264 {
265 	return B_ERROR;
266 }
267 
268 
269 status_t
270 unix_deliver_data(net_protocol *_protocol, net_buffer *buffer)
271 {
272 	return B_ERROR;
273 }
274 
275 
276 status_t
277 unix_error_received(net_error error, net_buffer *data)
278 {
279 	return B_ERROR;
280 }
281 
282 
283 status_t
284 unix_error_reply(net_protocol *protocol, net_buffer *cause, net_error error,
285 	net_error_data *errorData)
286 {
287 	return B_ERROR;
288 }
289 
290 
291 status_t
292 unix_add_ancillary_data(net_protocol *self, ancillary_data_container *container,
293 	const cmsghdr *header)
294 {
295 	TRACE("[%" B_PRId32 "] unix_add_ancillary_data(%p, %p, %p (level: %d, type: %d, "
296 		"len: %" B_PRId32 "))\n", find_thread(NULL), self, container, header,
297 		header->cmsg_level, header->cmsg_type, header->cmsg_len);
298 
299 	// we support only SCM_RIGHTS
300 	if (header->cmsg_level != SOL_SOCKET || header->cmsg_type != SCM_RIGHTS)
301 		return B_BAD_VALUE;
302 
303 	int* fds = (int*)CMSG_DATA(header);
304 	int count = (header->cmsg_len - CMSG_ALIGN(sizeof(cmsghdr))) / sizeof(int);
305 	if (count == 0)
306 		return B_BAD_VALUE;
307 
308 	file_descriptor** descriptors = new(std::nothrow) file_descriptor*[count];
309 	if (descriptors == NULL)
310 		return ENOBUFS;
311 	ArrayDeleter<file_descriptor*> _(descriptors);
312 	memset(descriptors, 0, sizeof(file_descriptor*) * count);
313 
314 	// get the file descriptors
315 	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
316 
317 	status_t error = B_OK;
318 	for (int i = 0; i < count; i++) {
319 		descriptors[i] = get_open_fd(ioContext, fds[i]);
320 		if (descriptors[i] == NULL) {
321 			error = EBADF;
322 			break;
323 		}
324 	}
325 
326 	// attach the ancillary data to the container
327 	if (error == B_OK) {
328 		ancillary_data_header header;
329 		header.level = SOL_SOCKET;
330 		header.type = SCM_RIGHTS;
331 		header.len = count * sizeof(file_descriptor*);
332 
333 		TRACE("[%" B_PRId32 "] unix_add_ancillary_data(): adding %d FDs to "
334 			"container\n", find_thread(NULL), count);
335 
336 		error = gStackModule->add_ancillary_data(container, &header,
337 			descriptors, destroy_scm_rights_descriptors, NULL);
338 	}
339 
340 	// cleanup on error
341 	if (error != B_OK) {
342 		for (int i = 0; i < count; i++) {
343 			if (descriptors[i] != NULL) {
344 				close_fd(ioContext, descriptors[i]);
345 				put_fd(descriptors[i]);
346 			}
347 		}
348 	}
349 
350 	return error;
351 }
352 
353 
/*!	Turns SCM_RIGHTS ancillary data back into a control message for the
	receiving side.

	Writes a cmsghdr into \a buffer and installs every file_descriptor from
	\a data into the current I/O context with new_fd(). The new FD numbers
	are stored as the message payload.

	\return The number of bytes used in \a buffer (CMSG_SPACE of the FD
		array) on success. B_BAD_VALUE for anything other than
		SOL_SOCKET/SCM_RIGHTS data, or if \a buffer is too small. Otherwise
		the error returned by new_fd(); in that case all FDs created by this
		call are closed again, but the header already written to \a buffer is
		left in place.
*/
ssize_t
unix_process_ancillary_data(net_protocol *self,
	const ancillary_data_header *header, const void *data, void *buffer,
	size_t bufferSize)
{
	TRACE("[%" B_PRId32 "] unix_process_ancillary_data(%p, %p (level: %d, "
		"type: %d, len: %lu), %p, %p, %lu)\n", find_thread(NULL), self, header,
		header->level, header->type, header->len, data, buffer, bufferSize);

	// we support only SCM_RIGHTS
	if (header->level != SOL_SOCKET || header->type != SCM_RIGHTS)
		return B_BAD_VALUE;

	// data is the file_descriptor* array stored by unix_add_ancillary_data()
	int count = header->len / sizeof(file_descriptor*);
	file_descriptor** descriptors = (file_descriptor**)data;

	// check if there's enough space in the buffer
	size_t neededBufferSpace = CMSG_SPACE(sizeof(int) * count);
	if (bufferSize < neededBufferSpace)
		return B_BAD_VALUE;

	// init header
	cmsghdr* messageHeader = (cmsghdr*)buffer;
	messageHeader->cmsg_level = header->level;
	messageHeader->cmsg_type = header->type;
	messageHeader->cmsg_len = CMSG_LEN(sizeof(int) * count);

	// create FDs for the current process
	int* fds = (int*)CMSG_DATA(messageHeader);
	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());

	status_t error = B_OK;
	for (int i = 0; i < count; i++) {
		// Get an additional reference which will go to the FD table index. The
		// reference and open reference acquired in unix_add_ancillary_data()
		// will be released when the container is destroyed.
		inc_fd_ref_count(descriptors[i]);
		fds[i] = new_fd(ioContext, descriptors[i]);

		if (fds[i] < 0) {
			// new_fd() failed and fds[i] holds its error code. The reference
			// taken above did not go into the FD table, so drop it again.
			error = fds[i];
			put_fd(descriptors[i]);

			// close FD indices
			for (int k = i - 1; k >= 0; k--)
				close_fd_index(ioContext, fds[k]);
			break;
		}
	}

	// on success report the number of control buffer bytes used
	return error == B_OK ? neededBufferSpace : error;
}
405 }
406 
407 
408 ssize_t
409 unix_send_data_no_buffer(net_protocol *_protocol, const iovec *vecs,
410 	size_t vecCount, ancillary_data_container *ancillaryData,
411 	const struct sockaddr *address, socklen_t addressLength, int flags)
412 {
413 	return ((UnixEndpoint*)_protocol)->Send(vecs, vecCount, ancillaryData,
414 		address, addressLength, flags);
415 }
416 
417 
418 ssize_t
419 unix_read_data_no_buffer(net_protocol *_protocol, const iovec *vecs,
420 	size_t vecCount, ancillary_data_container **_ancillaryData,
421 	struct sockaddr *_address, socklen_t *_addressLength, int flags)
422 {
423 	return ((UnixEndpoint*)_protocol)->Receive(vecs, vecCount, _ancillaryData,
424 		_address, _addressLength, flags);
425 }
426 
427 
428 // #pragma mark -
429 
430 
431 status_t
432 init_unix()
433 {
434 	new(&gAddressManager) UnixAddressManager;
435 	status_t error = gAddressManager.Init();
436 	if (error != B_OK)
437 		return error;
438 
439 	error = gStackModule->register_domain_protocols(AF_UNIX, SOCK_STREAM, 0,
440 		"network/protocols/unix/v1", NULL);
441 	if (error == B_OK) {
442 		error = gStackModule->register_domain_protocols(AF_UNIX, SOCK_DGRAM, 0,
443 			"network/protocols/unix/v1", NULL);
444 	}
445 
446 	if (error != B_OK) {
447 		gAddressManager.~UnixAddressManager();
448 		return error;
449 	}
450 
451 	error = gStackModule->register_domain(AF_UNIX, "unix", &gUnixModule,
452 		&gAddressModule, &sDomain);
453 	if (error != B_OK) {
454 		gAddressManager.~UnixAddressManager();
455 		return error;
456 	}
457 
458 	return B_OK;
459 }
460 
461 
462 status_t
463 uninit_unix()
464 {
465 	gStackModule->unregister_domain(sDomain);
466 
467 	gAddressManager.~UnixAddressManager();
468 
469 	return B_OK;
470 }
471 
472 
473 static status_t
474 unix_std_ops(int32 op, ...)
475 {
476 	switch (op) {
477 		case B_MODULE_INIT:
478 			return init_unix();
479 		case B_MODULE_UNINIT:
480 			return uninit_unix();
481 
482 		default:
483 			return B_ERROR;
484 	}
485 }
486 
487 
/*!	Protocol module descriptor that init_unix() registers for AF_UNIX stream
	and datagram sockets.

	The initializer is positional: the entries must stay in the field order
	of net_protocol_module_info (see <net_protocol.h>) -- do not reorder.
*/
net_protocol_module_info gUnixModule = {
	{
		"network/protocols/unix/v1",
		0,
		unix_std_ops
	},
	0,	// NET_PROTOCOL_ATOMIC_MESSAGES,
		// (no protocol flags are set; the comment above names a flag that
		// is not used here)

	unix_init_protocol,
	unix_uninit_protocol,
	unix_open,
	unix_close,
	unix_free,
	unix_connect,
	unix_accept,
	unix_control,
	unix_getsockopt,
	unix_setsockopt,
	unix_bind,
	unix_unbind,
	unix_listen,
	unix_shutdown,
	unix_send_data,
	unix_send_routed_data,
	unix_send_avail,
	unix_read_data,
	unix_read_avail,
	unix_get_domain,
	unix_get_mtu,
	unix_receive_data,
	unix_deliver_data,
	unix_error_received,
	unix_error_reply,
	unix_add_ancillary_data,
	unix_process_ancillary_data,
	NULL,
		// NOTE(review): hook not implemented by this module -- presumably
		// process_ancillary_data_no_container; confirm against net_protocol.h
	unix_send_data_no_buffer,
	unix_read_data_no_buffer
};
526 };
527 
// Modules this add-on depends on, each paired with the global that
// presumably receives its module_info pointer when the add-on is loaded.
module_dependency module_dependencies[] = {
	{NET_STACK_MODULE_NAME, (module_info **)&gStackModule},
	{NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule},
//	{NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule},
	{NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule},
	{}
		// empty entry presumably terminates the list
};

// Modules exported by this add-on; the list is NULL-terminated.
module_info *modules[] = {
	(module_info *)&gUnixModule,
	NULL
};
539 };
540