Skip to content

Commit

Permalink
updated common.py to ruff standard
Browse files Browse the repository at this point in the history
  • Loading branch information
Tjalling-dejong committed Dec 3, 2024
1 parent ad3ac07 commit e60fdf2
Showing 1 changed file with 102 additions and 92 deletions.
194 changes: 102 additions & 92 deletions fm2prof/common.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
# -*- coding: utf-8 -*-
"""
Base classes and data containers
"""
"""Base classes and data containers."""
from __future__ import annotations

import logging

# Imports from standard library
import os
from datetime import datetime
from logging import Logger, LogRecord
from pathlib import Path
from time import time
from typing import AnyStr, Mapping
from typing import TYPE_CHECKING

import colorama

Expand All @@ -20,16 +18,18 @@
from colorama import Back, Fore, Style

# Import from package
# none
if TYPE_CHECKING:
from fm2prof.ini_file import IniFile


IniFile = "fm2prof.IniFile.IniFile"


class TqdmLoggingHandler(logging.StreamHandler):
def __init__(self):
"""Logging handler for tqdm package."""
def __init__(self) -> None:
super().__init__()

def emit(self, record):
def emit(self, record: LogRecord):
try:
msg = self.format(record)
if self.formatter.pbar:
Expand All @@ -39,15 +39,15 @@ def emit(self, record):
stream.write(msg + self.terminator)

self.flush()
except Exception as e:
except Exception:
self.handleError(record)


class ElapsedFormatter:
__new_iteration = 1

def __init__(self):
self.start_time = time()
self.start_time = datetime.now()
self.number_of_iterations: int = 1
self.current_iteration: int = 0
self._pbar: tqdm.tqdm = None
Expand All @@ -66,44 +66,43 @@ def __init__(self):
self._loglibrary: dict = {"ERROR": 0, "WARNING": 0}

@property
def pbar(self):
def pbar(self) -> None | tqdm.tqdm | tqdm.std.tqdm:
"""Progress bar."""
return self._pbar

@pbar.setter
def pbar(self, pbar):
def pbar(self, pbar: tqdm.std.tqdm | tqdm.tqdm | None) -> None:
"""Set progress bar."""
if isinstance(pbar, (tqdm.std.tqdm, type(None))):
self._pbar = pbar
else:
raise ValueError
raise TypeError

def format(self, record):
def format(self, record: LogRecord) -> str:
"""Format log record."""
if self._intro:
return self.__format_intro(record)
if self.__new_iteration > 0:
return self.__format_header(record)
if self.__new_iteration == -1:
return self.__format_footer(record)
else:
return self.__format_message(record)
return self.__format_message(record)

def __format_intro(self, record: LogRecord):
def __format_intro(self, record: LogRecord) -> str:
return f"{record.getMessage()}"

def __format_header(self, record: LogRecord):
"""Formats the header of a new task"""

def __format_header(self, record: LogRecord) -> str:
self.__new_iteration -= 1
message = record.getMessage()
current_time = datetime.now().strftime("%Y-%m-%d %H:%M")
return f"╔═════╣ {self._resetStyle}{current_time} {message}{self._resetStyle}"

def __format_footer(self, record: LogRecord):
def __format_footer(self, record: LogRecord) -> str:
self.__new_iteration -= 1
elapsed_seconds = record.created - self.start_time
message = record.getMessage()
return f"╚═════╣ {self._resetStyle}Task finished in {elapsed_seconds:.2f}sec{self._resetStyle}"

def __format_message(self, record: LogRecord):
def __format_message(self, record: LogRecord) -> str:
elapsed_seconds = record.created - self.start_time
color = self._colors

Expand All @@ -114,38 +113,46 @@ def __format_message(self, record: LogRecord):
if level in self._loglibrary:
self._loglibrary[level] += 1

formatted_string = (
return (
f"║ {color[level][0]} {level:>7} "
+ f"{self._resetStyle}{color[level][1]}{self._resetStyle} T+ {elapsed_seconds:.2f}s {message}"
f"{self._resetStyle}{color[level][1]}{self._resetStyle} T+ {elapsed_seconds:.2f}s {message}"
)

return formatted_string

def __reset(self):
def __reset(self) -> None:
self.start_time = time()

def start_new_iteration(self, pbar: tqdm.tqdm = None):
def start_new_iteration(self, pbar: tqdm.tqdm | None = None) -> None:
"""Start a new iteration with a progress bar."""
self.current_iteration += 1
self.new_task()
self.pbar = pbar

def new_task(self):
def new_task(self) -> None:
"""Reset ElapsedTimeFormatter."""
self.__new_iteration = 1
self.__reset()

def finish_task(self):
def finish_task(self) -> None:
"""Finish task."""
self.__new_iteration = -1

def set_number_of_iterations(self, n):
assert n > 0, "Total number of iterations should be higher than zero"
def set_number_of_iterations(self, n: int) -> None:
"""Set numbber of iterations."""
if n > 0:
err_msg = "Total number of iterations should be higher than zero"
raise ValueError(err_msg)
self.number_of_iterations = n

def set_intro(self, flag: bool = True):
def set_intro(self,*, flag: bool = True) -> None:
"""Indicate intro section for formatter."""
self._intro = flag

def get_elapsed_time(self):
current_time = datetime.now().strftime("%Y-%m-%d %H:%M")
return current_time - self.start_time
def get_elapsed_time(self) -> float:
"""Get elapsed time in seconds."""
current_time = datetime.now()
elapsed_time = current_time - self.start_time
return elapsed_time.total_seconds()


class ElapsedFileFormatter(ElapsedFormatter):
Expand All @@ -160,18 +167,11 @@ def __init__(self):
"RESET": "",
}

@property
def pbar(self):
return self._pbar

@pbar.setter
def pbar(self, pbar):
self._pbar = None


class FM2ProfBase:
"""
Base class for FM2PROF types. Implements methods for logging, project specific parameters
"""Base class for FM2PROF types.
Implements methods for logging, project specific parameters
"""

__logger = None
Expand All @@ -182,64 +182,66 @@ class FM2ProfBase:
__copyright__ = "Copyright 2016-2020, University of Twente & Deltares"
__license__ = "LPGL"

def __init__(self, logger: Logger = None, inifile: IniFile = None):
def __init__(self, logger: Logger | None = None, inifile: IniFile | None= None):
if logger:
self.set_logger(logger)
if inifile:
self.set_inifile(inifile)

def _create_logger(self):
def _create_logger(self) -> None:
# Create logger
self.__logger = logging.getLogger(__name__)
self.__logger.setLevel(logging.DEBUG)

# create formatter
self.__logger.__logformatter = ElapsedFormatter()
self.__logger._Filelogformatter = ElapsedFileFormatter()
self.__logger.__logformatter = ElapsedFormatter() #noqa: SLF001
self.__logger._Filelogformatter = ElapsedFileFormatter() #noqa: SLF001

# create console handler
if TqdmLoggingHandler not in map(type, self.__logger.handlers):
ch = TqdmLoggingHandler()
ch.setLevel(logging.DEBUG)
ch.setFormatter(self.__logger.__logformatter)
ch.setFormatter(self.__logger.__logformatter) #noqa: SLF001
self.__logger.addHandler(ch)

def get_logger(self) -> Logger:
"""Use this method to return logger object"""
"""Use this method to return logger object."""
return self.__logger

def set_logger(self, logger: Logger) -> None:
"""
Use to set logger
"""Use to set logger.
Parameters:
Args:
logger (Logger): Logger instance
"""
assert isinstance(logger, Logger), (
"" + "logger should be instance of Logger class"
)
if not isinstance(logger, Logger):
err_msg = "logger should be instance of Logger class"
raise TypeError(err_msg)
self.__logger = logger

def set_logger_message(
self, err_mssg: str = "", level: str = "info", header: bool = False
self, err_mssg: str = "", level: str = "info", *, header: bool = False,
) -> None:
"""Sets message to logger if this is set.
Arguments:
err_mssg {str} -- Error message to send to logger.
Args:
err_mssg (str, optional): Error message to log. Defaults to "".
level (str, optional): Log level. Defaults to "info".
header (bool, optional): Set error message as header. Defaults to False.
"""
if not self.__logger:
return

if header:
self.get_logformatter().set_intro(True)
self.get_logger()._Filelogformatter.set_intro(True)
self.get_logger()._Filelogformatter.set_intro(True) #noqa: SLF001
else:
self.get_logformatter().set_intro(False)
self.get_logger()._Filelogformatter.set_intro(False)
self.get_logger()._Filelogformatter.set_intro(False) #noqa: SLF001

if level.lower() not in ["info", "debug", "warning", "error", "critical"]:
self.__logger.error("{} is not valid logging level.".format(level.lower()))
err_msg = f"{level.lower()} is not valid logging level."
raise ValueError(err_msg)

if level.lower() == "info":
self.__logger.info(err_mssg)
Expand All @@ -253,10 +255,9 @@ def set_logger_message(
self.__logger.critical(err_mssg)

def start_new_log_task(
self, task_name: str = "NOT DEFINED", pbar: tqdm.tqdm = None
self, task_name: str = "NOT DEFINED", pbar: tqdm.tqdm = None,
) -> None:
"""
Use this method to start a new task. Will reset the internal clock.
"""Use this method to start a new task. Will reset the internal clock.
:param task_name: task name, will be displayed in log message
"""
Expand All @@ -265,8 +266,7 @@ def start_new_log_task(
self.set_logger_message(f"Starting new task: {task_name}")

def finish_log_task(self) -> None:
"""
Use this method to finish task.
"""Use this method to finish task.
:param task_name: task name, will be displayed in log message
"""
Expand All @@ -275,23 +275,31 @@ def finish_log_task(self) -> None:
self.pbar = None

def get_logformatter(self) -> ElapsedFormatter:
"""Returns formatter"""
return self.get_logger().__logformatter
"""Returns log formatter."""
return self.get_logger().__logformatter #noqa: SLF001

def get_filelogformatter(self) -> ElapsedFormatter:
"""Returns formatter"""
return self.get_logger()._Filelogformatter
"""Returns file log formatter."""
return self.get_logger()._Filelogformatter #noqa: SLF001

def set_logfile(self, output_dir: str, filename: str = "fm2prof.log") -> None:
def set_logfile(self, output_dir: str | Path, filename: str = "fm2prof.log") -> None:
"""Set log file.
Args:
output_dir (str): _description_
filename (str, optional): _description_. Defaults to "fm2prof.log".
"""
# create file handler
fh = logging.FileHandler(os.path.join(output_dir, filename), encoding="utf-8")
if not output_dir:
err_msg = "output_dir is required."
raise ValueError(err_msg)
fh = logging.FileHandler(Path(output_dir).joinpath(filename), encoding="utf-8")
fh.setLevel(logging.DEBUG)
fh.setFormatter(self.get_logger()._Filelogformatter)
fh.setFormatter(self.get_logger()._Filelogformatter) #noqa: SLF001
self.__logger.addHandler(fh)

def set_inifile(self, inifile: IniFile = None):
"""
Use this method to set configuration file object.
def set_inifile(self, inifile: IniFile = None) -> None:
"""Use this method to set configuration file object.
For loading from file, use ``load_inifile`` instead
Expand All @@ -301,28 +309,30 @@ def set_inifile(self, inifile: IniFile = None):
self.__iniFile = inifile

def get_inifile(self) -> IniFile:
""" "Use this method to get the inifile object"""
"""Get the inifile object."""
return self.__iniFile


class FrictionTable:
"""
Container for friction table
"""
"""Container for friction table."""

def __init__(self, level, friction):
def __init__(self, level: np.ndarray, friction: np.ndarray) -> None:
if self._validate_input(level, friction):
self.level = level
self.friction = friction

def interpolate(self, new_z):
def interpolate(self, new_z: np.ndarray) -> None:
self.friction = np.interp(new_z, self.level, self.friction)
self.level = new_z

@staticmethod
def _validate_input(level, friction):
assert isinstance(level, np.ndarray)
assert isinstance(friction, np.ndarray)
assert level.shape == friction.shape

def _validate_input(level: np.ndarray, friction: np.ndarray) -> bool:
if not isinstance(level, np.ndarray):
err_msg = f"level argument not of type {np.ndarray}."
raise TypeError(err_msg)
if not isinstance(friction, np.ndarray):
err_msg = f"friction argument not of type {np.ndarray}."
if level.shape != friction.shape:
err_msg = "level and friction arrays should have the same shape."
raise ValueError(err_msg)
return True

0 comments on commit e60fdf2

Please sign in to comment.