1 // RequestChannel.cpp
2
3 #include "RequestChannel.h"
4
5 #include <new>
6 #include <typeinfo>
7
8 #include <stdlib.h>
9 #include <string.h>
10
11 #include <AutoDeleter.h>
12 #include <ByteOrder.h>
13
14 #include "Channel.h"
15 #include "Compatibility.h"
16 #include "DebugSupport.h"
17 #include "Request.h"
18 #include "RequestDumper.h"
19 #include "RequestFactory.h"
20 #include "RequestFlattener.h"
21 #include "RequestUnflattener.h"
22
23 static const int32 kMaxSaneRequestSize = 128 * 1024; // 128 KB
24 static const int32 kDefaultBufferSize = 4096; // 4 KB
25
26 // ChannelWriter
27 class RequestChannel::ChannelWriter : public Writer {
28 public:
ChannelWriter(Channel * channel,void * buffer,int32 bufferSize)29 ChannelWriter(Channel* channel, void* buffer, int32 bufferSize)
30 : fChannel(channel),
31 fBuffer(buffer),
32 fBufferSize(bufferSize),
33 fBytesWritten(0)
34 {
35 }
36
Write(const void * buffer,int32 size)37 virtual status_t Write(const void* buffer, int32 size)
38 {
39 status_t error = B_OK;
40 // if the data don't fit into the buffer anymore, flush the buffer
41 if (fBytesWritten + size > fBufferSize) {
42 error = Flush();
43 if (error != B_OK)
44 return error;
45 }
46 // if the data don't even fit into an empty buffer, just send it,
47 // otherwise append it to the buffer
48 if (size > fBufferSize) {
49 error = fChannel->Send(buffer, size);
50 if (error != B_OK)
51 return error;
52 } else {
53 memcpy((uint8*)fBuffer + fBytesWritten, buffer, size);
54 fBytesWritten += size;
55 }
56 return error;
57 }
58
Flush()59 status_t Flush()
60 {
61 if (fBytesWritten == 0)
62 return B_OK;
63 status_t error = fChannel->Send(fBuffer, fBytesWritten);
64 if (error != B_OK)
65 return error;
66 fBytesWritten = 0;
67 return B_OK;
68 }
69
70 private:
71 Channel* fChannel;
72 void* fBuffer;
73 int32 fBufferSize;
74 int32 fBytesWritten;
75 };
76
77
78 // MemoryReader
79 class RequestChannel::MemoryReader : public Reader {
80 public:
MemoryReader(void * buffer,int32 bufferSize)81 MemoryReader(void* buffer, int32 bufferSize)
82 : Reader(),
83 fBuffer(buffer),
84 fBufferSize(bufferSize),
85 fBytesRead(0)
86 {
87 }
88
Read(void * buffer,int32 size)89 virtual status_t Read(void* buffer, int32 size)
90 {
91 // check parameters
92 if (!buffer || size < 0)
93 return B_BAD_VALUE;
94 if (size == 0)
95 return B_OK;
96 // get pointer into data buffer
97 void* localBuffer;
98 bool mustFree;
99 status_t error = Read(size, &localBuffer, &mustFree);
100 if (error != B_OK)
101 return error;
102 // copy data into supplied buffer
103 memcpy(buffer, localBuffer, size);
104 return B_OK;
105 }
106
Read(int32 size,void ** buffer,bool * mustFree)107 virtual status_t Read(int32 size, void** buffer, bool* mustFree)
108 {
109 // check parameters
110 if (size < 0 || !buffer || !mustFree)
111 return B_BAD_VALUE;
112 if (fBytesRead + size > fBufferSize)
113 return B_BAD_VALUE;
114 // get the data pointer
115 *buffer = (uint8*)fBuffer + fBytesRead;
116 *mustFree = false;
117 fBytesRead += size;
118 return B_OK;
119 }
120
AllBytesRead() const121 bool AllBytesRead() const
122 {
123 return (fBytesRead == fBufferSize);
124 }
125
126 private:
127 void* fBuffer;
128 int32 fBufferSize;
129 int32 fBytesRead;
130 };
131
132
133 // RequestHeader
134 struct RequestChannel::RequestHeader {
135 uint32 type;
136 int32 size;
137 };
138
139
140 // constructor
RequestChannel(Channel * channel)141 RequestChannel::RequestChannel(Channel* channel)
142 : fChannel(channel),
143 fBuffer(NULL),
144 fBufferSize(0)
145 {
146 // allocate the send buffer
147 fBuffer = malloc(kDefaultBufferSize);
148 if (fBuffer)
149 fBufferSize = kDefaultBufferSize;
150 }
151
152 // destructor
~RequestChannel()153 RequestChannel::~RequestChannel()
154 {
155 free(fBuffer);
156 }
157
158 // SendRequest
159 status_t
SendRequest(Request * request)160 RequestChannel::SendRequest(Request* request)
161 {
162 if (!request)
163 RETURN_ERROR(B_BAD_VALUE);
164 PRINT("%p->RequestChannel::SendRequest(): request: %p, type: %s\n", this, request, typeid(*request).name());
165
166 // get request size
167 int32 size;
168 status_t error = _GetRequestSize(request, &size);
169 if (error != B_OK)
170 RETURN_ERROR(error);
171 if (size < 0 || size > kMaxSaneRequestSize) {
172 ERROR("RequestChannel::SendRequest(): ERROR: Invalid request size: "
173 "%" B_PRId32 "\n", size);
174 RETURN_ERROR(B_BAD_DATA);
175 }
176
177 // write the request header
178 RequestHeader header;
179 header.type = B_HOST_TO_BENDIAN_INT32(request->GetType());
180 header.size = B_HOST_TO_BENDIAN_INT32(size);
181 ChannelWriter writer(fChannel, fBuffer, fBufferSize);
182 error = writer.Write(&header, sizeof(RequestHeader));
183 if (error != B_OK)
184 RETURN_ERROR(error);
185
186 // now write the request in earnest
187 RequestFlattener flattener(&writer);
188 request->Flatten(&flattener);
189 error = flattener.GetStatus();
190 if (error != B_OK)
191 RETURN_ERROR(error);
192 error = writer.Flush();
193 RETURN_ERROR(error);
194 }
195
196 // ReceiveRequest
197 status_t
ReceiveRequest(Request ** _request)198 RequestChannel::ReceiveRequest(Request** _request)
199 {
200 if (!_request)
201 RETURN_ERROR(B_BAD_VALUE);
202
203 // get the request header
204 RequestHeader header;
205 status_t error = fChannel->Receive(&header, sizeof(RequestHeader));
206 if (error != B_OK)
207 RETURN_ERROR(error);
208 header.type = B_HOST_TO_BENDIAN_INT32(header.type);
209 header.size = B_HOST_TO_BENDIAN_INT32(header.size);
210 if (header.size < 0 || header.size > kMaxSaneRequestSize) {
211 ERROR("RequestChannel::ReceiveRequest(): ERROR: Invalid request size: "
212 "%" B_PRId32 "\n", header.size);
213 RETURN_ERROR(B_BAD_DATA);
214 }
215
216 // create the request
217 Request* request;
218 error = RequestFactory::CreateRequest(header.type, &request);
219 if (error != B_OK)
220 RETURN_ERROR(error);
221 ObjectDeleter<Request> requestDeleter(request);
222
223 // allocate a buffer for the data and read them
224 if (header.size > 0) {
225 RequestBuffer* requestBuffer = RequestBuffer::Create(header.size);
226 if (!requestBuffer)
227 RETURN_ERROR(B_NO_MEMORY);
228 request->AttachBuffer(requestBuffer);
229
230 // receive the data
231 error = fChannel->Receive(requestBuffer->GetData(), header.size);
232 if (error != B_OK)
233 RETURN_ERROR(error);
234
235 // unflatten the request
236 MemoryReader reader(requestBuffer->GetData(), header.size);
237 RequestUnflattener unflattener(&reader);
238 request->Unflatten(&unflattener);
239 error = unflattener.GetStatus();
240 if (error != B_OK)
241 RETURN_ERROR(error);
242 if (!reader.AllBytesRead())
243 RETURN_ERROR(B_BAD_DATA);
244 }
245
246 requestDeleter.Detach();
247 *_request = request;
248 PRINT("%p->RequestChannel::ReceiveRequest(): request: %p, type: %s\n", this, request, typeid(*request).name());
249 return B_OK;
250 }
251
252 // _GetRequestSize
253 status_t
_GetRequestSize(Request * request,int32 * size)254 RequestChannel::_GetRequestSize(Request* request, int32* size)
255 {
256 DummyWriter dummyWriter;
257 RequestFlattener flattener(&dummyWriter);
258 request->ShowAround(&flattener);
259 status_t error = flattener.GetStatus();
260 if (error != B_OK)
261 return error;
262 *size = flattener.GetBytesWritten();
263 return error;
264 }
265
266