1 /* 2 * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>. 3 * All rights reserved. Distributed under the terms of the MIT License. 4 */ 5 6 7 #include <boot/net/UDP.h> 8 9 #include <stdio.h> 10 11 #include <KernelExport.h> 12 13 #include <boot/net/ChainBuffer.h> 14 #include <boot/net/NetStack.h> 15 16 17 //#define TRACE_UDP 18 #ifdef TRACE_UDP 19 # define TRACE(x) dprintf x 20 #else 21 # define TRACE(x) ; 22 #endif 23 24 25 using std::nothrow; 26 27 28 // #pragma mark - UDPPacket 29 30 31 UDPPacket::UDPPacket() 32 : 33 fNext(NULL), 34 fData(NULL), 35 fSize(0) 36 { 37 } 38 39 40 UDPPacket::~UDPPacket() 41 { 42 free(fData); 43 } 44 45 46 status_t 47 UDPPacket::SetTo(const void *data, size_t size, ip_addr_t sourceAddress, 48 uint16 sourcePort, ip_addr_t destinationAddress, uint16 destinationPort) 49 { 50 if (data == NULL) 51 return B_BAD_VALUE; 52 53 // clone the data 54 fData = malloc(size); 55 if (fData == NULL) 56 return B_NO_MEMORY; 57 memcpy(fData, data, size); 58 59 fSize = size; 60 fSourceAddress = sourceAddress; 61 fDestinationAddress = destinationAddress; 62 fSourcePort = sourcePort; 63 fDestinationPort = destinationPort; 64 65 return B_OK; 66 } 67 68 69 UDPPacket * 70 UDPPacket::Next() const 71 { 72 return fNext; 73 } 74 75 76 void 77 UDPPacket::SetNext(UDPPacket *next) 78 { 79 fNext = next; 80 } 81 82 83 const void * 84 UDPPacket::Data() const 85 { 86 return fData; 87 } 88 89 90 size_t 91 UDPPacket::DataSize() const 92 { 93 return fSize; 94 } 95 96 97 ip_addr_t 98 UDPPacket::SourceAddress() const 99 { 100 return fSourceAddress; 101 } 102 103 104 uint16 105 UDPPacket::SourcePort() const 106 { 107 return fSourcePort; 108 } 109 110 111 ip_addr_t 112 UDPPacket::DestinationAddress() const 113 { 114 return fDestinationAddress; 115 } 116 117 118 uint16 119 UDPPacket::DestinationPort() const 120 { 121 return fDestinationPort; 122 } 123 124 125 // #pragma mark - UDPSocket 126 127 128 UDPSocket::UDPSocket() 129 : 130 fUDPService(NetStack::Default()->GetUDPService()), 131 fFirstPacket(NULL), 132 fLastPacket(NULL), 133 fAddress(INADDR_ANY), 134 fPort(0) 135 { 136 } 137 138 139 UDPSocket::~UDPSocket() 140 { 141 if (fPort != 0 && fUDPService != NULL) 142 fUDPService->UnbindSocket(this); 143 } 144 145 146 status_t 147 UDPSocket::Bind(ip_addr_t address, uint16 port) 148 { 149 if (fUDPService == NULL) { 150 printf("UDPSocket::Bind(): no UDP service\n"); 151 return B_NO_INIT; 152 } 153 154 if (address == INADDR_BROADCAST || port == 0) { 155 printf("UDPSocket::Bind(): broadcast IP or port 0\n"); 156 return B_BAD_VALUE; 157 } 158 159 if (fPort != 0) { 160 printf("UDPSocket::Bind(): already bound\n"); 161 return EALREADY; 162 // correct code? 163 } 164 165 status_t error = fUDPService->BindSocket(this, address, port); 166 if (error != B_OK) { 167 printf("UDPSocket::Bind(): service BindSocket() failed\n"); 168 return error; 169 } 170 171 fAddress = address; 172 fPort = port; 173 174 return B_OK; 175 } 176 177 178 status_t 179 UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort, 180 ChainBuffer *buffer) 181 { 182 if (fUDPService == NULL) 183 return B_NO_INIT; 184 185 return fUDPService->Send(fPort, destinationAddress, destinationPort, 186 buffer); 187 } 188 189 190 status_t 191 UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort, 192 const void *data, size_t size) 193 { 194 if (data == NULL) 195 return B_BAD_VALUE; 196 197 ChainBuffer buffer((void*)data, size); 198 return Send(destinationAddress, destinationPort, &buffer); 199 } 200 201 202 status_t 203 UDPSocket::Receive(UDPPacket **_packet, bigtime_t timeout) 204 { 205 if (fUDPService == NULL) 206 return B_NO_INIT; 207 208 if (_packet == NULL) 209 return B_BAD_VALUE; 210 211 bigtime_t startTime = system_time(); 212 for (;;) { 213 fUDPService->ProcessIncomingPackets(); 214 *_packet = PopPacket(); 215 if (*_packet != NULL) 216 return B_OK; 217 218 if (system_time() - startTime > timeout) 219 return (timeout == 0 ? B_WOULD_BLOCK : B_TIMED_OUT); 220 } 221 } 222 223 224 void 225 UDPSocket::PushPacket(UDPPacket *packet) 226 { 227 if (fLastPacket != NULL) 228 fLastPacket->SetNext(packet); 229 else 230 fFirstPacket = packet; 231 232 fLastPacket = packet; 233 packet->SetNext(NULL); 234 } 235 236 237 UDPPacket * 238 UDPSocket::PopPacket() 239 { 240 if (fFirstPacket == NULL) 241 return NULL; 242 243 UDPPacket *packet = fFirstPacket; 244 fFirstPacket = packet->Next(); 245 246 if (fFirstPacket == NULL) 247 fLastPacket = NULL; 248 249 packet->SetNext(NULL); 250 return packet; 251 } 252 253 254 // #pragma mark - UDPService 255 256 257 UDPService::UDPService(IPService *ipService) 258 : 259 IPSubService(kUDPServiceName), 260 fIPService(ipService) 261 { 262 } 263 264 265 UDPService::~UDPService() 266 { 267 if (fIPService != NULL) 268 fIPService->UnregisterIPSubService(this); 269 } 270 271 272 status_t 273 UDPService::Init() 274 { 275 if (fIPService == NULL) 276 return B_BAD_VALUE; 277 if (!fIPService->RegisterIPSubService(this)) 278 return B_NO_MEMORY; 279 return B_OK; 280 } 281 282 283 uint8 284 UDPService::IPProtocol() const 285 { 286 return IPPROTO_UDP; 287 } 288 289 290 void 291 UDPService::HandleIPPacket(IPService *ipService, ip_addr_t sourceIP, 292 ip_addr_t destinationIP, const void *data, size_t size) 293 { 294 TRACE(("UDPService::HandleIPPacket(): source: %08lx, destination: %08lx, " 295 "%lu - %lu bytes\n", sourceIP, destinationIP, size, 296 sizeof(udp_header))); 297 298 if (data == NULL || size < sizeof(udp_header)) 299 return; 300 301 const udp_header *header = (const udp_header*)data; 302 uint16 source = ntohs(header->source); 303 uint16 destination = ntohs(header->destination); 304 uint16 length = ntohs(header->length); 305 306 // check the header 307 if (length < sizeof(udp_header) || length > size 308 || (header->checksum != 0 // 0 => checksum disabled 309 && _ChecksumData(data, length, sourceIP, destinationIP) != 0)) { 310 TRACE(("UDPService::HandleIPPacket(): dropping packet -- invalid size " 311 "or checksum\n")); 312 return; 313 } 314 315 // find the target socket 316 UDPSocket *socket = _FindSocket(destinationIP, destination); 317 if (socket == NULL) 318 return; 319 320 // create a UDPPacket and queue it in the socket 321 UDPPacket *packet = new(nothrow) UDPPacket; 322 if (packet == NULL) 323 return; 324 status_t error = packet->SetTo((uint8*)data + sizeof(udp_header), 325 length - sizeof(udp_header), sourceIP, source, destinationIP, 326 destination); 327 if (error == B_OK) 328 socket->PushPacket(packet); 329 else 330 delete packet; 331 } 332 333 334 status_t 335 UDPService::Send(uint16 sourcePort, ip_addr_t destinationAddress, 336 uint16 destinationPort, ChainBuffer *buffer) 337 { 338 TRACE(("UDPService::Send(source port: %hu, to: %08lx:%hu, %lu bytes)\n", 339 sourcePort, destinationAddress, destinationPort, 340 (buffer != NULL ? buffer->TotalSize() : 0))); 341 342 if (fIPService == NULL) 343 return B_NO_INIT; 344 345 if (buffer == NULL) 346 return B_BAD_VALUE; 347 348 // prepend the UDP header 349 udp_header header; 350 ChainBuffer headerBuffer(&header, sizeof(header), buffer); 351 header.source = htons(sourcePort); 352 header.destination = htons(destinationPort); 353 header.length = htons(headerBuffer.TotalSize()); 354 355 // compute the checksum 356 header.checksum = 0; 357 header.checksum = htons(_ChecksumBuffer(&headerBuffer, 358 fIPService->IPAddress(), destinationAddress, 359 headerBuffer.TotalSize())); 360 // 0 means checksum disabled; 0xffff is equivalent in this case 361 if (header.checksum == 0) 362 header.checksum = 0xffff; 363 364 return fIPService->Send(destinationAddress, IPPROTO_UDP, &headerBuffer); 365 } 366 367 368 void 369 UDPService::ProcessIncomingPackets() 370 { 371 if (fIPService != NULL) 372 fIPService->ProcessIncomingPackets(); 373 } 374 375 376 status_t 377 UDPService::BindSocket(UDPSocket *socket, ip_addr_t address, uint16 port) 378 { 379 if (socket == NULL) 380 return B_BAD_VALUE; 381 382 if (_FindSocket(address, port) != NULL) { 383 printf("UDPService::BindSocket(): address in use\n"); 384 return EADDRINUSE; 385 } 386 387 return fSockets.Add(socket); 388 } 389 390 391 void 392 UDPService::UnbindSocket(UDPSocket *socket) 393 { 394 fSockets.Remove(socket); 395 } 396 397 398 uint16 399 UDPService::_ChecksumBuffer(ChainBuffer *buffer, ip_addr_t source, 400 ip_addr_t destination, uint16 length) 401 { 402 // The checksum is calculated over a pseudo-header plus the UDP packet. 403 // So we temporarily prepend the pseudo-header. 404 struct pseudo_header { 405 ip_addr_t source; 406 ip_addr_t destination; 407 uint8 pad; 408 uint8 protocol; 409 uint16 length; 410 } __attribute__ ((__packed__)); 411 pseudo_header header = { 412 htonl(source), 413 htonl(destination), 414 0, 415 IPPROTO_UDP, 416 htons(length) 417 }; 418 419 ChainBuffer headerBuffer(&header, sizeof(header), buffer); 420 uint16 checksum = ip_checksum(&headerBuffer); 421 headerBuffer.DetachNext(); 422 return checksum; 423 } 424 425 426 uint16 427 UDPService::_ChecksumData(const void *data, uint16 length, ip_addr_t source, 428 ip_addr_t destination) 429 { 430 ChainBuffer buffer((void*)data, length); 431 return _ChecksumBuffer(&buffer, source, destination, length); 432 } 433 434 435 UDPSocket * 436 UDPService::_FindSocket(ip_addr_t address, uint16 port) 437 { 438 int count = fSockets.Count(); 439 for (int i = 0; i < count; i++) { 440 UDPSocket *socket = fSockets.ElementAt(i); 441 if ((address == INADDR_ANY || socket->Address() == INADDR_ANY 442 || socket->Address() == address) 443 && port == socket->Port()) { 444 return socket; 445 } 446 } 447 448 return NULL; 449 } 450