xref: /haiku/src/add-ons/kernel/file_systems/netfs/shared/RequestChannel.cpp (revision 20f046edb99c55b1af0a17340ff8a581d000bc5c)
1 // RequestChannel.cpp
2 
3 #include "RequestChannel.h"
4 
5 #include <new>
6 #include <typeinfo>
7 
8 #include <stdlib.h>
9 #include <string.h>
10 
11 #include <AutoDeleter.h>
12 #include <ByteOrder.h>
13 
14 #include "Channel.h"
15 #include "Compatibility.h"
16 #include "DebugSupport.h"
17 #include "Request.h"
18 #include "RequestDumper.h"
19 #include "RequestFactory.h"
20 #include "RequestFlattener.h"
21 #include "RequestUnflattener.h"
22 
23 static const int32 kMaxSaneRequestSize	= 128 * 1024;	// 128 KB
24 static const int32 kDefaultBufferSize	= 4096;			// 4 KB
25 
26 // ChannelWriter
27 class RequestChannel::ChannelWriter : public Writer {
28 public:
ChannelWriter(Channel * channel,void * buffer,int32 bufferSize)29 	ChannelWriter(Channel* channel, void* buffer, int32 bufferSize)
30 		: fChannel(channel),
31 		  fBuffer(buffer),
32 		  fBufferSize(bufferSize),
33 		  fBytesWritten(0)
34 	{
35 	}
36 
Write(const void * buffer,int32 size)37 	virtual status_t Write(const void* buffer, int32 size)
38 	{
39 		status_t error = B_OK;
40 		// if the data don't fit into the buffer anymore, flush the buffer
41 		if (fBytesWritten + size > fBufferSize) {
42 			error = Flush();
43 			if (error != B_OK)
44 				return error;
45 		}
46 		// if the data don't even fit into an empty buffer, just send it,
47 		// otherwise append it to the buffer
48 		if (size > fBufferSize) {
49 			error = fChannel->Send(buffer, size);
50 			if (error != B_OK)
51 				return error;
52 		} else {
53 			memcpy((uint8*)fBuffer + fBytesWritten, buffer, size);
54 			fBytesWritten += size;
55 		}
56 		return error;
57 	}
58 
Flush()59 	status_t Flush()
60 	{
61 		if (fBytesWritten == 0)
62 			return B_OK;
63 		status_t error = fChannel->Send(fBuffer, fBytesWritten);
64 		if (error != B_OK)
65 			return error;
66 		fBytesWritten = 0;
67 		return B_OK;
68 	}
69 
70 private:
71 	Channel*	fChannel;
72 	void*		fBuffer;
73 	int32		fBufferSize;
74 	int32		fBytesWritten;
75 };
76 
77 
78 // MemoryReader
79 class RequestChannel::MemoryReader : public Reader {
80 public:
MemoryReader(void * buffer,int32 bufferSize)81 	MemoryReader(void* buffer, int32 bufferSize)
82 		: Reader(),
83 		  fBuffer(buffer),
84 		  fBufferSize(bufferSize),
85 		  fBytesRead(0)
86 	{
87 	}
88 
Read(void * buffer,int32 size)89 	virtual status_t Read(void* buffer, int32 size)
90 	{
91 		// check parameters
92 		if (!buffer || size < 0)
93 			return B_BAD_VALUE;
94 		if (size == 0)
95 			return B_OK;
96 		// get pointer into data buffer
97 		void* localBuffer;
98 		bool mustFree;
99 		status_t error = Read(size, &localBuffer, &mustFree);
100 		if (error != B_OK)
101 			return error;
102 		// copy data into supplied buffer
103 		memcpy(buffer, localBuffer, size);
104 		return B_OK;
105 	}
106 
Read(int32 size,void ** buffer,bool * mustFree)107 	virtual status_t Read(int32 size, void** buffer, bool* mustFree)
108 	{
109 		// check parameters
110 		if (size < 0 || !buffer || !mustFree)
111 			return B_BAD_VALUE;
112 		if (fBytesRead + size > fBufferSize)
113 			return B_BAD_VALUE;
114 		// get the data pointer
115 		*buffer = (uint8*)fBuffer + fBytesRead;
116 		*mustFree = false;
117 		fBytesRead += size;
118 		return B_OK;
119 	}
120 
AllBytesRead() const121 	bool AllBytesRead() const
122 	{
123 		return (fBytesRead == fBufferSize);
124 	}
125 
126 private:
127 	void*	fBuffer;
128 	int32	fBufferSize;
129 	int32	fBytesRead;
130 };
131 
132 
133 // RequestHeader
134 struct RequestChannel::RequestHeader {
135 	uint32	type;
136 	int32	size;
137 };
138 
139 
140 // constructor
RequestChannel(Channel * channel)141 RequestChannel::RequestChannel(Channel* channel)
142 	: fChannel(channel),
143 	  fBuffer(NULL),
144 	  fBufferSize(0)
145 {
146 	// allocate the send buffer
147 	fBuffer = malloc(kDefaultBufferSize);
148 	if (fBuffer)
149 		fBufferSize = kDefaultBufferSize;
150 }
151 
152 // destructor
~RequestChannel()153 RequestChannel::~RequestChannel()
154 {
155 	free(fBuffer);
156 }
157 
158 // SendRequest
159 status_t
SendRequest(Request * request)160 RequestChannel::SendRequest(Request* request)
161 {
162 	if (!request)
163 		RETURN_ERROR(B_BAD_VALUE);
164 	PRINT("%p->RequestChannel::SendRequest(): request: %p, type: %s\n", this, request, typeid(*request).name());
165 
166 	// get request size
167 	int32 size;
168 	status_t error = _GetRequestSize(request, &size);
169 	if (error != B_OK)
170 		RETURN_ERROR(error);
171 	if (size < 0 || size > kMaxSaneRequestSize) {
172 		ERROR("RequestChannel::SendRequest(): ERROR: Invalid request size: "
173 			"%" B_PRId32 "\n", size);
174 		RETURN_ERROR(B_BAD_DATA);
175 	}
176 
177 	// write the request header
178 	RequestHeader header;
179 	header.type = B_HOST_TO_BENDIAN_INT32(request->GetType());
180 	header.size = B_HOST_TO_BENDIAN_INT32(size);
181 	ChannelWriter writer(fChannel, fBuffer, fBufferSize);
182 	error = writer.Write(&header, sizeof(RequestHeader));
183 	if (error != B_OK)
184 		RETURN_ERROR(error);
185 
186 	// now write the request in earnest
187 	RequestFlattener flattener(&writer);
188 	request->Flatten(&flattener);
189 	error = flattener.GetStatus();
190 	if (error != B_OK)
191 		RETURN_ERROR(error);
192 	error = writer.Flush();
193 	RETURN_ERROR(error);
194 }
195 
196 // ReceiveRequest
197 status_t
ReceiveRequest(Request ** _request)198 RequestChannel::ReceiveRequest(Request** _request)
199 {
200 	if (!_request)
201 		RETURN_ERROR(B_BAD_VALUE);
202 
203 	// get the request header
204 	RequestHeader header;
205 	status_t error = fChannel->Receive(&header, sizeof(RequestHeader));
206 	if (error != B_OK)
207 		RETURN_ERROR(error);
208 	header.type = B_HOST_TO_BENDIAN_INT32(header.type);
209 	header.size = B_HOST_TO_BENDIAN_INT32(header.size);
210 	if (header.size < 0 || header.size > kMaxSaneRequestSize) {
211 		ERROR("RequestChannel::ReceiveRequest(): ERROR: Invalid request size: "
212 			"%" B_PRId32 "\n", header.size);
213 		RETURN_ERROR(B_BAD_DATA);
214 	}
215 
216 	// create the request
217 	Request* request;
218 	error = RequestFactory::CreateRequest(header.type, &request);
219 	if (error != B_OK)
220 		RETURN_ERROR(error);
221 	ObjectDeleter<Request> requestDeleter(request);
222 
223 	// allocate a buffer for the data and read them
224 	if (header.size > 0) {
225 		RequestBuffer* requestBuffer = RequestBuffer::Create(header.size);
226 		if (!requestBuffer)
227 			RETURN_ERROR(B_NO_MEMORY);
228 		request->AttachBuffer(requestBuffer);
229 
230 		// receive the data
231 		error = fChannel->Receive(requestBuffer->GetData(), header.size);
232 		if (error != B_OK)
233 			RETURN_ERROR(error);
234 
235 		// unflatten the request
236 		MemoryReader reader(requestBuffer->GetData(), header.size);
237 		RequestUnflattener unflattener(&reader);
238 		request->Unflatten(&unflattener);
239 		error = unflattener.GetStatus();
240 		if (error != B_OK)
241 			RETURN_ERROR(error);
242 		if (!reader.AllBytesRead())
243 			RETURN_ERROR(B_BAD_DATA);
244 	}
245 
246 	requestDeleter.Detach();
247 	*_request = request;
248 	PRINT("%p->RequestChannel::ReceiveRequest(): request: %p, type: %s\n", this, request, typeid(*request).name());
249 	return B_OK;
250 }
251 
252 // _GetRequestSize
253 status_t
_GetRequestSize(Request * request,int32 * size)254 RequestChannel::_GetRequestSize(Request* request, int32* size)
255 {
256 	DummyWriter dummyWriter;
257 	RequestFlattener flattener(&dummyWriter);
258 	request->ShowAround(&flattener);
259 	status_t error = flattener.GetStatus();
260 	if (error != B_OK)
261 		return error;
262 	*size = flattener.GetBytesWritten();
263 	return error;
264 }
265 
266