Skip to content

Commit

Permalink
refactor(tests): update tests suite
Browse files Browse the repository at this point in the history
  • Loading branch information
yshalenyk committed Jun 14, 2024
1 parent d1a931f commit f97ee30
Show file tree
Hide file tree
Showing 8 changed files with 512 additions and 55 deletions.
6 changes: 0 additions & 6 deletions docs/index.rst

This file was deleted.

1 change: 1 addition & 0 deletions nightingale/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Datasource:
class Publishing:
version: str
publisher: str
base_uri: str


@dataclass(frozen=True)
Expand Down
16 changes: 6 additions & 10 deletions nightingale/loader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging
import sqlite3

import pandas as pd

logger = logging.getLogger(__name__)

# TODO: postgresql support
Expand All @@ -11,8 +9,9 @@
class DataLoader:
"""Load data from a database using a SQL query"""

def __init__(self, config):
def __init__(self, config, connection=None):
self.config = config
self._connection = connection

def load(self, selector):
cursor = self.get_cursor()
Expand All @@ -26,8 +25,12 @@ def get_cursor(self):
return cursor

def get_connection(self):
if self._connection:
return self._connection
conn = sqlite3.connect(self.config.connection)
conn.row_factory = sqlite3.Row
# for tests
self._connection = conn
logger.info(f"Connected to {self.config.connection}")
return conn

Expand Down Expand Up @@ -60,10 +63,3 @@ def validate_selector(self, data_elements):
for column in columns:
if column not in labels_for_mapping:
logger.warning(f"Column {column} is not mapped")


class PDLoader(DataLoader):
"""Load data from sql to a pandas dataframe"""

def load(self, selector):
self.data = pd.read_sql(selector, self.get_connection())
9 changes: 2 additions & 7 deletions nightingale/mapping/v1/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def get_schema_sheet(self):
def normmalize_mapping_column(self, mappings):
"""Normalize the mapping column by setting all space separators to one space"""
for mapping in mappings:
mapping["mapping"] = mapping["mapping"].replace(" ", " ")
if " " in mapping["mapping"]:
mapping["mapping"] = " ".join((p.strip() for p in mapping["mapping"].split(" ")))
return mappings

def read_mapping_sheet(self, sheet):
Expand Down Expand Up @@ -174,12 +175,6 @@ def get_paths_for_mapping(self, key, force_publish=False):
def is_array_path(self, path):
return self.schema.get(path, {}).get("type") == "array"

def is_in_array(self, keys):
for path in ("/" + "/".join(keys[: i + 1]) for i in range(len(keys))):
if self.is_array_path(path):
return path
return False

def get_arrays(self):
result = []
for path, schema in self.schema.items():
Expand Down
54 changes: 29 additions & 25 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,39 @@
import sqlite3
import unittest
from unittest.mock import MagicMock, patch

from nightingale.config import Datasources
from nightingale.config import Datasource
from nightingale.loader import DataLoader


class TestDataLoader(unittest.TestCase):
def setUp(self):
self.config = Datasources(connection="test.db", selector="SELECT * FROM test_table")
self.loader = DataLoader(self.config)

@patch("sqlite3.connect")
def test_get_connection(self, mock_connect):
self.loader.get_connection()
mock_connect.assert_called_once_with(self.config.connection)

@patch("sqlite3.connect")
def test_get_cursor(self, mock_connect):
mock_cursor = MagicMock()
mock_connect.return_value.cursor.return_value = mock_cursor
cursor = self.loader.get_cursor()
self.assertEqual(cursor, mock_cursor)

@patch.object(DataLoader, "get_cursor")
def test_load(self, mock_get_cursor):
mock_cursor = MagicMock()
mock_get_cursor.return_value = mock_cursor
mock_cursor.fetchall.return_value = [{"column1": "value1"}]
self.loader.load()
mock_cursor.execute.assert_called_once_with(self.config.selector)
self.assertEqual(self.loader.data, [{"column1": "value1"}])
# Using an in-memory SQLite database for testing
self.config = Datasource(connection=":memory:")
self.connection = sqlite3.connect(self.config.connection)
self.connection.row_factory = sqlite3.Row
self.cursor = self.connection.cursor()

# Set up the in-memory database and table for testing
self.cursor.execute("CREATE TABLE test_table (column1 TEXT)")
self.cursor.execute("INSERT INTO test_table (column1) VALUES ('value1')")
self.connection.commit()

# Inject the in-memory connection into DataLoader
self.loader = DataLoader(self.config, connection=self.connection)

def tearDown(self):
# Close the connection after each test
self.connection.close()

def test_get_connection(self):
loader = DataLoader(self.config)
connection = loader.get_connection()
self.assertIsNotNone(connection)
self.assertEqual(connection.execute("PRAGMA database_list").fetchall()[0]["file"], "")

def test_load(self):
data = self.loader.load("SELECT * FROM test_table")
self.assertEqual(data, [{"column1": "value1"}])


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit f97ee30

Please sign in to comment.