1 /* 2 * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>. 3 * All rights reserved. Distributed under the terms of the MIT License. 4 */ 5 6 7 #include <boot/net/UDP.h> 8 9 #include <stdio.h> 10 11 #include <KernelExport.h> 12 13 #include <boot/net/ChainBuffer.h> 14 #include <boot/net/NetStack.h> 15 16 17 //#define TRACE_UDP 18 #ifdef TRACE_UDP 19 # define TRACE(x) dprintf x 20 #else 21 # define TRACE(x) ; 22 #endif 23 24 25 using std::nothrow; 26 27 28 // #pragma mark - UDPPacket 29 30 31 UDPPacket::UDPPacket() 32 : 33 fNext(NULL), 34 fData(NULL), 35 fSize(0) 36 { 37 } 38 39 40 UDPPacket::~UDPPacket() 41 { 42 free(fData); 43 } 44 45 46 status_t 47 UDPPacket::SetTo(const void *data, size_t size, ip_addr_t sourceAddress, 48 uint16 sourcePort, ip_addr_t destinationAddress, uint16 destinationPort) 49 { 50 if (data == NULL) 51 return B_BAD_VALUE; 52 53 // clone the data 54 fData = malloc(size); 55 if (fData == NULL) 56 return B_NO_MEMORY; 57 memcpy(fData, data, size); 58 59 fSize = size; 60 fSourceAddress = sourceAddress; 61 fDestinationAddress = destinationAddress; 62 fSourcePort = sourcePort; 63 fDestinationPort = destinationPort; 64 65 return B_OK; 66 } 67 68 69 UDPPacket * 70 UDPPacket::Next() const 71 { 72 return fNext; 73 } 74 75 76 void 77 UDPPacket::SetNext(UDPPacket *next) 78 { 79 fNext = next; 80 } 81 82 83 const void * 84 UDPPacket::Data() const 85 { 86 return fData; 87 } 88 89 90 size_t 91 UDPPacket::DataSize() const 92 { 93 return fSize; 94 } 95 96 97 ip_addr_t 98 UDPPacket::SourceAddress() const 99 { 100 return fSourceAddress; 101 } 102 103 104 uint16 105 UDPPacket::SourcePort() const 106 { 107 return fSourcePort; 108 } 109 110 111 ip_addr_t 112 UDPPacket::DestinationAddress() const 113 { 114 return fDestinationAddress; 115 } 116 117 118 uint16 119 UDPPacket::DestinationPort() const 120 { 121 return fDestinationPort; 122 } 123 124 125 // #pragma mark - UDPSocket 126 127 128 UDPSocket::UDPSocket() 129 : 130 fUDPService(NetStack::Default()->GetUDPService()), 131 fFirstPacket(NULL), 132 fLastPacket(NULL), 133 fAddress(INADDR_ANY), 134 fPort(0) 135 { 136 } 137 138 139 UDPSocket::~UDPSocket() 140 { 141 if (fPort != 0 && fUDPService != NULL) 142 fUDPService->UnbindSocket(this); 143 } 144 145 146 status_t 147 UDPSocket::Bind(ip_addr_t address, uint16 port) 148 { 149 if (fUDPService == NULL) { 150 printf("UDPSocket::Bind(): no UDP service\n"); 151 return B_NO_INIT; 152 } 153 154 if (address == INADDR_BROADCAST || port == 0) { 155 printf("UDPSocket::Bind(): broadcast IP or port 0\n"); 156 return B_BAD_VALUE; 157 } 158 159 if (fPort != 0) { 160 printf("UDPSocket::Bind(): already bound\n"); 161 return EALREADY; 162 // correct code? 163 } 164 165 status_t error = fUDPService->BindSocket(this, address, port); 166 if (error != B_OK) { 167 printf("UDPSocket::Bind(): service BindSocket() failed\n"); 168 return error; 169 } 170 171 fAddress = address; 172 fPort = port; 173 174 return B_OK; 175 } 176 177 178 void 179 UDPSocket::Detach() 180 { 181 fUDPService = NULL; 182 // This will lead to subsequent methods returning B_NO_INIT 183 } 184 185 186 187 status_t 188 UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort, 189 ChainBuffer *buffer) 190 { 191 if (fUDPService == NULL) 192 return B_NO_INIT; 193 194 return fUDPService->Send(fPort, destinationAddress, destinationPort, 195 buffer); 196 } 197 198 199 status_t 200 UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort, 201 const void *data, size_t size) 202 { 203 if (data == NULL) 204 return B_BAD_VALUE; 205 206 ChainBuffer buffer((void*)data, size); 207 return Send(destinationAddress, destinationPort, &buffer); 208 } 209 210 211 status_t 212 UDPSocket::Receive(UDPPacket **_packet, bigtime_t timeout) 213 { 214 if (fUDPService == NULL) 215 return B_NO_INIT; 216 217 if (_packet == NULL) 218 return B_BAD_VALUE; 219 220 bigtime_t startTime = system_time(); 221 for (;;) { 222 fUDPService->ProcessIncomingPackets(); 223 *_packet = PopPacket(); 224 if (*_packet != NULL) 225 return B_OK; 226 227 if (system_time() - startTime > timeout) 228 return (timeout == 0 ? B_WOULD_BLOCK : B_TIMED_OUT); 229 } 230 } 231 232 233 void 234 UDPSocket::PushPacket(UDPPacket *packet) 235 { 236 if (fLastPacket != NULL) 237 fLastPacket->SetNext(packet); 238 else 239 fFirstPacket = packet; 240 241 fLastPacket = packet; 242 packet->SetNext(NULL); 243 } 244 245 246 UDPPacket * 247 UDPSocket::PopPacket() 248 { 249 if (fFirstPacket == NULL) 250 return NULL; 251 252 UDPPacket *packet = fFirstPacket; 253 fFirstPacket = packet->Next(); 254 255 if (fFirstPacket == NULL) 256 fLastPacket = NULL; 257 258 packet->SetNext(NULL); 259 return packet; 260 } 261 262 263 // #pragma mark - UDPService 264 265 266 UDPService::UDPService(IPService *ipService) 267 : 268 IPSubService(kUDPServiceName), 269 fIPService(ipService) 270 { 271 } 272 273 274 UDPService::~UDPService() 275 { 276 int count = fSockets.Count(); 277 for (int i = 0; i < count; i++) { 278 UDPSocket *socket = fSockets.ElementAt(i); 279 socket->Detach(); 280 } 281 282 if (fIPService != NULL) 283 fIPService->UnregisterIPSubService(this); 284 } 285 286 287 status_t 288 UDPService::Init() 289 { 290 if (fIPService == NULL) 291 return B_BAD_VALUE; 292 if (!fIPService->RegisterIPSubService(this)) 293 return B_NO_MEMORY; 294 return B_OK; 295 } 296 297 298 uint8 299 UDPService::IPProtocol() const 300 { 301 return IPPROTO_UDP; 302 } 303 304 305 void 306 UDPService::HandleIPPacket(IPService *ipService, ip_addr_t sourceIP, 307 ip_addr_t destinationIP, const void *data, size_t size) 308 { 309 TRACE(("UDPService::HandleIPPacket(): source: %08lx, destination: %08lx, " 310 "%lu - %lu bytes\n", sourceIP, destinationIP, size, 311 sizeof(udp_header))); 312 313 if (data == NULL || size < sizeof(udp_header)) 314 return; 315 316 const udp_header *header = (const udp_header*)data; 317 uint16 source = ntohs(header->source); 318 uint16 destination = ntohs(header->destination); 319 uint16 length = ntohs(header->length); 320 321 // check the header 322 if (length < sizeof(udp_header) || length > size 323 || (header->checksum != 0 // 0 => checksum disabled 324 && _ChecksumData(data, length, sourceIP, destinationIP) != 0)) { 325 TRACE(("UDPService::HandleIPPacket(): dropping packet -- invalid size " 326 "or checksum\n")); 327 return; 328 } 329 330 // find the target socket 331 UDPSocket *socket = _FindSocket(destinationIP, destination); 332 if (socket == NULL) 333 return; 334 335 // create a UDPPacket and queue it in the socket 336 UDPPacket *packet = new(nothrow) UDPPacket; 337 if (packet == NULL) 338 return; 339 status_t error = packet->SetTo((uint8*)data + sizeof(udp_header), 340 length - sizeof(udp_header), sourceIP, source, destinationIP, 341 destination); 342 if (error == B_OK) 343 socket->PushPacket(packet); 344 else 345 delete packet; 346 } 347 348 349 status_t 350 UDPService::Send(uint16 sourcePort, ip_addr_t destinationAddress, 351 uint16 destinationPort, ChainBuffer *buffer) 352 { 353 TRACE(("UDPService::Send(source port: %hu, to: %08lx:%hu, %lu bytes)\n", 354 sourcePort, destinationAddress, destinationPort, 355 (buffer != NULL ? buffer->TotalSize() : 0))); 356 357 if (fIPService == NULL) 358 return B_NO_INIT; 359 360 if (buffer == NULL) 361 return B_BAD_VALUE; 362 363 // prepend the UDP header 364 udp_header header; 365 ChainBuffer headerBuffer(&header, sizeof(header), buffer); 366 header.source = htons(sourcePort); 367 header.destination = htons(destinationPort); 368 header.length = htons(headerBuffer.TotalSize()); 369 370 // compute the checksum 371 header.checksum = 0; 372 header.checksum = htons(_ChecksumBuffer(&headerBuffer, 373 fIPService->IPAddress(), destinationAddress, 374 headerBuffer.TotalSize())); 375 // 0 means checksum disabled; 0xffff is equivalent in this case 376 if (header.checksum == 0) 377 header.checksum = 0xffff; 378 379 return fIPService->Send(destinationAddress, IPPROTO_UDP, &headerBuffer); 380 } 381 382 383 void 384 UDPService::ProcessIncomingPackets() 385 { 386 if (fIPService != NULL) 387 fIPService->ProcessIncomingPackets(); 388 } 389 390 391 status_t 392 UDPService::BindSocket(UDPSocket *socket, ip_addr_t address, uint16 port) 393 { 394 if (socket == NULL) 395 return B_BAD_VALUE; 396 397 if (_FindSocket(address, port) != NULL) { 398 printf("UDPService::BindSocket(): address in use\n"); 399 return EADDRINUSE; 400 } 401 402 return fSockets.Add(socket); 403 } 404 405 406 void 407 UDPService::UnbindSocket(UDPSocket *socket) 408 { 409 fSockets.Remove(socket); 410 } 411 412 413 uint16 414 UDPService::_ChecksumBuffer(ChainBuffer *buffer, ip_addr_t source, 415 ip_addr_t destination, uint16 length) 416 { 417 // The checksum is calculated over a pseudo-header plus the UDP packet. 418 // So we temporarily prepend the pseudo-header. 419 struct pseudo_header { 420 ip_addr_t source; 421 ip_addr_t destination; 422 uint8 pad; 423 uint8 protocol; 424 uint16 length; 425 } __attribute__ ((__packed__)); 426 pseudo_header header = { 427 htonl(source), 428 htonl(destination), 429 0, 430 IPPROTO_UDP, 431 htons(length) 432 }; 433 434 ChainBuffer headerBuffer(&header, sizeof(header), buffer); 435 uint16 checksum = ip_checksum(&headerBuffer); 436 headerBuffer.DetachNext(); 437 return checksum; 438 } 439 440 441 uint16 442 UDPService::_ChecksumData(const void *data, uint16 length, ip_addr_t source, 443 ip_addr_t destination) 444 { 445 ChainBuffer buffer((void*)data, length); 446 return _ChecksumBuffer(&buffer, source, destination, length); 447 } 448 449 450 UDPSocket * 451 UDPService::_FindSocket(ip_addr_t address, uint16 port) 452 { 453 int count = fSockets.Count(); 454 for (int i = 0; i < count; i++) { 455 UDPSocket *socket = fSockets.ElementAt(i); 456 if ((address == INADDR_ANY || socket->Address() == INADDR_ANY 457 || socket->Address() == address) 458 && port == socket->Port()) { 459 return socket; 460 } 461 } 462 463 return NULL; 464 } 465