xref: /haiku/src/tools/remote_disk_server/remote_disk_server.cpp (revision 3cb015b1ee509d69c643506e8ff573808c86dcfc)
1 /*
2  * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>.
3  * All rights reserved. Distributed under the terms of the MIT License.
4  */
5 
6 #include <endian.h>
7 #include <errno.h>
8 #include <fcntl.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <stdint.h>
12 #include <string.h>
13 #include <unistd.h>
14 #include <netinet/in.h>
15 #include <sys/socket.h>
16 #include <sys/stat.h>
17 
18 #include <boot/net/RemoteDiskDefs.h>
19 
20 
/* The protocol transmits 64 bit values in network (big endian) byte order.
 * If <endian.h> did not tell us the host byte order, refuse to build: with
 * the macros undefined both "== __LITTLE_ENDIAN" and "== __BIG_ENDIAN" would
 * compare 0 == 0 and silently pick the wrong conversion. */
#if !defined(__BYTE_ORDER) || !defined(__LITTLE_ENDIAN) \
	|| !defined(__BIG_ENDIAN)
#	error "<endian.h> does not define __BYTE_ORDER; cannot determine host byte order"
#endif

#if __BYTE_ORDER == __LITTLE_ENDIAN

/* Reverses the order of the eight bytes of a 64 bit value. */
static inline
uint64_t swap_uint64(uint64_t data)
{
	return ((data & 0xff) << 56)
		| ((data & 0xff00) << 40)
		| ((data & 0xff0000) << 24)
		| ((data & 0xff000000) << 8)
		| ((data >> 8) & 0xff000000)
		| ((data >> 24) & 0xff0000)
		| ((data >> 40) & 0xff00)
		| ((data >> 56) & 0xff);
}

#	define host_to_net64(data)	swap_uint64(data)
#	define net_to_host64(data)	swap_uint64(data)

#elif __BYTE_ORDER == __BIG_ENDIAN

/* Host order already is network order. */
#	define host_to_net64(data)	(data)
#	define net_to_host64(data)	(data)

#else
#	error "Unsupported host byte order"
#endif

/* 64 bit counterparts of htonl()/ntohl(); replace any platform versions so
 * that the definitions above are used consistently. */
#undef htonll
#undef ntohll
#define htonll(data)	host_to_net64(data)
#define ntohll(data)	net_to_host64(data)
50 
51 
52 class Server {
53 public:
54 	Server(const char *fileName)
55 		: fImagePath(fileName),
56 		  fImageFD(-1),
57 		  fImageSize(0),
58 		  fSocket(-1)
59 	{
60 	}
61 
62 	int Run()
63 	{
64 		_CreateSocket();
65 
66 		// main server loop
67 		for (;;) {
68 			// receive
69 			fClientAddress.sin_family = AF_INET;
70 			fClientAddress.sin_port = 0;
71 			fClientAddress.sin_addr.s_addr = htonl(INADDR_ANY);
72 			socklen_t addrSize = sizeof(fClientAddress);
73 			char buffer[2048];
74 			ssize_t bytesRead = recvfrom(fSocket, buffer, sizeof(buffer), 0,
75 								(sockaddr*)&fClientAddress, &addrSize);
76 			// handle error
77 			if (bytesRead < 0) {
78 				if (errno == EINTR)
79 					continue;
80 				fprintf(stderr, "Error: Failed to read from socket: %s.\n",
81 					strerror(errno));
82 				exit(1);
83 			}
84 
85 			// short package?
86 			if (bytesRead < (ssize_t)sizeof(remote_disk_header)) {
87 				fprintf(stderr, "Dropping short request package (%d bytes).\n",
88 					bytesRead);
89 				continue;
90 			}
91 
92 			fRequest = (remote_disk_header*)buffer;
93 			fRequestSize = bytesRead;
94 
95 			switch (fRequest->command) {
96 				case REMOTE_DISK_HELLO_REQUEST:
97 					_HandleHelloRequest();
98 					break;
99 
100 				case REMOTE_DISK_READ_REQUEST:
101 					_HandleReadRequest();
102 					break;
103 
104 				case REMOTE_DISK_WRITE_REQUEST:
105 					_HandleWriteRequest();
106 					break;
107 
108 				default:
109 					fprintf(stderr, "Ignoring invalid request %d.\n",
110 						(int)fRequest->command);
111 					break;
112 			}
113 		}
114 
115 		return 0;
116 	}
117 
118 private:
119 	void _OpenImage(bool reopen)
120 	{
121 		// already open?
122 		if (fImageFD >= 0) {
123 			if (!reopen)
124 				return;
125 
126 			close(fImageFD);
127 			fImageFD = -1;
128 			fImageSize = 0;
129 		}
130 
131 		// open the image
132 		fImageFD = open(fImagePath, O_RDWR);
133 		if (fImageFD < 0) {
134 			fprintf(stderr, "Error: Failed to open \"%s\": %s.\n", fImagePath,
135 				strerror(errno));
136 			exit(1);
137 		}
138 
139 		// get its size
140 		struct stat st;
141 		if (fstat(fImageFD, &st) < 0) {
142 			fprintf(stderr, "Error: Failed to stat \"%s\": %s.\n", fImagePath,
143 				strerror(errno));
144 			exit(1);
145 		}
146 		fImageSize = st.st_size;
147 	}
148 
149 	void _CreateSocket()
150 	{
151 		// create a socket
152 		fSocket = socket(AF_INET, SOCK_DGRAM, 0);
153 		if (fSocket < 0) {
154 			fprintf(stderr, "Error: Failed to create a socket: %s.",
155 				strerror(errno));
156 			exit(1);
157 		}
158 
159 		// bind it to the port
160 		sockaddr_in addr;
161 		addr.sin_family = AF_INET;
162 		addr.sin_port = htons(REMOTE_DISK_SERVER_PORT);
163 		addr.sin_addr.s_addr = INADDR_ANY;
164 		if (bind(fSocket, (sockaddr*)&addr, sizeof(addr)) < 0) {
165 			fprintf(stderr, "Error: Failed to bind socket to port %hu: %s\n",
166 				REMOTE_DISK_SERVER_PORT, strerror(errno));
167 			exit(1);
168 		}
169 	}
170 
171 	void _HandleHelloRequest()
172 	{
173 		printf("HELLO request\n");
174 
175 		_OpenImage(true);
176 
177 		remote_disk_header reply;
178 		reply.offset = htonll(fImageSize);
179 
180 		reply.command = REMOTE_DISK_HELLO_REPLY;
181 		_SendReply(&reply, sizeof(remote_disk_header));
182 	}
183 
184 	void _HandleReadRequest()
185 	{
186 		_OpenImage(false);
187 
188 		char buffer[2048];
189 		remote_disk_header *reply = (remote_disk_header*)buffer;
190 		uint64_t offset = ntohll(fRequest->offset);
191 		int16_t size = ntohs(fRequest->size);
192 		int16_t result = 0;
193 
194 		printf("READ request: offset: %llu, %hd bytes\n", offset, size);
195 
196 		if (offset < (uint64_t)fImageSize && size > 0) {
197 			// always read 1024 bytes
198 			size = REMOTE_DISK_BLOCK_SIZE;
199 			if (offset + size > (uint64_t)fImageSize)
200 				size = fImageSize - offset;
201 
202 			// seek to the offset
203 			off_t oldOffset = lseek(fImageFD, offset, SEEK_SET);
204 			if (oldOffset >= 0) {
205 				// read
206 				ssize_t bytesRead = read(fImageFD, reply->data, size);
207 				if (bytesRead >= 0) {
208 					result = bytesRead;
209 				} else {
210 					fprintf(stderr, "Error: Failed to read at position %llu: "
211 						"%s.", offset, strerror(errno));
212 					result = REMOTE_DISK_IO_ERROR;
213 				}
214 			} else {
215 				fprintf(stderr, "Error: Failed to seek to position %llu: %s.",
216 					offset, strerror(errno));
217 				result = REMOTE_DISK_IO_ERROR;
218 			}
219 		}
220 
221 		// send reply
222 		reply->command = REMOTE_DISK_READ_REPLY;
223 		reply->offset = htonll(offset);
224 		reply->size = htons(result);
225 		_SendReply(reply, sizeof(*reply) + (result >= 0 ? result : 0));
226 	}
227 
228 	void _HandleWriteRequest()
229 	{
230 		_OpenImage(false);
231 
232 		remote_disk_header reply;
233 		uint64_t offset = ntohll(fRequest->offset);
234 		int16_t size = ntohs(fRequest->size);
235 		int16_t result = 0;
236 
237 		printf("READ request: offset: %llu, %hd bytes\n", offset, size);
238 
239 		if (size < 0
240 			|| (uint32_t)size > fRequestSize - sizeof(remote_disk_header)
241 			|| offset > (uint64_t)fImageSize) {
242 			result = REMOTE_DISK_BAD_REQUEST;
243 		} else if (offset < (uint64_t)fImageSize && size > 0) {
244 			if (offset + size > (uint64_t)fImageSize)
245 				size = fImageSize - offset;
246 
247 			// seek to the offset
248 			off_t oldOffset = lseek(fImageFD, offset, SEEK_SET);
249 			if (oldOffset >= 0) {
250 				// write
251 				ssize_t bytesWritten = write(fImageFD, fRequest->data, size);
252 				if (bytesWritten >= 0) {
253 					result = bytesWritten;
254 				} else {
255 					fprintf(stderr, "Error: Failed to write at position %llu: "
256 						"%s.", offset, strerror(errno));
257 					result = REMOTE_DISK_IO_ERROR;
258 				}
259 			} else {
260 				fprintf(stderr, "Error: Failed to seek to position %llu: %s.",
261 					offset, strerror(errno));
262 				result = REMOTE_DISK_IO_ERROR;
263 			}
264 		}
265 
266 		// send reply
267 		reply.command = REMOTE_DISK_WRITE_REPLY;
268 		reply.offset = htonll(offset);
269 		reply.size = htons(result);
270 		_SendReply(&reply, sizeof(reply));
271 	}
272 
273 	void _SendReply(remote_disk_header *reply, int size)
274 	{
275 		reply->request_id = fRequest->request_id;
276 		reply->port = htons(REMOTE_DISK_SERVER_PORT);
277 
278 		for (;;) {
279 			ssize_t bytesSent = sendto(fSocket, reply, size, 0,
280 				(const sockaddr*)&fClientAddress, sizeof(fClientAddress));
281 
282 			if (bytesSent < 0) {
283 				if (errno == EINTR)
284 					continue;
285 				fprintf(stderr, "Error: Failed to send reply to client: %s.\n",
286 					strerror(errno));
287 			}
288 			break;
289 		}
290 	}
291 
292 private:
293 	const char			*fImagePath;
294 	int					fImageFD;
295 	off_t				fImageSize;
296 	int					fSocket;
297 	remote_disk_header	*fRequest;
298 	ssize_t				fRequestSize;
299 	sockaddr_in			fClientAddress;
300 };
301 
302 
303 // main
304 int
305 main(int argc, const char *const *argv)
306 {
307 	if (argc != 2) {
308 		fprintf(stderr, "Usage: %s <image path>\n", argv[0]);
309 		exit(1);
310 	}
311 	const char *fileName = argv[1];
312 
313 	Server server(fileName);
314 	return server.Run();
315 }
316