// RequestChannel.cpp #include "RequestChannel.h" #include #include #include #include #include #include #include "Channel.h" #include "Compatibility.h" #include "DebugSupport.h" #include "Request.h" #include "RequestDumper.h" #include "RequestFactory.h" #include "RequestFlattener.h" #include "RequestUnflattener.h" static const int32 kMaxSaneRequestSize = 128 * 1024; // 128 KB static const int32 kDefaultBufferSize = 4096; // 4 KB // ChannelWriter class RequestChannel::ChannelWriter : public Writer { public: ChannelWriter(Channel* channel, void* buffer, int32 bufferSize) : fChannel(channel), fBuffer(buffer), fBufferSize(bufferSize), fBytesWritten(0) { } virtual status_t Write(const void* buffer, int32 size) { status_t error = B_OK; // if the data don't fit into the buffer anymore, flush the buffer if (fBytesWritten + size > fBufferSize) { error = Flush(); if (error != B_OK) return error; } // if the data don't even fit into an empty buffer, just send it, // otherwise append it to the buffer if (size > fBufferSize) { error = fChannel->Send(buffer, size); if (error != B_OK) return error; } else { memcpy((uint8*)fBuffer + fBytesWritten, buffer, size); fBytesWritten += size; } return error; } status_t Flush() { if (fBytesWritten == 0) return B_OK; status_t error = fChannel->Send(fBuffer, fBytesWritten); if (error != B_OK) return error; fBytesWritten = 0; return B_OK; } private: Channel* fChannel; void* fBuffer; int32 fBufferSize; int32 fBytesWritten; }; // MemoryReader class RequestChannel::MemoryReader : public Reader { public: MemoryReader(void* buffer, int32 bufferSize) : Reader(), fBuffer(buffer), fBufferSize(bufferSize), fBytesRead(0) { } virtual status_t Read(void* buffer, int32 size) { // check parameters if (!buffer || size < 0) return B_BAD_VALUE; if (size == 0) return B_OK; // get pointer into data buffer void* localBuffer; bool mustFree; status_t error = Read(size, &localBuffer, &mustFree); if (error != B_OK) return error; // copy data into supplied buffer memcpy(buffer, localBuffer, size); return B_OK; } virtual status_t Read(int32 size, void** buffer, bool* mustFree) { // check parameters if (size < 0 || !buffer || !mustFree) return B_BAD_VALUE; if (fBytesRead + size > fBufferSize) return B_BAD_VALUE; // get the data pointer *buffer = (uint8*)fBuffer + fBytesRead; *mustFree = false; fBytesRead += size; return B_OK; } bool AllBytesRead() const { return (fBytesRead == fBufferSize); } private: void* fBuffer; int32 fBufferSize; int32 fBytesRead; }; // RequestHeader struct RequestChannel::RequestHeader { uint32 type; int32 size; }; // constructor RequestChannel::RequestChannel(Channel* channel) : fChannel(channel), fBuffer(NULL), fBufferSize(0) { // allocate the send buffer fBuffer = malloc(kDefaultBufferSize); if (fBuffer) fBufferSize = kDefaultBufferSize; } // destructor RequestChannel::~RequestChannel() { free(fBuffer); } // SendRequest status_t RequestChannel::SendRequest(Request* request) { if (!request) RETURN_ERROR(B_BAD_VALUE); PRINT("%p->RequestChannel::SendRequest(): request: %p, type: %s\n", this, request, typeid(*request).name()); // get request size int32 size; status_t error = _GetRequestSize(request, &size); if (error != B_OK) RETURN_ERROR(error); if (size < 0 || size > kMaxSaneRequestSize) { ERROR("RequestChannel::SendRequest(): ERROR: Invalid request size: " "%" B_PRId32 "\n", size); RETURN_ERROR(B_BAD_DATA); } // write the request header RequestHeader header; header.type = B_HOST_TO_BENDIAN_INT32(request->GetType()); header.size = B_HOST_TO_BENDIAN_INT32(size); ChannelWriter writer(fChannel, fBuffer, fBufferSize); error = writer.Write(&header, sizeof(RequestHeader)); if (error != B_OK) RETURN_ERROR(error); // now write the request in earnest RequestFlattener flattener(&writer); request->Flatten(&flattener); error = flattener.GetStatus(); if (error != B_OK) RETURN_ERROR(error); error = writer.Flush(); RETURN_ERROR(error); } // ReceiveRequest status_t RequestChannel::ReceiveRequest(Request** _request) { if (!_request) RETURN_ERROR(B_BAD_VALUE); // get the request header RequestHeader header; status_t error = fChannel->Receive(&header, sizeof(RequestHeader)); if (error != B_OK) RETURN_ERROR(error); header.type = B_HOST_TO_BENDIAN_INT32(header.type); header.size = B_HOST_TO_BENDIAN_INT32(header.size); if (header.size < 0 || header.size > kMaxSaneRequestSize) { ERROR("RequestChannel::ReceiveRequest(): ERROR: Invalid request size: " "%" B_PRId32 "\n", header.size); RETURN_ERROR(B_BAD_DATA); } // create the request Request* request; error = RequestFactory::CreateRequest(header.type, &request); if (error != B_OK) RETURN_ERROR(error); ObjectDeleter requestDeleter(request); // allocate a buffer for the data and read them if (header.size > 0) { RequestBuffer* requestBuffer = RequestBuffer::Create(header.size); if (!requestBuffer) RETURN_ERROR(B_NO_MEMORY); request->AttachBuffer(requestBuffer); // receive the data error = fChannel->Receive(requestBuffer->GetData(), header.size); if (error != B_OK) RETURN_ERROR(error); // unflatten the request MemoryReader reader(requestBuffer->GetData(), header.size); RequestUnflattener unflattener(&reader); request->Unflatten(&unflattener); error = unflattener.GetStatus(); if (error != B_OK) RETURN_ERROR(error); if (!reader.AllBytesRead()) RETURN_ERROR(B_BAD_DATA); } requestDeleter.Detach(); *_request = request; PRINT("%p->RequestChannel::ReceiveRequest(): request: %p, type: %s\n", this, request, typeid(*request).name()); return B_OK; } // _GetRequestSize status_t RequestChannel::_GetRequestSize(Request* request, int32* size) { DummyWriter dummyWriter; RequestFlattener flattener(&dummyWriter); request->ShowAround(&flattener); status_t error = flattener.GetStatus(); if (error != B_OK) return error; *size = flattener.GetBytesWritten(); return error; }