Skip to content

Commit

Permalink
Merge pull request #7 from yunzheng/main
Browse files Browse the repository at this point in the history
Output scanning stats and version information
  • Loading branch information
yunzheng authored Dec 15, 2021
2 parents 0d4c8c9 + 3fb31cc commit 52c6e2a
Showing 1 changed file with 48 additions and 13 deletions.
61 changes: 48 additions & 13 deletions log4j-finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import io
import sys
import time
import zipfile
import logging
import argparse
Expand All @@ -29,13 +30,23 @@

from pathlib import Path

__version__ = "1.0.1"
FIGLET = f"""\
__ _____ __ ___ __ __
| |.-----.-----.| | ||__|______.' _|__|.-----.--| |.-----.----.
| || _ | _ ||__ | |______| _| || | _ || -__| _|
|__||_____|___ | |__|| | |__| |__||__|__|_____||_____|__|
|_____| |___| v{__version__} https://github.com/fox-it/log4j-finder
"""

# Optionally import colorama to enable colored output for Windows
try:
import colorama

colorama.init()
NO_COLOR = False
except ImportError:
NO_COLOR = True if sys.platform == 'win32' else False
NO_COLOR = True if sys.platform == "win32" else False

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,15 +93,17 @@ def md5_digest(fobj):
return d.hexdigest()


def iter_scandir(path):
def iter_scandir(path, stats=None):
"""
Yields all files matcthing JAR_EXTENSIONS or FILENAMES recursively in path
"""
p = Path(path)
if p.is_file():
if stats:
stats["files"] += 1
yield p
try:
for entry in scantree(path):
for entry in scantree(path, stats=stats):
if entry.is_symlink():
continue
elif entry.is_file():
Expand All @@ -103,20 +116,24 @@ def iter_scandir(path):
log.debug(e)


def scantree(path):
def scantree(path, stats=None):
"""Recursively yield DirEntry objects for given directory."""
try:
with os.scandir(path) as it:
for entry in it:
if entry.is_dir(follow_symlinks=False):
yield from scantree(entry.path)
if stats:
stats["directories"] += 1
yield from scantree(entry.path, stats=stats)
else:
if stats:
stats["files"] += 1
yield entry
except IOError as e:
log.debug(e)


def iter_jarfile(fobj, parents=None):
def iter_jarfile(fobj, parents=None, stats=None):
"""
Yields (zfile, zinfo, zpath, parents) for each file in zipfile that matches `FILENAMES` or `JAR_EXTENSIONS` (recursively)
"""
Expand All @@ -133,9 +150,9 @@ def iter_jarfile(fobj, parents=None):
zfile.open(zinfo.filename), parents=parents + [zpath]
)
except IOError as e:
log.debug(f"{fobj}: {e}", e)
log.debug(f"{fobj}: {e}")
except zipfile.BadZipFile as e:
log.debug(f"{fobj}: {e}", e)
log.debug(f"{fobj}: {e}")


def red(s):
Expand Down Expand Up @@ -210,6 +227,7 @@ def main():
parser.add_argument(
"-n", "--no-color", action="store_true", help="disable color output"
)
parser.add_argument("-b", "--no-banner", action="store_true", help="disable banner")
args = parser.parse_args()
logging.basicConfig(
format="%(asctime)s %(levelname)s %(message)s",
Expand All @@ -226,20 +244,29 @@ def main():
NO_COLOR = True

stats = {
"scanned": 0,
"files": 0,
"directories": 0,
"vulnerable": 0,
"good": 0,
"unknown": 0,
}
start_time = time.monotonic()

if not args.no_banner:
print(FIGLET)
for directory in args.path:
print("Scanning:", directory)
for p in iter_scandir(directory):
print(f"[{datetime.datetime.utcnow()}] Scanning: {directory}")
for p in iter_scandir(directory, stats=stats):
if p.name.lower() in FILENAMES:
stats["scanned"] += 1
log.info(f"Found file: {p}")
with p.open("rb") as fobj:
check_vulnerable(fobj, [p], stats)
if p.suffix.lower() in JAR_EXTENSIONS:
try:
log.info(f"Found jar file: {p}")
stats["scanned"] += 1
for (zinfo, zfile, zpath, parents) in iter_jarfile(
p.resolve().open("rb"), parents=[p.resolve()]
):
Expand All @@ -249,13 +276,21 @@ def main():
except IOError as e:
log.debug(f"{p}: {e}", e)

elapsed_time = time.monotonic() - start_time
print(
f"[{datetime.datetime.utcnow()}] Finished scan, elapsed time: {elapsed_time:.2f} seconds"
)

print("\nSummary:")
print(f" Processed {stats['files']} files and {stats['directories']} directories")
print(f" Scanned {stats['scanned']} files")
if stats["vulnerable"]:
print(" Found {} vulnerable files".format(stats["vulnerable"]))
print(" Found {} vulnerable files".format(stats["vulnerable"]))
if stats["good"]:
print(" Found {} good files".format(stats["good"]))
print(" Found {} good files".format(stats["good"]))
if stats["unknown"]:
print(" Found {} unknown files".format(stats["unknown"]))
print(" Found {} unknown files".format(stats["unknown"]))
print(f"\nElapsed time: {elapsed_time:.2f} seconds ")


if __name__ == "__main__":
Expand Down

0 comments on commit 52c6e2a

Please sign in to comment.