diff --git a/heidgaf/cache.py b/heidgaf/cache.py
index 673f2fd..94601a4 100644
--- a/heidgaf/cache.py
+++ b/heidgaf/cache.py
@@ -1,10 +1,34 @@
 import logging
-from typing import Any
 
 import polars as pl
 import redis
 
 
+class StringRedisCache(object):
+    def __init__(self, redis_host="localhost", redis_port=6379, redis_db=0, redis_max_connections=20) -> None:
+        logging.debug("Connect to Redis server")
+        self.pool = redis.ConnectionPool(host=redis_host, port=redis_port, db=redis_db, max_connections=redis_max_connections)
+        self.redis_client = redis.Redis(connection_pool=self.pool)
+        self.redis_client.ping()
+
+    def __getitem__(self, key: str) -> str:
+        if self.redis_client.exists(key):
+            return self.redis_client.get(key)
+        else:
+            return ""
+
+    def __contains__(self, key):
+        return self.redis_client.exists(key)
+
+    def __str__(self):
+        return f'Redis has stored following keys: {self.redis_client.keys()}'
+
+    def __setitem__(self, key: str, value: str) -> None:
+        self.redis_client.set(key, value)
+
+    def __delitem__(self, key: str) -> None:
+        self.redis_client.delete(key)
+
 class DataFrameRedisCache(object):
     def __init__(self ,redis_host="localhost", redis_port=6379, redis_db=0, redis_max_connections=20,) -> None:
         logging.debug("Connect to Redis server")
diff --git a/heidgaf/dataset/__init__.py b/heidgaf/dataset/__init__.py
index 62b9f85..6b117f4 100644
--- a/heidgaf/dataset/__init__.py
+++ b/heidgaf/dataset/__init__.py
@@ -52,8 +52,6 @@ def __init__(self, data_path: Any, cast_dataset: Callable = None) -> None:
         else:
             self.data = pl.read_csv(data_path)
 
-        logging.info(self.data)
-
         self.X_train, self.X_val, self.X_test, self.Y_train, self.Y_val, self.Y_test = self.__train_test_val_split()
 
     def __len__(self):
diff --git a/heidgaf/dataset/majestic.py b/heidgaf/dataset/majestic.py
deleted file mode 100644
index eff2a5e..0000000
--- a/heidgaf/dataset/majestic.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import os
-import string
-
-import polars as pl
-from torch.utils.data.dataset import Dataset
-
-from heidgaf.cache import DataFrameRedisCache
-
-
-class MajesticMillionDataset(Dataset):
-    def __init__(self, csv_file: str = "/home/smachmeier/projects/heiDGA/data/majestic_million/majestic_million.csv") -> None:
-        self.data = pl.read_csv(csv_file)
-
-    def __len__(self) -> int:
-        return len(self.data)
-
-    def __getitem__(self, idx: int) -> pl.DataFrame:
-        return self.data[idx, 0]
-
-    def __call__(self, name: str, key: str) -> pl.DataFrame:
-        return self.data.filter(pl.col(key) == name)
\ No newline at end of file
diff --git a/heidgaf/main.py b/heidgaf/main.py
index 7347407..9717fc5 100644
--- a/heidgaf/main.py
+++ b/heidgaf/main.py
@@ -6,7 +6,7 @@
 import polars as pl
 from click import Path
 
-from heidgaf.cache import DataFrameRedisCache
+from heidgaf.cache import DataFrameRedisCache, StringRedisCache
 from heidgaf.post.feature import Preprocessor
 from heidgaf.pre.domain_analyzer import DomainAnalyzer
 from heidgaf.pre.ip_analyzer import IPAnalyzer
@@ -25,18 +25,16 @@ class Separator(Enum):
 
 class DNSAnalyzerPipeline:
     def __init__(self, path: Path, redis_host="localhost", redis_port=6379, redis_db=0, redis_max_connections=20, filetype=FileType.TXT, separator=Separator.SPACE) -> None:
-        self.redis_cache = DataFrameRedisCache(redis_host, redis_port, redis_db, redis_max_connections)
+        self.df_cache = DataFrameRedisCache(redis_host, redis_port, redis_db, redis_max_connections)
+        self.string_cache = StringRedisCache(redis_host, redis_port, redis_db, redis_max_connections)
 
         if os.path.isfile(path):
logging.debug(f"Processing files: {path}") self.data = self.load_data(path, separator.value) elif os.path.isdir(path): - # TODO Handle large files, Currently redis cannot store more than 512 MB logging.debug(f"Processing files: {path}/*.{filetype.value}") self.data = self.load_data(f'{path}/*.{filetype.value}', separator.value) - # self.redis_cache["data"] = self.data - def load_data(self, path, separator): dataframes = pl.read_csv(path, separator=separator, try_parse_dates=False, has_header=False).with_columns( [ @@ -44,7 +42,7 @@ def load_data(self, path, separator): ] ) - dataframes = dataframes.rename( + dataframes = dataframes[:50000].rename( { "column_1": "timestamp", "column_2": "return_code", @@ -60,11 +58,12 @@ def load_data(self, path, separator): return dataframes def run(self): + # Running modules to analyze log files # TODO Multithreading preprocessor = Preprocessor(features_to_drop=[]) processed_data = preprocessor.transform(self.data) - IPAnalyzer.run(processed_data, self.redis_cache) - DomainAnalyzer.run(processed_data, self.redis_cache) - TimeAnalyzer.run(processed_data, self.redis_cache) + IPAnalyzer.run(processed_data, self.df_cache) + DomainAnalyzer.run(processed_data, self.df_cache) + TimeAnalyzer.run(processed_data, self.df_cache) diff --git a/heidgaf/models/__init__.py b/heidgaf/models/__init__.py index 2ed2d04..a89b44d 100644 --- a/heidgaf/models/__init__.py +++ b/heidgaf/models/__init__.py @@ -18,13 +18,13 @@ def __init__(self, def fit(self, x_train, y_train): x_train = self.preprocessor.transform(x=x_train) - # x_train = self.target_encoder.fit_transform(x=x_train, y=y_train) - # x_train = self.mean_imputer.fit_transform(x=x_train) + x_train = self.target_encoder.fit_transform(x=x_train, y=y_train) + x_train = self.mean_imputer.fit_transform(x=x_train) self.clf.fit(x=x_train.to_numpy(), y=y_train) def predict(self, x): x = self.preprocessor.transform(x=x) - # x = self.target_encoder.transform(x=x) - # x = self.mean_imputer.transform(x=x) + x = self.target_encoder.transform(x=x) + x = self.mean_imputer.transform(x=x) return self.clf.predict(x=x.to_numpy()) diff --git a/heidgaf/models/xgboost2.py b/heidgaf/models/xgboost2.py deleted file mode 100644 index e69de29..0000000 diff --git a/heidgaf/post/feature.py b/heidgaf/post/feature.py index 421a969..edd75fb 100644 --- a/heidgaf/post/feature.py +++ b/heidgaf/post/feature.py @@ -1,4 +1,6 @@ +import math +from string import ascii_lowercase as alc from typing import List import polars as pl @@ -15,8 +17,6 @@ def __init__(self, features_to_drop: List): feature_to_drop (list): list of feature to drop """ self.features_to_drop = features_to_drop - # TODO Set majestic million score - self.majesticmillion = MajesticMillionDataset() def transform(self, x: pl.DataFrame) -> pl.DataFrame: """Transform our dataset with new features @@ -35,58 +35,80 @@ def transform(self, x: pl.DataFrame) -> pl.DataFrame: (pl.col("query").str.strip_chars(".").str.len_chars().alias("label_average")), ] ) + # Get letter frequency + for i in alc: + x = x.with_columns( + [ + (pl.col("query").str.to_lowercase().str.count_matches(rf"{i}").truediv(pl.col("query").str.len_chars())).alias(f"freq_{i}"), + ] + ) x = x.with_columns( - [ + [ # FQDN + (pl.col("query")).alias("fqdn"), + (pl.col("query").str.len_chars().alias("fqdn_full_count")), + (pl.col("query").str.count_matches(r"[a-zA-Z]").truediv(pl.col("query").str.len_chars())).alias("fqdn_alpha_count"), + 
(pl.col("query").str.count_matches(r"[0-9]").truediv(pl.col("query").str.len_chars())).alias("fqdn_numeric_count"), + (pl.col("query").str.count_matches(r"[^\w\s]").truediv(pl.col("query").str.len_chars())).alias("fqdn_special_count"), + ] + ) + x = x.with_columns( + [ + # Second-level domain (pl.when(pl.col("labels").list.len() > 2) .then( pl.col("labels").list.get(-2) ).otherwise( pl.col("labels").list.get(0) - ).alias("SLD")), - (pl.col("query").str.len_chars().alias("FQDN_full_count")), - (pl.col("query").str.count_matches(r"[0-9]").alias("FQDN_numeric_count")), - (pl.col("query").str.count_matches(r"[^\w\s]").alias("FQDN_special_count")), + ).alias("secondleveldomain")) + ] + ) + x = x.with_columns( + [ + (pl.col("secondleveldomain").str.len_chars().truediv(pl.col("secondleveldomain").str.len_chars()).alias("secondleveldomain_full_count")), + (pl.col("secondleveldomain").str.count_matches(r"[a-zA-Z]").truediv(pl.col("secondleveldomain").str.len_chars())).alias("secondleveldomainn_alpha_count"), + (pl.col("secondleveldomain").str.count_matches(r"[0-9]").truediv(pl.col("secondleveldomain").str.len_chars())).alias("secondleveldomainn_numeric_count"), + (pl.col("secondleveldomain").str.count_matches(r"[^\w\s]").truediv(pl.col("secondleveldomain").str.len_chars())).alias("secondleveldomain_special_count"), + ] + ) + x = x.with_columns( + [ # Third-level domain (pl.when(pl.col("labels").list.len() > 2) .then( - pl.col("labels").list.get(0).str.len_chars() - ).otherwise(0).alias("SD_full_count")), - (pl.when(pl.col("labels").list.len() > 2) - .then( - pl.col("labels").list.get(0).str.count_matches(r"[0-9]") - ).otherwise(0).alias("SD_numeric_count")), - (pl.when(pl.col("labels").list.len() > 2) - .then( - pl.col("labels").list.get(0).str.count_matches(r"[^\w\s]") - ).otherwise(0).alias("SD_special_count")), - # Second-level domain - (pl.when(pl.col("labels").list.len() > 2) - .then( - pl.col("labels").list.get(1).str.len_chars() - ).otherwise( - pl.col("labels").list.get(0).str.len_chars() - ).alias("SLD_full_count")), - (pl.when(pl.col("labels").list.len() > 2) - .then( - pl.col("labels").list.get(1).str.count_matches(r"[0-9]") - ).otherwise( - pl.col("labels").list.get(0).str.count_matches(r"[0-9]") - ).alias("SLD_numeric_count")), - (pl.when(pl.col("labels").list.len() > 2) - .then( - pl.col("labels").list.get(1).str.count_matches(r"[^\w\s]") - ).otherwise( - pl.col("labels").list.get(0).str.count_matches(r"[^\w\s]") - ).alias("SLD_special_count")), + pl.col("labels").list.slice(0, pl.col("labels").list.len() - 2).list.join(".") + ).otherwise(pl.lit("")).alias("thirdleveldomain")), + ] + ) + x = x.with_columns( + [ + (pl.col("thirdleveldomain").str.len_chars().truediv(pl.col("thirdleveldomain").str.len_chars()).alias("thirdleveldomain_full_count")), + (pl.col("thirdleveldomain").str.count_matches(r"[a-zA-Z]").truediv(pl.col("thirdleveldomain").str.len_chars())).alias("thirdleveldomain_alpha_count"), + (pl.col("thirdleveldomain").str.count_matches(r"[0-9]").truediv(pl.col("thirdleveldomain").str.len_chars())).alias("thirdleveldomain_numeric_count"), + (pl.col("thirdleveldomain").str.count_matches(r"[^\w\s]").truediv(pl.col("thirdleveldomain").str.len_chars())).alias("thirdleveldomain_special_count"), ] ) - x = x.with_columns([ - (pl.col("query").entropy(base=2).alias("FQDN_entropy")), - ]) - + for ent in ["fqdn", "thirdleveldomain", "secondleveldomain"]: + x = x.with_columns( + [ + (pl.col(ent).map_elements(lambda x: [float(str(x).count(c)) / len(str(x)) for c in 
+                ]
+            )
+
+            t = math.log(2.0)
+
+            x = x.with_columns(
+                [
+                    # - sum([ p * math.log(p) / math.log(2.0) for p in prob ])
+                    (pl.col("prob").list.eval(- pl.element() * pl.element().log() / t).list.sum()).alias(f"{ent}_entropy"),
+                ]
+            )
+            x = x.drop("prob")
+
+        # Fill NaN
+        x = x.fill_nan(0)
 
         # Drop features not useful anymore
         x = x.drop(self.features_to_drop)
diff --git a/heidgaf/pre/__init__.py b/heidgaf/pre/__init__.py
index e97f9c2..bc20ffc 100644
--- a/heidgaf/pre/__init__.py
+++ b/heidgaf/pre/__init__.py
@@ -1,22 +1,78 @@
+import datetime
+import logging
 from abc import ABCMeta, abstractmethod
 from typing import Any
 
 import polars as pl
 
 from heidgaf import ReturnCode
-from heidgaf.cache import DataFrameRedisCache
+from heidgaf.cache import DataFrameRedisCache, StringRedisCache
 
 
 class Analyzer(metaclass=ABCMeta):
+
+
     def __init__(self) -> None:
         pass
 
     @classmethod
     @abstractmethod
-    def run(self, data: pl.DataFrame, redis_cache: DataFrameRedisCache):
+    def run(self, data: pl.DataFrame, df_cache: DataFrameRedisCache) -> None:
         pass
 
-    @classmethod
+    def set_warning(self, data: pl.DataFrame, warnings: pl.DataFrame, id: str) -> None:
+        warning_data = data.join(warnings, on=id, how="semi")
+        logging.info(warning_data)
+
+    def update_count(self, df: pl.DataFrame, id: str, key: str, df_cache: DataFrameRedisCache, threshold: int) -> pl.DataFrame:
+        # Subtract lowest from highest timestamp to get the time range in hours.
+        # By this, we can work on relative values and are able to compare new data
+        timestamp_range = (df["timestamp"].max() - df["timestamp"].min()).seconds // 3600
+        if timestamp_range == 0:
+            timestamp_range = 1
+
+        frequency = df.group_by(id).count().with_columns(
+            [
+                pl.lit(datetime.datetime.now()).alias("timestamp"),
+                pl.col("count").truediv(timestamp_range),
+                pl.lit(timestamp_range).alias("duration")
+            ]
+        )
+
+        # Check if the frequency key already exists in the Redis cache
+        if key in df_cache:
+            frequency = df_cache[key].join(frequency, on=id, how="left")
+
+            frequency = frequency.with_columns(
+                [
+                    (pl.col("count").add(pl.col("count_right").fill_null(0))),
+                    (
+                        pl.when(pl.col("timestamp_right") > pl.col("timestamp"))
+                        .then(
+                            pl.col("timestamp_right")
+                        ).otherwise(
+                            pl.col("timestamp")
+                        )
+                    ).alias("timestamp_new")
+                ]
+            )
+            frequency = frequency.drop("timestamp", "count_right", "timestamp_right", "duration_right")
+            frequency = frequency.rename(
+                {
+                    "timestamp_new": "timestamp"
+                }
+            )
+            logging.debug(f'Redis Data: {frequency}')
+
+        ip_warnings = frequency.filter(pl.col("count") > (threshold / timestamp_range)).sort("count")
+
+        # Store information in redis client
+        df_cache[key] = frequency
+
+        return frequency, ip_warnings
+
     @abstractmethod
-    def set_warning(self, data: Any, redis_cache: DataFrameRedisCache):
-        pass
\ No newline at end of file
+    def update_threshold(threshold, tpr, fpr):
+        pass
diff --git a/heidgaf/pre/domain_analyzer.py b/heidgaf/pre/domain_analyzer.py
index e295d4b..00632d0 100644
--- a/heidgaf/pre/domain_analyzer.py
+++ b/heidgaf/pre/domain_analyzer.py
@@ -7,12 +7,23 @@
 
 
 class DomainAnalyzer(Analyzer):
+    KEY_SECOND_LEVEL_DOMAIN = "secondleveldomain_frequency"
+    KEY_THIRD_LEVEL_DOMAIN = "thirdleveldomain_frequency"
+    KEY_FQDN = "fqdn_frequency"
+
    def __init__(self) -> None:
         super().__init__()
 
     @classmethod
-    def run(self, data: pl.DataFrame, redis_cache: DataFrameRedisCache):
+    def run(self, data: pl.DataFrame, redis_cache: DataFrameRedisCache) -> pl.DataFrame:
         # Filter data with no errors
data.filter(pl.col("query") != "|").filter(pl.col("return_code") != ReturnCode.NOERROR.value).filter(pl.col("query").str.split(".").list.len() != 1) + df = data.filter(pl.col("query") != "|").filter(pl.col("query").str.split(".").list.len() != 1) + + _, warning = self.update_count(self, df, "fqdn", self.KEY_FQDN, redis_cache, 5000) + self.set_warning(self, df, warning, "fqdn") + + _, warning = self.update_count(self, df, "secondleveldomain", self.KEY_SECOND_LEVEL_DOMAIN, redis_cache, 5000) + self.set_warning(self, df, warning, "secondleveldomain") - \ No newline at end of file + _, warning = self.update_count(self, df, "thirdleveldomain", self.KEY_THIRD_LEVEL_DOMAIN, redis_cache, 5000) + self.set_warning(self, df, warning, "thirdleveldomain") diff --git a/heidgaf/pre/ip_analyzer.py b/heidgaf/pre/ip_analyzer.py index fbfe781..dd67610 100644 --- a/heidgaf/pre/ip_analyzer.py +++ b/heidgaf/pre/ip_analyzer.py @@ -1,5 +1,6 @@ import logging +import numpy as np import polars as pl from heidgaf import ReturnCode @@ -8,42 +9,20 @@ class IPAnalyzer(Analyzer): - KEY_IP_FREQUENCY = "client_ip_frequency" - KEY_DNS_SERVER = "dns_server_frequency" - KEY_SLD = "sld_frequency" + KEY_IP_FREQUENCY = "client_ip_error_frequency" + KEY_DNS_SERVER = "dns_server_error_frequency" def __init__(self) -> None: super().__init__() - self.threshold = 200 @classmethod - def run(self, data: pl.DataFrame, redis_cache: DataFrameRedisCache): + def run(self, data: pl.DataFrame, df_cache: DataFrameRedisCache) -> pl.DataFrame: # Filter data with no errors df = data.filter(pl.col("query") != "|").filter(pl.col("return_code") != ReturnCode.NOERROR.value).filter(pl.col("query").str.split(".").list.len() != 1) - # Update IP frequency based on errors - self.__update_count(df, "client_ip", self.KEY_IP_FREQUENCY, redis_cache) - self.__update_count(df, "dns_server", self.KEY_DNS_SERVER, redis_cache) + # Update frequencies based on errors + _, warning = self.update_count(self, df, "client_ip", self.KEY_IP_FREQUENCY, df_cache, 200) + self.set_warning(self, df, warning, "client_ip") - self.__update_count(df, "SLD", self.KEY_SLD, redis_cache) - - # TODO: Process frequency and return values - - # TODO: Check if IP has more than threshold error request -> if yes, check distribution. 
-
-
-    def __update_count(df: pl.DataFrame, id: str, key: str, redis_cache: DataFrameRedisCache) -> None:
-        frequency = df.group_by(id).count()
-
-        # TODO Dividing highest and lowest timestamp
-        df.max(pl.col("timestamp"))
-        df.min(pl.col("timestamp"))
-
-        # Check if dns_server_frequency exists in redis cache
-        if key in redis_cache:
-            frequency = pl.concat([redis_cache[key], frequency]).groupby(id).agg(pl.sum('count'))
-            logging.debug(f'Redis Data: {frequency}')
-
-        # Store information in redis client
-        redis_cache[key] = frequency
-
-
\ No newline at end of file
+        _, warning = self.update_count(self, df, "dns_server", self.KEY_DNS_SERVER, df_cache, 200)
+        self.set_warning(self, df, warning, "dns_server")
diff --git a/heidgaf/pre/noerror_analyzer.py b/heidgaf/pre/noerror_analyzer.py
deleted file mode 100644
index 1e667f5..0000000
--- a/heidgaf/pre/noerror_analyzer.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import polars as pl
-import redis
-
-from heidgaf import ReturnCode
-from heidgaf.cache import DataFrameRedisCache
-from heidgaf.pre import Analyzer
-
-
-class NoErrorAnalyzer(Analyzer):
-    def __init__(self) -> None:
-        super().__init__()
-
-    @classmethod
-    def run(self, data: pl.DataFrame, redis_cache: DataFrameRedisCache):
-        # Filter data with no errors
-        df = data.filter(pl.col("query") != "|").filter(pl.col("return_code") != ReturnCode.NOERROR.value).filter(pl.col("query").str.split(".").list.len() != 1)
diff --git a/heidgaf/pre/time_analyer.py b/heidgaf/pre/time_analyer.py
index 7cdee24..3438147 100644
--- a/heidgaf/pre/time_analyer.py
+++ b/heidgaf/pre/time_analyer.py
@@ -1,3 +1,5 @@
+import logging
+
 import polars as pl
 
 from heidgaf.cache import DataFrameRedisCache
@@ -5,9 +7,15 @@
 
 
 class TimeAnalyzer(Analyzer):
+    KEY_IP_FREQUENCY = "client_ip_frequency"
+
     def __init__(self) -> None:
         super().__init__()
 
     @classmethod
-    def run(self, data: pl.DataFrame, redis_cache: DataFrameRedisCache):
-        pass
\ No newline at end of file
+    def run(self, data: pl.DataFrame, df_cache: DataFrameRedisCache):
+        df = data.filter(pl.col("query") != "|").filter(pl.col("query").str.split(".").list.len() != 1)
+
+        # Update count and handle warnings
+        _, warning = self.update_count(self, df, "client_ip", self.KEY_IP_FREQUENCY, df_cache, 200)
+        self.set_warning(self, df, warning, "client_ip")
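
Note for reviewers (not part of the diff): a minimal, self-contained sketch of what the new per-label entropy features in heidgaf/post/feature.py compute — a character-frequency distribution per value ("prob"), reduced to Shannon entropy in bits. It assumes only polars and the standard library; the column name "domain" and the sample values below are made up for illustration.

```python
# Standalone check of the entropy logic added in heidgaf/post/feature.py.
# Illustrative only: "domain" and the sample strings are hypothetical.
import math

import polars as pl


def shannon_entropy(value: str) -> float:
    """Entropy in bits of the character distribution of `value`."""
    probs = [value.count(c) / len(value) for c in dict.fromkeys(value)]
    return -sum(p * math.log(p) / math.log(2.0) for p in probs)


df = pl.DataFrame({"domain": ["example", "aaaaaaa", "x7f9q2k"]})

# Same shape as the diff: build a per-row probability list, then reduce it.
t = math.log(2.0)
df = df.with_columns(
    pl.col("domain")
    .map_elements(
        lambda s: [s.count(c) / len(s) for c in dict.fromkeys(s)],
        return_dtype=pl.List(pl.Float64),
    )
    .alias("prob")
).with_columns(
    (pl.col("prob").list.eval(-pl.element() * pl.element().log() / t).list.sum()).alias("entropy")
)

print(df.select("domain", "entropy"))
print([round(shannon_entropy(s), 3) for s in ["example", "aaaaaaa", "x7f9q2k"]])
```

A single repeated character ("aaaaaaa") scores 0, while mixed strings score higher; the diff's later `fill_nan(0)` keeps empty third-level domains from propagating NaN into these features.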