Skip to content

Commit

Permalink
Add first draft
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanDeveloper committed Apr 15, 2024
1 parent 5248c57 commit e0ceddd
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 174 deletions.
2 changes: 1 addition & 1 deletion heidgaf/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from heidgaf import CONTEXT_SETTINGS
from heidgaf.main import DNSAnalyzerPipeline, Detector, FileType, Separator
from heidgaf.main import Detector, DNSAnalyzerPipeline, FileType, Separator
from heidgaf.models.lr import LogisticRegression
from heidgaf.train import DNSAnalyzerTraining
from heidgaf.version import __version__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import numpy as np
import polars as pl
import sklearn.model_selection
from torch.utils.data.dataset import Dataset
from fe_polars.encoding.one_hot_encoding import OneHotEncoder
from torch.utils.data.dataset import Dataset


def preprocess(x: pl.DataFrame):
Expand Down
File renamed without changes.
118 changes: 70 additions & 48 deletions heidgaf/pre/__init__.py → heidgaf/inspectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,63 +2,49 @@
import logging
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, List
import matplotlib.pyplot as plt
from typing import List

import numpy as np
import polars as pl
import torch
from fe_polars.encoding.target_encoding import TargetEncoder
from fe_polars.imputing.base_imputing import Imputer
from heidgaf import ReturnCode

from heidgaf.cache import DataFrameRedisCache
from heidgaf.detectors.base_anomaly import AnomalyDetector, AnomalyDetectorConfig
from heidgaf.detectors.thresholding_algorithm import ThresholdingAnomalyDetector
from heidgaf.detectors.base_anomaly import AnomalyDetector
from heidgaf.models import Pipeline
from heidgaf.post.feature import Preprocessor
from heidgaf.feature import Preprocessor


@dataclass
class AnalyzerConfig:
"""Configuration class of Analyzers"""
class InspectorConfig:
"""Configuration class of inspectors."""

detector: AnomalyDetector
df_cache: DataFrameRedisCache
threshold: 3
model: torch.nn.Module


class Analyzer(metaclass=ABCMeta):
"""_summary_
class Inspector(metaclass=ABCMeta):
"""Metaclass to test DNS requests.
Args:
metaclass (_type_, optional): _description_. Defaults to ABCMeta.
metaclass (_type_, optional): Metaclass object. Defaults to ABCMeta.
"""

def __init__(self, config: AnalyzerConfig) -> None:
def __init__(self, config: InspectorConfig) -> None:
"""Initializes detector, cache, threshold, and model.
Args:
config (TesterConfig): TesterConfig.
"""
self.detector = config.detector
self.df_cache = config.df_cache
self.threshold = config.threshold
self.model = config.model

@abstractmethod
def run(self, data: pl.DataFrame) -> None:
"""_summary_
Args:
data (pl.DataFrame): _description_
df_cache (DataFrameRedisCache): _description_
"""
pass

def set_warning(self, data: pl.DataFrame, warnings: List, id: str) -> None:
"""Creates initial warning for classifiers
Args:
data (pl.DataFrame): _description_
warnings (pl.DataFrame): _description_
id (str): _description_
"""
model_pipeline = Pipeline(
self.model_pipeline = Pipeline(
preprocessor=Preprocessor(
features_to_drop=[
"timestamp",
Expand All @@ -80,22 +66,57 @@ def set_warning(self, data: pl.DataFrame, warnings: List, id: str) -> None:
target_encoder=TargetEncoder(smoothing=100, features_to_encode=[]),
clf=self.model,
)
for warning in warnings:

@abstractmethod
def run(self, data: pl.DataFrame) -> pl.DataFrame:
"""Runs tester.
Args:
data (pl.DataFrame): Preprocessed data
Returns:
pl.DataFrame: Suspicious Ids.
"""
pass

def warnings(self, data: pl.DataFrame, suspicious: List, id: str) -> None:
"""Creates initial warning for classifiers
Args:
data (pl.DataFrame): Preprocessed data.
suspicious (List): Suspicious Id's retrieved by detectors.
id (str): Id of column to process.
"""

fqdn_distro = data.group_by("fqdn").agg(
pl.col("return_code")
.is_in(["NXDOMAIN", "SERVFAIL"])
.sum()
.truediv(
pl.col("client_ip").count().truediv(pl.col("client_ip").n_unique())
)
.alias("distro")
)
fqdn_distro = fqdn_distro.filter(pl.col("distro") > 0.1)

for warning in suspicious:
logging.debug(f"Analyze data in depth for {warning}")

data_id = data.filter(pl.col(id) == warning)

# Predict data based on model
y_pred = model_pipeline.predict(data_id)

indices = np.where(y_pred == 1)[0]
data_id = data_id.with_row_count(name="idx", offset=0)
supicious_data = data_id.filter(pl.col("idx").is_in(indices))


if not supicious_data.is_empty():
print(f"{warning} has following errors")
print(supicious_data.select(["fqdn"]).unique())

data_id = data.filter(pl.col(id) == warning).filter(
pl.col("fqdn").is_in(fqdn_distro["fqdn"].to_list())
)
if not data_id.is_empty():
# Predict data based on model
y_pred = self.model_pipeline.predict(data_id)

indices = np.where(y_pred == 1)[0]
data_id = data_id.with_row_count(name="idx", offset=0)
supicious_data = data_id.filter(pl.col("idx").is_in(indices))

if not supicious_data.is_empty():
logging.debug(f"{warning} has following errors")
with pl.Config(tbl_rows=100):
logging.debug(supicious_data.select(["fqdn"]).unique())

def update_count(
self,
Expand Down Expand Up @@ -126,14 +147,15 @@ def update_count(
id_distribution = all_dates.join(
id_distribution, how="left", on=[id, "timestamp"]
).fill_null(0)

# Iterate over all unique IDs
unique_id = id_distribution.select([id]).unique()
for row in unique_id.rows(named=True):

# Run detector
unique_id_distro = id_distribution.filter(pl.col(id) == row[id])
detector_results = self.detector.run(unique_id_distro["count"])

# If total count of signals is higher than threshold, we append our supsicious IPs to our warnings list
if np.sum(detector_results["signals"]) > self.threshold:
warnings.append(row[id])
Expand Down
53 changes: 53 additions & 0 deletions heidgaf/inspectors/domain_analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import polars as pl

from heidgaf.inspectors import Inspector, InspectorConfig


class DomainInspector(Inspector):
    """Inspector that checks domain-name based features for anomalies.

    For the ``fqdn``, ``secondleveldomain`` and ``thirdleveldomain`` columns,
    the filtered data is passed to the inherited ``update_count`` and whatever
    it returns is forwarded to the inherited ``warnings`` method.
    """

    KEY_SECOND_LEVEL_DOMAIN = "secondleveldomain_frequency"
    KEY_THIRD_LEVEL_DOMAIN = "thirdleveldomain_frequency"
    KEY_FQDN = "fqdn_frequency"

    def __init__(self, config: InspectorConfig) -> None:
        """Initializes the inspector from the shared configuration.

        Args:
            config (InspectorConfig): Inspector configuration, passed
                unchanged to the base class.
        """
        super().__init__(config)

    def update_threshold(self, threshould, tpr, fpr):
        """Placeholder for threshold adaptation; currently a no-op.

        ``self`` was missing from the original signature, so the instance was
        bound to the first parameter and a three-argument call on an instance
        raised ``TypeError``. The parameter spelling is kept as-is so that
        keyword callers are not affected.

        Args:
            threshould: New threshold value (not used yet).
            tpr: Presumably the true positive rate -- TODO confirm once this
                method is implemented.
            fpr: Presumably the false positive rate -- TODO confirm once this
                method is implemented.
        """
        pass

    def run(self, data: pl.DataFrame) -> pl.DataFrame:
        """Runs the domain-name based checks on ``data``.

        Args:
            data (pl.DataFrame): Preprocessed data. The ``timestamp`` and
                ``query`` columns are read directly here; the checked id
                columns are handed to the inherited helpers by name.

        Returns:
            pl.DataFrame: NOTE(review): annotated as a data frame, but no
            value is returned at present -- the results are only passed to
            ``self.warnings``.
        """
        min_date = data.select(["timestamp"]).min().item()
        max_date = data.select(["timestamp"]).max().item()

        # Drop rows whose query is the "|" placeholder or that consist of a
        # single label (no "." separator).
        df = data.filter(pl.col("query") != "|").filter(
            pl.col("query").str.split(".").list.len() != 1
        )

        # (id column, frequency key) pairs, checked in this order.
        checks = (
            ("fqdn", self.KEY_FQDN),
            ("secondleveldomain", self.KEY_SECOND_LEVEL_DOMAIN),
            ("thirdleveldomain", self.KEY_THIRD_LEVEL_DOMAIN),
        )
        for column, key in checks:
            suspicious = self.update_count(df, min_date, max_date, column, key)
            # The unfiltered frame is passed on, exactly as before.
            self.warnings(data, suspicious, column)
35 changes: 22 additions & 13 deletions heidgaf/pre/ip_analyzer.py → heidgaf/inspectors/ip_analyzer.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,36 @@
import datetime
import logging

import numpy as np
import polars as pl

from heidgaf import ReturnCode
from heidgaf.cache import DataFrameRedisCache
from heidgaf.pre import Analyzer, AnalyzerConfig
from heidgaf.inspectors import Inspector, InspectorConfig


class IPAnalyzer(Analyzer):
class IPInspector(Inspector):
KEY_IP_FREQUENCY = "client_ip_error_frequency"
KEY_DNS_SERVER = "dns_server_error_frequency"

def __init__(self, config: AnalyzerConfig) -> None:
def __init__(self, config: InspectorConfig) -> None:
"""IP analyzer class. It checks for anomalies of requests by client IPs.
Args:
config (AnalyzerConfig): Analyzer configuraiton.
"""
super().__init__(config)

def update_threshold(threshould, tpr, fpr):
pass

def run(self, data: pl.DataFrame) -> pl.DataFrame:
"""Runs tester for IP address based features.
Args:
data (pl.DataFrame): Preprocessed data.
Returns:
pl.DataFrame: Suspicious Ids.
"""
min_date = data.select(["timestamp"]).min().item()
max_date = data.select(["timestamp"]).max().item()

# Filter data with no errors
df = (
data.filter(pl.col("query") != "|")
Expand All @@ -33,9 +42,9 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame:
warnings = self.update_count(
df, min_date, max_date, "client_ip", self.KEY_IP_FREQUENCY
)
self.set_warning(data, warnings, "client_ip")
self.warnings(data, warnings, "client_ip")

# warnings = self.update_count(
# df, min_date, max_date, "dns_server", self.KEY_DNS_SERVER
# )
# self.set_warning(data, warnings, "dns_server")
warnings = self.update_count(
df, min_date, max_date, "dns_server", self.KEY_DNS_SERVER
)
self.warnings(data, warnings, "dns_server")
35 changes: 17 additions & 18 deletions heidgaf/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import dataclass
import logging
import os
from enum import Enum, unique
Expand All @@ -7,16 +6,15 @@
import polars as pl
from click import Path

from heidgaf.cache import DataFrameRedisCache, StringRedisCache
from heidgaf.cache import DataFrameRedisCache
from heidgaf.detectors.arima_anomaly_detector import ARIMAAnomalyDetector
from heidgaf.detectors.base_anomaly import AnomalyDetectorConfig
from heidgaf.detectors.exponential_thresholding import EMAAnomalyDetector
from heidgaf.detectors.thresholding_algorithm import ThresholdingAnomalyDetector
from heidgaf.post.feature import Preprocessor
from heidgaf.pre import Analyzer, AnalyzerConfig
from heidgaf.pre.domain_analyzer import DomainAnalyzer
from heidgaf.pre.ip_analyzer import IPAnalyzer
from heidgaf.pre.time_analyer import TimeAnalyzer
from heidgaf.detectors.thresholding_algorithm import \
ThresholdingAnomalyDetector
from heidgaf.inspectors import Inspector, InspectorConfig
from heidgaf.inspectors.domain_analyzer import DomainInspector
from heidgaf.inspectors.ip_analyzer import IPInspector


@unique
Expand All @@ -38,11 +36,10 @@ class Separator(Enum):
COMMA = ","


def analyzer_factory(source: str, config: AnalyzerConfig) -> Analyzer:
def inspector_factory(source: str, config: InspectorConfig) -> Inspector:
factory = {
"IP": (IPAnalyzer(config)),
"Domain": (DomainAnalyzer(config)),
"Time": (TimeAnalyzer(config)),
"IP": (IPInspector(config)),
"Domain": (DomainInspector(config)),
}
if source in factory:
return factory[source]
Expand All @@ -68,7 +65,7 @@ def __init__(
redis_port=6379,
redis_db=0,
redis_max_connections=20,
threshold=5
threshold=5,
) -> None:
self.df_cache = DataFrameRedisCache(
redis_host, redis_port, redis_db, redis_max_connections
Expand Down Expand Up @@ -158,7 +155,7 @@ def load_data(self, path: str, separator: str) -> pl.DataFrame:
),
]
)

# Filter invalid domains
x = x.filter(pl.col("query") != "|")
x = x.filter(pl.col("labels").list.len() > 1)
Expand All @@ -180,12 +177,14 @@ def run(self):
detector = ThresholdingAnomalyDetector(config)
case "arima":
detector = ARIMAAnomalyDetector(config)
case "threshold":
case "ema":
detector = EMAAnomalyDetector(config)
case _:
raise NotImplementedError(f"Detector not implemented!")

# Run anaylzers to find anomalies in data
config = AnalyzerConfig(detector, self.df_cache, self.threshold, joblib.load("model.pkl"))
for analyzer in ["IP"]:
analyzer_factory(analyzer, config).run(self.data)
config = InspectorConfig(
detector, self.df_cache, self.threshold, joblib.load("model.pkl")
)
for inspector in ["IP", "Domain"]:
inspector_factory(inspector, config).run(self.data)
Loading

0 comments on commit e0ceddd

Please sign in to comment.