Commit

Update implementation
stefanDeveloper committed Apr 2, 2024
1 parent b582da5 commit f00495f
Showing 12 changed files with 193 additions and 133 deletions.
26 changes: 25 additions & 1 deletion heidgaf/cache.py
@@ -1,10 +1,34 @@
import logging
from typing import Any

import polars as pl
import redis


class StringRedisCache(object):
    def __init__(self, redis_host="localhost", redis_port=6379, redis_db=0, redis_max_connections=20) -> None:
logging.debug("Connect to Redis server")
self.pool = redis.ConnectionPool(host=redis_host, port=redis_port, db=redis_db, max_connections=redis_max_connections)
self.redis_client = redis.Redis(connection_pool=self.pool)
self.redis_client.ping()

    def __getitem__(self, key: str) -> str:
        if self.redis_client.exists(key):
            return self.redis_client.get(key)
        else:
            return ""

def __contains__(self, key):
return self.redis_client.exists(key)

    def __str__(self):
        return f"Redis has stored the following keys: {self.redis_client.keys()}"

    def __setitem__(self, key: str, value: str) -> None:
        self.redis_client.set(key, value)

def __delitem__(self, key: str) -> None:
self.redis_client.delete(key)

class DataFrameRedisCache(object):
    def __init__(self, redis_host="localhost", redis_port=6379, redis_db=0, redis_max_connections=20) -> None:
logging.debug("Connect to Redis server")
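For context, a minimal usage sketch of the new StringRedisCache (assuming a Redis server reachable on localhost:6379; note that redis-py returns raw bytes from GET unless the connection is created with decode_responses=True):

from heidgaf.cache import StringRedisCache

# Assumes a local Redis instance with default settings; the key is hypothetical.
cache = StringRedisCache()
cache["last_seen_domain"] = "example.com"  # __setitem__ -> SET
if "last_seen_domain" in cache:            # __contains__ -> EXISTS
    print(cache["last_seen_domain"])       # __getitem__ -> GET (bytes)
del cache["last_seen_domain"]              # __delitem__ -> DEL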
2 changes: 0 additions & 2 deletions heidgaf/dataset/__init__.py
@@ -52,8 +52,6 @@ def __init__(self, data_path: Any, cast_dataset: Callable = None) -> None:
else:
self.data = pl.read_csv(data_path)

logging.info(self.data)

self.X_train, self.X_val, self.X_test, self.Y_train, self.Y_val, self.Y_test = self.__train_test_val_split()

def __len__(self):
21 changes: 0 additions & 21 deletions heidgaf/dataset/majestic.py

This file was deleted.

17 changes: 8 additions & 9 deletions heidgaf/main.py
@@ -6,7 +6,7 @@
import polars as pl
from click import Path

from heidgaf.cache import DataFrameRedisCache, StringRedisCache
from heidgaf.post.feature import Preprocessor
from heidgaf.pre.domain_analyzer import DomainAnalyzer
from heidgaf.pre.ip_analyzer import IPAnalyzer
@@ -25,26 +25,24 @@ class Separator(Enum):

class DNSAnalyzerPipeline:
def __init__(self, path: Path, redis_host="localhost", redis_port=6379, redis_db=0, redis_max_connections=20, filetype=FileType.TXT, separator=Separator.SPACE) -> None:
        self.df_cache = DataFrameRedisCache(redis_host, redis_port, redis_db, redis_max_connections)
        self.string_cache = StringRedisCache(redis_host, redis_port, redis_db, redis_max_connections)

if os.path.isfile(path):
logging.debug(f"Processing files: {path}")
self.data = self.load_data(path, separator.value)
elif os.path.isdir(path):
            # TODO Handle large files; currently Redis cannot store values larger than 512 MB
logging.debug(f"Processing files: {path}/*.{filetype.value}")
self.data = self.load_data(f'{path}/*.{filetype.value}', separator.value)

# self.redis_cache["data"] = self.data

def load_data(self, path, separator):
dataframes = pl.read_csv(path, separator=separator, try_parse_dates=False, has_header=False).with_columns(
[
(pl.col('column_1').str.strptime(pl.Datetime).cast(pl.Datetime))
]
)

        dataframes = dataframes[:50000].rename(
{
"column_1": "timestamp",
"column_2": "return_code",
@@ -60,11 +58,12 @@ def load_data(self, path, separator):
return dataframes

def run(self):

# Running modules to analyze log files
# TODO Multithreading
preprocessor = Preprocessor(features_to_drop=[])
processed_data = preprocessor.transform(self.data)

        IPAnalyzer.run(processed_data, self.df_cache)
        DomainAnalyzer.run(processed_data, self.df_cache)
        TimeAnalyzer.run(processed_data, self.df_cache)
8 changes: 4 additions & 4 deletions heidgaf/models/__init__.py
@@ -18,13 +18,13 @@ def __init__(self,

def fit(self, x_train, y_train):
x_train = self.preprocessor.transform(x=x_train)
        x_train = self.target_encoder.fit_transform(x=x_train, y=y_train)
        x_train = self.mean_imputer.fit_transform(x=x_train)
self.clf.fit(x=x_train.to_numpy(), y=y_train)

def predict(self, x):
x = self.preprocessor.transform(x=x)
        x = self.target_encoder.transform(x=x)
        x = self.mean_imputer.transform(x=x)
return self.clf.predict(x=x.to_numpy())
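Re-enabling the target encoder and mean imputer restores the usual contract: statistics are learned once with fit_transform on training data and only applied with transform afterwards. A toy illustration of that contract (polars; the class is hypothetical, not the project's actual imputer):

import polars as pl

class ToyMeanImputer:
    """Minimal sketch of the fit_transform/transform split."""

    def fit_transform(self, x: pl.DataFrame) -> pl.DataFrame:
        # Learn column means from the training data only.
        self.means = {c: x[c].mean() for c in x.columns}
        return self.transform(x)

    def transform(self, x: pl.DataFrame) -> pl.DataFrame:
        # Reuse the stored training means so no test-set statistics leak in.
        return x.with_columns([pl.col(c).fill_null(self.means[c]) for c in x.columns])

train = pl.DataFrame({"f": [1.0, None, 3.0]})
test = pl.DataFrame({"f": [None]})
imp = ToyMeanImputer()
imp.fit_transform(train)  # null -> 2.0, the training mean
imp.transform(test)       # test null also -> 2.0, not a test-set statistic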

Empty file removed heidgaf/models/xgboost2.py
Empty file.
102 changes: 62 additions & 40 deletions heidgaf/post/feature.py
@@ -1,4 +1,6 @@

import math
from string import ascii_lowercase as alc
from typing import List

import polars as pl
@@ -15,8 +17,6 @@ def __init__(self, features_to_drop: List):
            features_to_drop (List): list of features to drop
"""
self.features_to_drop = features_to_drop
# TODO Set majestic million score
self.majesticmillion = MajesticMillionDataset()

def transform(self, x: pl.DataFrame) -> pl.DataFrame:
"""Transform our dataset with new features
@@ -35,58 +35,80 @@ def transform(self, x: pl.DataFrame) -> pl.DataFrame:
(pl.col("query").str.strip_chars(".").str.len_chars().alias("label_average")),
]
)
# Get letter frequency
for i in alc:
x = x.with_columns(
[
(pl.col("query").str.to_lowercase().str.count_matches(rf"{i}").truediv(pl.col("query").str.len_chars())).alias(f"freq_{i}"),
]
)

x = x.with_columns(
    [
        # FQDN
        (pl.col("query")).alias("fqdn"),
        (pl.col("query").str.len_chars().alias("fqdn_full_count")),
        (pl.col("query").str.count_matches(r"[a-zA-Z]").truediv(pl.col("query").str.len_chars())).alias("fqdn_alpha_count"),
        (pl.col("query").str.count_matches(r"[0-9]").truediv(pl.col("query").str.len_chars())).alias("fqdn_numeric_count"),
        (pl.col("query").str.count_matches(r"[^\w\s]").truediv(pl.col("query").str.len_chars())).alias("fqdn_special_count"),
    ]
)
x = x.with_columns(
    [
        # Second-level domain
        (pl.when(pl.col("labels").list.len() > 2)
            .then(
                pl.col("labels").list.get(-2)
            ).otherwise(
                pl.col("labels").list.get(0)
            ).alias("secondleveldomain")),
    ]
)
x = x.with_columns(
    [
        (pl.col("secondleveldomain").str.len_chars().alias("secondleveldomain_full_count")),
        (pl.col("secondleveldomain").str.count_matches(r"[a-zA-Z]").truediv(pl.col("secondleveldomain").str.len_chars())).alias("secondleveldomain_alpha_count"),
        (pl.col("secondleveldomain").str.count_matches(r"[0-9]").truediv(pl.col("secondleveldomain").str.len_chars())).alias("secondleveldomain_numeric_count"),
        (pl.col("secondleveldomain").str.count_matches(r"[^\w\s]").truediv(pl.col("secondleveldomain").str.len_chars())).alias("secondleveldomain_special_count"),
    ]
)
x = x.with_columns(
    [
        # Third-level domain: everything below the second level, joined back with "."
        (pl.when(pl.col("labels").list.len() > 2)
            .then(
                pl.col("labels").list.slice(0, pl.col("labels").list.len() - 2).list.join(".")
            ).otherwise(pl.lit("")).alias("thirdleveldomain")),
    ]
)
x = x.with_columns(
    [
        (pl.col("thirdleveldomain").str.len_chars().alias("thirdleveldomain_full_count")),
        (pl.col("thirdleveldomain").str.count_matches(r"[a-zA-Z]").truediv(pl.col("thirdleveldomain").str.len_chars())).alias("thirdleveldomain_alpha_count"),
        (pl.col("thirdleveldomain").str.count_matches(r"[0-9]").truediv(pl.col("thirdleveldomain").str.len_chars())).alias("thirdleveldomain_numeric_count"),
        (pl.col("thirdleveldomain").str.count_matches(r"[^\w\s]").truediv(pl.col("thirdleveldomain").str.len_chars())).alias("thirdleveldomain_special_count"),
    ]
)

for ent in ["fqdn", "thirdleveldomain", "secondleveldomain"]:
    x = x.with_columns(
        [
            (pl.col(ent).map_elements(lambda s: [float(str(s).count(c)) / len(str(s)) for c in dict.fromkeys(list(str(s)))])).alias("prob"),
        ]
    )

    t = math.log(2.0)

    x = x.with_columns(
        [
            # Shannon entropy: -sum(p * math.log(p) / math.log(2.0) for p in prob)
            (pl.col("prob").list.eval(-pl.element() * pl.element().log() / t).list.sum()).alias(f"{ent}_entropy"),
        ]
    )
    x = x.drop("prob")

# Fill NaN
x = x.fill_nan(0)
# Drop features not useful anymore
x = x.drop(self.features_to_drop)

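The loop above computes per-column Shannon character entropy, H = -Σ p_c · log2(p_c), using natural logs divided by ln 2. A pure-Python cross-check of the same quantity (toy inputs):

import math

def char_entropy(s: str) -> float:
    # First-occurrence order, mirroring dict.fromkeys in the expression above.
    probs = [s.count(c) / len(s) for c in dict.fromkeys(s)]
    return -sum(p * math.log(p) for p in probs) / math.log(2.0)

print(char_entropy("aaaa"))         # 0.0 -- one symbol, no uncertainty
print(char_entropy("abab"))         # 1.0 -- two equiprobable symbols
print(char_entropy("example.com"))  # higher for more varied strings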
66 changes: 61 additions & 5 deletions heidgaf/pre/__init__.py
@@ -1,22 +1,78 @@
import datetime
import logging
from abc import ABCMeta, abstractmethod
from typing import Any

import polars as pl

from heidgaf import ReturnCode
from heidgaf.cache import DataFrameRedisCache
from heidgaf.cache import DataFrameRedisCache, StringRedisCache


class Analyzer(metaclass=ABCMeta):
    def __init__(self) -> None:
        pass

    @classmethod
    @abstractmethod
    def run(self, data: pl.DataFrame, df_cache: DataFrameRedisCache) -> None:
        pass

@classmethod
def set_warning(self, data: pl.DataFrame, warnings: pl.DataFrame, id: str) -> None:
warning_data = data.join(warnings, on=id, how="semi")
logging.info(warning_data)

    def update_count(self, df: pl.DataFrame, id: str, key: str, df_cache: DataFrameRedisCache, threshold: int) -> tuple[pl.DataFrame, pl.DataFrame]:
        # Subtract the lowest from the highest timestamp to get the time range in hours.
        # Working with per-hour rates gives relative values that stay comparable across batches.
        timestamp_range = int((df["timestamp"].max() - df["timestamp"].min()).total_seconds() // 3600)
        if timestamp_range == 0:
            timestamp_range = 1

frequency = df.group_by(id).count().with_columns(
[
pl.lit(datetime.datetime.now()).alias("timestamp"),
pl.col("count").truediv(timestamp_range),
pl.lit(timestamp_range).alias("duration")
]
)

        # If the key already exists in the Redis cache, merge the cached counts in
        if key in df_cache:
            frequency = df_cache[key].join(frequency, on=id, how="left")

            frequency = frequency.with_columns(
                [
                    (pl.col("count").add(pl.col("count_right").fill_null(0))),
                    (
                        pl.when(pl.col("timestamp_right") > pl.col("timestamp"))
                        .then(
                            pl.col("timestamp_right")
                        ).otherwise(
                            pl.col("timestamp")
                        )
                    ).alias("timestamp_new")
                ]
            )
            frequency = frequency.drop("timestamp", "count_right", "timestamp_right", "duration_right")
            frequency = frequency.rename(
                {
                    "timestamp_new": "timestamp"
                }
            )
logging.debug(f'Redis Data: {frequency}')

        warnings = frequency.filter(pl.col("count") > (threshold / timestamp_range)).sort("count")

        # Persist the updated frequency table in the Redis cache
        df_cache[key] = frequency

        return frequency, warnings

    def update_threshold(self, threshold, tpr, fpr):
        pass
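To make the arithmetic in update_count concrete for a single batch, a small worked example with hypothetical numbers (the cached-count merge then adds previously stored rates on top). Both the per-id count and the threshold are divided by the observed time span in hours before the comparison:

# One client id with 12,000 queries in logs spanning 4 hours,
# checked against a threshold of 5,000.
count = 12_000
timestamp_range = 4                  # hours; update_count clamps 0 up to 1
threshold = 5_000

rate = count / timestamp_range       # 3000.0 queries/hour (the stored "count")
limit = threshold / timestamp_range  # 1250.0 -- the filter bound

print(rate > limit)                  # True: this id lands in the warnings frame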
17 changes: 14 additions & 3 deletions heidgaf/pre/domain_analyzer.py
@@ -7,12 +7,23 @@


class DomainAnalyzer(Analyzer):
KEY_SECOND_LEVEL_DOMAIN = "secondleveldomain_frequency"
KEY_THIRD_LEVEL_DOMAIN = "thirdleveldomain_frequency"
KEY_FQDN = "fqdn_frequency"

def __init__(self) -> None:
super().__init__()

    @classmethod
    def run(self, data: pl.DataFrame, redis_cache: DataFrameRedisCache) -> pl.DataFrame:
        # Filter out empty queries and single-label domains
        df = data.filter(pl.col("query") != "|").filter(pl.col("query").str.split(".").list.len() != 1)

_, warning = self.update_count(self, df, "fqdn", self.KEY_FQDN, redis_cache, 5000)
self.set_warning(self, df, warning, "fqdn")

_, warning = self.update_count(self, df, "secondleveldomain", self.KEY_SECOND_LEVEL_DOMAIN, redis_cache, 5000)
self.set_warning(self, df, warning, "secondleveldomain")


_, warning = self.update_count(self, df, "thirdleveldomain", self.KEY_THIRD_LEVEL_DOMAIN, redis_cache, 5000)
self.set_warning(self, df, warning, "thirdleveldomain")
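set_warning relies on a polars semi join: it keeps exactly the rows of data whose key occurs in warnings and adds no columns from the right side. A toy illustration with made-up values:

import polars as pl

data = pl.DataFrame({"fqdn": ["a.com", "b.net", "c.org"], "count": [1, 2, 3]})
warnings = pl.DataFrame({"fqdn": ["b.net"]})

# Semi join: filter `data` by membership in `warnings`, keeping data's schema.
flagged = data.join(warnings, on="fqdn", how="semi")
print(flagged)  # a single row: ("b.net", 2)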
