// AuthenticationServer.cpp #include "AuthenticationServer.h" #include #include #include #include #include "AuthenticationPanel.h" #include "AuthenticationServerDefs.h" #include "DebugSupport.h" #include "TaskManager.h" // Authentication class AuthenticationServer::Authentication { public: Authentication() : fUser(), fPassword() { } Authentication(const char* user, const char* password) : fUser(user), fPassword(password) { } status_t SetTo(const char* user, const char* password) { if (fUser.SetTo(user) && fPassword.SetTo(password)) return B_OK; return B_NO_MEMORY; } bool IsValid() const { return (fUser.GetLength() > 0); } const char* GetUser() const { return fUser.GetString(); } const char* GetPassword() const { return fPassword.GetString(); } private: HashString fUser; HashString fPassword; }; // ServerKey class AuthenticationServer::ServerKey { public: ServerKey() : fContext(), fServer() { } ServerKey(const char* context, const char* server) : fContext(context), fServer(server) { } ServerKey(const ServerKey& other) : fContext(other.fContext), fServer(other.fServer) { } uint32 GetHashCode() const { return fContext.GetHashCode() * 17 + fServer.GetHashCode(); } ServerKey& operator=(const ServerKey& other) { fContext = other.fContext; fServer = other.fServer; return *this; } bool operator==(const ServerKey& other) const { return (fContext == other.fContext && fServer == other.fServer); } bool operator!=(const ServerKey& other) const { return !(*this == other); } private: HashString fContext; HashString fServer; }; // ServerEntry class AuthenticationServer::ServerEntry { public: ServerEntry() : fDefaultAuthentication(), fUseDefaultAuthentication(false) { } ~ServerEntry() { // delete the authentications for (AuthenticationMap::Iterator it = fAuthentications.GetIterator(); it.HasNext();) { delete it.Next().value; } } void SetUseDefaultAuthentication(bool useDefaultAuthentication) { fUseDefaultAuthentication = useDefaultAuthentication; } bool UseDefaultAuthentication() const { return fUseDefaultAuthentication; } status_t SetDefaultAuthentication(const char* user, const char* password) { return fDefaultAuthentication.SetTo(user, password); } const Authentication& GetDefaultAuthentication() const { return fDefaultAuthentication; } status_t SetAuthentication(const char* share, const char* user, const char* password) { // check, if an entry already exists for the share -- if it does, // just set it Authentication* authentication = fAuthentications.Get(share); if (authentication) return authentication->SetTo(user, password); // the entry does not exist yet: create and add a new one authentication = new(std::nothrow) Authentication; if (!authentication) return B_NO_MEMORY; status_t error = authentication->SetTo(user, password); if (error == B_OK) error = fAuthentications.Put(share, authentication); if (error != B_OK) delete authentication; return error; } Authentication* GetAuthentication(const char* share) const { return fAuthentications.Get(share); } private: typedef HashMap AuthenticationMap; Authentication fDefaultAuthentication; bool fUseDefaultAuthentication; AuthenticationMap fAuthentications; }; // ServerEntryMap struct AuthenticationServer::ServerEntryMap : HashMap { }; // UserDialogTask class AuthenticationServer::UserDialogTask : public Task { public: UserDialogTask(AuthenticationServer* authenticationServer, const char* context, const char* server, const char* share, bool badPassword, port_id replyPort, int32 replyToken) : Task("user dialog task"), fAuthenticationServer(authenticationServer), fContext(context), fServer(server), fShare(share), fBadPassword(badPassword), fReplyPort(replyPort), fReplyToken(replyToken), fPanel(NULL) { } virtual status_t Execute() { // open the panel char user[B_OS_NAME_LENGTH]; char password[B_OS_NAME_LENGTH]; bool keep = true; fPanel = new(std::nothrow) AuthenticationPanel(); status_t error = (fPanel ? B_OK : B_NO_MEMORY); bool cancelled = false; HashString defaultUser; HashString defaultPassword; fAuthenticationServer->_GetAuthentication(fContext.GetString(), fServer.GetString(), NULL, &defaultUser, &defaultPassword); if (error == B_OK) { cancelled = fPanel->GetAuthentication(fServer.GetString(), fShare.GetString(), defaultUser.GetString(), defaultPassword.GetString(), false, fBadPassword, user, password, &keep); } fPanel = NULL; // send the reply if (error != B_OK) { fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken, error, true, NULL, NULL); } else if (cancelled) { fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken, B_OK, true, NULL, NULL); } else { fAuthenticationServer->_AddAuthentication(fContext.GetString(), fServer.GetString(), fShare.GetString(), user, password, keep); fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken, B_OK, false, user, password); } return error; } virtual void Stop() { if (fPanel) fPanel->Cancel(); } private: AuthenticationServer* fAuthenticationServer; HashString fContext; HashString fServer; HashString fShare; bool fBadPassword; port_id fReplyPort; int32 fReplyToken; AuthenticationPanel* fPanel; }; // constructor AuthenticationServer::AuthenticationServer() : BApplication("application/x-vnd.haiku-authentication_server"), fLock(), fRequestPort(-1), fRequestThread(-1), fServerEntries(NULL), fTerminating(false) { } // destructor AuthenticationServer::~AuthenticationServer() { fTerminating = true; // terminate the request thread if (fRequestPort >= 0) delete_port(fRequestPort); if (fRequestThread >= 0) { int32 result; wait_for_thread(fRequestPort, &result); } // delete the server entries for (ServerEntryMap::Iterator it = fServerEntries->GetIterator(); it.HasNext();) { delete it.Next().value; } } // Init status_t AuthenticationServer::Init() { // create the server entry map fServerEntries = new(std::nothrow) ServerEntryMap; if (!fServerEntries) return B_NO_MEMORY; status_t error = fServerEntries->InitCheck(); if (error != B_OK) return error; // create the request port fRequestPort = create_port(10, kAuthenticationServerPortName); if (fRequestPort < 0) return fRequestPort; // spawn the request thread fRequestThread = spawn_thread(&_RequestThreadEntry, "request thread", B_NORMAL_PRIORITY, this); if (fRequestThread < 0) return fRequestThread; resume_thread(fRequestThread); return B_OK; } // _RequestThreadEntry int32 AuthenticationServer::_RequestThreadEntry(void* data) { return ((AuthenticationServer*)data)->_RequestThread(); } // _RequestThread int32 AuthenticationServer::_RequestThread() { TaskManager taskManager; while (!fTerminating) { taskManager.RemoveDoneTasks(); // read the request KMessage request; status_t error = request.ReceiveFrom(fRequestPort); if (error != B_OK) continue; // get the parameters const char* context = NULL; const char* server = NULL; const char* share = NULL; bool badPassword = true; request.FindString("context", &context); request.FindString("server", &server); request.FindString("share", &share); request.FindBool("badPassword", &badPassword); if (!context || !server || !share) continue; HashString foundUser; HashString foundPassword; if (!badPassword && _GetAuthentication(context, server, share, &foundUser, &foundPassword)) { _SendRequestReply(request.ReplyPort(), request.ReplyToken(), error, false, foundUser.GetString(), foundPassword.GetString()); } else { // we need to ask the user: create a task that does it UserDialogTask* task = new(std::nothrow) UserDialogTask(this, context, server, share, badPassword, request.ReplyPort(), request.ReplyToken()); if (!task) { ERROR("AuthenticationServer::_RequestThread(): ERROR: " "failed to allocate "); continue; } status_t error = taskManager.RunTask(task); if (error != B_OK) { ERROR("AuthenticationServer::_RequestThread(): Failed to " "start server info task: %s\n", strerror(error)); continue; } } } return 0; } // _GetAuthentication /*! If share is NULL, the default authentication for the server is returned. */ bool AuthenticationServer::_GetAuthentication(const char* context, const char* server, const char* share, HashString* user, HashString* password) { if (!context || !server || !user || !password) return B_BAD_VALUE; // get the server entry AutoLocker _(fLock); ServerKey key(context, server); ServerEntry* serverEntry = fServerEntries->Get(key); if (!serverEntry) return false; // get the authentication const Authentication* authentication = NULL; if (share) { serverEntry->GetAuthentication(share); if (!authentication && serverEntry->UseDefaultAuthentication()) authentication = &serverEntry->GetDefaultAuthentication(); } else authentication = &serverEntry->GetDefaultAuthentication(); if (!authentication || !authentication->IsValid()) return false; return (user->SetTo(authentication->GetUser()) && password->SetTo(authentication->GetPassword())); } // _AddAuthentication status_t AuthenticationServer::_AddAuthentication(const char* context, const char* server, const char* share, const char* user, const char* password, bool makeDefault) { AutoLocker _(fLock); ServerKey key(context, server); // get the server entry ServerEntry* serverEntry = fServerEntries->Get(key); if (!serverEntry) { // server entry does not exist yet: create a new one serverEntry = new(std::nothrow) ServerEntry; if (!serverEntry) return B_NO_MEMORY; status_t error = fServerEntries->Put(key, serverEntry); if (error != B_OK) { delete serverEntry; return error; } } // put the authentication status_t error = serverEntry->SetAuthentication(share, user, password); if (error == B_OK) { if (makeDefault || !serverEntry->UseDefaultAuthentication()) serverEntry->SetDefaultAuthentication(user, password); if (makeDefault) serverEntry->SetUseDefaultAuthentication(true); } return error; } // _SendRequestReply status_t AuthenticationServer::_SendRequestReply(port_id port, int32 token, status_t error, bool cancelled, const char* user, const char* password) { // prepare the reply KMessage reply; reply.AddInt32("error", error); if (error == B_OK) { reply.AddBool("cancelled", cancelled); if (!cancelled) { reply.AddString("user", user); reply.AddString("password", password); } } // send the reply return reply.SendTo(port, token); } // main int main() { AuthenticationServer app; status_t error = app.Init(); if (error != B_OK) return 1; app.Run(); return 0; }