1 // InsecureConnection.cpp 2 3 // TODO: Asynchronous connecting on client side? 4 5 #include <new> 6 7 #include <errno.h> 8 #include <netdb.h> 9 #include <stdio.h> 10 #include <string.h> 11 #include <unistd.h> 12 13 #include <ByteOrder.h> 14 15 #include "Compatibility.h" 16 #include "DebugSupport.h" 17 #include "InsecureChannel.h" 18 #include "InsecureConnection.h" 19 #include "NetAddress.h" 20 #include "NetFSDefs.h" 21 22 namespace InsecureConnectionDefs { 23 24 const int32 kProtocolVersion = 1; 25 26 const bigtime_t kAcceptingTimeout = 10000000; // 10 s 27 28 // number of client up/down stream channels 29 const int32 kMinUpStreamChannels = 1; 30 const int32 kMaxUpStreamChannels = 10; 31 const int32 kDefaultUpStreamChannels = 5; 32 const int32 kMinDownStreamChannels = 1; 33 const int32 kMaxDownStreamChannels = 5; 34 const int32 kDefaultDownStreamChannels = 1; 35 36 } // namespace InsecureConnectionDefs 37 38 using namespace InsecureConnectionDefs; 39 40 // SocketCloser 41 struct SocketCloser { 42 SocketCloser(int fd) : fFD(fd) {} 43 ~SocketCloser() 44 { 45 if (fFD >= 0) 46 closesocket(fFD); 47 } 48 49 int Detach() 50 { 51 int fd = fFD; 52 fFD = -1; 53 return fd; 54 } 55 56 private: 57 int fFD; 58 }; 59 60 61 // #pragma mark - 62 // #pragma mark ----- InsecureConnection ----- 63 64 // constructor 65 InsecureConnection::InsecureConnection() 66 { 67 } 68 69 // destructor 70 InsecureConnection::~InsecureConnection() 71 { 72 } 73 74 // Init (server side) 75 status_t 76 InsecureConnection::Init(int fd) 77 { 78 status_t error = AbstractConnection::Init(); 79 if (error != B_OK) { 80 closesocket(fd); 81 return error; 82 } 83 // create the initial channel 84 Channel* channel = new(std::nothrow) InsecureChannel(fd); 85 if (!channel) { 86 closesocket(fd); 87 return B_NO_MEMORY; 88 } 89 // add it 90 error = AddDownStreamChannel(channel); 91 if (error != B_OK) { 92 delete channel; 93 return error; 94 } 95 return B_OK; 96 } 97 98 // Init (client side) 99 status_t 100 InsecureConnection::Init(const char* parameters) 101 { 102 PRINT(("InsecureConnection::Init\n")); 103 if (!parameters) 104 return B_BAD_VALUE; 105 status_t error = AbstractConnection::Init(); 106 if (error != B_OK) 107 return error; 108 // parse the parameters to get a server name and a port we shall connect to 109 // parameter format is "<server>[:port] [ <up> [ <down> ] ]" 110 char server[256]; 111 uint16 port = kDefaultInsecureConnectionPort; 112 int upStreamChannels = kDefaultUpStreamChannels; 113 int downStreamChannels = kDefaultDownStreamChannels; 114 if (strchr(parameters, ':')) { 115 int result = sscanf(parameters, "%255[^:]:%hu %d %d", server, &port, 116 &upStreamChannels, &downStreamChannels); 117 if (result < 2) 118 return B_BAD_VALUE; 119 } else { 120 int result = sscanf(parameters, "%255[^:] %d %d", server, 121 &upStreamChannels, &downStreamChannels); 122 if (result < 1) 123 return B_BAD_VALUE; 124 } 125 // resolve server address 126 NetAddress netAddress; 127 error = NetAddressResolver().GetHostAddress(server, &netAddress); 128 if (error != B_OK) 129 return error; 130 in_addr serverAddr = netAddress.GetAddress().sin_addr; 131 // open the initial channel 132 Channel* channel; 133 error = _OpenClientChannel(serverAddr, port, &channel); 134 if (error != B_OK) 135 return error; 136 error = AddUpStreamChannel(channel); 137 if (error != B_OK) { 138 delete channel; 139 return error; 140 } 141 // send the server a connect request 142 ConnectRequest request; 143 request.protocolVersion = B_HOST_TO_BENDIAN_INT32(kProtocolVersion); 144 request.serverAddress = serverAddr.s_addr; 145 request.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels); 146 request.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels); 147 error = channel->Send(&request, sizeof(ConnectRequest)); 148 if (error != B_OK) 149 return error; 150 // get the server reply 151 ConnectReply reply; 152 error = channel->Receive(&reply, sizeof(ConnectReply)); 153 if (error != B_OK) 154 return error; 155 error = B_BENDIAN_TO_HOST_INT32(reply.error); 156 if (error != B_OK) 157 return error; 158 upStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.upStreamChannels); 159 downStreamChannels = B_BENDIAN_TO_HOST_INT32(reply.downStreamChannels); 160 port = B_BENDIAN_TO_HOST_INT16(reply.port); 161 // open the remaining channels 162 int32 allChannels = upStreamChannels + downStreamChannels; 163 for (int32 i = 1; i < allChannels; i++) { 164 PRINT((" creating channel %ld\n", i)); 165 // open the channel 166 error = _OpenClientChannel(serverAddr, port, &channel); 167 if (error != B_OK) 168 RETURN_ERROR(error); 169 // add it 170 if (i < upStreamChannels) 171 error = AddUpStreamChannel(channel); 172 else 173 error = AddDownStreamChannel(channel); 174 if (error != B_OK) { 175 delete channel; 176 return error; 177 } 178 } 179 return B_OK; 180 } 181 182 // FinishInitialization 183 status_t 184 InsecureConnection::FinishInitialization() 185 { 186 PRINT(("InsecureConnection::FinishInitialization()\n")); 187 // get the down stream channel 188 InsecureChannel* channel 189 = dynamic_cast<InsecureChannel*>(DownStreamChannelAt(0)); 190 if (!channel) 191 return B_BAD_VALUE; 192 // receive the connect request 193 ConnectRequest request; 194 status_t error = channel->Receive(&request, sizeof(ConnectRequest)); 195 if (error != B_OK) 196 return error; 197 // check the protocol version 198 int32 protocolVersion = B_BENDIAN_TO_HOST_INT32(request.protocolVersion); 199 if (protocolVersion != kProtocolVersion) { 200 _SendErrorReply(channel, B_ERROR); 201 return B_ERROR; 202 } 203 // get our address (we need it for binding) 204 in_addr serverAddr; 205 serverAddr.s_addr = request.serverAddress; 206 // check number of up and down stream channels 207 int32 upStreamChannels = B_BENDIAN_TO_HOST_INT32(request.upStreamChannels); 208 int32 downStreamChannels = B_BENDIAN_TO_HOST_INT32( 209 request.downStreamChannels); 210 if (upStreamChannels < kMinUpStreamChannels) 211 upStreamChannels = kMinUpStreamChannels; 212 else if (upStreamChannels > kMaxUpStreamChannels) 213 upStreamChannels = kMaxUpStreamChannels; 214 if (downStreamChannels < kMinDownStreamChannels) 215 downStreamChannels = kMinDownStreamChannels; 216 else if (downStreamChannels > kMaxDownStreamChannels) 217 downStreamChannels = kMaxDownStreamChannels; 218 // due to a bug on BONE we have a maximum of 2 working connections 219 // accepted on one listener socket. 220 NetAddress peerAddress; 221 if (channel->GetPeerAddress(&peerAddress) == B_OK 222 && peerAddress.IsLocal()) { 223 upStreamChannels = 1; 224 if (downStreamChannels > 2) 225 downStreamChannels = 2; 226 } 227 int32 allChannels = upStreamChannels + downStreamChannels; 228 // create a listener socket 229 int fd = socket(AF_INET, SOCK_STREAM, 0); 230 if (fd < 0) { 231 error = errno; 232 _SendErrorReply(channel, error); 233 return error; 234 } 235 SocketCloser _(fd); 236 // bind it to some port 237 sockaddr_in addr; 238 addr.sin_family = AF_INET; 239 addr.sin_port = 0; 240 addr.sin_addr = serverAddr; 241 if (bind(fd, (sockaddr*)&addr, sizeof(addr)) < 0) { 242 error = errno; 243 _SendErrorReply(channel, error); 244 return error; 245 } 246 // get the port 247 socklen_t addrSize = sizeof(addr); 248 if (getsockname(fd, (sockaddr*)&addr, &addrSize) < 0) { 249 error = errno; 250 _SendErrorReply(channel, error); 251 return error; 252 } 253 // set socket to non-blocking 254 int dontBlock = 1; 255 if (setsockopt(fd, SOL_SOCKET, SO_NONBLOCK, &dontBlock, sizeof(int)) < 0) { 256 error = errno; 257 _SendErrorReply(channel, error); 258 return error; 259 } 260 // start listening 261 if (listen(fd, allChannels - 1) < 0) { 262 error = errno; 263 _SendErrorReply(channel, error); 264 return error; 265 } 266 // send the reply 267 ConnectReply reply; 268 reply.error = B_HOST_TO_BENDIAN_INT32(B_OK); 269 reply.upStreamChannels = B_HOST_TO_BENDIAN_INT32(upStreamChannels); 270 reply.downStreamChannels = B_HOST_TO_BENDIAN_INT32(downStreamChannels); 271 reply.port = addr.sin_port; 272 error = channel->Send(&reply, sizeof(ConnectReply)); 273 if (error != B_OK) 274 return error; 275 // start accepting 276 bigtime_t startAccepting = system_time(); 277 for (int32 i = 1; i < allChannels; ) { 278 // accept a connection 279 int channelFD = accept(fd, NULL, 0); 280 if (channelFD < 0) { 281 error = errno; 282 if (error == B_INTERRUPTED) { 283 error = B_OK; 284 continue; 285 } 286 if (error == B_WOULD_BLOCK) { 287 bigtime_t now = system_time(); 288 if (now - startAccepting > kAcceptingTimeout) 289 RETURN_ERROR(B_TIMED_OUT); 290 snooze(10000); 291 continue; 292 } 293 RETURN_ERROR(error); 294 } 295 PRINT((" accepting channel %ld\n", i)); 296 // create a channel 297 channel = new(std::nothrow) InsecureChannel(channelFD); 298 if (!channel) { 299 closesocket(channelFD); 300 return B_NO_MEMORY; 301 } 302 // add it 303 if (i < upStreamChannels) // inverse, since we are on server side 304 error = AddDownStreamChannel(channel); 305 else 306 error = AddUpStreamChannel(channel); 307 if (error != B_OK) { 308 delete channel; 309 return error; 310 } 311 i++; 312 startAccepting = system_time(); 313 } 314 return B_OK; 315 } 316 317 // _OpenClientChannel 318 status_t 319 InsecureConnection::_OpenClientChannel(in_addr serverAddr, uint16 port, 320 Channel** _channel) 321 { 322 // create a socket 323 int fd = socket(AF_INET, SOCK_STREAM, 0); 324 if (fd < 0) 325 return errno; 326 // connect 327 sockaddr_in addr; 328 addr.sin_family = AF_INET; 329 addr.sin_port = htons(port); 330 addr.sin_addr = serverAddr; 331 if (connect(fd, (sockaddr*)&addr, sizeof(addr)) < 0) { 332 status_t error = errno; 333 closesocket(fd); 334 RETURN_ERROR(error); 335 } 336 // create the channel 337 Channel* channel = new(std::nothrow) InsecureChannel(fd); 338 if (!channel) { 339 closesocket(fd); 340 return B_NO_MEMORY; 341 } 342 *_channel = channel; 343 return B_OK; 344 } 345 346 // _SendErrorReply 347 status_t 348 InsecureConnection::_SendErrorReply(Channel* channel, status_t error) 349 { 350 ConnectReply reply; 351 reply.error = B_HOST_TO_BENDIAN_INT32(error); 352 return channel->Send(&reply, sizeof(ConnectReply)); 353 } 354 355