xref: /haiku/src/add-ons/kernel/file_systems/netfs/authentication_server/AuthenticationServer.cpp (revision 02354704729d38c3b078c696adc1bbbd33cbcf72)
1 // AuthenticationServer.cpp
2 
3 #include "AuthenticationServer.h"
4 
5 #include <new>
6 
7 #include <HashMap.h>
8 #include <HashString.h>
9 #include <util/KMessage.h>
10 
11 #include "AuthenticationPanel.h"
12 #include "AuthenticationServerDefs.h"
13 #include "DebugSupport.h"
14 #include "TaskManager.h"
15 
16 
17 // Authentication
18 class AuthenticationServer::Authentication {
19 public:
20 	Authentication()
21 		: fUser(),
22 		  fPassword()
23 	{
24 	}
25 
26 	Authentication(const char* user, const char* password)
27 		: fUser(user),
28 		  fPassword(password)
29 	{
30 	}
31 
32 	status_t SetTo(const char* user, const char* password)
33 	{
34 		if (fUser.SetTo(user) && fPassword.SetTo(password))
35 			return B_OK;
36 		return B_NO_MEMORY;
37 	}
38 
39 	bool IsValid() const
40 	{
41 		return (fUser.GetLength() > 0);
42 	}
43 
44 	const char* GetUser() const
45 	{
46 		return fUser.GetString();
47 	}
48 
49 	const char* GetPassword() const
50 	{
51 		return fPassword.GetString();
52 	}
53 
54 private:
55 	HashString	fUser;
56 	HashString	fPassword;
57 };
58 
59 // ServerKey
60 class AuthenticationServer::ServerKey {
61 public:
62 	ServerKey()
63 		: fContext(),
64 		  fServer()
65 	{
66 	}
67 
68 	ServerKey(const char* context, const char* server)
69 		: fContext(context),
70 		  fServer(server)
71 	{
72 	}
73 
74 	ServerKey(const ServerKey& other)
75 		: fContext(other.fContext),
76 		  fServer(other.fServer)
77 	{
78 	}
79 
80 	uint32 GetHashCode() const
81 	{
82 		return fContext.GetHashCode() * 17 + fServer.GetHashCode();
83 	}
84 
85 	ServerKey& operator=(const ServerKey& other)
86 	{
87 		fContext = other.fContext;
88 		fServer = other.fServer;
89 		return *this;
90 	}
91 
92 	bool operator==(const ServerKey& other) const
93 	{
94 		return (fContext == other.fContext && fServer == other.fServer);
95 	}
96 
97 	bool operator!=(const ServerKey& other) const
98 	{
99 		return !(*this == other);
100 	}
101 
102 private:
103 	HashString	fContext;
104 	HashString	fServer;
105 };
106 
107 // ServerEntry
108 class AuthenticationServer::ServerEntry {
109 public:
110 	ServerEntry()
111 		: fDefaultAuthentication(),
112 		  fUseDefaultAuthentication(false)
113 	{
114 	}
115 
116 	~ServerEntry()
117 	{
118 		// delete the authentications
119 		for (AuthenticationMap::Iterator it = fAuthentications.GetIterator();
120 			 it.HasNext();) {
121 			delete it.Next().value;
122 		}
123 	}
124 
125 	void SetUseDefaultAuthentication(bool useDefaultAuthentication)
126 	{
127 		fUseDefaultAuthentication = useDefaultAuthentication;
128 	}
129 
130 	bool UseDefaultAuthentication() const
131 	{
132 		return fUseDefaultAuthentication;
133 	}
134 
135 	status_t SetDefaultAuthentication(const char* user, const char* password)
136 	{
137 		return fDefaultAuthentication.SetTo(user, password);
138 	}
139 
140 	const Authentication& GetDefaultAuthentication() const
141 	{
142 		return fDefaultAuthentication;
143 	}
144 
145 	status_t SetAuthentication(const char* share, const char* user,
146 		const char* password)
147 	{
148 		// check, if an entry already exists for the share -- if it does,
149 		// just set it
150 		Authentication* authentication = fAuthentications.Get(share);
151 		if (authentication)
152 			return authentication->SetTo(user, password);
153 		// the entry does not exist yet: create and add a new one
154 		authentication = new(std::nothrow) Authentication;
155 		if (!authentication)
156 			return B_NO_MEMORY;
157 		status_t error = authentication->SetTo(user, password);
158 		if (error == B_OK)
159 			error = fAuthentications.Put(share, authentication);
160 		if (error != B_OK)
161 			delete authentication;
162 		return error;
163 	}
164 
165 	Authentication* GetAuthentication(const char* share) const
166 	{
167 		return fAuthentications.Get(share);
168 	}
169 
170 private:
171 	typedef HashMap<HashString, Authentication*> AuthenticationMap;
172 
173 	Authentication		fDefaultAuthentication;
174 	bool				fUseDefaultAuthentication;
175 	AuthenticationMap	fAuthentications;
176 };
177 
178 // ServerEntryMap
179 struct AuthenticationServer::ServerEntryMap
180 	: HashMap<ServerKey, ServerEntry*> {
181 };
182 
183 // UserDialogTask
184 class AuthenticationServer::UserDialogTask : public Task {
185 public:
186 	UserDialogTask(AuthenticationServer* authenticationServer,
187 		const char* context, const char* server, const char* share,
188 		bool badPassword, port_id replyPort,
189 		int32 replyToken)
190 		: Task("user dialog task"),
191 		  fAuthenticationServer(authenticationServer),
192 		  fContext(context),
193 		  fServer(server),
194 		  fShare(share),
195 		  fBadPassword(badPassword),
196 		  fReplyPort(replyPort),
197 		  fReplyToken(replyToken),
198 		  fPanel(NULL)
199 	{
200 	}
201 
202 	virtual status_t Execute()
203 	{
204 		// open the panel
205 		char user[B_OS_NAME_LENGTH];
206 		char password[B_OS_NAME_LENGTH];
207 		bool keep = true;
208 		fPanel = new(std::nothrow) AuthenticationPanel();
209 		status_t error = (fPanel ? B_OK : B_NO_MEMORY);
210 		bool cancelled = false;
211 		HashString defaultUser;
212 		HashString defaultPassword;
213 		fAuthenticationServer->_GetAuthentication(fContext.GetString(),
214 			fServer.GetString(), NULL, &defaultUser, &defaultPassword);
215 		if (error == B_OK) {
216 			cancelled = fPanel->GetAuthentication(fServer.GetString(),
217 				fShare.GetString(), defaultUser.GetString(),
218 				defaultPassword.GetString(), false, fBadPassword, user,
219 				password, &keep);
220 		}
221 		fPanel = NULL;
222 		// send the reply
223 		if (error != B_OK) {
224 			fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken,
225 				error, true, NULL, NULL);
226 		} else if (cancelled) {
227 			fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken,
228 				B_OK, true, NULL, NULL);
229 		} else {
230 			fAuthenticationServer->_AddAuthentication(fContext.GetString(),
231 				fServer.GetString(), fShare.GetString(), user, password,
232 				keep);
233 			fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken,
234 				B_OK, false, user, password);
235 		}
236 		return error;
237 	}
238 
239 	virtual void Stop()
240 	{
241 		if (fPanel)
242 			fPanel->Cancel();
243 	}
244 
245 private:
246 	AuthenticationServer*	fAuthenticationServer;
247 	HashString				fContext;
248 	HashString				fServer;
249 	HashString				fShare;
250 	bool					fBadPassword;
251 	port_id					fReplyPort;
252 	int32					fReplyToken;
253 	AuthenticationPanel*	fPanel;
254 };
255 
256 
257 // constructor
258 AuthenticationServer::AuthenticationServer()
259 	:
260 	BApplication("application/x-vnd.haiku-authentication_server"),
261 	fLock(),
262 	fRequestPort(-1),
263 	fRequestThread(-1),
264 	fServerEntries(NULL),
265 	fTerminating(false)
266 {
267 }
268 
269 // destructor
270 AuthenticationServer::~AuthenticationServer()
271 {
272 	fTerminating = true;
273 	// terminate the request thread
274 	if (fRequestPort >= 0)
275 		delete_port(fRequestPort);
276 	if (fRequestThread >= 0) {
277 		int32 result;
278 		wait_for_thread(fRequestPort, &result);
279 	}
280 	// delete the server entries
281 	for (ServerEntryMap::Iterator it = fServerEntries->GetIterator();
282 		 it.HasNext();) {
283 		delete it.Next().value;
284 	}
285 }
286 
287 // Init
288 status_t
289 AuthenticationServer::Init()
290 {
291 	// create the server entry map
292 	fServerEntries = new(std::nothrow) ServerEntryMap;
293 	if (!fServerEntries)
294 		return B_NO_MEMORY;
295 	status_t error = fServerEntries->InitCheck();
296 	if (error != B_OK)
297 		return error;
298 	// create the request port
299 	fRequestPort = create_port(10, kAuthenticationServerPortName);
300 	if (fRequestPort < 0)
301 		return fRequestPort;
302 	// spawn the request thread
303 	fRequestThread = spawn_thread(&_RequestThreadEntry, "request thread",
304 		B_NORMAL_PRIORITY, this);
305 	if (fRequestThread < 0)
306 		return fRequestThread;
307 	resume_thread(fRequestThread);
308 	return B_OK;
309 }
310 
311 // _RequestThreadEntry
312 int32
313 AuthenticationServer::_RequestThreadEntry(void* data)
314 {
315 	return ((AuthenticationServer*)data)->_RequestThread();
316 }
317 
318 // _RequestThread
319 int32
320 AuthenticationServer::_RequestThread()
321 {
322 	TaskManager taskManager;
323 	while (!fTerminating) {
324 		taskManager.RemoveDoneTasks();
325 		// read the request
326 		KMessage request;
327 		status_t error = request.ReceiveFrom(fRequestPort);
328 		if (error != B_OK)
329 			continue;
330 		// get the parameters
331 		const char* context = NULL;
332 		const char* server = NULL;
333 		const char* share = NULL;
334 		bool badPassword = true;
335 		request.FindString("context", &context);
336 		request.FindString("server", &server);
337 		request.FindString("share", &share);
338 		request.FindBool("badPassword", &badPassword);
339 		if (!context || !server || !share)
340 			continue;
341 		HashString foundUser;
342 		HashString foundPassword;
343 		if (!badPassword && _GetAuthentication(context, server, share,
344 			&foundUser, &foundPassword)) {
345 			_SendRequestReply(request.ReplyPort(), request.ReplyToken(),
346 				error, false, foundUser.GetString(), foundPassword.GetString());
347 		} else {
348 			// we need to ask the user: create a task that does it
349 			UserDialogTask* task = new(std::nothrow) UserDialogTask(this,
350 				context, server, share, badPassword, request.ReplyPort(),
351 				request.ReplyToken());
352 			if (!task) {
353 				ERROR("AuthenticationServer::_RequestThread(): ERROR: "
354 					"failed to allocate ");
355 				continue;
356 			}
357 			status_t error = taskManager.RunTask(task);
358 			if (error != B_OK) {
359 				ERROR("AuthenticationServer::_RequestThread(): Failed to "
360 					"start server info task: %s\n", strerror(error));
361 				continue;
362 			}
363 		}
364 	}
365 	return 0;
366 }
367 
368 // _GetAuthentication
369 /*!
370 	If share is NULL, the default authentication for the server is returned.
371 */
372 bool
373 AuthenticationServer::_GetAuthentication(const char* context,
374 	const char* server, const char* share, HashString* user,
375 	HashString* password)
376 {
377 	if (!context || !server || !user || !password)
378 		return B_BAD_VALUE;
379 	// get the server entry
380 	AutoLocker<BLocker> _(fLock);
381 	ServerKey key(context, server);
382 	ServerEntry* serverEntry = fServerEntries->Get(key);
383 	if (!serverEntry)
384 		return false;
385 	// get the authentication
386 	const Authentication* authentication = NULL;
387 	if (share) {
388 		serverEntry->GetAuthentication(share);
389 		if (!authentication && serverEntry->UseDefaultAuthentication())
390 			authentication = &serverEntry->GetDefaultAuthentication();
391 	} else
392 		authentication = &serverEntry->GetDefaultAuthentication();
393 	if (!authentication || !authentication->IsValid())
394 		return false;
395 	return (user->SetTo(authentication->GetUser())
396 		&& password->SetTo(authentication->GetPassword()));
397 }
398 
399 // _AddAuthentication
400 status_t
401 AuthenticationServer::_AddAuthentication(const char* context,
402 	const char* server, const char* share, const char* user,
403 	const char* password, bool makeDefault)
404 {
405 	AutoLocker<BLocker> _(fLock);
406 	ServerKey key(context, server);
407 	// get the server entry
408 	ServerEntry* serverEntry = fServerEntries->Get(key);
409 	if (!serverEntry) {
410 		// server entry does not exist yet: create a new one
411 		serverEntry = new(std::nothrow) ServerEntry;
412 		if (!serverEntry)
413 			return B_NO_MEMORY;
414 		status_t error = fServerEntries->Put(key, serverEntry);
415 		if (error != B_OK) {
416 			delete serverEntry;
417 			return error;
418 		}
419 	}
420 	// put the authentication
421 	status_t error = serverEntry->SetAuthentication(share, user, password);
422 	if (error == B_OK) {
423 		if (makeDefault || !serverEntry->UseDefaultAuthentication())
424 			serverEntry->SetDefaultAuthentication(user, password);
425 		if (makeDefault)
426 			serverEntry->SetUseDefaultAuthentication(true);
427 	}
428 	return error;
429 }
430 
431 // _SendRequestReply
432 status_t
433 AuthenticationServer::_SendRequestReply(port_id port, int32 token,
434 	status_t error, bool cancelled, const char* user, const char* password)
435 {
436 	// prepare the reply
437 	KMessage reply;
438 	reply.AddInt32("error", error);
439 	if (error == B_OK) {
440 		reply.AddBool("cancelled", cancelled);
441 		if (!cancelled) {
442 			reply.AddString("user", user);
443 			reply.AddString("password", password);
444 		}
445 	}
446 	// send the reply
447 	return reply.SendTo(port, token);
448 }
449 
450 
451 // main
452 int
453 main()
454 {
455 	AuthenticationServer app;
456 	status_t error = app.Init();
457 	if (error != B_OK)
458 		return 1;
459 	app.Run();
460 	return 0;
461 }
462 
463