xref: /haiku/src/system/boot/loader/net/UDP.cpp (revision 93a78ecaa45114d68952d08c4778f073515102f2)
1 /*
2  * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>.
3  * All rights reserved. Distributed under the terms of the MIT License.
4  */
5 
6 #include <boot/net/UDP.h>
7 
8 #include <stdio.h>
9 #include <KernelExport.h>
10 
11 #include <boot/net/ChainBuffer.h>
12 #include <boot/net/NetStack.h>
13 
14 
// Tracing is compiled out unless TRACE_UDP is defined; uncomment the next
// line to enable it.
//#define TRACE_UDP
#if defined(TRACE_UDP)
	// TRACE() expects a parenthesized dprintf() argument list:
	//     TRACE(("format %d\n", value));
#	define TRACE(x) dprintf x
#else
	// Expands to an empty statement; the arguments are not evaluated.
#	define TRACE(x) ;
#endif
21 
22 
23 // #pragma mark - UDPPacket
24 
25 // constructor
26 UDPPacket::UDPPacket()
27 	: fNext(NULL),
28 		fData(NULL),
29 		fSize(0)
30 {
31 }
32 
33 // destructor
34 UDPPacket::~UDPPacket()
35 {
36 	free(fData);
37 }
38 
39 // SetTo
40 status_t
41 UDPPacket::SetTo(const void *data, size_t size, ip_addr_t sourceAddress,
42 	uint16 sourcePort, ip_addr_t destinationAddress, uint16 destinationPort)
43 {
44 	if (!data)
45 		return B_BAD_VALUE;
46 
47 	// clone the data
48 	fData = malloc(size);
49 	if (!fData)
50 		return B_NO_MEMORY;
51 	memcpy(fData, data, size);
52 
53 	fSize = size;
54 	fSourceAddress = sourceAddress;
55 	fDestinationAddress = destinationAddress;
56 	fSourcePort = sourcePort;
57 	fDestinationPort = destinationPort;
58 
59 	return B_OK;
60 }
61 
62 // Next
63 UDPPacket *
64 UDPPacket::Next() const
65 {
66 	return fNext;
67 }
68 
69 // SetNext
70 void
71 UDPPacket::SetNext(UDPPacket *next)
72 {
73 	fNext = next;
74 }
75 
76 // Data
77 const void *
78 UDPPacket::Data() const
79 {
80 	return fData;
81 }
82 
83 // DataSize
84 size_t
85 UDPPacket::DataSize() const
86 {
87 	return fSize;
88 }
89 
90 // SourceAddress
91 ip_addr_t
92 UDPPacket::SourceAddress() const
93 {
94 	return fSourceAddress;
95 }
96 
97 // SourcePort
98 uint16
99 UDPPacket::SourcePort() const
100 {
101 	return fSourcePort;
102 }
103 
104 // DestinationAddress
105 ip_addr_t
106 UDPPacket::DestinationAddress() const
107 {
108 	return fDestinationAddress;
109 }
110 
111 // DestinationPort
112 uint16
113 UDPPacket::DestinationPort() const
114 {
115 	return fDestinationPort;
116 }
117 
118 
119 // #pragma mark - UDPSocket
120 
121 // constructor
122 UDPSocket::UDPSocket()
123 	: fUDPService(NetStack::Default()->GetUDPService()),
124 	  fFirstPacket(NULL),
125 	  fLastPacket(NULL),
126 	  fAddress(INADDR_ANY),
127 	  fPort(0)
128 {
129 }
130 
131 // destructor
132 UDPSocket::~UDPSocket()
133 {
134 	if (fPort != 0 && fUDPService)
135 		fUDPService->UnbindSocket(this);
136 }
137 
138 // Bind
139 status_t
140 UDPSocket::Bind(ip_addr_t address, uint16 port)
141 {
142 	if (!fUDPService) {
143 		printf("UDPSocket::Bind(): no UDP service\n");
144 		return B_NO_INIT;
145 	}
146 
147 	if (address == INADDR_BROADCAST || port == 0) {
148 		printf("UDPSocket::Bind(): broadcast IP or port 0\n");
149 		return B_BAD_VALUE;
150 	}
151 
152 	if (fPort != 0) {
153 		printf("UDPSocket::Bind(): already bound\n");
154 		return EALREADY; // correct code?
155 	}
156 
157 	status_t error = fUDPService->BindSocket(this, address, port);
158 	if (error != B_OK) {
159 		printf("UDPSocket::Bind(): service BindSocket() failed\n");
160 		return error;
161 	}
162 
163 	fAddress = address;
164 	fPort = port;
165 
166 	return B_OK;
167 }
168 
169 // Send
170 status_t
171 UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort,
172 	ChainBuffer *buffer)
173 {
174 	if (!fUDPService)
175 		return B_NO_INIT;
176 
177 	return fUDPService->Send(fPort, destinationAddress, destinationPort,
178 		buffer);
179 }
180 
181 // Send
182 status_t
183 UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort,
184 	const void *data, size_t size)
185 {
186 	if (!data)
187 		return B_BAD_VALUE;
188 
189 	ChainBuffer buffer((void*)data, size);
190 	return Send(destinationAddress, destinationPort, &buffer);
191 }
192 
193 // Receive
194 status_t
195 UDPSocket::Receive(UDPPacket **_packet, bigtime_t timeout)
196 {
197 	if (!fUDPService)
198 		return B_NO_INIT;
199 
200 	if (!_packet)
201 		return B_BAD_VALUE;
202 
203 	bigtime_t startTime = system_time();
204 	for (;;) {
205 		fUDPService->ProcessIncomingPackets();
206 		if ((*_packet = PopPacket()))
207 			return B_OK;
208 
209 		if (system_time() - startTime > timeout)
210 			return (timeout == 0 ? B_WOULD_BLOCK : B_TIMED_OUT);
211 	}
212 }
213 
214 // PushPacket
215 void
216 UDPSocket::PushPacket(UDPPacket *packet)
217 {
218 	if (fLastPacket)
219 		fLastPacket->SetNext(packet);
220 	else
221 		fFirstPacket = packet;
222 
223 	fLastPacket = packet;
224 	packet->SetNext(NULL);
225 }
226 
227 // PopPacket
228 UDPPacket *
229 UDPSocket::PopPacket()
230 {
231 	if (!fFirstPacket)
232 		return NULL;
233 
234 	UDPPacket *packet = fFirstPacket;
235 	fFirstPacket = packet->Next();
236 
237 	if (!fFirstPacket)
238 		fLastPacket = NULL;
239 
240 	packet->SetNext(NULL);
241 	return packet;
242 }
243 
244 
245 // #pragma mark - UDPService
246 
247 // constructor
248 UDPService::UDPService(IPService *ipService)
249 	: IPSubService(kUDPServiceName),
250 		fIPService(ipService)
251 {
252 }
253 
254 // destructor
255 UDPService::~UDPService()
256 {
257 	if (fIPService)
258 		fIPService->UnregisterIPSubService(this);
259 }
260 
261 // Init
262 status_t
263 UDPService::Init()
264 {
265 	if (!fIPService)
266 		return B_BAD_VALUE;
267 	if (!fIPService->RegisterIPSubService(this))
268 		return B_NO_MEMORY;
269 	return B_OK;
270 }
271 
272 // IPProtocol
273 uint8
274 UDPService::IPProtocol() const
275 {
276 	return IPPROTO_UDP;
277 }
278 
279 // HandleIPPacket
280 void
281 UDPService::HandleIPPacket(IPService *ipService, ip_addr_t sourceIP,
282 	ip_addr_t destinationIP, const void *data, size_t size)
283 {
284 	TRACE(("UDPService::HandleIPPacket(): source: %08lx, destination: %08lx, "
285 		"%lu - %lu bytes\n", sourceIP, destinationIP, size,
286 		sizeof(udp_header)));
287 
288 	if (!data || size < sizeof(udp_header))
289 		return;
290 
291 	const udp_header *header = (const udp_header*)data;
292 	uint16 source = ntohs(header->source);
293 	uint16 destination = ntohs(header->destination);
294 	uint16 length = ntohs(header->length);
295 
296 	// check the header
297 	if (length < sizeof(udp_header) || length > size
298 		|| (header->checksum != 0	// 0 => checksum disabled
299 			&& _ChecksumData(data, length, sourceIP, destinationIP) != 0)) {
300 		TRACE(("UDPService::HandleIPPacket(): dropping packet -- invalid size "
301 			"or checksum\n"));
302 		return;
303 	}
304 
305 	// find the target socket
306 	UDPSocket *socket = _FindSocket(destinationIP, destination);
307 	if (!socket)
308 		return;
309 
310 	// create a UDPPacket and queue it in the socket
311 	UDPPacket *packet = new(nothrow) UDPPacket;
312 	if (!packet)
313 		return;
314 	status_t error = packet->SetTo((uint8*)data + sizeof(udp_header),
315 		length - sizeof(udp_header), sourceIP, source, destinationIP,
316 		destination);
317 	if (error == B_OK)
318 		socket->PushPacket(packet);
319 	else
320 		delete packet;
321 }
322 
323 // Send
324 status_t
325 UDPService::Send(uint16 sourcePort, ip_addr_t destinationAddress,
326 	uint16 destinationPort, ChainBuffer *buffer)
327 {
328 	TRACE(("UDPService::Send(source port: %hu, to: %08lx:%hu, %lu bytes)\n",
329 		sourcePort, destinationAddress, destinationPort,
330 		(buffer ? buffer->TotalSize() : 0)));
331 
332 	if (!fIPService)
333 		return B_NO_INIT;
334 
335 	if (!buffer)
336 		return B_BAD_VALUE;
337 
338 	// prepend the UDP header
339 	udp_header header;
340 	ChainBuffer headerBuffer(&header, sizeof(header), buffer);
341 	header.source = htons(sourcePort);
342 	header.destination = htons(destinationPort);
343 	header.length = htons(headerBuffer.TotalSize());
344 
345 	// compute the checksum
346 	header.checksum = 0;
347 	header.checksum = htons(_ChecksumBuffer(&headerBuffer,
348 		fIPService->IPAddress(), destinationAddress,
349 		headerBuffer.TotalSize()));
350 	// 0 means checksum disabled; 0xffff is equivalent in this case
351 	if (header.checksum == 0)
352 		header.checksum = 0xffff;
353 
354 	return fIPService->Send(destinationAddress, IPPROTO_UDP, &headerBuffer);
355 }
356 
357 // ProcessIncomingPackets
358 void
359 UDPService::ProcessIncomingPackets()
360 {
361 	if (fIPService)
362 		fIPService->ProcessIncomingPackets();
363 }
364 
365 // BindSocket
366 status_t
367 UDPService::BindSocket(UDPSocket *socket, ip_addr_t address, uint16 port)
368 {
369 	if (!socket)
370 		return B_BAD_VALUE;
371 
372 	if (_FindSocket(address, port)) {
373 		printf("UDPService::BindSocket(): address in use\n");
374 		return EADDRINUSE;
375 	}
376 
377 	return fSockets.Add(socket);
378 }
379 
380 // UnbindSocket
381 void
382 UDPService::UnbindSocket(UDPSocket *socket)
383 {
384 	fSockets.Remove(socket);
385 }
386 
387 // _ChecksumBuffer
388 uint16
389 UDPService::_ChecksumBuffer(ChainBuffer *buffer, ip_addr_t source,
390 	ip_addr_t destination, uint16 length)
391 {
392 	// The checksum is calculated over a pseudo-header plus the UDP packet.
393 	// So we temporarily prepend the pseudo-header.
394 	struct pseudo_header {
395 		ip_addr_t	source;
396 		ip_addr_t	destination;
397 		uint8		pad;
398 		uint8		protocol;
399 		uint16		length;
400 	} __attribute__ ((__packed__));
401 	pseudo_header header = {
402 		htonl(source),
403 		htonl(destination),
404 		0,
405 		IPPROTO_UDP,
406 		htons(length)
407 	};
408 
409 	ChainBuffer headerBuffer(&header, sizeof(header), buffer);
410 	uint16 checksum = ip_checksum(&headerBuffer);
411 	headerBuffer.DetachNext();
412 	return checksum;
413 }
414 
415 // _ChecksumData
416 uint16
417 UDPService::_ChecksumData(const void *data, uint16 length, ip_addr_t source,
418 	ip_addr_t destination)
419 {
420 	ChainBuffer buffer((void*)data, length);
421 	return _ChecksumBuffer(&buffer, source, destination, length);
422 }
423 
424 // _FindSocket
425 UDPSocket *
426 UDPService::_FindSocket(ip_addr_t address, uint16 port)
427 {
428 	int count = fSockets.Count();
429 	for (int i = 0; i < count; i++) {
430 		UDPSocket *socket = fSockets.ElementAt(i);
431 		if ((address == INADDR_ANY || socket->Address() == INADDR_ANY
432 				|| socket->Address() == address)
433 			&& port == socket->Port()) {
434 			return socket;
435 		}
436 	}
437 
438 	return NULL;
439 }
440