xref: /haiku/src/add-ons/kernel/file_systems/nfs4/RPCServer.cpp (revision 25a7b01d15612846f332751841da3579db313082)
1 /*
2  * Copyright 2012 Haiku, Inc. All rights reserved.
3  * Distributed under the terms of the MIT License.
4  *
5  * Authors:
6  *		Paweł Dziepak, pdziepak@quarnos.org
7  */
8 
9 
10 #include "RPCServer.h"
11 
12 #include <stdlib.h>
13 
14 #include <util/AutoLock.h>
15 #include <util/Random.h>
16 
17 #include "RPCCallbackServer.h"
18 #include "RPCReply.h"
19 
20 
21 using namespace RPC;
22 
23 
RequestManager()24 RequestManager::RequestManager()
25 	:
26 	fQueueHead(NULL),
27 	fQueueTail(NULL)
28 {
29 	mutex_init(&fLock, NULL);
30 }
31 
32 
~RequestManager()33 RequestManager::~RequestManager()
34 {
35 	mutex_destroy(&fLock);
36 }
37 
38 
39 void
AddRequest(Request * request)40 RequestManager::AddRequest(Request* request)
41 {
42 	ASSERT(request != NULL);
43 
44 	MutexLocker _(fLock);
45 	if (fQueueTail != NULL)
46 		fQueueTail->fNext = request;
47 	else
48 		fQueueHead = request;
49 	fQueueTail = request;
50 	request->fNext = NULL;
51 }
52 
53 
54 Request*
FindRequest(uint32 xid)55 RequestManager::FindRequest(uint32 xid)
56 {
57 	MutexLocker _(fLock);
58 	Request* req = fQueueHead;
59 	Request* prev = NULL;
60 	while (req != NULL) {
61 		if (req->fXID == xid) {
62 			if (prev != NULL)
63 				prev->fNext = req->fNext;
64 			if (fQueueTail == req)
65 				fQueueTail = prev;
66 			if (fQueueHead == req)
67 				fQueueHead = req->fNext;
68 
69 			return req;
70 		}
71 
72 		prev = req;
73 		req = req->fNext;
74 	}
75 
76 	return NULL;
77 }
78 
79 
Server(Connection * connection,PeerAddress * address)80 Server::Server(Connection* connection, PeerAddress* address)
81 	:
82 	fConnection(connection),
83 	fAddress(address),
84 	fPrivateData(NULL),
85 	fCallback(NULL),
86 	fRepairCount(0),
87 	fXID(get_random<uint32>())
88 {
89 	ASSERT(connection != NULL);
90 	ASSERT(address != NULL);
91 
92 	mutex_init(&fCallbackLock, NULL);
93 	mutex_init(&fRepairLock, NULL);
94 
95 	_StartListening();
96 }
97 
98 
~Server()99 Server::~Server()
100 {
101 	if (fCallback != NULL)
102 		fCallback->CBServer()->UnregisterCallback(fCallback);
103 	delete fCallback;
104 	mutex_destroy(&fCallbackLock);
105 	mutex_destroy(&fRepairLock);
106 
107 	delete fPrivateData;
108 
109 	fThreadCancel = true;
110 	fConnection->Disconnect();
111 
112 	status_t result;
113 	wait_for_thread(fThread, &result);
114 
115 	delete fConnection;
116 }
117 
118 
119 status_t
_StartListening()120 Server::_StartListening()
121 {
122 	fThreadCancel = false;
123 	fThreadError = B_OK;
124 	fThread = spawn_kernel_thread(&Server::_ListenerThreadStart,
125 		"NFSv4 Listener", B_NORMAL_PRIORITY, this);
126 	if (fThread < B_OK)
127 		return fThread;
128 
129 	status_t result = resume_thread(fThread);
130 	if (result != B_OK) {
131 		kill_thread(fThread);
132 		return result;
133 	}
134 
135 	return B_OK;
136 }
137 
138 
139 status_t
SendCallAsync(Call * call,Reply ** reply,Request ** request)140 Server::SendCallAsync(Call* call, Reply** reply, Request** request)
141 {
142 	ASSERT(call != NULL);
143 	ASSERT(reply != NULL);
144 	ASSERT(request != NULL);
145 
146 	if (fThreadError != B_OK && Repair() != B_OK)
147 		return fThreadError;
148 
149 	Request* req = new(std::nothrow) Request;
150 	if (req == NULL)
151 		return B_NO_MEMORY;
152 
153 	uint32 xid = _GetXID();
154 	call->SetXID(xid);
155 	req->fXID = xid;
156 	req->fReply = reply;
157 	req->fEvent.Init(&req->fEvent, NULL);
158 	req->fDone = false;
159 	req->fError = B_OK;
160 	req->fNext = NULL;
161 
162 	fRequests.AddRequest(req);
163 
164 	*request = req;
165 	status_t error = ResendCallAsync(call, req);
166 	if (error != B_OK)
167 		delete req;
168 	return error;
169 }
170 
171 
172 status_t
ResendCallAsync(Call * call,Request * request)173 Server::ResendCallAsync(Call* call, Request* request)
174 {
175 	ASSERT(call != NULL);
176 	ASSERT(request != NULL);
177 
178 	if (fThreadError != B_OK && Repair() != B_OK) {
179 		fRequests.FindRequest(request->fXID);
180 		return fThreadError;
181 	}
182 
183 	XDR::WriteStream& stream = call->Stream();
184 	status_t result = fConnection->Send(stream.Buffer(), stream.Size());
185 	if (result != B_OK) {
186 		fRequests.FindRequest(request->fXID);
187 		return result;
188 	}
189 
190 	return B_OK;
191 }
192 
193 
194 status_t
WakeCall(Request * request)195 Server::WakeCall(Request* request)
196 {
197 	ASSERT(request != NULL);
198 
199 	Request* req = fRequests.FindRequest(request->fXID);
200 	if (req == NULL)
201 		return B_OK;
202 
203 	request->fError = B_FILE_ERROR;
204 	*request->fReply = NULL;
205 	request->fDone = true;
206 	request->fEvent.NotifyAll();
207 
208 	return B_OK;
209 }
210 
211 
212 status_t
Repair()213 Server::Repair()
214 {
215 	uint32 thisRepair = fRepairCount;
216 
217 	MutexLocker _(fRepairLock);
218 	if (fRepairCount != thisRepair)
219 		return B_OK;
220 
221 	fThreadCancel = true;
222 
223 	status_t result = fConnection->Reconnect();
224 	if (result != B_OK)
225 		return result;
226 
227 	wait_for_thread(fThread, &result);
228 	result = _StartListening();
229 
230 	if (result == B_OK)
231 		fRepairCount++;
232 
233 	return result;
234 }
235 
236 
237 Callback*
GetCallback()238 Server::GetCallback()
239 {
240 	MutexLocker _(fCallbackLock);
241 
242 	if (fCallback == NULL) {
243 		fCallback = new(std::nothrow) Callback(this);
244 		if (fCallback == NULL)
245 			return NULL;
246 
247 		CallbackServer* server = CallbackServer::Get(this);
248 		if (server == NULL) {
249 			delete fCallback;
250 			return NULL;
251 		}
252 
253 		if (server->RegisterCallback(fCallback) != B_OK) {
254 			delete fCallback;
255 			return NULL;
256 		}
257 	}
258 
259 	return fCallback;
260 }
261 
262 
263 uint32
_GetXID()264 Server::_GetXID()
265 {
266 	return static_cast<uint32>(atomic_add(&fXID, 1));
267 }
268 
269 
270 status_t
_Listener()271 Server::_Listener()
272 {
273 	status_t result;
274 	uint32 size;
275 	void* buffer = NULL;
276 
277 	while (!fThreadCancel) {
278 		result = fConnection->Receive(&buffer, &size);
279 		if (result == B_NO_MEMORY)
280 			continue;
281 		else if (result != B_OK) {
282 			fThreadError = result;
283 			return result;
284 		}
285 
286 		ASSERT(buffer != NULL && size > 0);
287 		Reply* reply = new(std::nothrow) Reply(buffer, size);
288 		if (reply == NULL) {
289 			free(buffer);
290 			continue;
291 		}
292 
293 		Request* req = fRequests.FindRequest(reply->GetXID());
294 		if (req != NULL) {
295 			*req->fReply = reply;
296 			req->fDone = true;
297 			req->fEvent.NotifyAll();
298 		} else
299 			delete reply;
300 	}
301 
302 	return B_OK;
303 }
304 
305 
306 status_t
_ListenerThreadStart(void * object)307 Server::_ListenerThreadStart(void* object)
308 {
309 	ASSERT(object != NULL);
310 
311 	Server* server = reinterpret_cast<Server*>(object);
312 	return server->_Listener();
313 }
314 
315 
ServerManager()316 ServerManager::ServerManager()
317 	:
318 	fRoot(NULL)
319 {
320 	mutex_init(&fLock, NULL);
321 }
322 
323 
~ServerManager()324 ServerManager::~ServerManager()
325 {
326 	mutex_destroy(&fLock);
327 }
328 
329 
330 status_t
Acquire(Server ** _server,AddressResolver * resolver,ProgramData * (* createPrivateData)(Server *))331 ServerManager::Acquire(Server** _server, AddressResolver* resolver,
332 	ProgramData* (*createPrivateData)(Server*))
333 {
334 	PeerAddress address;
335 	status_t result;
336 
337 	while ((result = resolver->GetNextAddress(&address)) == B_OK) {
338 		result = _Acquire(_server, address, createPrivateData);
339 		if (result == B_OK)
340 			break;
341 	}
342 
343 	return result;
344 }
345 
346 
347 status_t
_Acquire(Server ** _server,const PeerAddress & address,ProgramData * (* createPrivateData)(Server *))348 ServerManager::_Acquire(Server** _server, const PeerAddress& address,
349 	ProgramData* (*createPrivateData)(Server*))
350 {
351 	ASSERT(_server != NULL);
352 	ASSERT(createPrivateData != NULL);
353 
354 	status_t result;
355 
356 	MutexLocker locker(fLock);
357 	ServerNode* node = _Find(address);
358 	if (node != NULL) {
359 		node->fRefCount++;
360 		*_server = node->fServer;
361 
362 		return B_OK;
363 	}
364 
365 	node = new(std::nothrow) ServerNode;
366 	if (node == NULL)
367 		return B_NO_MEMORY;
368 
369 	node->fID = address;
370 
371 	Connection* conn;
372 	result = Connection::Connect(&conn, address);
373 	if (result != B_OK) {
374 		delete node;
375 		return result;
376 	}
377 
378 	node->fServer = new Server(conn, &node->fID);
379 	if (node->fServer == NULL) {
380 		delete node;
381 		delete conn;
382 		return B_NO_MEMORY;
383 	}
384 	node->fServer->SetPrivateData(createPrivateData(node->fServer));
385 
386 	node->fRefCount = 1;
387 	node->fLeft = node->fRight = NULL;
388 
389 	ServerNode* nd = _Insert(node);
390 	if (nd != node) {
391 		nd->fRefCount++;
392 
393 		delete node->fServer;
394 		delete node;
395 		*_server = nd->fServer;
396 		return B_OK;
397 	}
398 
399 	*_server = node->fServer;
400 	return B_OK;
401 }
402 
403 
404 void
Release(Server * server)405 ServerManager::Release(Server* server)
406 {
407 	ASSERT(server != NULL);
408 
409 	MutexLocker _(fLock);
410 	ServerNode* node = _Find(server->ID());
411 	if (node != NULL) {
412 		node->fRefCount--;
413 
414 		if (node->fRefCount == 0) {
415 			_Delete(node);
416 			delete node->fServer;
417 			delete node;
418 		}
419 	}
420 }
421 
422 
423 ServerNode*
_Find(const PeerAddress & address)424 ServerManager::_Find(const PeerAddress& address)
425 {
426 	ServerNode* node = fRoot;
427 	while (node != NULL) {
428 		if (node->fID == address)
429 			return node;
430 		if (node->fID < address)
431 			node = node->fRight;
432 		else
433 			node = node->fLeft;
434 	}
435 
436 	return node;
437 }
438 
439 
440 void
_Delete(ServerNode * node)441 ServerManager::_Delete(ServerNode* node)
442 {
443 	ASSERT(node != NULL);
444 
445 	bool found = false;
446 	ServerNode* previous = NULL;
447 	ServerNode* current = fRoot;
448 	while (current != NULL) {
449 		if (current->fID == node->fID) {
450 			found = true;
451 			break;
452 		}
453 
454 		if (current->fID < node->fID) {
455 			previous = current;
456 			current = current->fRight;
457 		} else {
458 			previous = current;
459 			current = current->fLeft;
460 		}
461 	}
462 
463 	if (!found)
464 		return;
465 
466 	if (previous == NULL)
467 		fRoot = NULL;
468 	else if (current->fLeft == NULL && current->fRight == NULL) {
469 		if (previous->fID < node->fID)
470 			previous->fRight = NULL;
471 		else
472 			previous->fLeft = NULL;
473 	} else if (current->fLeft != NULL && current->fRight == NULL) {
474 		if (previous->fID < node->fID)
475 			previous->fRight = current->fLeft;
476 		else
477 			previous->fLeft = current->fLeft;
478 	} else if (current->fLeft == NULL && current->fRight != NULL) {
479 		if (previous->fID < node->fID)
480 			previous->fRight = current->fRight;
481 		else
482 			previous->fLeft = current->fRight;
483 	} else {
484 		ServerNode* left_prev = current;
485 		ServerNode*	left = current->fLeft;
486 
487 		while (left->fLeft != NULL) {
488 			left_prev = left;
489 			left = left->fLeft;
490 		}
491 
492 		if (previous->fID < node->fID)
493 			previous->fRight = left;
494 		else
495 			previous->fLeft = left;
496 
497 
498 		left_prev->fLeft = NULL;
499 	}
500 }
501 
502 
503 ServerNode*
_Insert(ServerNode * node)504 ServerManager::_Insert(ServerNode* node)
505 {
506 	ASSERT(node != NULL);
507 
508 	ServerNode* previous = NULL;
509 	ServerNode* current = fRoot;
510 	while (current != NULL) {
511 		if (current->fID == node->fID)
512 			return current;
513 		if (current->fID < node->fID) {
514 			previous = current;
515 			current = current->fRight;
516 		} else {
517 			previous = current;
518 			current = current->fLeft;
519 		}
520 	}
521 
522 	if (previous == NULL)
523 		fRoot = node;
524 	else if (previous->fID < node->fID)
525 		previous->fRight = node;
526 	else
527 		previous->fLeft = node;
528 
529 	return node;
530 }
531 
532