xref: /haiku/src/add-ons/kernel/network/protocols/unix/unix.cpp (revision d2e1e872611179c9cfaa43ce11bd58b1e3554e4b)
1 /*
2  * Copyright 2008, Ingo Weinhold, ingo_weinhold@gmx.de.
3  * Distributed under the terms of the MIT License.
4  */
5 
6 #include <stdio.h>
7 #include <sys/un.h>
8 
9 #include <new>
10 
11 #include <AutoDeleter.h>
12 
13 #include <fs/fd.h>
14 #include <lock.h>
15 #include <util/AutoLock.h>
16 #include <vfs.h>
17 
18 #include <net_buffer.h>
19 #include <net_protocol.h>
20 #include <net_socket.h>
21 #include <net_stack.h>
22 
23 #include "UnixAddressManager.h"
24 #include "UnixEndpoint.h"
25 
26 
27 #define UNIX_MODULE_DEBUG_LEVEL	2
28 #define UNIX_DEBUG_LEVEL		UNIX_MODULE_DEBUG_LEVEL
29 #include "UnixDebug.h"
30 
31 
32 extern net_protocol_module_info gUnixModule;
33 	// extern only for forwarding
34 
35 net_stack_module_info *gStackModule;
36 net_socket_module_info *gSocketModule;
37 net_buffer_module_info *gBufferModule;
38 UnixAddressManager gAddressManager;
39 
40 static struct net_domain *sDomain;
41 
42 
43 void
44 destroy_scm_rights_descriptors(const ancillary_data_header* header,
45 	void* data)
46 {
47 	int count = header->len / sizeof(file_descriptor*);
48 	file_descriptor** descriptors = (file_descriptor**)data;
49 
50 	for (int i = 0; i < count; i++) {
51 		if (descriptors[i] != NULL)
52 			put_fd(descriptors[i]);
53 	}
54 }
55 
56 
57 // #pragma mark -
58 
59 
60 net_protocol *
61 unix_init_protocol(net_socket *socket)
62 {
63 	TRACE("[%ld] unix_init_protocol(%p)\n", find_thread(NULL), socket);
64 
65 	UnixEndpoint* endpoint = new(std::nothrow) UnixEndpoint(socket);
66 	if (endpoint == NULL)
67 		return NULL;
68 
69 	status_t error = endpoint->Init();
70 	if (error != B_OK) {
71 		delete endpoint;
72 		return NULL;
73 	}
74 
75 	return endpoint;
76 }
77 
78 
79 status_t
80 unix_uninit_protocol(net_protocol *_protocol)
81 {
82 	TRACE("[%ld] unix_uninit_protocol(%p)\n", find_thread(NULL), _protocol);
83 	((UnixEndpoint*)_protocol)->Uninit();
84 	return B_OK;
85 }
86 
87 
88 status_t
89 unix_open(net_protocol *_protocol)
90 {
91 	return ((UnixEndpoint*)_protocol)->Open();
92 }
93 
94 
95 status_t
96 unix_close(net_protocol *_protocol)
97 {
98 	return ((UnixEndpoint*)_protocol)->Close();
99 }
100 
101 
102 status_t
103 unix_free(net_protocol *_protocol)
104 {
105 	return ((UnixEndpoint*)_protocol)->Free();
106 }
107 
108 
109 status_t
110 unix_connect(net_protocol *_protocol, const struct sockaddr *address)
111 {
112 	return ((UnixEndpoint*)_protocol)->Connect(address);
113 }
114 
115 
116 status_t
117 unix_accept(net_protocol *_protocol, struct net_socket **_acceptedSocket)
118 {
119 	return ((UnixEndpoint*)_protocol)->Accept(_acceptedSocket);
120 }
121 
122 
123 status_t
124 unix_control(net_protocol *protocol, int level, int option, void *value,
125 	size_t *_length)
126 {
127 	return B_BAD_VALUE;
128 }
129 
130 
131 status_t
132 unix_getsockopt(net_protocol *protocol, int level, int option, void *value,
133 	int *_length)
134 {
135 	return gSocketModule->get_option(protocol->socket, level, option, value,
136 		_length);
137 }
138 
139 
140 status_t
141 unix_setsockopt(net_protocol *protocol, int level, int option,
142 	const void *_value, int length)
143 {
144 	UnixEndpoint* endpoint = (UnixEndpoint*)protocol;
145 
146 	if (level == SOL_SOCKET) {
147 		if (option == SO_RCVBUF) {
148 			if (length != sizeof(int))
149 				return B_BAD_VALUE;
150 
151 			endpoint->SetReceiveBufferSize(*(int*)_value);
152 		} else if (option == SO_SNDBUF) {
153 			// We don't have a receive buffer, so silently ignore this one.
154 		}
155 	}
156 
157 	return gSocketModule->set_option(protocol->socket, level, option,
158 		_value, length);
159 }
160 
161 
162 status_t
163 unix_bind(net_protocol *_protocol, const struct sockaddr *_address)
164 {
165 	return ((UnixEndpoint*)_protocol)->Bind(_address);
166 }
167 
168 
169 status_t
170 unix_unbind(net_protocol *_protocol, struct sockaddr *_address)
171 {
172 	return ((UnixEndpoint*)_protocol)->Unbind();
173 }
174 
175 
176 status_t
177 unix_listen(net_protocol *_protocol, int count)
178 {
179 	return ((UnixEndpoint*)_protocol)->Listen(count);
180 }
181 
182 
183 status_t
184 unix_shutdown(net_protocol *_protocol, int direction)
185 {
186 	return ((UnixEndpoint*)_protocol)->Shutdown(direction);
187 }
188 
189 
190 status_t
191 unix_send_routed_data(net_protocol *_protocol, struct net_route *route,
192 	net_buffer *buffer)
193 {
194 	return B_ERROR;
195 }
196 
197 
198 status_t
199 unix_send_data(net_protocol *_protocol, net_buffer *buffer)
200 {
201 	return ((UnixEndpoint*)_protocol)->Send(buffer);
202 }
203 
204 
205 ssize_t
206 unix_send_avail(net_protocol *_protocol)
207 {
208 	return ((UnixEndpoint*)_protocol)->Sendable();
209 }
210 
211 
212 status_t
213 unix_read_data(net_protocol *_protocol, size_t numBytes, uint32 flags,
214 	net_buffer **_buffer)
215 {
216 	return ((UnixEndpoint*)_protocol)->Receive(numBytes, flags, _buffer);
217 }
218 
219 
220 ssize_t
221 unix_read_avail(net_protocol *_protocol)
222 {
223 	return ((UnixEndpoint*)_protocol)->Receivable();
224 }
225 
226 
227 struct net_domain *
228 unix_get_domain(net_protocol *protocol)
229 {
230 	return sDomain;
231 }
232 
233 
234 size_t
235 unix_get_mtu(net_protocol *protocol, const struct sockaddr *address)
236 {
237 	return UNIX_MAX_TRANSFER_UNIT;
238 }
239 
240 
241 status_t
242 unix_receive_data(net_buffer *buffer)
243 {
244 	return B_ERROR;
245 }
246 
247 
248 status_t
249 unix_deliver_data(net_protocol *_protocol, net_buffer *buffer)
250 {
251 	return B_ERROR;
252 }
253 
254 
255 status_t
256 unix_error(uint32 code, net_buffer *data)
257 {
258 	return B_ERROR;
259 }
260 
261 
262 status_t
263 unix_error_reply(net_protocol *protocol, net_buffer *causedError, uint32 code,
264 	void *errorData)
265 {
266 	return B_ERROR;
267 }
268 
269 
270 status_t
271 unix_attach_ancillary_data(net_protocol *self, net_buffer *buffer,
272 	const cmsghdr *header)
273 {
274 	TRACE("[%ld] unix_attach_ancillary_data(%p, %p, %p (level: %d, type: %d, "
275 		"len: %d))\n", find_thread(NULL), self, buffer, header,
276 		header->cmsg_level, header->cmsg_type, (int)header->cmsg_len);
277 
278 	// we support only SCM_RIGHTS
279 	if (header->cmsg_level != SOL_SOCKET || header->cmsg_type != SCM_RIGHTS)
280 		return B_BAD_VALUE;
281 
282 	int* fds = (int*)CMSG_DATA(header);
283 	int count = (header->cmsg_len - CMSG_ALIGN(sizeof(cmsghdr))) / sizeof(int);
284 	if (count == 0)
285 		return B_BAD_VALUE;
286 
287 	file_descriptor** descriptors = new(std::nothrow) file_descriptor*[count];
288 	if (descriptors == NULL)
289 		return ENOBUFS;
290 	ArrayDeleter<file_descriptor*> _(descriptors);
291 	memset(descriptors, 0, sizeof(file_descriptor*) * count);
292 
293 	// get the file descriptors
294 	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
295 
296 	status_t error = B_OK;
297 	for (int i = 0; i < count; i++) {
298 		descriptors[i] = get_fd(ioContext, fds[i]);
299 		if (descriptors[i] == NULL) {
300 			error = EBADF;
301 			break;
302 		}
303 	}
304 
305 	// attach the ancillary data to the buffer
306 	if (error == B_OK) {
307 		ancillary_data_header header;
308 		header.level = SOL_SOCKET;
309 		header.type = SCM_RIGHTS;
310 		header.len = count * sizeof(file_descriptor*);
311 
312 		TRACE("[%ld] unix_attach_ancillary_data(): attaching %d FDs to "
313 			"buffer\n", find_thread(NULL), count);
314 
315 		error = gBufferModule->attach_ancillary_data(buffer, &header,
316 			descriptors, destroy_scm_rights_descriptors, NULL);
317 	}
318 
319 	// cleanup on error
320 	if (error != B_OK) {
321 		for (int i = 0; i < count; i++) {
322 			if (descriptors[i] != NULL)
323 				put_fd(descriptors[i]);
324 		}
325 	}
326 
327 	return error;
328 }
329 
330 
331 ssize_t
332 unix_process_ancillary_data(net_protocol *self,
333 	const ancillary_data_header *header, const void *data, void *buffer,
334 	size_t bufferSize)
335 {
336 	TRACE("[%ld] unix_process_ancillary_data(%p, %p (level: %d, type: %d, "
337 		"len: %lu), %p, %p, %lu)\n", find_thread(NULL), self, header,
338 		header->level, header->type, header->len, data, buffer, bufferSize);
339 
340 	// we support only SCM_RIGHTS
341 	if (header->level != SOL_SOCKET || header->type != SCM_RIGHTS)
342 		return B_BAD_VALUE;
343 
344 	int count = header->len / sizeof(file_descriptor*);
345 	file_descriptor** descriptors = (file_descriptor**)data;
346 
347 	// check if there's enough space in the buffer
348 	size_t neededBufferSpace = CMSG_SPACE(sizeof(int) * count);
349 	if (bufferSize < neededBufferSpace)
350 		return B_BAD_VALUE;
351 
352 	// init header
353 	cmsghdr* messageHeader = (cmsghdr*)buffer;
354 	messageHeader->cmsg_level = header->level;
355 	messageHeader->cmsg_type = header->type;
356 	messageHeader->cmsg_len = CMSG_LEN(sizeof(int) * count);
357 
358 	// create FDs for the current process
359 	int* fds = (int*)CMSG_DATA(messageHeader);
360 	io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());
361 
362 	status_t error = B_OK;
363 	for (int i = 0; i < count; i++) {
364 		// get an additional reference which will go to the FD table index
365 		inc_fd_ref_count(descriptors[i]);
366 		fds[i] = new_fd(ioContext, descriptors[i]);
367 
368 		if (fds[i] < 0) {
369 			error = fds[i];
370 			put_fd(descriptors[i]);
371 
372 			// close FD indices
373 			for (int k = i - 1; k >= 0; k--)
374 				close_fd_index(ioContext, fds[k]);
375 			break;
376 		}
377 	}
378 
379 	return error == B_OK ? neededBufferSpace : error;
380 }
381 
382 
383 // #pragma mark -
384 
385 
386 status_t
387 init_unix()
388 {
389 	new(&gAddressManager) UnixAddressManager;
390 	status_t error = gAddressManager.Init();
391 	if (error != B_OK)
392 		return error;
393 
394 	error = gStackModule->register_domain_protocols(AF_UNIX, SOCK_STREAM, 0,
395 		"network/protocols/unix/v1", NULL);
396 	if (error != B_OK) {
397 		gAddressManager.~UnixAddressManager();
398 		return error;
399 	}
400 
401 	error = gStackModule->register_domain(AF_UNIX, "unix", &gUnixModule,
402 		&gAddressModule, &sDomain);
403 	if (error != B_OK) {
404 		gAddressManager.~UnixAddressManager();
405 		return error;
406 	}
407 
408 	return B_OK;
409 }
410 
411 
412 status_t
413 uninit_unix()
414 {
415 	gStackModule->unregister_domain(sDomain);
416 
417 	gAddressManager.~UnixAddressManager();
418 
419 	return B_OK;
420 }
421 
422 
423 static status_t
424 unix_std_ops(int32 op, ...)
425 {
426 	switch (op) {
427 		case B_MODULE_INIT:
428 			return init_unix();
429 		case B_MODULE_UNINIT:
430 			return uninit_unix();
431 
432 		default:
433 			return B_ERROR;
434 	}
435 }
436 
437 
438 net_protocol_module_info gUnixModule = {
439 	{
440 		"network/protocols/unix/v1",
441 		0,
442 		unix_std_ops
443 	},
444 	0,	// NET_PROTOCOL_ATOMIC_MESSAGES,
445 
446 	unix_init_protocol,
447 	unix_uninit_protocol,
448 	unix_open,
449 	unix_close,
450 	unix_free,
451 	unix_connect,
452 	unix_accept,
453 	unix_control,
454 	unix_getsockopt,
455 	unix_setsockopt,
456 	unix_bind,
457 	unix_unbind,
458 	unix_listen,
459 	unix_shutdown,
460 	unix_send_data,
461 	unix_send_routed_data,
462 	unix_send_avail,
463 	unix_read_data,
464 	unix_read_avail,
465 	unix_get_domain,
466 	unix_get_mtu,
467 	unix_receive_data,
468 	unix_deliver_data,
469 	unix_error,
470 	unix_error_reply,
471 	unix_attach_ancillary_data,
472 	unix_process_ancillary_data
473 };
474 
475 module_dependency module_dependencies[] = {
476 	{NET_STACK_MODULE_NAME, (module_info **)&gStackModule},
477 	{NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule},
478 //	{NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule},
479 	{NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule},
480 	{}
481 };
482 
483 module_info *modules[] = {
484 	(module_info *)&gUnixModule,
485 	NULL
486 };
487