1 // RequestChannel.cpp 2 3 #include "RequestChannel.h" 4 5 #include <new> 6 #include <typeinfo> 7 8 #include <stdlib.h> 9 #include <string.h> 10 11 #include <AutoDeleter.h> 12 #include <ByteOrder.h> 13 14 #include "Channel.h" 15 #include "Compatibility.h" 16 #include "DebugSupport.h" 17 #include "Request.h" 18 #include "RequestDumper.h" 19 #include "RequestFactory.h" 20 #include "RequestFlattener.h" 21 #include "RequestUnflattener.h" 22 23 static const int32 kMaxSaneRequestSize = 128 * 1024; // 128 KB 24 static const int32 kDefaultBufferSize = 4096; // 4 KB 25 26 // ChannelWriter 27 class RequestChannel::ChannelWriter : public Writer { 28 public: 29 ChannelWriter(Channel* channel, void* buffer, int32 bufferSize) 30 : fChannel(channel), 31 fBuffer(buffer), 32 fBufferSize(bufferSize), 33 fBytesWritten(0) 34 { 35 } 36 37 virtual status_t Write(const void* buffer, int32 size) 38 { 39 status_t error = B_OK; 40 // if the data don't fit into the buffer anymore, flush the buffer 41 if (fBytesWritten + size > fBufferSize) { 42 error = Flush(); 43 if (error != B_OK) 44 return error; 45 } 46 // if the data don't even fit into an empty buffer, just send it, 47 // otherwise append it to the buffer 48 if (size > fBufferSize) { 49 error = fChannel->Send(buffer, size); 50 if (error != B_OK) 51 return error; 52 } else { 53 memcpy((uint8*)fBuffer + fBytesWritten, buffer, size); 54 fBytesWritten += size; 55 } 56 return error; 57 } 58 59 status_t Flush() 60 { 61 if (fBytesWritten == 0) 62 return B_OK; 63 status_t error = fChannel->Send(fBuffer, fBytesWritten); 64 if (error != B_OK) 65 return error; 66 fBytesWritten = 0; 67 return B_OK; 68 } 69 70 private: 71 Channel* fChannel; 72 void* fBuffer; 73 int32 fBufferSize; 74 int32 fBytesWritten; 75 }; 76 77 78 // MemoryReader 79 class RequestChannel::MemoryReader : public Reader { 80 public: 81 MemoryReader(void* buffer, int32 bufferSize) 82 : Reader(), 83 fBuffer(buffer), 84 fBufferSize(bufferSize), 85 fBytesRead(0) 86 { 87 } 88 89 virtual status_t Read(void* buffer, int32 size) 90 { 91 // check parameters 92 if (!buffer || size < 0) 93 return B_BAD_VALUE; 94 if (size == 0) 95 return B_OK; 96 // get pointer into data buffer 97 void* localBuffer; 98 bool mustFree; 99 status_t error = Read(size, &localBuffer, &mustFree); 100 if (error != B_OK) 101 return error; 102 // copy data into supplied buffer 103 memcpy(buffer, localBuffer, size); 104 return B_OK; 105 } 106 107 virtual status_t Read(int32 size, void** buffer, bool* mustFree) 108 { 109 // check parameters 110 if (size < 0 || !buffer || !mustFree) 111 return B_BAD_VALUE; 112 if (fBytesRead + size > fBufferSize) 113 return B_BAD_VALUE; 114 // get the data pointer 115 *buffer = (uint8*)fBuffer + fBytesRead; 116 *mustFree = false; 117 fBytesRead += size; 118 return B_OK; 119 } 120 121 bool AllBytesRead() const 122 { 123 return (fBytesRead == fBufferSize); 124 } 125 126 private: 127 void* fBuffer; 128 int32 fBufferSize; 129 int32 fBytesRead; 130 }; 131 132 133 // RequestHeader 134 struct RequestChannel::RequestHeader { 135 uint32 type; 136 int32 size; 137 }; 138 139 140 // constructor 141 RequestChannel::RequestChannel(Channel* channel) 142 : fChannel(channel), 143 fBuffer(NULL), 144 fBufferSize(0) 145 { 146 // allocate the send buffer 147 fBuffer = malloc(kDefaultBufferSize); 148 if (fBuffer) 149 fBufferSize = kDefaultBufferSize; 150 } 151 152 // destructor 153 RequestChannel::~RequestChannel() 154 { 155 free(fBuffer); 156 } 157 158 // SendRequest 159 status_t 160 RequestChannel::SendRequest(Request* request) 161 { 162 if (!request) 163 RETURN_ERROR(B_BAD_VALUE); 164 PRINT("%p->RequestChannel::SendRequest(): request: %p, type: %s\n", this, request, typeid(*request).name()); 165 166 // get request size 167 int32 size; 168 status_t error = _GetRequestSize(request, &size); 169 if (error != B_OK) 170 RETURN_ERROR(error); 171 if (size < 0 || size > kMaxSaneRequestSize) { 172 ERROR("RequestChannel::SendRequest(): ERROR: Invalid request size: " 173 "%" B_PRId32 "\n", size); 174 RETURN_ERROR(B_BAD_DATA); 175 } 176 177 // write the request header 178 RequestHeader header; 179 header.type = B_HOST_TO_BENDIAN_INT32(request->GetType()); 180 header.size = B_HOST_TO_BENDIAN_INT32(size); 181 ChannelWriter writer(fChannel, fBuffer, fBufferSize); 182 error = writer.Write(&header, sizeof(RequestHeader)); 183 if (error != B_OK) 184 RETURN_ERROR(error); 185 186 // now write the request in earnest 187 RequestFlattener flattener(&writer); 188 request->Flatten(&flattener); 189 error = flattener.GetStatus(); 190 if (error != B_OK) 191 RETURN_ERROR(error); 192 error = writer.Flush(); 193 RETURN_ERROR(error); 194 } 195 196 // ReceiveRequest 197 status_t 198 RequestChannel::ReceiveRequest(Request** _request) 199 { 200 if (!_request) 201 RETURN_ERROR(B_BAD_VALUE); 202 203 // get the request header 204 RequestHeader header; 205 status_t error = fChannel->Receive(&header, sizeof(RequestHeader)); 206 if (error != B_OK) 207 RETURN_ERROR(error); 208 header.type = B_HOST_TO_BENDIAN_INT32(header.type); 209 header.size = B_HOST_TO_BENDIAN_INT32(header.size); 210 if (header.size < 0 || header.size > kMaxSaneRequestSize) { 211 ERROR("RequestChannel::ReceiveRequest(): ERROR: Invalid request size: " 212 "%" B_PRId32 "\n", header.size); 213 RETURN_ERROR(B_BAD_DATA); 214 } 215 216 // create the request 217 Request* request; 218 error = RequestFactory::CreateRequest(header.type, &request); 219 if (error != B_OK) 220 RETURN_ERROR(error); 221 ObjectDeleter<Request> requestDeleter(request); 222 223 // allocate a buffer for the data and read them 224 if (header.size > 0) { 225 RequestBuffer* requestBuffer = RequestBuffer::Create(header.size); 226 if (!requestBuffer) 227 RETURN_ERROR(B_NO_MEMORY); 228 request->AttachBuffer(requestBuffer); 229 230 // receive the data 231 error = fChannel->Receive(requestBuffer->GetData(), header.size); 232 if (error != B_OK) 233 RETURN_ERROR(error); 234 235 // unflatten the request 236 MemoryReader reader(requestBuffer->GetData(), header.size); 237 RequestUnflattener unflattener(&reader); 238 request->Unflatten(&unflattener); 239 error = unflattener.GetStatus(); 240 if (error != B_OK) 241 RETURN_ERROR(error); 242 if (!reader.AllBytesRead()) 243 RETURN_ERROR(B_BAD_DATA); 244 } 245 246 requestDeleter.Detach(); 247 *_request = request; 248 PRINT("%p->RequestChannel::ReceiveRequest(): request: %p, type: %s\n", this, request, typeid(*request).name()); 249 return B_OK; 250 } 251 252 // _GetRequestSize 253 status_t 254 RequestChannel::_GetRequestSize(Request* request, int32* size) 255 { 256 DummyWriter dummyWriter; 257 RequestFlattener flattener(&dummyWriter); 258 request->ShowAround(&flattener); 259 status_t error = flattener.GetStatus(); 260 if (error != B_OK) 261 return error; 262 *size = flattener.GetBytesWritten(); 263 return error; 264 } 265 266