diff --git a/AnkiServer/apps/sync_app.py b/AnkiServer/apps/sync_app.py
index 658ecf3..6de6b90 100644
--- a/AnkiServer/apps/sync_app.py
+++ b/AnkiServer/apps/sync_app.py
@@ -35,6 +35,8 @@ from anki.utils import intTime, checksum, isMac
 
 from anki.consts import SYNC_ZIP_SIZE, SYNC_ZIP_COUNT
 
+from AnkiServer.user_managers import SimpleUserManager, SqliteUserManager
+
 try:
     import simplejson as json
 except ImportError:
@@ -338,25 +340,6 @@ def save(self, hkey, session):
     def delete(self, hkey):
         del self.sessions[hkey]
 
-class SimpleUserManager(object):
-    """A simple user manager that always allows any user."""
-
-    def authenticate(self, username, password):
-        """
-        Returns True if this username is allowed to connect with this password. False otherwise.
-        Override this to change how users are authenticated.
-        """
-
-        return True
-
-    def username2dirname(self, username):
-        """
-        Returns the directory name for the given user. By default, this is just the username.
-        Override this to adjust the mapping between users and their directory.
-        """
-
-        return username
-
 class SyncApp(object):
 
     valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download', 'getDecks']
@@ -730,34 +713,6 @@ def delete(self, hkey):
         cursor.execute("DELETE FROM session WHERE hkey=?", (hkey,))
         conn.commit()
 
-class SqliteUserManager(SimpleUserManager):
-    """Authenticates users against a SQLite database."""
-
-    def __init__(self, auth_db_path):
-        self.auth_db_path = os.path.abspath(auth_db_path)
-
-    def authenticate(self, username, password):
-        """Returns True if this username is allowed to connect with this password. False otherwise."""
-
-        conn = sqlite.connect(self.auth_db_path)
-        cursor = conn.cursor()
-        param = (username,)
-
-        cursor.execute("SELECT hash FROM auth WHERE user=?", param)
-
-        db_ret = cursor.fetchone()
-
-        if db_ret != None:
-            db_hash = str(db_ret[0])
-            salt = db_hash[-16:]
-            hashobj = hashlib.sha256()
-
-            hashobj.update(username+password+salt)
-
-        conn.close()
-
-        return (db_ret != None and hashobj.hexdigest()+salt == db_hash)
-
 # Our entry point
 def make_app(global_conf, **local_conf):
     if local_conf.has_key('session_db_path'):
diff --git a/AnkiServer/user_managers.py b/AnkiServer/user_managers.py
new file mode 100644
index 0000000..49c4d80
--- /dev/null
+++ b/AnkiServer/user_managers.py
@@ -0,0 +1,173 @@
+# -*- coding: utf-8 -*-
+
+
+import binascii
+import hashlib
+import logging
+import os
+import sqlite3 as sqlite
+
+
+class SimpleUserManager(object):
+    """A simple user manager that always allows any user."""
+
+    def __init__(self, collection_path=''):
+        self.collection_path = collection_path
+
+    def authenticate(self, username, password):
+        """
+        Returns True if this username is allowed to connect with this password.
+        False otherwise. Override this to change how users are authenticated.
+        """
+
+        return True
+
+    def username2dirname(self, username):
+        """
+        Returns the directory name for the given user. By default, this is just
+        the username. Override this to adjust the mapping between users and
+        their directory.
+        """
+
+        return username
+
+    def _create_user_dir(self, username):
+        user_dir_path = os.path.join(self.collection_path, username)
+        if not os.path.isdir(user_dir_path):
+            logging.info("Creating collection directory for user '{}' at {}"
+                         .format(username, user_dir_path))
+            os.makedirs(user_dir_path)
+
+
+class SqliteUserManager(SimpleUserManager):
+    """Authenticates users against a SQLite database."""
+
+    def __init__(self, auth_db_path, collection_path=None):
+        SimpleUserManager.__init__(self, collection_path)
+        self.auth_db_path = auth_db_path
+
+    def auth_db_exists(self):
+        return os.path.isfile(self.auth_db_path)
+
+    def user_list(self):
+        if not self.auth_db_exists():
+            raise ValueError("Cannot list users for nonexistent auth db {}."
+                             .format(self.auth_db_path))
+        else:
+            conn = sqlite.connect(self.auth_db_path)
+            cursor = conn.cursor()
+            cursor.execute("SELECT user FROM auth")
+            rows = cursor.fetchall()
+            conn.commit()
+            conn.close()
+
+            return [row[0] for row in rows]
+
+    def user_exists(self, username):
+        users = self.user_list()
+        return username in users
+
+    def del_user(self, username):
+        if not self.auth_db_exists():
+            raise ValueError("Cannot remove user from nonexistent auth db {}."
+                             .format(self.auth_db_path))
+        else:
+            conn = sqlite.connect(self.auth_db_path)
+            cursor = conn.cursor()
+            logging.info("Removing user '{}' from auth db."
+                         .format(username))
+            cursor.execute("DELETE FROM auth WHERE user=?", (username,))
+            conn.commit()
+            conn.close()
+
+    def add_user(self, username, password):
+        self._add_user_to_auth_db(username, password)
+        self._create_user_dir(username)
+
+    def add_users(self, users_data):
+        for username, password in users_data:
+            self.add_user(username, password)
+
+    def _add_user_to_auth_db(self, username, password):
+        if not self.auth_db_exists():
+            self.create_auth_db()
+
+        pass_hash = self._create_pass_hash(username, password)
+
+        conn = sqlite.connect(self.auth_db_path)
+        cursor = conn.cursor()
+        logging.info("Adding user '{}' to auth db.".format(username))
+        cursor.execute("INSERT INTO auth VALUES (?, ?)",
+                       (username, pass_hash))
+        conn.commit()
+        conn.close()
+
+    def set_password_for_user(self, username, new_password):
+        if not self.auth_db_exists():
+            raise ValueError("Cannot change password in nonexistent auth db {}."
+                             .format(self.auth_db_path))
+        elif not self.user_exists(username):
+            raise ValueError("Cannot change password for nonexistent user {}."
+                             .format(username))
+        else:
+            hash = self._create_pass_hash(username, new_password)
+
+            conn = sqlite.connect(self.auth_db_path)
+            cursor = conn.cursor()
+            cursor.execute("UPDATE auth SET hash=? WHERE user=?", (hash, username))
+            conn.commit()
+            conn.close()
+
+            logging.info("Changed password for user {}.".format(username))
+
+    def authenticate_user(self, username, password):
+        """Returns True if this username is allowed to connect with this password. False otherwise."""
+
+        conn = sqlite.connect(self.auth_db_path)
+        cursor = conn.cursor()
+        param = (username,)
+        cursor.execute("SELECT hash FROM auth WHERE user=?", param)
+        db_hash = cursor.fetchone()
+        conn.close()
+
+        if db_hash is None:
+            logging.info("Authentication failed for nonexistent user {}."
+                         .format(username))
+            return False
+        else:
+            expected_value = str(db_hash[0])
+            salt = self._extract_salt(expected_value)
+
+            hashobj = hashlib.sha256()
+            hashobj.update(username + password + salt)
+            actual_value = hashobj.hexdigest() + salt
+
+            if actual_value == expected_value:
+                logging.info("Authentication succeeded for user {}."
+                             .format(username))
+                return True
+            else:
+                logging.info("Authentication failed for user {}."
+                             .format(username))
+                return False
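+
+    # Stored hash layout (see _create_pass_hash below): the SHA-256 hex digest
+    # of username + password + salt, with the 16-character hex salt appended.
+    # _extract_salt() recovers the salt by taking the last 16 characters.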
+
+    @staticmethod
+    def _extract_salt(hash):
+        return hash[-16:]
+
+    @staticmethod
+    def _create_pass_hash(username, password):
+        salt = binascii.b2a_hex(os.urandom(8))
+        pass_hash = (hashlib.sha256(username + password + salt).hexdigest() +
+                     salt)
+        return pass_hash
+
+    def create_auth_db(self):
+        conn = sqlite.connect(self.auth_db_path)
+        cursor = conn.cursor()
+        logging.info("Creating auth db at {}."
+                     .format(self.auth_db_path))
+        cursor.execute("""CREATE TABLE IF NOT EXISTS auth
+                          (user VARCHAR PRIMARY KEY, hash VARCHAR)""")
+        conn.commit()
+        conn.close()
diff --git a/ankiserverctl.py b/ankiserverctl.py
index ac8d47c..7f2b4e8 100755
--- a/ankiserverctl.py
+++ b/ankiserverctl.py
@@ -1,13 +1,14 @@
 #!/usr/bin/env python
 
+from __future__ import print_function
+
 import os
 import sys
 import signal
 import subprocess
-import binascii
 import getpass
-import hashlib
-import sqlite3
+
+from AnkiServer.user_managers import SqliteUserManager
 
 SERVERCONFIG = "production.ini"
 AUTHDBPATH = "auth.db"
@@ -15,16 +16,16 @@ COLLECTIONPATH = "collections/"
 
 
 def usage():
-    print "usage: "+sys.argv[0]+" <command> [<args>]"
-    print
-    print "Commands:"
-    print "  start [configfile] - start the server"
-    print "  debug [configfile] - start the server in debug mode"
-    print "  stop - stop the server"
-    print "  adduser <username> - add a new user"
-    print "  deluser <username> - delete a user"
-    print "  lsuser - list users"
-    print "  passwd <username> - change password of a user"
+    print("usage: "+sys.argv[0]+" <command> [<args>]")
+    print()
+    print("Commands:")
+    print("  start [configfile] - start the server")
+    print("  debug [configfile] - start the server in debug mode")
+    print("  stop - stop the server")
+    print("  adduser <username> - add a new user")
+    print("  deluser <username> - delete a user")
+    print("  lsuser - list users")
+    print("  passwd <username> - change password of a user")
 
 def startsrv(configpath, debug):
     if not configpath:
@@ -58,81 +59,62 @@ def stopsrv():
             os.kill(pid, signal.SIGKILL)
             os.remove(PIDPATH)
 
-    except Exception, error:
-        print >>sys.stderr, sys.argv[0]+": Failed to stop server: "+error.message
+    except Exception as error:
+        print("{}: Failed to stop server: {}"
+              .format(sys.argv[0], error.message), file=sys.stderr)
     else:
-        print >>sys.stderr, sys.argv[0]+": The server is not running"
+        print("{}: The server is not running".format(sys.argv[0]),
+              file=sys.stderr)
 
 def adduser(username):
     if username:
-        print "Enter password for "+username+": "
-
+        print("Enter password for {}:".format(username))
         password = getpass.getpass()
 
-        salt = binascii.b2a_hex(os.urandom(8))
-        hash = hashlib.sha256(username+password+salt).hexdigest()+salt
-
-        conn = sqlite3.connect(AUTHDBPATH)
-        cursor = conn.cursor()
-
-        cursor.execute( "CREATE TABLE IF NOT EXISTS auth "
-                        "(user VARCHAR PRIMARY KEY, hash VARCHAR)")
-        cursor.execute("INSERT INTO auth VALUES (?, ?)", (username, hash))
-
-        if not os.path.isdir(COLLECTIONPATH+username):
-            os.makedirs(COLLECTIONPATH+username)
-
-        conn.commit()
-        conn.close()
+        user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH)
+        user_manager.add_user(username, password)
     else:
         usage()
 
 def deluser(username):
     if username and os.path.isfile(AUTHDBPATH):
-        conn = sqlite3.connect(AUTHDBPATH)
-        cursor = conn.cursor()
+        user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH)
 
-        cursor.execute("DELETE FROM auth WHERE user=?", (username,))
-
-        conn.commit()
-        conn.close()
+        try:
+            user_manager.del_user(username)
+        except ValueError as error:
+            print("Could not delete user {}: {}"
+                  .format(username, error.message), file=sys.stderr)
     elif not username:
         usage()
     else:
-        print >>sys.stderr, sys.argv[0]+": Database file does not exist"
+        print("{}: Database file does not exist".format(sys.argv[0]),
+              file=sys.stderr)
 
 def lsuser():
-    conn = sqlite3.connect(AUTHDBPATH)
-    cursor = conn.cursor()
-
-    cursor.execute("SELECT user FROM auth")
-
-    row = cursor.fetchone()
-
-    while row is not None:
-        print row[0]
-
-        row = cursor.fetchone()
-
-    conn.close()
+    user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH)
+    try:
+        users = user_manager.user_list()
+        for username in users:
+            print(username)
+    except ValueError as error:
+        print("Could not list users: {}".format(error.message),
+              file=sys.stderr)
 
 def passwd(username):
     if os.path.isfile(AUTHDBPATH):
-        print "Enter password for "+username+": "
-
+        print("Enter password for {}:".format(username))
         password = getpass.getpass()
 
-        salt = binascii.b2a_hex(os.urandom(8))
-        hash = hashlib.sha256(username+password+salt).hexdigest()+salt
-
-        conn = sqlite3.connect(AUTHDBPATH)
-        cursor = conn.cursor()
-        cursor.execute("UPDATE auth SET hash=? WHERE user=?", (hash, username))
-
-        conn.commit()
-        conn.close()
+        user_manager = SqliteUserManager(AUTHDBPATH, COLLECTIONPATH)
+        try:
+            user_manager.set_password_for_user(username, password)
+        except ValueError as error:
+            print("Could not set password for user {}: {}"
+                  .format(username, error.message), file=sys.stderr)
     else:
-        print >>sys.stderr, sys.argv[0]+": Database file does not exist"
+        print("{}: Database file does not exist".format(sys.argv[0]),
+              file=sys.stderr)
 
 def main():
     argc = len(sys.argv)
diff --git a/setup.py b/setup.py
index 328a6e1..9e8295b 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,8 @@ def get_anki_bundled_files():
     ],
     tests_require=[
         'nose>=1.3.0',
-        'mock>=1.0.0,<2.0.0a',
+        'mock>=1.0.0',
+        'webtest>=2.0.20'
     ],
     data_files=get_anki_bundled_files()+[
         ('examples', [
diff --git a/test.ini b/test.ini
new file mode 100644
index 0000000..6663431
--- /dev/null
+++ b/test.ini
@@ -0,0 +1,30 @@
+
+[server:main]
+use = egg:AnkiServer#server
+host = 127.0.0.1
+port = 27701
+
+[filter-app:main]
+use = egg:Paste#translogger
+next = real
+
+[app:real]
+use = egg:Paste#urlmap
+/ = rest_app
+/msync = sync_app
+/sync = sync_app
+
+[app:rest_app]
+use = egg:AnkiServer#rest_app
+data_root = ./collections
+allowed_hosts = 127.0.0.1
+;logging.config_file = logging.conf
+
+[app:sync_app]
+use = egg:AnkiServer#sync_app
+data_root = ./collections
+base_url = /sync/
+base_media_url = /msync/
+session_db_path = ./session.db
+auth_db_path = ./auth.db
+
diff --git a/tests/assets/blue.jpg b/tests/assets/blue.jpg
new file mode 100644
index 0000000..c958d13
Binary files /dev/null and b/tests/assets/blue.jpg differ
diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/helpers/collection_utils.py b/tests/helpers/collection_utils.py
new file mode 100644
index 0000000..3eb3483
--- /dev/null
+++ b/tests/helpers/collection_utils.py
@@ -0,0 +1,99 @@
+# -*- coding: utf-8 -*-
+
+
+import os
+import shutil
+import tempfile
+
+
+from anki import Collection
+from helpers.file_utils import FileUtils
+
+
+class CollectionUtils(object):
+    """
+    Provides utility methods for creating, inspecting and manipulating anki
+    collections.
+    """
+
+    def __init__(self):
+        self.collections_to_close = []
+        self.fileutils = FileUtils()
+        self.master_db_path = None
+
+    def __create_master_col(self):
+        """
+        Creates an empty master anki db that will be copied on each request
+        for a new db. This is more efficient than initializing a new db each
+        time.
+        """
+
+        file_descriptor, file_path = tempfile.mkstemp(suffix=".anki2")
+        os.close(file_descriptor)
+        os.unlink(file_path)  # We only need the file path.
+        master_col = Collection(file_path)
+        self.__mark_col_paths_for_deletion(master_col)
+        master_col.db.close()
+        self.master_db_path = file_path
+
+        self.fileutils.mark_for_deletion(self.master_db_path)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.clean_up()
+
+    def __mark_collection_for_closing(self, collection):
+        self.collections_to_close.append(collection)
+
+    def __mark_col_paths_for_deletion(self, collection):
+        """
+        Marks the paths of all the database files and directories managed by
+        the collection for later deletion.
+        """
+        self.fileutils.mark_for_deletion(collection.path)
+        self.fileutils.mark_for_deletion(collection.media.dir())
+        self.fileutils.mark_for_deletion(collection.media.col.path)
+
+    def clean_up(self):
+        """
+        Removes all files created by the Collection objects we issued and the
+        master db file.
+        """
+
+        # Close collections.
+        for col in self.collections_to_close:
+            col.close()  # This also closes the media col.
+        self.collections_to_close = []
+
+        # Remove the files created by the collections.
+        self.fileutils.clean_up()
+
+        self.master_db_path = None
+
+    def create_empty_col(self):
+        """
+        Returns a Collection object using a copy of our master db file.
+        """
+
+        if self.master_db_path is None:
+            self.__create_master_col()
+
+        file_descriptor, file_path = tempfile.mkstemp(suffix=".anki2")
+
+        # Overwrite temp file with a copy of our master db.
+        shutil.copy(self.master_db_path, file_path)
+        collection = Collection(file_path)
+
+        self.__mark_collection_for_closing(collection)
+        self.__mark_col_paths_for_deletion(collection)
+        return collection
+
+    @staticmethod
+    def create_col_from_existing_db(db_file_path):
+        """
+        Returns a Collection object created from an existing anki db file.
+        """
+
+        return Collection(db_file_path)
diff --git a/tests/helpers/db_utils.py b/tests/helpers/db_utils.py
new file mode 100644
index 0000000..24fd650
--- /dev/null
+++ b/tests/helpers/db_utils.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+
+
+import os
+import sqlite3
+import subprocess
+
+
+from helpers.file_utils import FileUtils
+
+
+class DBUtils(object):
+    """Provides methods for creating and comparing sqlite databases."""
+
+    def __init__(self):
+        self.fileutils = FileUtils()
+
+    def clean_up(self):
+        self.fileutils.clean_up()
+
+    def create_sqlite_db_with_sql(self, sql_string):
+        """
+        Creates an SQLite db and executes the passed sql statements on it.
+
+        :param sql_string: the sql statements to execute on the newly created
+                           db
+        :return: the path to the created db file
+        """
+
+        db_path = self.fileutils.create_file_path(suffix=".anki2")
+        connection = sqlite3.connect(db_path)
+        cursor = connection.cursor()
+        cursor.executescript(sql_string)
+        connection.commit()
+        connection.close()
+
+        return db_path
+
+    @staticmethod
+    def sqlite_db_to_sql_string(database):
+        """
+        Returns a string containing the sql export of the database. Used for
+        debugging.
+
+        :param database: either the path to the SQLite db file or an open
+                         connection to it
+        :return: a string representing the sql export of the database
+        """
+
+        if type(database) == str:
+            connection = sqlite3.connect(database)
+        else:
+            connection = database
+
+        res = '\n'.join(connection.iterdump())
+
+        if type(database) == str:
+            connection.close()
+
+        return res
+
+    def media_dbs_differ(self, left_db_path, right_db_path, compare_timestamps=False):
+        """
+        Compares two media sqlite database files for equality. mtime and dirMod
+        timestamps are only considered when compare_timestamps is True.
+
+        :param left_db_path: path to the left db file
+        :param right_db_path: path to the right db file
+        :param compare_timestamps: flag determining if timestamp values
+                                   (media.mtime and meta.dirMod) are included
+                                   in the comparison
+        :return: True if the specified databases differ, False otherwise
+        """
+
+        if not os.path.isfile(left_db_path):
+            raise IOError("file '" + left_db_path + "' does not exist")
+        elif not os.path.isfile(right_db_path):
+            raise IOError("file '" + right_db_path + "' does not exist")
+
+        # Create temporary copies of the files to act on.
+        left_db_path = self.fileutils.create_file_copy(left_db_path)
+        right_db_path = self.fileutils.create_file_copy(right_db_path)
+
+        if not compare_timestamps:
+            # Set all timestamps that are not NULL to 0.
+            for dbPath in [left_db_path, right_db_path]:
+                connection = sqlite3.connect(dbPath)
+
+                connection.execute("""UPDATE media SET mtime=0
+                                      WHERE mtime IS NOT NULL""")
+
+                connection.execute("""UPDATE meta SET dirMod=0
+                                      WHERE rowid=1""")
+                connection.commit()
+                connection.close()
+
+        return self.__sqlite_dbs_differ(left_db_path, right_db_path)
+
+    def __sqlite_dbs_differ(self, left_db_path, right_db_path):
+        """
+        Uses the sqldiff cli tool to compare two sqlite files for equality.
+        Returns True if the databases differ, False if they don't.
+
+        :param left_db_path: path to the left db file
+        :param right_db_path: path to the right db file
+        :return: True if the specified databases differ, False otherwise
+        """
+
+        command = ["/bin/sqldiff", left_db_path, right_db_path]
+
+        try:
+            child_process = subprocess.Popen(command,
+                                             shell=False,
+                                             stdout=subprocess.PIPE)
+            stdout, stderr = child_process.communicate()
+            exit_code = child_process.returncode
+
+            if exit_code != 0 or stderr is not None:
+                raise RuntimeError("Command {} encountered an error, exit "
+                                   "code: {}, stderr: {}"
                                   .format(" ".join(command),
+                                           exit_code,
+                                           stderr))
+
+            # Any output from sqldiff means the databases differ.
+            return stdout != ""
+        except OSError as err:
+            raise err
diff --git a/tests/helpers/file_utils.py b/tests/helpers/file_utils.py
new file mode 100644
index 0000000..2238b5c
--- /dev/null
+++ b/tests/helpers/file_utils.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+
+
+from cStringIO import StringIO
+import json
+import logging
+import logging.config
+import os
+import random
+import shutil
+import tempfile
+import unicodedata
+import zipfile
+
+
+from anki.consts import SYNC_ZIP_SIZE
+from anki.utils import checksum
+
+
+class FileUtils(object):
+    """
+    Provides utility methods for creating temporary files and directories. All
+    created files and dirs are recursively removed when clean_up() is called.
+    Supports the with statement.
+    """
+
+    def __init__(self):
+        self.paths_to_delete = []
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        self.clean_up()
+
+    def clean_up(self):
+        """
+        Recursively removes all files and directories created by this instance.
+        """
+
+        # Change cwd to a dir we're not about to delete so later calls to
+        # os.getcwd() and similar functions don't raise Exceptions.
+        os.chdir("/tmp")
+
+        # Error callback for shutil.rmtree().
+        def on_error(func, path, excinfo):
+            logging.error("Error removing file: func={}, path={}, excinfo={}"
+                          .format(func, path, excinfo))
+
+        for path in self.paths_to_delete:
+            if os.path.isfile(path):
+                logging.debug("Removing temporary file '{}'.".format(path))
+                os.remove(path)
+            elif os.path.isdir(path):
+                logging.debug(("Removing temporary dir tree '{}' with " +
+                               "files {}").format(path, os.listdir(path)))
+                shutil.rmtree(path, onerror=on_error)
+
+        self.paths_to_delete = []
+
+    def mark_for_deletion(self, path):
+        self.paths_to_delete.append(path)
+
+    def create_file(self, suffix='', prefix='tmp'):
+        file_descriptor, file_path = tempfile.mkstemp(suffix=suffix,
+                                                      prefix=prefix)
+        self.mark_for_deletion(file_path)
+        return file_path
+
+    def create_dir(self, suffix='', prefix='tmp'):
+        dir_path = tempfile.mkdtemp(suffix=suffix,
+                                    prefix=prefix)
+        self.mark_for_deletion(dir_path)
+        return dir_path
+
+    def create_file_path(self, suffix='', prefix='tmp'):
+        """Generates a file path."""
+
+        file_path = self.create_file(suffix, prefix)
+        os.unlink(file_path)
+        return file_path
+
+    def create_dir_path(self, suffix='', prefix='tmp'):
+        dir_path = self.create_dir(suffix, prefix)
+        os.rmdir(dir_path)
+        return dir_path
+
+    def create_named_file(self, filename, file_contents=None):
+        """
+        Creates a temporary file with a custom name within a new temporary
+        directory and marks that parent dir for recursive deletion.
+        """
+
+        # We need to create a parent directory for the file so we can freely
+        # choose the file name.
+        temp_file_parent_dir = tempfile.mkdtemp(prefix="anki")
+        self.mark_for_deletion(temp_file_parent_dir)
+
+        file_path = os.path.join(temp_file_parent_dir, filename)
+
+        if file_contents is not None:
+            open(file_path, 'w').write(file_contents)
+
+        return file_path
+
+    def create_named_file_path(self, filename):
+        file_path = self.create_named_file(filename)
+        return file_path
+
+    def create_file_copy(self, path):
+        basename = os.path.basename(path)
+        temp_file_path = self.create_named_file_path(basename)
+        shutil.copyfile(path, temp_file_path)
+        return temp_file_path
+
+    def create_named_files(self, filenames_and_data):
+        """
+        Creates temporary files within the same new temporary parent directory
+        and marks that parent for recursive deletion.
+
+        :param filenames_and_data: list of tuples (filename, file contents)
+        :return: list of paths to the created files
+        """
+
+        temp_files_parent_dir = tempfile.mkdtemp(prefix="anki")
+        self.mark_for_deletion(temp_files_parent_dir)
+
+        file_paths = []
+        for filename, file_contents in filenames_and_data:
+            path = os.path.join(temp_files_parent_dir, filename)
+            file_paths.append(path)
+            if file_contents is not None:
+                open(path, 'w').write(file_contents)
+
+        return file_paths
+
+    @staticmethod
+    def create_zip_with_existing_files(file_paths):
+        """
+        The method zips existing files and returns the zip data. Logic is
+        adapted from Anki Desktop's MediaManager.mediaChangesZip().
+
+        :param file_paths: the paths of the files to include in the zip
+        :type file_paths: list
+        :return: the data of the created zip file
+        """
+
+        file_buffer = StringIO()
+        zip_file = zipfile.ZipFile(file_buffer,
+                                   'w',
+                                   compression=zipfile.ZIP_DEFLATED)
+
+        meta = []
+        sz = 0
+
+        for count, filePath in enumerate(file_paths):
+            zip_file.write(filePath, str(count))
+            normname = unicodedata.normalize(
+                "NFC",
+                os.path.basename(filePath)
+            )
+            meta.append((normname, str(count)))
+
+            sz += os.path.getsize(filePath)
+            if sz >= SYNC_ZIP_SIZE:
+                break
+
+        zip_file.writestr("_meta", json.dumps(meta))
+        zip_file.close()
+
+        return file_buffer.getvalue()
+
+    def get_asset_path(self, relative_file_path):
+        """
+        Retrieves the path of a file for testing from the "assets" directory.
+
+        :param relative_file_path: the name of the file to retrieve, relative
+                                   to the "assets" directory
+        :return: the absolute path to the file in the "assets" directory.
+        """
+
+        join = os.path.join
+
+        script_dir = os.path.dirname(os.path.realpath(__file__))
+        support_dir = join(script_dir, os.pardir, "assets")
+        res = join(support_dir, relative_file_path)
+        return res
diff --git a/tests/helpers/mock_servers.py b/tests/helpers/mock_servers.py
new file mode 100644
index 0000000..6ac0332
--- /dev/null
+++ b/tests/helpers/mock_servers.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+
+
+import logging
+
+
+from anki.sync import HttpSyncer, RemoteServer, RemoteMediaServer, FullSyncer
+
+
+class MockServerConnection(object):
+    """
+    Mock for HttpSyncer's con attribute, a httplib2 connection. All requests
+    that would normally go to the remote server will be redirected to our
+    server_app_to_test object.
+    """
+
+    def __init__(self, server_app_to_test):
+        self.test_app = server_app_to_test
+
+    def request(self, uri, method='GET', headers=None, body=None):
+        if method == 'POST':
+            logging.debug("Posting to URI '{}'.".format(uri))
+            logging.info("Posting to URI '{}'.".format(uri))
+            test_response = self.test_app.post(uri,
+                                               params=body,
+                                               headers=headers,
+                                               status="*")
+
+            resp = test_response.headers
+            resp.update({
+                "status": str(test_response.status_int)
+            })
+            cont = test_response.body
+            return resp, cont
+        else:
+            raise Exception('Unexpected HttpSyncer.req() behavior.')
+
+
+class MockRemoteServer(RemoteServer):
+    """
+    Mock for RemoteServer. All communication to our remote counterpart is
+    routed to our TestApp object.
+    """
+
+    def __init__(self, hkey, server_test_app):
+        # Create a custom connection object we will use to communicate with our
+        # 'remote' server counterpart.
+        connection = MockServerConnection(server_test_app)
+        HttpSyncer.__init__(self, hkey, connection)
+
+    def syncURL(self):  # Overrides RemoteServer.syncURL().
+        return "/sync/"
+
+
+class MockRemoteMediaServer(RemoteMediaServer):
+    """
+    Mock for RemoteMediaServer. All communication to our remote counterpart is
+    routed to our TestApp object.
+    """
+
+    def __init__(self, col, hkey, server_test_app):
+        # Create a custom connection object we will use to communicate with our
+        # 'remote' server counterpart.
+        connection = MockServerConnection(server_test_app)
+        HttpSyncer.__init__(self, hkey, connection)
+
+    def syncURL(self):  # Overrides RemoteServer.syncURL().
+        return "/msync/"
diff --git a/tests/helpers/monkey_patches.py b/tests/helpers/monkey_patches.py
new file mode 100644
index 0000000..f9ea83b
--- /dev/null
+++ b/tests/helpers/monkey_patches.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+
+
+import os
+import sqlite3 as sqlite
+from anki.media import MediaManager
+from anki.storage import DB
+
+
+mediamanager_orig_funcs = {
+    "findChanges": None,
+    "mediaChangesZip": None,
+    "addFilesFromZip": None,
+    "syncDelete": None
+}
+
+
+db_orig_funcs = {
+    "__init__": None
+}
+
+
+def monkeypatch_mediamanager():
+    """
+    Monkey patches anki.media.MediaManager's methods so they chdir to
+    self.dir() before acting on its media directory and chdir back to the
+    original cwd after finishing.
+    """
+
+    def make_cwd_safe(original_func):
+        mediamanager_orig_funcs["findChanges"] = MediaManager.findChanges
+        mediamanager_orig_funcs["mediaChangesZip"] = MediaManager.mediaChangesZip
+        mediamanager_orig_funcs["addFilesFromZip"] = MediaManager.addFilesFromZip
+        mediamanager_orig_funcs["syncDelete"] = MediaManager.syncDelete
+
+        def wrapper(instance, *args):
+            old_cwd = os.getcwd()
+            os.chdir(instance.dir())
+
+            res = original_func(instance, *args)
+
+            os.chdir(old_cwd)
+            return res
+        return wrapper
+
+    MediaManager.findChanges = make_cwd_safe(MediaManager.findChanges)
+    MediaManager.mediaChangesZip = make_cwd_safe(MediaManager.mediaChangesZip)
+    MediaManager.addFilesFromZip = make_cwd_safe(MediaManager.addFilesFromZip)
+    MediaManager.syncDelete = make_cwd_safe(MediaManager.syncDelete)
+
+
+def unpatch_mediamanager():
+    """Undoes monkey patches to Anki's MediaManager."""
+
+    MediaManager.findChanges = mediamanager_orig_funcs["findChanges"]
+    MediaManager.mediaChangesZip = mediamanager_orig_funcs["mediaChangesZip"]
+    MediaManager.addFilesFromZip = mediamanager_orig_funcs["addFilesFromZip"]
+    MediaManager.syncDelete = mediamanager_orig_funcs["syncDelete"]
+
+    mediamanager_orig_funcs["findChanges"] = None
+    mediamanager_orig_funcs["mediaChangesZip"] = None
+    mediamanager_orig_funcs["addFilesFromZip"] = None
+    mediamanager_orig_funcs["syncDelete"] = None
+
+
+def monkeypatch_db():
+    """
+    Monkey patches Anki's DB.__init__ to allow access to the db connection
+    from more than one thread, so that we can inspect and modify the db
+    created in the app in our test code.
+    """
+    db_orig_funcs["__init__"] = DB.__init__
+
+    def patched___init__(self, path, text=None, timeout=0):
+        # Code taken from Anki's DB.__init__()
+        encpath = path
+        if isinstance(encpath, unicode):
+            encpath = path.encode("utf-8")
+        # Allow more than one thread to use this connection.
+        self._db = sqlite.connect(encpath,
+                                  timeout=timeout,
+                                  check_same_thread=False)
+        if text:
+            self._db.text_factory = text
+        self._path = path
+        self.echo = os.environ.get("DBECHO")  # echo db modifications
+        self.mod = False  # flag that db has been modified?
+
+    DB.__init__ = patched___init__
+
+
+def unpatch_db():
+    """Undoes monkey patches to Anki's DB."""
+
+    DB.__init__ = db_orig_funcs["__init__"]
+    db_orig_funcs["__init__"] = None
diff --git a/tests/helpers/server_utils.py b/tests/helpers/server_utils.py
new file mode 100644
index 0000000..6fb58cf
--- /dev/null
+++ b/tests/helpers/server_utils.py
@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+
+
+import filecmp
+import logging
+import os
+from paste.deploy.loadwsgi import appconfig
+import shutil
+
+
+from AnkiServer.apps.sync_app import make_app as make_sync_app
+from AnkiServer.apps.sync_app import SyncCollectionHandler, SyncMediaHandler
+from AnkiServer.apps.rest_app import make_app as make_rest_app
+from helpers.file_utils import FileUtils
+
+
+class ServerUtils(object):
+    def __init__(self):
+        self.fileutils = FileUtils()
+
+    def clean_up(self):
+        self.fileutils.clean_up()
+
+    def create_server_paths(self):
+        """
+        Creates temporary files and dirs for our app to use during tests.
+        """
+
+        auth = self.fileutils.create_file_path(suffix='.db',
+                                               prefix='ankiserver_auth_db_')
+        session = self.fileutils.create_file_path(suffix='.db',
+                                                  prefix='ankiserver_session_db_')
+        data = self.fileutils.create_dir(suffix='',
+                                         prefix='ankiserver_data_root_')
+        return {
+            "auth_db": auth,
+            "session_db": session,
+            "data_root": data
+        }
+
+    @staticmethod
+    def _create_server_app(server_paths, config_path, make_app_func=None):
+        settings = appconfig("config:{}".format(config_path), "sync_app")
+
+        # Use custom files and dirs in settings.
+        settings.local_conf["auth_db_path"] = server_paths["auth_db"]
+        settings.local_conf["session_db_path"] = server_paths["session_db"]
+        settings.local_conf["data_root"] = server_paths["data_root"]
+
+        server_app = make_app_func(settings.global_conf, **settings.local_conf)
+        return server_app
+
+    @staticmethod
+    def create_server_sync_app(server_paths, config_path):
+        return ServerUtils._create_server_app(server_paths,
+                                              config_path,
+                                              make_sync_app)
+
+    @staticmethod
+    def create_server_rest_app(server_paths, config_path):
+        return ServerUtils._create_server_app(server_paths,
+                                              config_path,
+                                              make_rest_app)
+
+    def get_session_for_hkey(self, server, hkey):
+        return server.session_manager.load(hkey)
+
+    def get_thread_for_hkey(self, server, hkey):
+        session = self.get_session_for_hkey(server, hkey)
+        thread = session.get_thread()
+        return thread
+
+    def get_col_wrapper_for_hkey(self, server, hkey):
+        logging.debug("Getting col wrapper for hkey " + hkey)
+        logging.debug("All session keys: " + str(server.session_manager.sessions.keys()))
+        thread = self.get_thread_for_hkey(server, hkey)
+        col_wrapper = thread.wrapper
+        return col_wrapper
+
+    def get_col_for_hkey(self, server, hkey):
+        col_wrapper = self.get_col_wrapper_for_hkey(server, hkey)
+        col_wrapper.open()  # Make sure the col is opened.
+        return col_wrapper._CollectionWrapper__col
+
+    def get_col_db_path_for_hkey(self, server, hkey):
+        col = self.get_col_for_hkey(server, hkey)
+        return col.db._path
+
+    def get_syncer_for_hkey(self, server, hkey, syncer_type='collection'):
+        col = self.get_col_for_hkey(server, hkey)
+
+        session = self.get_session_for_hkey(server, hkey)
+
+        syncer_type = syncer_type.lower()
+        if syncer_type == 'collection':
+            handler_method = SyncCollectionHandler.operations[0]
+        elif syncer_type == 'media':
+            handler_method = SyncMediaHandler.operations[0]
+
+        return session.get_handler_for_operation(handler_method, col)
+
+    def add_files_to_mediasyncer(self,
+                                 media_syncer,
+                                 filepaths,
+                                 update_db=False,
+                                 bump_last_usn=False):
+        """
+        If bump_last_usn is True, the media syncer's lastUsn will be incremented
+        once for each added file. Use this when adding files to the server.
+        """
+
+        for filepath in filepaths:
+            logging.debug("Adding file '{}' to mediaSyncer".format(filepath))
+            # Import file into media dir.
+            media_syncer.col.media.addFile(filepath)
+            if bump_last_usn:
+                # Need to bump lastUsn once for each file.
+                media_manager = media_syncer.col.media
+                media_manager.setLastUsn(media_syncer.col.media.lastUsn() + 1)
+
+        if update_db:
+            media_syncer.col.media.findChanges()  # Write changes to db.
diff --git a/tests/sync_app_functional_media_test.py b/tests/sync_app_functional_media_test.py
new file mode 100644
index 0000000..7cce65c
--- /dev/null
+++ b/tests/sync_app_functional_media_test.py
@@ -0,0 +1,315 @@
+# -*- coding: utf-8 -*-
+
+
+import filecmp
+import os
+
+
+from anki.sync import Syncer, MediaSyncer
+from helpers.mock_servers import MockRemoteMediaServer
+from helpers.monkey_patches import monkeypatch_mediamanager, unpatch_mediamanager
+from sync_app_functional_test_base import SyncAppFunctionalTestBase
+
+
+class SyncAppFunctionalMediaTest(SyncAppFunctionalTestBase):
+
+    def setUp(self):
+        SyncAppFunctionalTestBase.setUp(self)
+
+        monkeypatch_mediamanager()
+        self.hkey = self.mock_remote_server.hostKey("testuser", "testpassword")
+        client_collection = self.colutils.create_empty_col()
+        self.client_syncer = self.create_client_syncer(client_collection,
+                                                       self.hkey,
+                                                       self.server_test_app)
+
+    def tearDown(self):
+        self.hkey = None
+        self.client_syncer = None
+        unpatch_mediamanager()
+        SyncAppFunctionalTestBase.tearDown(self)
+
+    @staticmethod
+    def create_client_syncer(collection, hkey, server_test_app):
+        mock_remote_server = MockRemoteMediaServer(col=collection,
+                                                   hkey=hkey,
+                                                   server_test_app=server_test_app)
+        media_syncer = MediaSyncer(col=collection,
+                                   server=mock_remote_server)
+        return media_syncer
+
+    def test_sync_empty_media_dbs(self):
+        # With both the client and the server having no media to sync,
+        # syncing should change nothing.
+        self.assertEqual('noChanges', self.client_syncer.sync())
+        self.assertEqual('noChanges', self.client_syncer.sync())
+
+    def test_sync_file_from_server(self):
+        """
+        Adds a file on the server. After syncing, client and server should have
+        the identical file in their media directories and media databases.
+        """
+        client = self.client_syncer
+        server = self.serverutils.get_syncer_for_hkey(self.server_app,
+                                                      self.hkey,
+                                                      'media')
+
+        # Create a test file.
+        temp_file_path = self.fileutils.create_named_file(u"foo.jpg", "hello")
+
+        # Add the test file to the server's collection.
+        self.serverutils.add_files_to_mediasyncer(server,
+                                                  [temp_file_path],
+                                                  update_db=True,
+                                                  bump_last_usn=True)
+
+        # Syncing should work.
+        self.assertEqual(client.sync(), 'OK')
+
+        # The test file should be present in the server's and in the client's
+        # media directory.
+        self.assertTrue(
+            filecmp.cmp(os.path.join(client.col.media.dir(), u"foo.jpg"),
+                        os.path.join(server.col.media.dir(), u"foo.jpg")))
+
+        # Further syncing should do nothing.
+        self.assertEqual(client.sync(), 'noChanges')
+
+    def test_sync_file_from_client(self):
+        """
+        Adds a file on the client. After syncing, client and server should have
+        the identical file in their media directories and media databases.
+        """
+        join = os.path.join
+        client = self.client_syncer
+        server = self.serverutils.get_syncer_for_hkey(self.server_app,
+                                                      self.hkey,
+                                                      'media')
+
+        # Create a test file.
+        temp_file_path = self.fileutils.create_named_file(u"foo.jpg", "hello")
+
+        # Add the test file to the client's media collection.
+        self.serverutils.add_files_to_mediasyncer(client,
+                                                  [temp_file_path],
+                                                  update_db=True,
+                                                  bump_last_usn=False)
+
+        # Syncing should work.
+        self.assertEqual(client.sync(), 'OK')
+
+        # The same file should be present in both the client's and the server's
+        # media directory.
+        self.assertTrue(filecmp.cmp(join(client.col.media.dir(), u"foo.jpg"),
+                                    join(server.col.media.dir(), u"foo.jpg")))
+
+        # Further syncing should do nothing.
+        self.assertEqual(client.sync(), 'noChanges')
+
+        # Except for timestamps, the media databases of client and server
+        # should be identical.
+        self.assertFalse(
+            self.dbutils.media_dbs_differ(client.col.media.db._path,
+                                          server.col.media.db._path)
+        )
+
+    def test_sync_different_files(self):
+        """
+        Adds a file on the client and a file with different name and content on
+        the server. After syncing, both client and server should have both
+        files in their media directories and databases.
+        """
+        join = os.path.join
+        isfile = os.path.isfile
+        client = self.client_syncer
+        server = self.serverutils.get_syncer_for_hkey(self.server_app,
+                                                      self.hkey,
+                                                      'media')
+
+        # Create two files and add one to the server and one to the client.
+        file_for_client, file_for_server = self.fileutils.create_named_files([
+            (u"foo.jpg", "hello"),
+            (u"bar.jpg", "goodbye")
+        ])
+
+        self.serverutils.add_files_to_mediasyncer(client,
+                                                  [file_for_client],
+                                                  update_db=True)
+        self.serverutils.add_files_to_mediasyncer(server,
+                                                  [file_for_server],
+                                                  update_db=True,
+                                                  bump_last_usn=True)
+
+        # Syncing should work.
+        self.assertEqual(client.sync(), 'OK')
+
+        # Both files should be present in the client's and in the server's
+        # media directories.
+        self.assertTrue(isfile(join(client.col.media.dir(), u"foo.jpg")))
+        self.assertTrue(isfile(join(server.col.media.dir(), u"foo.jpg")))
+        self.assertTrue(filecmp.cmp(
+            join(client.col.media.dir(), u"foo.jpg"),
+            join(server.col.media.dir(), u"foo.jpg"))
+        )
+        self.assertTrue(isfile(join(client.col.media.dir(), u"bar.jpg")))
+        self.assertTrue(isfile(join(server.col.media.dir(), u"bar.jpg")))
+        self.assertTrue(filecmp.cmp(
+            join(client.col.media.dir(), u"bar.jpg"),
+            join(server.col.media.dir(), u"bar.jpg"))
+        )
+
+        # Further syncing should change nothing.
+        self.assertEqual(client.sync(), 'noChanges')
+
+    def test_sync_different_contents(self):
+        """
+        Adds a file to the client and a file with identical name but different
+        contents to the server. After syncing, both client and server should
+        have the server's version of the file in their media directories and
+        databases.
+        """
+        join = os.path.join
+        isfile = os.path.isfile
+        client = self.client_syncer
+        server = self.serverutils.get_syncer_for_hkey(self.server_app,
+                                                      self.hkey,
+                                                      'media')
+
+        # Create two files with identical names but different contents and
+        # checksums. Add one to the server and one to the client.
+        file_for_client, file_for_server = self.fileutils.create_named_files([
+            (u"foo.jpg", "hello"),
+            (u"foo.jpg", "goodbye")
+        ])
+
+        self.serverutils.add_files_to_mediasyncer(client,
+                                                  [file_for_client],
+                                                  update_db=True)
+        self.serverutils.add_files_to_mediasyncer(server,
+                                                  [file_for_server],
+                                                  update_db=True,
+                                                  bump_last_usn=True)
+
+        # Syncing should work.
+        self.assertEqual(client.sync(), 'OK')
+
+        # A version of the file should be present in both the client's and the
+        # server's media directory.
+        self.assertTrue(isfile(join(client.col.media.dir(), u"foo.jpg")))
+        self.assertEqual(os.listdir(client.col.media.dir()), ['foo.jpg'])
+        self.assertTrue(isfile(join(server.col.media.dir(), u"foo.jpg")))
+        self.assertEqual(os.listdir(server.col.media.dir()), ['foo.jpg'])
+        self.assertEqual(client.sync(), 'noChanges')
+
+        # Both files should have the contents of the server's version.
+        _checksum = client.col.media._checksum
+        self.assertEqual(_checksum(join(client.col.media.dir(), u"foo.jpg")),
+                         _checksum(file_for_server))
+        self.assertEqual(_checksum(join(server.col.media.dir(), u"foo.jpg")),
+                         _checksum(file_for_server))
+
+    def test_sync_add_and_delete_on_client(self):
+        """
+        Adds a file on the client. After syncing, the client and server should
+        both have the file. Then removes the file from the client's directory
+        and marks it as deleted in its database. After syncing again, the
+        server should have removed its version of the file from its media dir
+        and marked it as deleted in its db.
+        """
+        join = os.path.join
+        isfile = os.path.isfile
+        client = self.client_syncer
+        server = self.serverutils.get_syncer_for_hkey(self.server_app,
+                                                      self.hkey,
+                                                      'media')
+
+        # Create a test file.
+        temp_file_path = self.fileutils.create_named_file(u"foo.jpg", "hello")
+
+        # Add the test file to client's media collection.
+        self.serverutils.add_files_to_mediasyncer(client,
+                                                  [temp_file_path],
+                                                  update_db=True,
+                                                  bump_last_usn=False)
+
+        # Syncing client should work.
+        self.assertEqual(client.sync(), 'OK')
+
+        # The same file should be present in both client's and the server's
+        # media directory.
+        self.assertTrue(filecmp.cmp(join(client.col.media.dir(), u"foo.jpg"),
+                                    join(server.col.media.dir(), u"foo.jpg")))
+
+        # Syncing client again should do nothing.
+        self.assertEqual(client.sync(), 'noChanges')
+
+        # Remove files from client's media dir and write changes to its db.
+        os.remove(join(client.col.media.dir(), u"foo.jpg"))
+
+        # TODO: client.col.media.findChanges() doesn't work here - why?
+        client.col.media._logChanges()
+        self.assertEqual(client.col.media.syncInfo(u"foo.jpg"), (None, 1))
+        self.assertFalse(isfile(join(client.col.media.dir(), u"foo.jpg")))
+
+        # Syncing client again should work.
+        self.assertEqual(client.sync(), 'OK')
+
+        # server should have picked up the removal from client.
+        self.assertEqual(server.col.media.syncInfo(u"foo.jpg"), (None, 0))
+        self.assertFalse(isfile(join(server.col.media.dir(), u"foo.jpg")))
+
+        # Syncing client again should do nothing.
+        self.assertEqual(client.sync(), 'noChanges')
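+
+    # Note for the test below: the dirMod and mtime values in the expected SQL
+    # are arbitrary, since media_dbs_differ() zeroes timestamps on both sides
+    # before comparing (compare_timestamps defaults to False).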
+
+    def test_sync_compare_database_to_expected(self):
+        """
+        Adds a test image file to the client's media directory. After syncing,
+        the client's media database should, except for timestamps, be identical
+        to a database containing the expected data.
+        """
+        client = self.client_syncer
+
+        # Add a test image file to the client's media collection but don't
+        # update its media db since the desktop client updates that, using
+        # findChanges(), only during syncs.
+        support_file = self.fileutils.get_asset_path(u'blue.jpg')
+        self.assertTrue(os.path.isfile(support_file))
+        self.serverutils.add_files_to_mediasyncer(client,
+                                                  [support_file],
+                                                  update_db=False)
+
+        # Syncing should work.
+        self.assertEqual(client.sync(), "OK")
+
+        # Create temporary db file with expected results.
+        chksum = client.col.media._checksum(support_file)
+        sql = ("""
+            CREATE TABLE meta (dirMod int, lastUsn int);
+
+            INSERT INTO `meta` (dirMod, lastUsn) VALUES (123456789,1);
+
+            CREATE TABLE media (
+                fname text not null primary key,
+                csum text,
+                mtime int not null,
+                dirty int not null
+            );
+
+            INSERT INTO `media` (fname, csum, mtime, dirty) VALUES (
+                'blue.jpg',
+                '%s',
+                1441483037,
+                0
+            );
+
+            CREATE INDEX idx_media_dirty on media (dirty);
+        """ % chksum)
+
+        temp_db_path = self.dbutils.create_sqlite_db_with_sql(sql)
+
+        # Except for timestamps, the client's db after sync should be identical
+        # to the expected data.
+        self.assertFalse(self.dbutils.media_dbs_differ(
+            client.col.media.db._path,
+            temp_db_path
+        ))
diff --git a/tests/sync_app_functional_test_base.py b/tests/sync_app_functional_test_base.py
new file mode 100644
index 0000000..54abdcd
--- /dev/null
+++ b/tests/sync_app_functional_test_base.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+
+
+import os
+import unittest
+from webtest import TestApp
+
+
+from AnkiServer.user_managers import SqliteUserManager
+from helpers.collection_utils import CollectionUtils
+from helpers.db_utils import DBUtils
+from helpers.file_utils import FileUtils
+from helpers.mock_servers import MockRemoteServer
+from helpers.monkey_patches import monkeypatch_db, unpatch_db
+from helpers.server_utils import ServerUtils
+
+
+class SyncAppFunctionalTestBase(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.fileutils = FileUtils()
+        cls.colutils = CollectionUtils()
+        cls.serverutils = ServerUtils()
+        cls.dbutils = DBUtils()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.fileutils.clean_up()
+        cls.fileutils = None
+
+        cls.colutils.clean_up()
+        cls.colutils = None
+
+        cls.serverutils.clean_up()
+        cls.serverutils = None
+
+        cls.dbutils.clean_up()
+        cls.dbutils = None
+
+    def setUp(self):
+        monkeypatch_db()
+
+        # Create temporary files and dirs the server will use.
+        self.server_paths = self.serverutils.create_server_paths()
+
+        # Add a test user to the temp auth db the server will use.
+        self.user_manager = SqliteUserManager(self.server_paths['auth_db'],
+                                              self.server_paths['data_root'])
+        self.user_manager.add_user('testuser', 'testpassword')
+
+        # Get absolute path to test ini file.
+        script_dir = os.path.dirname(os.path.realpath(__file__))
+        ini_file_path = os.path.join(script_dir,
+                                     os.pardir,
+                                     "test.ini")
+
+        # Create SyncApp instance using the test ini file and the temporary
+        # paths.
+        self.server_app = self.serverutils.create_server_sync_app(self.server_paths,
+                                                                  ini_file_path)
+
+        # Wrap the SyncApp object in a TestApp instance for testing.
+        self.server_test_app = TestApp(self.server_app)
+
+        # MockRemoteServer instance needed for testing normal collection
+        # syncing and for retrieving hkey for other tests.
+        self.mock_remote_server = MockRemoteServer(hkey=None,
+                                                   server_test_app=self.server_test_app)
+
+    def tearDown(self):
+        self.server_paths = {}
+        self.user_manager = None
+
+        # Shut down server.
+        self.server_app.collection_manager.shutdown()
+        self.server_app = None
+
+        self.client_server_connection = None
+
+        unpatch_db()
diff --git a/tests/test_sync_app.py b/tests/test_sync_app.py
index 6fef6fb..0582869 100644
--- a/tests/test_sync_app.py
+++ b/tests/test_sync_app.py
@@ -57,84 +57,6 @@ def test_meta(self):
         self.assertEqual(meta['cont'], True)
 
 
-class SimpleUserManagerTest(unittest.TestCase):
-    _good_test_un = 'username'
-    _good_test_pw = 'password'
-
-    _bad_test_un = 'notAUsername'
-    _bad_test_pw = 'notAPassword'
-
-    def setUp(self):
-        self._user_manager = SimpleUserManager()
-
-    def tearDown(self):
-        self._user_manager = None
-
-    def test_authenticate(self):
-        self.assertTrue(self._user_manager.authenticate(self._good_test_un,
-                                                        self._good_test_pw))
-
-        self.assertTrue(self._user_manager.authenticate(self._bad_test_un,
-                                                        self._bad_test_pw))
-
-        self.assertTrue(self._user_manager.authenticate(self._good_test_un,
-                                                        self._bad_test_pw))
-
-        self.assertTrue(self._user_manager.authenticate(self._bad_test_un,
-                                                        self._good_test_pw))
-
-    def test_username2dirname(self):
-        dirname = self._user_manager.username2dirname(self._good_test_un)
-        self.assertEqual(dirname, self._good_test_un)
-
-
-class SqliteUserManagerTest(SimpleUserManagerTest):
-    file_descriptor, _test_auth_db_path = tempfile.mkstemp(suffix=".db")
-    os.close(file_descriptor)
-    os.unlink(_test_auth_db_path)
-
-    def _create_test_auth_db(self, db_path, username, password):
-        if os.path.exists(db_path):
-            os.remove(db_path)
-
-        salt = binascii.b2a_hex(os.urandom(8))
-        crypto_hash = hashlib.sha256(username+password+salt).hexdigest()+salt
-
-        conn = sqlite3.connect(db_path)
-        cursor = conn.cursor()
-
-        cursor.execute("""CREATE TABLE IF NOT EXISTS auth
-                          (user VARCHAR PRIMARY KEY, hash VARCHAR)""")
-
-        cursor.execute("INSERT INTO auth VALUES (?, ?)", (username, crypto_hash))
-
-        conn.commit()
-        conn.close()
-
-    def setUp(self):
-        self._create_test_auth_db(self._test_auth_db_path,
-                                  self._good_test_un,
-                                  self._good_test_pw)
-        self._user_manager = SqliteUserManager(self._test_auth_db_path)
-
-    def tearDown(self):
-        if os.path.exists(self._test_auth_db_path):
-            os.remove(self._test_auth_db_path)
-
-    def test_authenticate(self):
-        self.assertTrue(self._user_manager.authenticate(self._good_test_un,
-                                                        self._good_test_pw))
-
-        self.assertFalse(self._user_manager.authenticate(self._bad_test_un,
-                                                         self._bad_test_pw))
-
-        self.assertFalse(self._user_manager.authenticate(self._good_test_un,
-                                                         self._bad_test_pw))
-
-        self.assertFalse(self._user_manager.authenticate(self._bad_test_un,
-                                                         self._good_test_pw))
-
-
 class SimpleSessionManagerTest(unittest.TestCase):
     test_hkey = '1234567890'
     test_session = SyncUserSession('testName', 'testPath', None, None)
diff --git a/tests/test_user_managers.py b/tests/test_user_managers.py
new file mode 100644
index 0000000..a44a55f
--- /dev/null
+++ b/tests/test_user_managers.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+
+
+import os
+import unittest
+
+
+from AnkiServer.user_managers import SimpleUserManager, SqliteUserManager
+from helpers.file_utils import FileUtils
+
+
+class SimpleUserManagerTest(unittest.TestCase):
+    def setUp(self):
+        self.user_manager = SimpleUserManager()
+
+    def tearDown(self):
+        self.user_manager = None
+
+    def test_authenticate(self):
+        good_test_un = 'username'
+        good_test_pw = 'password'
+        bad_test_un = 'notAUsername'
+        bad_test_pw = 'notAPassword'
+
+        self.assertTrue(self.user_manager.authenticate(good_test_un,
+                                                       good_test_pw))
+        self.assertTrue(self.user_manager.authenticate(bad_test_un,
+                                                       bad_test_pw))
+        self.assertTrue(self.user_manager.authenticate(good_test_un,
+                                                       bad_test_pw))
+        self.assertTrue(self.user_manager.authenticate(bad_test_un,
+                                                       good_test_pw))
+
+    def test_username2dirname(self):
+        username = 'my_username'
+        dirname = self.user_manager.username2dirname(username)
+        self.assertEqual(dirname, username)
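+
+
+# Replaces the SqliteUserManagerTest previously defined in tests/test_sync_app.py;
+# the auth db and collection directory are created per test via FileUtils.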
+
+
+class SqliteUserManagerTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.fileutils = FileUtils()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.fileutils.clean_up()
+        cls.fileutils = None
+
+    def setUp(self):
+        self.auth_db_path = self.fileutils.create_file_path(suffix='auth.db')
+        self.collection_path = self.fileutils.create_dir_path()
+        self.user_manager = SqliteUserManager(self.auth_db_path,
+                                              self.collection_path)
+
+    def tearDown(self):
+        self.user_manager = None
+
+    def test_auth_db_exists(self):
+        self.assertFalse(self.user_manager.auth_db_exists())
+
+        self.user_manager.create_auth_db()
+        self.assertTrue(self.user_manager.auth_db_exists())
+
+        os.unlink(self.auth_db_path)
+        self.assertFalse(self.user_manager.auth_db_exists())
+
+    def test_user_list(self):
+        username = "my_username"
+        password = "my_password"
+        self.user_manager.create_auth_db()
+
+        self.assertEqual(self.user_manager.user_list(), [])
+
+        self.user_manager.add_user(username, password)
+        self.assertEqual(self.user_manager.user_list(), [username])
+
+    def test_user_exists(self):
+        username = "my_username"
+        password = "my_password"
+        self.user_manager.create_auth_db()
+        self.user_manager.add_user(username, password)
+        self.assertTrue(self.user_manager.user_exists(username))
+
+        self.user_manager.del_user(username)
+        self.assertFalse(self.user_manager.user_exists(username))
+
+    def test_del_user(self):
+        username = "my_username"
+        password = "my_password"
+        collection_dir_path = os.path.join(self.collection_path, username)
+        self.user_manager.create_auth_db()
+        self.user_manager.add_user(username, password)
+        self.user_manager.del_user(username)
+
+        # User should be gone.
+        self.assertFalse(self.user_manager.user_exists(username))
+        # User's collection dir should still be there.
+        self.assertTrue(os.path.isdir(collection_dir_path))
+
+    def test_add_user(self):
+        username = "my_username"
+        password = "my_password"
+        expected_dir_path = os.path.join(self.collection_path, username)
+        self.user_manager.create_auth_db()
+
+        self.assertFalse(os.path.exists(expected_dir_path))
+
+        self.user_manager.add_user(username, password)
+
+        # User db entry and collection dir should be present.
+        self.assertTrue(self.user_manager.user_exists(username))
+        self.assertTrue(os.path.isdir(expected_dir_path))
+
+    def test_add_users(self):
+        users_data = [("my_first_username", "my_first_password"),
+                      ("my_second_username", "my_second_password")]
+        self.user_manager.create_auth_db()
+        self.user_manager.add_users(users_data)
+
+        user_list = self.user_manager.user_list()
+        self.assertIn("my_first_username", user_list)
+        self.assertIn("my_second_username", user_list)
+        self.assertTrue(os.path.isdir(os.path.join(self.collection_path,
+                                                   "my_first_username")))
+        self.assertTrue(os.path.isdir(os.path.join(self.collection_path,
+                                                   "my_second_username")))
+
+    def test__add_user_to_auth_db(self):
+        username = "my_username"
+        password = "my_password"
+        self.user_manager.create_auth_db()
+        self.user_manager.add_user(username, password)
+
+        self.assertTrue(self.user_manager.user_exists(username))
+
+    def test_create_auth_db(self):
+        self.assertFalse(os.path.exists(self.auth_db_path))
+        self.user_manager.create_auth_db()
+        self.assertTrue(os.path.isfile(self.auth_db_path))
+
+    def test__create_user_dir(self):
+        username = "my_username"
+        expected_dir_path = os.path.join(self.collection_path, username)
+        self.assertFalse(os.path.exists(expected_dir_path))
+        self.user_manager._create_user_dir(username)
+        self.assertTrue(os.path.isdir(expected_dir_path))
+
+    def test_authenticate_user(self):
+        username = "my_username"
+        password = "my_password"
+
+        self.user_manager.create_auth_db()
+        self.user_manager.add_user(username, password)
+
+        self.assertTrue(self.user_manager.authenticate_user(username,
+                                                            password))
+
+    def test_set_password_for_user(self):
+        username = "my_username"
+        password = "my_password"
+        new_password = "my_new_password"
+
+        self.user_manager.create_auth_db()
+        self.user_manager.add_user(username, password)
+
+        self.user_manager.set_password_for_user(username, new_password)
+        self.assertFalse(self.user_manager.authenticate_user(username,
+                                                             password))
+        self.assertTrue(self.user_manager.authenticate_user(username,
+                                                            new_password))