xref: /haiku/src/tests/system/kernel/syscall_restart_test.cpp (revision f9eba888cacd8d30171db256c98aa88cebcf5b17)
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

#include <OS.h>
11 
12 
/*!	Selects how the test thread's SIGINT disposition is set up for a run. */
enum run_mode {
	RUN_IGNORE_SIGNAL = 0,
		// SIGINT is set to SIG_IGN; the syscall must complete normally
	RUN_HANDLE_SIGNAL = 1,
		// a handler is installed without SA_RESTART; the syscall is expected
		// to return B_INTERRUPTED
	RUN_HANDLE_SIGNAL_RESTART = 2
		// a handler is installed with SA_RESTART; the syscall is expected to
		// be restarted and to complete normally
};
18 
19 
20 class Test {
21 public:
Test(const char * name)22 	Test(const char* name)
23 		: fName(name)
24 	{
25 	}
26 
~Test()27 	virtual ~Test()
28 	{
29 	}
30 
Run(run_mode mode)31 	bool Run(run_mode mode)
32 	{
33 		fRunMode = mode;
34 		fSignalHandlerCalled = false;
35 
36 		status_t error = Prepare();
37 		if (error != B_OK)
38 			return Error("Failed to prepare test: %s", strerror(error));
39 
40 		thread_id thread = spawn_thread(_ThreadEntry, fName, B_NORMAL_PRIORITY,
41 			this);
42 		if (thread < 0)
43 			return Error("Failed to spawn thread: %s\n", strerror(thread));
44 
45 		resume_thread(thread);
46 
47 		// ...
48 		// * interrupt without restart
49 		// * interrupt with restart
50 
51 		snooze(100000);
52 		kill(thread, SIGINT);
53 
54 		PrepareFinish();
55 
56 		status_t result;
57 		wait_for_thread(thread, &result);
58 
59 		if (result != (Interrupted() ? B_INTERRUPTED : B_OK)) {
60 			return Error("Unexpected syscall return value: %s\n",
61 				strerror(result));
62 		}
63 
64 		if ((RunMode() == RUN_IGNORE_SIGNAL) == fSignalHandlerCalled) {
65 			if (RunMode() == RUN_IGNORE_SIGNAL)
66 				return Error("Handler was called but shouldn't have been!");
67 			else
68 				return Error("Handler was not called!");
69 		}
70 
71 		return Finish(Interrupted());
72 	}
73 
Run()74 	void Run()
75 	{
76 		printf("%s\n", fName);
77 
78 		struct {
79 			const char*	name;
80 			run_mode	mode;
81 		} tests[] = {
82 			{ "ignore signal", RUN_IGNORE_SIGNAL },
83 			{ "handle signal no restart", RUN_HANDLE_SIGNAL },
84 			{ "handle signal restart", RUN_HANDLE_SIGNAL_RESTART },
85 			{}
86 		};
87 
88 		for (int i = 0; tests[i].name != NULL; i++) {
89 			printf("  %-30s: ", tests[i].name);
90 			fflush(stdout);
91 			ClearError();
92 			if (Run(tests[i].mode))
93 				printf("ok\n");
94 			else
95 				printf("failed (%s)\n", fError);
96 
97 			Cleanup();
98 		}
99 	}
100 
RunMode() const101 	run_mode RunMode() const { return fRunMode; }
Interrupted() const102 	bool Interrupted() const { return RunMode() == RUN_HANDLE_SIGNAL; }
TimeWaited() const103 	bigtime_t TimeWaited() const { return fTimeWaited; }
104 
105 protected:
Prepare()106 	virtual status_t Prepare()
107 	{
108 		return B_OK;
109 	}
110 
111 	virtual status_t DoSyscall() = 0;
112 
HandleSignal()113 	virtual void HandleSignal()
114 	{
115 	}
116 
PrepareFinish()117 	virtual void PrepareFinish()
118 	{
119 	}
120 
Finish(bool interrupted)121 	virtual bool Finish(bool interrupted)
122 	{
123 		return true;
124 	}
125 
Cleanup()126 	virtual void Cleanup()
127 	{
128 	}
129 
Error(const char * format,...)130 	bool Error(const char* format,...)
131 	{
132 		va_list args;
133 		va_start(args, format);
134 		vsnprintf(fError, sizeof(fError), format, args);
135 		va_end(args);
136 
137 		return false;
138 	}
139 
Check(bool condition,const char * format,...)140 	bool Check(bool condition, const char* format,...)
141 	{
142 		if (condition)
143 			return true;
144 
145 		va_list args;
146 		va_start(args, format);
147 		vsnprintf(fError, sizeof(fError), format, args);
148 		va_end(args);
149 
150 		return false;
151 	}
152 
ClearError()153 	void ClearError()
154 	{
155 		fError[0] = '\0';
156 	}
157 
158 private:
_ThreadEntry(void * data)159 	static status_t _ThreadEntry(void* data)
160 	{
161 		return ((Test*)data)->_TestThread();
162 	}
163 
_SignalHandler(int signal,char * userData)164 	static void _SignalHandler(int signal, char* userData)
165 	{
166 		Test* self = (Test*)userData;
167 
168 		self->fSignalHandlerCalled = true;
169 		self->HandleSignal();
170 	}
171 
_TestThread()172 	status_t _TestThread()
173 	{
174 		// install handler
175 		struct sigaction action;
176 		if (RunMode() == RUN_IGNORE_SIGNAL)
177 			action.sa_handler = SIG_IGN;
178 		else
179 			action.sa_handler = (void (*)(int))_SignalHandler;
180 
181 		action.sa_flags = RunMode() == RUN_HANDLE_SIGNAL_RESTART
182 			? SA_RESTART : 0;
183 
184 		sigemptyset(&action.sa_mask);
185 		action.sa_userdata = this;
186 
187 		sigaction(SIGINT, &action, NULL);
188 
189 		bigtime_t startTime = system_time();
190 		status_t status = DoSyscall();
191 		fTimeWaited = system_time() - startTime;
192 
193 		return status;
194 	}
195 
196 private:
197 	const char*	fName;
198 	run_mode	fRunMode;
199 	bool		fSignalHandlerCalled;
200 	bigtime_t	fTimeWaited;
201 	char		fError[1024];
202 };
203 
204 
205 class SnoozeTest : public Test {
206 public:
SnoozeTest()207 	SnoozeTest()
208 		: Test("snooze")
209 	{
210 	}
211 
DoSyscall()212 	virtual status_t DoSyscall()
213 	{
214 		return snooze(1000000);
215 	}
216 
Finish(bool interrupted)217 	virtual bool Finish(bool interrupted)
218 	{
219 		if (interrupted)
220 			return Check(TimeWaited() < 200000, "waited too long");
221 
222 		return Check(TimeWaited() > 900000 && TimeWaited() < 1100000,
223 			"waited %lld us instead of 1000000 us", TimeWaited());
224 	}
225 };
226 
227 
228 class ReadTest : public Test {
229 public:
ReadTest()230 	ReadTest()
231 		: Test("read")
232 	{
233 	}
234 
Prepare()235 	virtual status_t Prepare()
236 	{
237 		fBytesRead = -1;
238 		fFDs[0] = -1;
239 		fFDs[1] = -1;
240 
241 		if (pipe(fFDs) != 0)
242 			return errno;
243 
244 		return B_OK;
245 	}
246 
DoSyscall()247 	virtual status_t DoSyscall()
248 	{
249 		char buffer[256];
250 		fBytesRead = read(fFDs[0], buffer, sizeof(buffer));
251 
252 		return fBytesRead < 0 ? errno : B_OK;
253 	}
254 
PrepareFinish()255 	virtual void PrepareFinish()
256 	{
257 		write(fFDs[1], "Ingo", 4);
258 	}
259 
Finish(bool interrupted)260 	virtual bool Finish(bool interrupted)
261 	{
262 		if (interrupted)
263 			return Check(fBytesRead < 0, "unexpected read");
264 
265 		return Check(fBytesRead == 4, "should have read 4 bytes, read only %ld "
266 			"bytes", fBytesRead);
267 	}
268 
Cleanup()269 	virtual void Cleanup()
270 	{
271 		close(fFDs[0]);
272 		close(fFDs[1]);
273 	}
274 
275 private:
276 	bigtime_t	fTimeWaited;
277 	ssize_t		fBytesRead;
278 	int			fFDs[2];
279 };
280 
281 
282 class WriteTest : public Test {
283 public:
WriteTest()284 	WriteTest()
285 		: Test("write")
286 	{
287 	}
288 
Prepare()289 	virtual status_t Prepare()
290 	{
291 		fBytesWritten = -1;
292 		fFDs[0] = -1;
293 		fFDs[1] = -1;
294 
295 		if (pipe(fFDs) != 0)
296 			return errno;
297 
298 		// fill pipe
299 		fcntl(fFDs[1], F_SETFL, O_NONBLOCK);
300 		while (write(fFDs[1], "a", 1) == 1);
301 
302 		return B_OK;
303 	}
304 
DoSyscall()305 	virtual status_t DoSyscall()
306 	{
307 		// blocking wait
308 		fcntl(fFDs[1], F_SETFL, 0);
309 		fBytesWritten = write(fFDs[1], "Ingo", 4);
310 
311 		return fBytesWritten < 0 ? errno : B_OK;
312 	}
313 
PrepareFinish()314 	virtual void PrepareFinish()
315 	{
316 		char buffer[256];
317 		read(fFDs[0], buffer, sizeof(buffer));
318 	}
319 
Finish(bool interrupted)320 	virtual bool Finish(bool interrupted)
321 	{
322 		if (interrupted)
323 			return Check(fBytesWritten < 0, "unexpected write");
324 
325 		return Check(fBytesWritten == 4, "should have written 4 bytes, wrote only %ld "
326 			"bytes", fBytesWritten);
327 	}
328 
Cleanup()329 	virtual void Cleanup()
330 	{
331 		close(fFDs[0]);
332 		close(fFDs[1]);
333 	}
334 
335 private:
336 	ssize_t		fBytesWritten;
337 	int			fFDs[2];
338 };
339 
340 
341 class AcquireSwitchSemTest : public Test {
342 public:
AcquireSwitchSemTest(bool useSwitch)343 	AcquireSwitchSemTest(bool useSwitch)
344 		: Test(useSwitch ? "switch_sem" : "acquire_sem"),
345 		fSwitch(useSwitch)
346 	{
347 	}
348 
Prepare()349 	virtual status_t Prepare()
350 	{
351 		fSemaphore = create_sem(0, "test sem");
352 
353 		return (fSemaphore >= 0 ? B_OK : fSemaphore);
354 	}
355 
DoSyscall()356 	virtual status_t DoSyscall()
357 	{
358 		if (fSwitch)
359 			return switch_sem(-1, fSemaphore);
360 
361 		return acquire_sem(fSemaphore);
362 	}
363 
PrepareFinish()364 	virtual void PrepareFinish()
365 	{
366 		release_sem(fSemaphore);
367 	}
368 
369 /*
370 	virtual bool Finish(bool interrupted)
371 	{
372 //		int32 semCount = -1;
373 //		get_sem_count(fSemaphore, &semCount);
374 
375 		if (interrupted)
376 			return true;
377 
378 		return Check(fBytesWritten == 4, "should have written 4 bytes, wrote only %ld "
379 			"bytes", fBytesWritten);
380 	}
381 */
382 
Cleanup()383 	virtual void Cleanup()
384 	{
385 		delete_sem(fSemaphore);
386 	}
387 
388 protected:
389 	sem_id		fSemaphore;
390 	bool		fSwitch;
391 };
392 
393 
394 class AcquireSwitchSemEtcTest : public Test {
395 public:
AcquireSwitchSemEtcTest(bool useSwitch)396 	AcquireSwitchSemEtcTest(bool useSwitch)
397 		: Test(useSwitch ? "switch_sem_etc" : "acquire_sem_etc"),
398 		fSwitch(useSwitch)
399 	{
400 	}
401 
Prepare()402 	virtual status_t Prepare()
403 	{
404 		fSemaphore = create_sem(0, "test sem");
405 
406 		return fSemaphore >= 0 ? B_OK : fSemaphore;
407 	}
408 
DoSyscall()409 	virtual status_t DoSyscall()
410 	{
411 		status_t status;
412 		if (fSwitch) {
413 			status = switch_sem_etc(-1, fSemaphore, 1, B_RELATIVE_TIMEOUT,
414 				1000000);
415 		} else {
416 			status = acquire_sem_etc(fSemaphore, 1, B_RELATIVE_TIMEOUT,
417 				1000000);
418 		}
419 
420 		if (!Interrupted() && status == B_TIMED_OUT)
421 			return B_OK;
422 
423 		return status;
424 	}
425 
Finish(bool interrupted)426 	virtual bool Finish(bool interrupted)
427 	{
428 		if (interrupted)
429 			return Check(TimeWaited() < 200000, "waited too long");
430 
431 		return Check(TimeWaited() > 900000 && TimeWaited() < 1100000,
432 			"waited %lld us instead of 1000000 us", TimeWaited());
433 	}
434 
Cleanup()435 	virtual void Cleanup()
436 	{
437 		delete_sem(fSemaphore);
438 	}
439 
440 protected:
441 	sem_id		fSemaphore;
442 	bool		fSwitch;
443 };
444 
445 
446 class AcceptTest : public Test {
447 public:
AcceptTest()448 	AcceptTest()
449 		: Test("accept")
450 	{
451 	}
452 
Prepare()453 	virtual status_t Prepare()
454 	{
455 		fAcceptedSocket = -1;
456 
457 		fServerSocket = socket(AF_INET, SOCK_STREAM, 0);
458 		if (fServerSocket < 0) {
459 			fprintf(stderr, "Could not open server socket: %s\n",
460 				strerror(errno));
461 			return errno;
462 		}
463 
464 		int reuse = 1;
465 		if (setsockopt(fServerSocket, SOL_SOCKET, SO_REUSEADDR, &reuse,
466 				sizeof(int)) == -1) {
467 			fprintf(stderr, "Could not make server socket reusable: %s\n",
468 				strerror(errno));
469 			return errno;
470 		}
471 
472 		memset(&fServerAddress, 0, sizeof(sockaddr_in));
473 		fServerAddress.sin_family = AF_INET;
474 		fServerAddress.sin_addr.s_addr = INADDR_LOOPBACK;
475 
476 		if (bind(fServerSocket, (struct sockaddr *)&fServerAddress,
477 				sizeof(struct sockaddr)) == -1) {
478 			fprintf(stderr, "Could not bind server socket: %s\n",
479 				strerror(errno));
480 			return errno;
481 		}
482 
483 		socklen_t length = sizeof(sockaddr_in);
484 		getsockname(fServerSocket, (sockaddr*)&fServerAddress,
485 			&length);
486 
487 		if (listen(fServerSocket, 10) == -1) {
488 			fprintf(stderr, "Could not listen on server socket: %s\n",
489 				strerror(errno));
490 			return errno;
491 		}
492 
493 		return B_OK;
494 	}
495 
DoSyscall()496 	virtual status_t DoSyscall()
497 	{
498 		sockaddr_in clientAddress;
499 		socklen_t length = sizeof(struct sockaddr_in);
500 
501 		fAcceptedSocket = accept(fServerSocket,
502 			(struct sockaddr *)&clientAddress, &length);
503 		if (fAcceptedSocket == -1)
504 			return errno;
505 
506 		return B_OK;
507 	}
508 
PrepareFinish()509 	virtual void PrepareFinish()
510 	{
511 		if (Interrupted())
512 			return;
513 
514 		int clientSocket = socket(AF_INET, SOCK_STREAM, 0);
515 		if (clientSocket == -1) {
516 			fprintf(stderr, "Could not open client socket: %s\n",
517 				strerror(errno));
518 			return;
519 		}
520 
521 		if (connect(clientSocket, (struct sockaddr *)&fServerAddress,
522 				sizeof(struct sockaddr)) == -1) {
523 			fprintf(stderr, "Could not connect to server socket: %s\n",
524 				strerror(errno));
525 		}
526 
527 		close(clientSocket);
528 	}
529 
Finish(bool interrupted)530 	virtual bool Finish(bool interrupted)
531 	{
532 		if (interrupted)
533 			return Check(fAcceptedSocket < 0, "got socket");
534 
535 		return Check(fAcceptedSocket >= 0, "got no socket");
536 	}
537 
Cleanup()538 	virtual void Cleanup()
539 	{
540 		close(fAcceptedSocket);
541 		close(fServerSocket);
542 	}
543 
544 protected:
545 	int			fServerSocket;
546 	sockaddr_in	fServerAddress;
547 	int			fAcceptedSocket;
548 };
549 
550 
551 class ReceiveTest : public Test {
552 public:
ReceiveTest()553 	ReceiveTest()
554 		: Test("recv")
555 	{
556 	}
557 
Prepare()558 	virtual status_t Prepare()
559 	{
560 		fBytesRead = -1;
561 		fAcceptedSocket = -1;
562 		fClientSocket = -1;
563 
564 		fServerSocket = socket(AF_INET, SOCK_STREAM, 0);
565 		if (fServerSocket < 0) {
566 			fprintf(stderr, "Could not open server socket: %s\n",
567 				strerror(errno));
568 			return errno;
569 		}
570 
571 		int reuse = 1;
572 		if (setsockopt(fServerSocket, SOL_SOCKET, SO_REUSEADDR, &reuse,
573 				sizeof(int)) == -1) {
574 			fprintf(stderr, "Could not make server socket reusable: %s\n",
575 				strerror(errno));
576 			return errno;
577 		}
578 
579 		memset(&fServerAddress, 0, sizeof(sockaddr_in));
580 		fServerAddress.sin_family = AF_INET;
581 		fServerAddress.sin_addr.s_addr = INADDR_LOOPBACK;
582 
583 		if (bind(fServerSocket, (struct sockaddr *)&fServerAddress,
584 				sizeof(struct sockaddr)) == -1) {
585 			fprintf(stderr, "Could not bind server socket: %s\n",
586 				strerror(errno));
587 			return errno;
588 		}
589 
590 		socklen_t length = sizeof(sockaddr_in);
591 		getsockname(fServerSocket, (sockaddr*)&fServerAddress,
592 			&length);
593 
594 		if (listen(fServerSocket, 10) == -1) {
595 			fprintf(stderr, "Could not listen on server socket: %s\n",
596 				strerror(errno));
597 			return errno;
598 		}
599 
600 		fClientSocket = socket(AF_INET, SOCK_STREAM, 0);
601 		if (fClientSocket == -1) {
602 			fprintf(stderr, "Could not open client socket: %s\n",
603 				strerror(errno));
604 			return errno;
605 		}
606 
607 		fcntl(fClientSocket, F_SETFL, O_NONBLOCK);
608 
609 		if (connect(fClientSocket, (struct sockaddr *)&fServerAddress,
610 				sizeof(struct sockaddr)) == -1) {
611 			if (errno != EINPROGRESS) {
612 				fprintf(stderr, "Could not connect to server socket: %s\n",
613 					strerror(errno));
614 				return errno;
615 			}
616 		}
617 
618 		sockaddr_in clientAddress;
619 		length = sizeof(struct sockaddr_in);
620 
621 		fAcceptedSocket = accept(fServerSocket,
622 			(struct sockaddr *)&clientAddress, &length);
623 		if (fAcceptedSocket == -1)
624 			return errno;
625 
626 		fcntl(fClientSocket, F_SETFL, 0);
627 
628 		snooze(100000);
629 
630 		return B_OK;
631 	}
632 
DoSyscall()633 	virtual status_t DoSyscall()
634 	{
635 		char buffer[256];
636 		fBytesRead = recv(fAcceptedSocket, buffer, sizeof(buffer), 0);
637 
638 		return fBytesRead < 0 ? errno : B_OK;
639 	}
640 
PrepareFinish()641 	virtual void PrepareFinish()
642 	{
643 		write(fClientSocket, "Axel", 4);
644 	}
645 
Finish(bool interrupted)646 	virtual bool Finish(bool interrupted)
647 	{
648 		if (interrupted)
649 			return Check(fBytesRead < 0, "unexpected read");
650 
651 		return Check(fBytesRead == 4, "should have read 4 bytes, read only %ld "
652 			"bytes", fBytesRead);
653 	}
654 
Cleanup()655 	virtual void Cleanup()
656 	{
657 		close(fAcceptedSocket);
658 		close(fServerSocket);
659 		close(fClientSocket);
660 	}
661 
662 protected:
663 	int			fServerSocket;
664 	sockaddr_in	fServerAddress;
665 	int			fAcceptedSocket;
666 	int			fClientSocket;
667 	ssize_t		fBytesRead;
668 };
669 
670 
671 int
main()672 main()
673 {
674 	Test* tests[] = {
675 		new SnoozeTest,
676 		new ReadTest,
677 		new WriteTest,
678 		new AcquireSwitchSemTest(false),
679 		new AcquireSwitchSemTest(true),
680 		new AcquireSwitchSemEtcTest(false),
681 		new AcquireSwitchSemEtcTest(true),
682 		new AcceptTest,
683 		new ReceiveTest,
684 		NULL
685 	};
686 
687 	for (int i = 0; tests[i] != NULL; i++)
688 		tests[i]->Run();
689 
690 	return 0;
691 }
692