Skip to content

Commit

Permalink
Update current state
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanDeveloper committed May 2, 2024
1 parent 25bffeb commit 3f70b46
Show file tree
Hide file tree
Showing 13 changed files with 111 additions and 42 deletions.
30 changes: 20 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,35 @@ heidgaf -h
Run your analysis:

```sh
heidgaf process start -r data/...
heidgaf inspect -r data/...
```

Train your own model:

```sh
heidgaf train -m xg -d all
```

### Data

Currently, we support the data format scheme provided by the [DNS-Collector](https://github.com/dmachard/go-dnscollector/):

`{{ .timestamp }} {{ .return_code }} {{ .client_ip }} {{ .server_ip }} {{ .query }} {{ .type }} {{ .answer }} {{ .size }}b`
- `{{ .timestamp }}`
- `{{ .return_code }}`
- `{{ .client_ip }}`
- `{{ .server_ip }}`
- `{{ .query }}`
- `{{ .type }}`
- `{{ .answer }}`
- `{{ .size }}b`

For training our models, we rely on the following data sets:

- [CICBellDNS2021](https://www.unb.ca/cic/datasets/dns-2021.html)
- [DGTA Benchmark](https://data.mendeley.com/datasets/2wzf9bz7xr/1)
- [DNS Tunneling Queries for Binary Classification](https://data.mendeley.com/datasets/mzn9hvdcxg/1)
- [UMUDGA - University of Murcia Domain Generation Algorithm Dataset](https://data.mendeley.com/datasets/y8ph45msv8/1)
- [Majestic Million](https://de.majestic.com/reports/majestic-million)
- [Real-CyberSecurity-Datasets](https://github.com/gfek/Real-CyberSecurity-Datasets/)

However, we compute all features separately and rely only on the `domain` and `class`.
Currently, we are only interested in binary classification; thus, the `class` is either `benign` or `malicious`.
Expand Down Expand Up @@ -112,12 +125,9 @@ Based on the following work, we implement heiDGAF to find malicious behaviour su

- SHAP Interpretations of Tree and Neural Network DNS Classifiers for Analyzing DGA Family Characteristics



### Similar Projects

- [Deep Lookup](https://github.com/ybubnov/deep-lookup/) is a deep learning approach for DNS.
- [DGA Detective](https://github.com/COSSAS/dgad)
- https://github.com/Erxathos/DGA-Detector
- https://github.com/gfek/Real-CyberSecurity-Datasets/
- https://github.com/aasthac67/DNS-Tunneling-Detection/
- [Deep Lookup](https://github.com/ybubnov/deep-lookup/) is a deep learning approach for DNS detection.
- [DGA Detective](https://github.com/COSSAS/dgad) is a temporal convolutional network approach for DNS detection.
- [DGA Detector](https://github.com/Erxathos/DGA-Detector) is an NLP approach for DNS detection.
- [DNS Tunneling Detection](https://github.com/aasthac67/DNS-Tunneling-Detection/)
2 changes: 1 addition & 1 deletion heidgaf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class ReturnCode(Enum):


logging.basicConfig(
level=logging.DEBUG,
level=logging.INFO,
format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s",
datefmt="%y-%m-%d %H:%M:%S",
handlers=[logging.FileHandler("heidgaf.log"), logging.StreamHandler()],
Expand Down
13 changes: 11 additions & 2 deletions heidgaf/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ def train(model, dataset, output_dir):
default=Detector.THRESHOLDING,
help="Sets the anomaly detector.",
)
@click.option(
"-m",
"--model",
"model",
required=True,
type=click.Choice(Model),
help="Model for prediction."
)
@click.option(
"-s",
"--separator",
Expand Down Expand Up @@ -135,12 +143,13 @@ def train(model, dataset, output_dir):
help="Sets Redis max connection for caching results.",
)
def inspection(
input_dir, detector, separator, filetype, lag, influence, n_standard_deviations, redis_host, redis_port, redis_db, redis_max_connection
input_dir, detector, model, separator, filetype, lag, influence, n_standard_deviations, redis_host, redis_port, redis_db, redis_max_connection
):
click.echo("Starts processing log lines of DNS traffic.")
pipeline = DNSInspectorPipeline(
path=input_dir,
detector=detector,
model=model,
lag=lag,
anomaly_influence=influence,
n_standard_deviations=n_standard_deviations,
Expand All @@ -149,7 +158,7 @@ def inspection(
redis_host=redis_host,
redis_port=redis_port,
redis_db=redis_db,
redis_max_connections=redis_max_connection
redis_max_connections=redis_max_connection,
)
pipeline.run()

Expand Down
13 changes: 9 additions & 4 deletions heidgaf/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def __init__(
self.__train_test_val_split()
)

def __len__(self):
def __len__(self) -> int:
"""Returns the length of data set.
Returns:
Expand Down Expand Up @@ -188,15 +188,20 @@ def __train_test_val_split(self, train_frac: float = 0.8, random_state: int = No
return X_train, X_val, X_test, Y_train, Y_val, Y_test

@property
def train(self):
def train(self) -> dict:
"""Training set
Returns:
dict: dictionary with features and labels.
"""
return {"X": self.X_train, "Y": self.Y_train}

@property
def test(self):
def test(self) -> dict:
return {"X": self.X_test, "Y": self.Y_test}

@property
def val(self):
def val(self) -> dict:
return {"X": self.X_val, "Y": self.Y_val}


Expand Down
2 changes: 1 addition & 1 deletion heidgaf/detectors/arima_anomaly_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from statsmodels.tsa.arima.model import ARIMA

from .base_anomaly import AnomalyDetector, AnomalyDetectorConfig
from heidgaf.detectors.base_anomaly import AnomalyDetector, AnomalyDetectorConfig


class ARIMAAnomalyDetector(AnomalyDetector):
Expand Down
4 changes: 1 addition & 3 deletions heidgaf/detectors/exponential_thresholding.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Dict, List, Tuple

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np

from .base_anomaly import AnomalyDetector, AnomalyDetectorConfig
from heidgaf.detectors.base_anomaly import AnomalyDetector, AnomalyDetectorConfig


class EMAAnomalyDetector(AnomalyDetector):
Expand Down
4 changes: 2 additions & 2 deletions heidgaf/detectors/real_time_anomaly.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Dict, List
from typing import List

import numpy as np

from .base_anomaly import AnomalyDetector, AnomalyDetectorConfig
from heidgaf.detectors.base_anomaly import AnomalyDetector, AnomalyDetectorConfig


class RealTimeAnomalyDetector(AnomalyDetector):
Expand Down
2 changes: 1 addition & 1 deletion heidgaf/detectors/thresholding_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from .base_anomaly import AnomalyDetector, AnomalyDetectorConfig
from heidgaf.detectors.base_anomaly import AnomalyDetector, AnomalyDetectorConfig


class ThresholdingAnomalyDetector(AnomalyDetector):
Expand Down
12 changes: 11 additions & 1 deletion heidgaf/inspectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, config: InspectorConfig) -> None:
"thirdleveldomain",
"secondleveldomain",
"fqdn",
"tld"
]
),
mean_imputer=Imputer(features_to_impute=[], strategy="mean"),
Expand All @@ -79,7 +80,7 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame:
"""
pass

def warnings(self, data: pl.DataFrame, suspicious: List, id: str) -> None:
def warnings(self, data: pl.DataFrame, suspicious: List, id: str) -> pl.DataFrame:
"""Creates initial warning for classifiers
Args:
Expand All @@ -98,15 +99,20 @@ def warnings(self, data: pl.DataFrame, suspicious: List, id: str) -> None:
.alias("distro")
)
fqdn_distro = fqdn_distro.filter(pl.col("distro") > 0.05)

total_warnings = []

for warning in suspicious:
logging.debug(f"Analyze data in depth for {warning}")

data_id = data.filter(pl.col(id) == warning).filter(
pl.col("fqdn").is_in(fqdn_distro["fqdn"].to_list())
)
supicious_data = data.clear()
if not data_id.is_empty():
# Predict data based on model
# TODO Create ensemble classification
# TODO Set to regressor
y_pred = self.model_pipeline.predict(data_id)

indices = np.where(y_pred == 1)[0]
Expand All @@ -117,6 +123,10 @@ def warnings(self, data: pl.DataFrame, suspicious: List, id: str) -> None:
logging.debug(f"{warning} has following errors")
with pl.Config(tbl_rows=100):
logging.debug(supicious_data.select(["fqdn"]).unique())
total_warnings.append(supicious_data)

return pl.concat(total_warnings)


def update_count(
self,
Expand Down
12 changes: 8 additions & 4 deletions heidgaf/inspectors/domain_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,23 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame:
df = data.filter(pl.col("query") != "|").filter(
pl.col("query").str.split(".").list.len() != 1
)


findings = []

# Check anomalies in FQDN
warnings = self.update_count(df, min_date, max_date, "fqdn", self.KEY_FQDN)
self.warnings(data, warnings, "fqdn")
findings.append(self.warnings(data, warnings, "fqdn"))

# Check anomalies in second level
warnings = self.update_count(
df, min_date, max_date, "secondleveldomain", self.KEY_SECOND_LEVEL_DOMAIN
)
self.warnings(data, warnings, "secondleveldomain")
findings.append(self.warnings(data, warnings, "secondleveldomain"))

# Check anomalies in third level
warnings = self.update_count(
df, min_date, max_date, "thirdleveldomain", self.KEY_THIRD_LEVEL_DOMAIN
)
self.warnings(data, warnings, "thirdleveldomain")
findings.append(self.warnings(data, warnings, "thirdleveldomain"))

return pl.concat(findings)
8 changes: 6 additions & 2 deletions heidgaf/inspectors/ip_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,18 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame:
.filter(pl.col("return_code") != ReturnCode.NOERROR.value)
.filter(pl.col("query").str.split(".").list.len() != 1)
)

findings = []

# Update frequencies based on errors
warnings = self.update_count(
df, min_date, max_date, "client_ip", self.KEY_IP_FREQUENCY
)
self.warnings(data, warnings, "client_ip")
findings.append(self.warnings(data, warnings, "client_ip"))

warnings = self.update_count(
df, min_date, max_date, "dns_server", self.KEY_DNS_SERVER
)
self.warnings(data, warnings, "dns_server")
findings.append(self.warnings(data, warnings, "dns_server"))

return pl.concat(findings)
35 changes: 32 additions & 3 deletions heidgaf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import redis
import redis.exceptions
from click import Path
import requests

from heidgaf.cache import DataFrameRedisCache
from heidgaf.detectors.arima_anomaly_detector import ARIMAAnomalyDetector
Expand All @@ -17,6 +18,7 @@
from heidgaf.inspectors import Inspector, InspectorConfig
from heidgaf.inspectors.domain_analyzer import DomainInspector
from heidgaf.inspectors.ip_analyzer import IPInspector
from heidgaf.train import Model


@unique
Expand Down Expand Up @@ -53,6 +55,8 @@ def inspector_factory(source: str, config: InspectorConfig) -> Inspector:

class DNSInspectorPipeline:
"""Main analyzer pipeline. It loads new data and processes it through our analyzers. If an anomaly occurs, our models run"""

MODELS_URL="https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021"

def __init__(
self,
Expand All @@ -69,13 +73,15 @@ def __init__(
redis_db=0,
redis_max_connections=20,
threshold=5,
model=Model.RANDOM_FOREST_CLASSIFIER
) -> None:
try:
self.df_cache = DataFrameRedisCache(
redis_host, redis_port, redis_db, redis_max_connections
)
except redis.exceptions.ConnectionError:
logging.warning("No connection to Redis host")
self.df_cache = None


if os.path.isfile(path):
Expand All @@ -93,6 +99,8 @@ def __init__(
self.detector = detector
self.threshold = threshold
self.order = order
self.model = self.__get_model(model)


def load_data(self, path: str, separator: str) -> pl.DataFrame:
"""Loads data from csv files
Expand Down Expand Up @@ -128,10 +136,13 @@ def load_data(self, path: str, separator: str) -> pl.DataFrame:
(pl.col("query").str.split(".").alias("labels")),
]
)

x = x.filter(pl.col("query").str.len_chars() > 0)
x = x.filter(pl.col("labels").list.len() > 1)

x = x.with_columns(
[
(pl.col("labels").get(-1).alias("tld")),
(pl.col("labels").list.get(-1).alias("tld")),
]
)

Expand Down Expand Up @@ -198,7 +209,25 @@ def run(self):

# Run inspectors to find anomalies in data
config = InspectorConfig(
detector, self.df_cache, self.threshold, joblib.load("model.pkl")
detector, self.df_cache, self.threshold, self.model
)

errors = []
for inspector in ["IP", "Domain"]:
inspector_factory(inspector, config).run(self.data)
errors.append(inspector_factory(inspector, config).run(self.data))

errors_pl: pl.DataFrame = pl.concat(errors)

group_errors_pl = errors_pl.group_by(["client_ip", "fqdn"])
with pl.Config(tbl_rows=100):
logging.warning(group_errors_pl)

def __get_model(self, model_type: Model):
    """Download the serialized model for ``model_type`` and load it.

    The file ``<model_type.value>.pkl`` is fetched from ``MODELS_URL``,
    written to a private temporary file and deserialized with joblib.

    Args:
        model_type (Model): Model whose ``value`` names the remote file.

    Returns:
        The object deserialized by ``joblib.load``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond in time.
    """
    import os
    import tempfile

    # Without a timeout, requests waits indefinitely on a stalled server.
    response = requests.get(
        f"{self.MODELS_URL}/files/?p=%2F{model_type.value}.pkl&dl=1",
        timeout=30,
    )

    response.raise_for_status()

    # mkstemp creates a uniquely named file that only the current user can
    # access, instead of a predictable path in a world-writable directory.
    fd, model_path = tempfile.mkstemp(
        prefix=f"{model_type.value}-", suffix=".pkl"
    )
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(response.content)

        # NOTE(security): joblib.load unpickles the downloaded bytes, which
        # can execute arbitrary code. Only point MODELS_URL at a trusted host.
        return joblib.load(model_path)
    finally:
        # The model is fully loaded into memory; do not leave the file behind.
        os.remove(model_path)
Loading

0 comments on commit 3f70b46

Please sign in to comment.