xref: /haiku/src/tests/kits/net/NetEndpointTest.cpp (revision b671e9bbdbd10268a042b4f4cc4317ccd03d105e)
1 /*
2  * Copyright 2008, Oliver Tappe, zooey@hirschkaefer.de.
3  * Distributed under the terms of the MIT license.
4  */
5 
6 
7 #include <Message.h>
8 #include <NetEndpoint.h>
9 
10 #include <errno.h>
11 #include <netinet/in.h>
12 #include <stdio.h>
13 #include <stdlib.h>
14 #include <string.h>
15 #include <sys/wait.h>
16 
17 
18 static BNetAddress serverAddr("127.0.0.1", 1234);
19 static BNetAddress clientAddr("127.0.0.1", 51234);
20 
21 
22 static int problemCount = 0;
23 
24 
25 void
26 checkAddrsAreEqual(const BNetAddress& na1, const BNetAddress& na2,
27 	const char* fmt)
28 {
29 	in_addr addr1, addr2;
30 	unsigned short port1, port2;
31 	na1.GetAddr(addr1, &port1);
32 	na2.GetAddr(addr2, &port2);
33 	if (addr1.s_addr == addr2.s_addr && port1 == port2)
34 		return;
35 	fprintf(stderr, fmt, addr1.s_addr, port1, addr2.s_addr, port2);
36 	exit(1);
37 }
38 
39 
40 void
41 checkArchive(const BNetEndpoint ne, int32 protocol,
42 	const BNetAddress& localNetAddress, const BNetAddress& remoteNetAddress)
43 {
44 	in_addr localAddr, remoteAddr;
45 	unsigned short localPort, remotePort;
46 	localNetAddress.GetAddr(localAddr, &localPort);
47 	remoteNetAddress.GetAddr(remoteAddr, &remotePort);
48 
49 	BMessage archive(0UL);
50 	status_t status = ne.Archive(&archive);
51 	if (status != B_OK) {
52 		fprintf(stderr, "Archive() failed - %lx:%s\n", status,
53 			strerror(status));
54 		problemCount++;
55 		exit(1);
56 	}
57 	const char* arcClass;
58 	if (archive.FindString("class", &arcClass) != B_OK) {
59 		fprintf(stderr, "'class' not found in archive\n");
60 		problemCount++;
61 		exit(1);
62 	}
63 	if (strcmp(arcClass, "BNetEndpoint") != 0) {
64 		fprintf(stderr, "expected 'class' to be 'BNetEndpoint' - is '%s'\n",
65 			arcClass);
66 		problemCount++;
67 		exit(1);
68 	}
69 
70 	if (ne.LocalAddr().InitCheck() == B_OK) {
71 		int32 arcAddr;
72 		if (archive.FindInt32("_BNetEndpoint_addr_addr", &arcAddr) != B_OK) {
73 			fprintf(stderr, "'_BNetEndpoint_addr_addr' not found in archive\n");
74 			problemCount++;
75 			exit(1);
76 		}
77 		if ((uint32)localAddr.s_addr != (uint32)arcAddr) {
78 			fprintf(stderr,
79 				"expected '_BNetEndpoint_addr_addr' to be %x - is %x\n",
80 				localAddr.s_addr, (unsigned int)arcAddr);
81 			problemCount++;
82 			exit(1);
83 		}
84 		int16 arcPort;
85 		if (archive.FindInt16("_BNetEndpoint_addr_port", &arcPort) != B_OK) {
86 			fprintf(stderr, "'_BNetEndpoint_addr_port' not found in archive\n");
87 			problemCount++;
88 			exit(1);
89 		}
90 		if ((uint16)localPort != (uint16)arcPort) {
91 			fprintf(stderr,
92 				"expected '_BNetEndpoint_addr_port' to be %d - is %d\n",
93 				localPort, (int)arcPort);
94 			problemCount++;
95 			exit(1);
96 		}
97 	}
98 
99 	if (ne.RemoteAddr().InitCheck() == B_OK) {
100 		int32 arcAddr;
101 		if (archive.FindInt32("_BNetEndpoint_peer_addr", &arcAddr) != B_OK) {
102 			fprintf(stderr, "'_BNetEndpoint_peer_addr' not found in archive\n");
103 			problemCount++;
104 			exit(1);
105 		}
106 		if ((uint32)remoteAddr.s_addr != (uint32)arcAddr) {
107 			fprintf(stderr,
108 				"expected '_BNetEndpoint_peer_addr' to be %x - is %x\n",
109 				remoteAddr.s_addr, (unsigned int)arcAddr);
110 			problemCount++;
111 			exit(1);
112 		}
113 		int16 arcPort;
114 		if (archive.FindInt16("_BNetEndpoint_peer_port", &arcPort) != B_OK) {
115 			fprintf(stderr, "'_BNetEndpoint_peer_port' not found in archive\n");
116 			problemCount++;
117 			exit(1);
118 		}
119 		if ((uint16)remotePort != (uint16)arcPort) {
120 			fprintf(stderr,
121 				"expected '_BNetEndpoint_peer_port' to be %u - is %u\n",
122 				remotePort, (unsigned short)arcPort);
123 			problemCount++;
124 			exit(1);
125 		}
126 	}
127 
128 	int64 arcTimeout;
129 	if (archive.FindInt64("_BNetEndpoint_timeout", &arcTimeout) != B_OK) {
130 		fprintf(stderr, "'_BNetEndpoint_timeout' not found in archive\n");
131 		problemCount++;
132 		exit(1);
133 	}
134 	if (arcTimeout != B_INFINITE_TIMEOUT) {
135 		fprintf(stderr,
136 			"expected '_BNetEndpoint_timeout' to be %llu - is %llu\n",
137 			B_INFINITE_TIMEOUT, (uint64)arcTimeout);
138 		problemCount++;
139 		exit(1);
140 	}
141 
142 	int32 arcProtocol;
143 	if (archive.FindInt32("_BNetEndpoint_proto", &arcProtocol) != B_OK) {
144 		fprintf(stderr, "'_BNetEndpoint_proto' not found in archive\n");
145 		problemCount++;
146 		exit(1);
147 	}
148 	if (arcProtocol != protocol) {
149 		fprintf(stderr, "expected '_BNetEndpoint_proto' to be %d - is %d\n",
150 			(int)protocol, (int)arcProtocol);
151 		problemCount++;
152 		exit(1);
153 	}
154 
155 	BNetEndpoint* clone
156 		= dynamic_cast<BNetEndpoint *>(BNetEndpoint::Instantiate(&archive));
157 	if (!clone) {
158 		fprintf(stderr, "unable to instantiate endpoint from archive\n");
159 		problemCount++;
160 		exit(1);
161 	}
162 	delete clone;
163 }
164 
165 void testServer(thread_id clientThread)
166 {
167 	char buf[1];
168 
169 	// check simple UDP "connection"
170 	BNetEndpoint server(SOCK_DGRAM);
171 	for(int i=0; i < 2; ++i) {
172 		status_t status = server.Bind(serverAddr);
173 		if (status != B_OK) {
174 			fprintf(stderr, "Bind() failed in testServer - %s\n",
175 				strerror(status));
176 			problemCount++;
177 			exit(1);
178 		}
179 
180 		checkAddrsAreEqual(server.LocalAddr(), serverAddr,
181 			"LocalAddr() doesn't match serverAddr\n");
182 
183 		if (i == 0)
184 			resume_thread(clientThread);
185 
186 		BNetAddress remoteAddr;
187 		status = server.ReceiveFrom(buf, 1, remoteAddr, 0);
188 		if (status < B_OK) {
189 			fprintf(stderr, "ReceiveFrom() failed in testServer - %s\n",
190 				strerror(status));
191 			problemCount++;
192 			exit(1);
193 		}
194 
195 		if (buf[0] != 'U') {
196 			fprintf(stderr, "expected to receive %c but got %c\n", 'U', buf[0]);
197 			problemCount++;
198 			exit(1);
199 		}
200 
201 		checkAddrsAreEqual(remoteAddr, clientAddr,
202 			"remoteAddr(%x:%d) doesn't match clientAddr(%x:%d)\n");
203 
204 		checkArchive(server, SOCK_DGRAM, serverAddr, clientAddr);
205 
206 		server.Close();
207 	}
208 
209 	// now switch to TCP and try again
210 	server.SetProtocol(SOCK_STREAM);
211 	status_t status = server.Bind(serverAddr);
212 	if (status != B_OK) {
213 		fprintf(stderr, "Bind() failed in testServer - %s\n",
214 			strerror(status));
215 		problemCount++;
216 		exit(1);
217 	}
218 
219 	checkAddrsAreEqual(server.LocalAddr(), serverAddr,
220 		"LocalAddr() doesn't match serverAddr\n");
221 
222 	status = server.Listen();
223 	BNetEndpoint* acceptedConn = server.Accept();
224 	if (acceptedConn == NULL) {
225 		fprintf(stderr, "Accept() failed in testServer\n");
226 		problemCount++;
227 		exit(1);
228 	}
229 
230 	const BNetAddress& remoteAddr = acceptedConn->RemoteAddr();
231 	checkAddrsAreEqual(remoteAddr, clientAddr,
232 		"remoteAddr(%x:%d) doesn't match clientAddr(%x:%d)\n");
233 
234 	status = acceptedConn->Receive(buf, 1);
235 	if (status < B_OK) {
236 		fprintf(stderr, "Receive() failed in testServer - %s\n",
237 			strerror(status));
238 		problemCount++;
239 		exit(1);
240 	}
241 	delete acceptedConn;
242 
243 	if (buf[0] != 'T') {
244 		fprintf(stderr, "expected to receive %c but got %c\n", 'T', buf[0]);
245 		problemCount++;
246 		exit(1);
247 	}
248 
249 	checkArchive(server, SOCK_STREAM, serverAddr, clientAddr);
250 
251 	server.Close();
252 }
253 
254 
255 int32 testClient(void *)
256 {
257 	BNetEndpoint client(SOCK_DGRAM);
258 	printf("testing udp...\n");
259 	for(int i=0; i < 2; ++i) {
260 		status_t status = client.Bind(clientAddr);
261 		if (status != B_OK) {
262 			fprintf(stderr, "Bind() failed in testClient - %s\n",
263 				strerror(status));
264 			problemCount++;
265 			exit(1);
266 		}
267 
268 		checkAddrsAreEqual(client.LocalAddr(), clientAddr,
269 			"LocalAddr(%x:%d) doesn't match clientAddr(%x:%d)\n");
270 
271 		status = client.SendTo("U", 1, serverAddr, 0);
272 		if (status < B_OK) {
273 			fprintf(stderr, "SendTo() failed in testClient - %s\n",
274 				strerror(status));
275 			problemCount++;
276 			exit(1);
277 		}
278 
279 		checkArchive(client, SOCK_DGRAM, clientAddr, serverAddr);
280 
281 		sleep(1);
282 
283 		client.Close();
284 	}
285 
286 	sleep(1);
287 
288 	printf("testing tcp...\n");
289 	// now switch to TCP and try again
290 	client.SetProtocol(SOCK_STREAM);
291 	status_t status = client.Bind(clientAddr);
292 	if (status != B_OK) {
293 		fprintf(stderr, "Bind() failed in testClient - %s\n",
294 			strerror(status));
295 		problemCount++;
296 		exit(1);
297 	}
298 
299 	checkAddrsAreEqual(client.LocalAddr(), clientAddr,
300 		"LocalAddr(%x:%d) doesn't match clientAddr(%x:%d)\n");
301 
302 	status = client.Connect(serverAddr);
303 	if (status < B_OK) {
304 		fprintf(stderr, "Connect() failed in testClient - %s\n",
305 			strerror(status));
306 		problemCount++;
307 		exit(1);
308 	}
309 	status = client.Send("T", 1);
310 	if (status < B_OK) {
311 		fprintf(stderr, "Send() failed in testClient - %s\n",
312 			strerror(status));
313 		problemCount++;
314 		exit(1);
315 	}
316 
317 	checkArchive(client, SOCK_STREAM, clientAddr, serverAddr);
318 
319 	client.Close();
320 
321 	return B_OK;
322 }
323 
324 
325 int
326 main(int argc, const char* const* argv)
327 {
328 	BNetEndpoint dummy(SOCK_DGRAM);
329 	if (sizeof(dummy) != 208) {
330 		fprintf(stderr, "expected sizeof(netEndpoint) to be 208 - is %ld\n",
331 			sizeof(dummy));
332 		exit(1);
333 	}
334 	dummy.Close();
335 
336 	// start thread for client
337 	thread_id tid = spawn_thread(testClient, "client", B_NORMAL_PRIORITY, NULL);
338 	if (tid < 0) {
339 		fprintf(stderr, "spawn_thread() failed: %s\n", strerror(tid));
340 		exit(1);
341 	}
342 
343 	testServer(tid);
344 
345 	status_t clientStatus;
346 	wait_for_thread(tid, &clientStatus);
347 
348 	if (!problemCount)
349 		printf("Everything went fine.\n");
350 
351 	return 0;
352 }
353