Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data augmentation using tf-image #62

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
107 changes: 107 additions & 0 deletions scripts/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import matplotlib.pyplot as plt
import tensorflow as tf

from tf2_yolov4.anchors import YOLOV4_ANCHORS
from tf2_yolov4.model import YOLOv4

# Square network input: the sample image is resized to HEIGHT x WIDTH and the
# model is built for an RGB input of the same size.
HEIGHT = WIDTH = 608
INPUT_SHAPE = (HEIGHT, WIDTH, 3)

# Label tables: the position of a name in a list is the class index that is
# looked up when a detection is drawn.
PASCAL_VOC_CLASSES = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]

COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]

# Point CLASSES at the label table that matches the weights being loaded
# (PASCAL_VOC_CLASSES, COCO_CLASSES, or a custom list of your own).
CLASSES = PASCAL_VOC_CLASSES


# Load the sample picture and turn it into a one-image batch for the model.
image = tf.io.read_file("../notebooks/images/cars.jpg")
# Force three colour channels: without `channels`, a grayscale or RGBA file
# would decode to 1 or 4 channels and no longer match the 3-channel model input.
image = tf.image.decode_image(image, channels=3)
image = tf.image.resize(image, (HEIGHT, WIDTH))
# Add the leading batch axis and rescale the pixel values by 1/255.
images = tf.expand_dims(image, axis=0) / 255.0

# Build the network with training=False and restore the trained weights.
# INPUT_SHAPE is reused (instead of repeating the tuple) so the model input
# can never drift away from the size the image was resized to.
model = YOLOv4(
    input_shape=INPUT_SHAPE,
    anchors=YOLOV4_ANCHORS,
    num_classes=len(CLASSES),
    training=False,
    yolo_max_boxes=100,
    yolo_iou_threshold=0.5,
    yolo_score_threshold=0.15,
)
model.load_weights("../yolov4_full.h5")
model.summary()

# One entry per image of the batch in each returned array.
boxes, scores, classes, valid_detections = model.predict(images)

# Palette used to draw detections: one [r, g, b] triplet per entry, each
# component expressed as a float between 0 and 1.
COLORS = [
    [0.0, 0.447, 0.741], [0.85, 0.325, 0.098],
    [0.929, 0.694, 0.125], [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188], [0.301, 0.745, 0.933],
]


def plot_results(pil_img, boxes, scores, classes):
    """Show an image with its detections drawn on top of it.

    Args:
        pil_img: Image handed unchanged to ``plt.imshow``.
        boxes: Array of ``(xmin, ymin, xmax, ymax)`` rows, expressed in the
            coordinate system of the displayed image. Must expose ``tolist()``.
        scores: Array with one score per box. Boxes whose score is not
            strictly positive are not drawn. Must expose ``tolist()``.
        classes: Array of integer class indices, used to index ``CLASSES``
            and to pick a colour. Must expose ``tolist()``.
    """
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()

    detections = zip(boxes.tolist(), scores.tolist(), classes.tolist())
    for (xmin, ymin, xmax, ymax), score, cl in detections:
        # Only strictly positive scores are drawn; everything else is skipped.
        if not score > 0:
            continue

        # Cycle through the palette whatever its length, rather than relying
        # on a hard-coded count that breaks as soon as COLORS is edited.
        color = COLORS[cl % len(COLORS)]
        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color=color,
                linewidth=3,
            )
        )
        text = f"{CLASSES[cl]}: {score:0.2f}"
        ax.text(
            xmin,
            ymin,
            text,
            color="white",
            fontsize=15,
            fontweight="bold",
            bbox=dict(facecolor=color, alpha=0.7),
        )
    plt.axis("off")
    plt.show()


# Draw the detections of the first (and only) image of the batch. Each box is
# multiplied component-wise by (WIDTH, HEIGHT, WIDTH, HEIGHT) to bring its
# coordinates to the scale of the resized image.
plot_results(
    pil_img=images[0],
    boxes=boxes[0] * [WIDTH, HEIGHT, WIDTH, HEIGHT],
    scores=scores[0],
    classes=classes[0].astype(int),
)
154 changes: 154 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
Training script for Pascal VOC using tf2-yolov4
"""
from datetime import datetime
from pathlib import Path

import click
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from tf2_yolov4.anchors import (
YOLOV4_ANCHORS,
YOLOV4_ANCHORS_MASKS,
compute_normalized_anchors,
)
from tf2_yolov4.datasets import prepare_dataset
from tf2_yolov4.losses import YoloV3Loss
from tf2_yolov4.model import YOLOv4

# Input shape shared by the dataset preparation, the model and the anchor
# normalization below (presumably height, width, channels -- confirm against
# YOLOv4 / prepare_dataset).
INPUT_SHAPE = (608, 608, 3)


def _fit_stage(
    model,
    ds_train,
    ds_test,
    steps_per_epoch,
    validation_steps,
    initial_epoch,
    epochs,
    log_dir,
    checkpoints,
):
    """Run one ``model.fit`` stage, logging to TensorBoard and checkpointing.

    Args:
        model: Compiled Keras model to train.
        ds_train: Training dataset.
        ds_test: Validation dataset.
        steps_per_epoch: Number of training steps per epoch.
        validation_steps: Number of validation steps per epoch.
        initial_epoch: Epoch index at which this stage starts.
        epochs: Epoch index at which this stage stops.
        log_dir: ``Path`` of the directory receiving logs and checkpoints.
        checkpoints: Iterable of ``(filename, monitored_metric)`` pairs; one
            best-only, weights-only checkpoint is written per pair.
    """
    callbacks = [tf.keras.callbacks.TensorBoard(log_dir=str(log_dir))]
    callbacks.extend(
        tf.keras.callbacks.ModelCheckpoint(
            str(log_dir / filename),
            save_best_only=True,
            save_weights_only=True,
            monitor=monitor,
        )
        for filename, monitor in checkpoints
    )
    model.fit(
        ds_train,
        steps_per_epoch=steps_per_epoch,
        validation_data=ds_test,
        validation_steps=validation_steps,
        epochs=epochs,
        initial_epoch=initial_epoch,
        callbacks=callbacks,
    )


def launch_training(batch_size, weights_path, all_frozen_epoch_number, backbone_frozen_epoch_number, num_epochs, dataset_name="voc"):
    """Train YOLOv4 on a TensorFlow Datasets dataset in three stages.

    The stages share one optimizer and one epoch counter:

    1. backbone and neck frozen, for ``all_frozen_epoch_number`` epochs;
    2. backbone frozen only, for ``backbone_frozen_epoch_number`` more epochs;
    3. every layer trainable, from the end of stage 2 up to ``num_epochs``.

    Logs and weight checkpoints are written under
    ``./logs/<dataset_name>/<timestamp>``.

    Args:
        batch_size: Size of the mini-batches.
        weights_path: Path to weights loaded by layer name before training
            (mismatching layers are skipped), or ``None`` to skip loading.
        all_frozen_epoch_number: Number of epochs of stage 1.
        backbone_frozen_epoch_number: Number of epochs of stage 2.
        num_epochs: Epoch index at which stage 3 stops, i.e. the total number
            of epochs over the three stages.
        dataset_name: Name of the dataset handed to ``tfds.load``. Its
            "train" and "validation" splits are used.
    """
    # NOTE: the timestamp contains ':' characters, which are not valid in
    # Windows file names.
    log_dir = Path("./logs") / dataset_name / datetime.now().strftime("%m-%d-%Y %H:%M:%S")

    dataset, infos = tfds.load(dataset_name, with_info=True, shuffle_files=True)

    ds_train = prepare_dataset(
        dataset["train"],
        shape=INPUT_SHAPE,
        batch_size=batch_size,
        shuffle=True,
        apply_data_augmentation=True,
        transform_to_bbox_by_stage=True,
    )
    ds_test = prepare_dataset(
        dataset["validation"],
        shape=INPUT_SHAPE,
        batch_size=batch_size,
        shuffle=False,
        apply_data_augmentation=False,
        transform_to_bbox_by_stage=True,
    )

    steps_per_epoch = infos.splits["train"].num_examples // batch_size
    validation_steps = infos.splits["validation"].num_examples // batch_size
    num_classes = infos.features["objects"]["label"].num_classes

    model = YOLOv4(
        input_shape=INPUT_SHAPE,
        anchors=YOLOV4_ANCHORS,
        num_classes=num_classes,
        training=True,
    )
    if weights_path is not None:
        model.load_weights(str(weights_path), by_name=True, skip_mismatch=True)
        print("Darknet weights loaded.")

    optimizer = tf.keras.optimizers.Adam(1e-4)
    normalized_anchors = compute_normalized_anchors(YOLOV4_ANCHORS, INPUT_SHAPE)
    # Flatten the per-stage anchors once; each loss then picks its own rows.
    all_anchors = np.concatenate(list(normalized_anchors), axis=0)
    loss = [YoloV3Loss(all_anchors[mask]) for mask in YOLOV4_ANCHORS_MASKS]

    backbone_layers = model.get_layer("CSPDarknet53").layers
    neck_layers = model.get_layer("YOLOv4_neck").layers
    backbone_frozen_end = all_frozen_epoch_number + backbone_frozen_epoch_number

    # Stage 1: backbone and neck frozen.
    for layer in backbone_layers + neck_layers:
        layer.trainable = False
    # The model is re-compiled after every change of `trainable` flags.
    model.compile(optimizer=optimizer, loss=loss)
    _fit_stage(
        model,
        ds_train,
        ds_test,
        steps_per_epoch,
        validation_steps,
        initial_epoch=0,
        epochs=all_frozen_epoch_number,
        log_dir=log_dir,
        checkpoints=[("yolov4_all_frozen.h5", "val_loss")],
    )

    # Stage 2: neck unfrozen, backbone still frozen.
    for layer in neck_layers:
        layer.trainable = True
    model.compile(optimizer=optimizer, loss=loss)
    _fit_stage(
        model,
        ds_train,
        ds_test,
        steps_per_epoch,
        validation_steps,
        initial_epoch=all_frozen_epoch_number,
        epochs=backbone_frozen_end,
        log_dir=log_dir,
        checkpoints=[("yolov4_backbone_frozen.h5", "val_loss")],
    )

    # Stage 3: every layer trainable, until `num_epochs` is reached.
    for layer in backbone_layers:
        layer.trainable = True
    model.compile(optimizer=optimizer, loss=loss)
    _fit_stage(
        model,
        ds_train,
        ds_test,
        steps_per_epoch,
        validation_steps,
        initial_epoch=backbone_frozen_end,
        epochs=num_epochs,
        log_dir=log_dir,
        checkpoints=[
            ("yolov4_full.h5", "val_loss"),
            ("yolov4_train_loss.h5", "loss"),
        ],
    )


@click.command()
@click.option(
    "--batch_size",
    type=int,
    default=16,
    help="Size of mini-batch",
)
@click.option(
    "--weights_path",
    type=click.Path(exists=True),
    default=None,
    help="Path to pretrained weights",
)
@click.option(
    "--all_frozen_epoch_number",
    type=int,
    default=20,
    help="Number of epochs to perform with backbone and neck frozen",
)
@click.option(
    "--backbone_frozen_epoch_number",
    type=int,
    default=10,
    help="Number of epochs to perform with backbone frozen",
)
@click.option(
    "--num_epochs",
    type=int,
    default=50,
    help="Total number of epochs to perform",
)
@click.option(
    "--dataset_name",
    type=str,
    default="voc",
    help="Dataset used during training. Refer to TensorFlow Datasets documentation for dataset names.",
)
def launch_training_command(batch_size, weights_path, all_frozen_epoch_number, backbone_frozen_epoch_number, num_epochs, dataset_name):
    # Command-line entry point: forwards the parsed options to
    # launch_training unchanged.
    launch_training(
        batch_size=batch_size,
        weights_path=weights_path,
        all_frozen_epoch_number=all_frozen_epoch_number,
        backbone_frozen_epoch_number=backbone_frozen_epoch_number,
        num_epochs=num_epochs,
        dataset_name=dataset_name,
    )


# Only start a training when the file is executed as a script.
if __name__ == "__main__":
    launch_training_command()
2 changes: 1 addition & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_model_should_predict_valid_shapes_at_inference(

@pytest.mark.parametrize("input_shape", [(32, 33, 3), (33, 32, 3)])
def test_model_instanciation_should_fail_with_input_shapes_not_multiple_of_32(
input_shape
input_shape,
):
with pytest.raises(ValueError):
YOLOv4(input_shape, 80, [])
2 changes: 2 additions & 0 deletions tf2_yolov4/anchors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
np.array([(142, 110), (192, 243), (459, 401)], np.float32),
]

YOLOV4_ANCHORS_MASKS = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]

YOLOV3_ANCHORS = [
np.array([(10, 13), (16, 30), (33, 23)], np.float32),
np.array([(30, 61), (62, 45), (59, 119)], np.float32),
Expand Down
Loading