Collection of various changes #3

Open · wants to merge 5 commits into base: augment_tfdataset
Changes from 4 commits
36 changes: 18 additions & 18 deletions src/README.md
@@ -11,10 +11,10 @@ The dataset is expected to be provided in the following structure:
 ```
 path/to/my/dataset/
 ├── label_colors.txt
-├── im1_image.tif
-├── im1_label.tif
-├── im2_image.tif
-├── im2_label.tif
+├── im1_image.tif/.vrt
+├── im1_label.tif/.vrt
+├── im2_image.tif/.vrt
+├── im2_label.tif/.vrt
 └── ...
 ```

@@ -43,24 +43,24 @@ directories.
 ```
 path/to/my/dataset/
 ├── label_colors.txt
-├── im1_image.tif
-├── im1_label.tif
-├── im2_image.tif
-├── im2_label.tif
+├── im1_image.tif/.vrt
+├── im1_label.tif/.vrt
+├── im2_image.tif/.vrt
+├── im2_label.tif/.vrt
 ├── train_images
-│   ├── image_0.tif
-│   ├── image_1.tif
-│   ├── image_2.tif
-│   └── image_4.tif
+│   ├── image_0.tif/.vrt
+│   ├── image_1.tif/.vrt
+│   ├── image_2.tif/.vrt
+│   └── image_4.tif/.vrt
 ├── train_masks
-│   ├── image_0.tif
-│   ├── image_1.tif
-│   ├── image_2.tif
-│   └── image_4.tif
+│   ├── image_0.tif/.vrt
+│   ├── image_1.tif/.vrt
+│   ├── image_2.tif/.vrt
+│   └── image_4.tif/.vrt
 ├── val_images
-│   └── image_3.tif
+│   └── image_3.tif/.vrt
 └── val_masks
-    └── image_3.tif
+    └── image_3.tif/.vrt
 ```

 ## Training
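The `.vrt` entries above are GDAL virtual rasters, which can stand in for GeoTIFFs anywhere in this dataset tree. As a minimal sketch of producing one with the Python GDAL bindings (filenames hypothetical, assumes the `osgeo` package is installed):

```
from osgeo import gdal

# build a virtual raster wrapping one or more source GeoTIFFs;
# the resulting .vrt can then be used in place of a .tif above
vrt = gdal.BuildVRT('im1_image.vrt', ['im1_image.tif'])
vrt = None  # dereference to flush and close the dataset
```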
10 changes: 8 additions & 2 deletions src/cnn_lib.py
@@ -1214,8 +1214,14 @@ def get_tf_dataset(data_dir, batch_size=5, operation='train',
     # create variables useful throughout the entire class
     nr_samples = len(os.listdir(images_dir))

-    img_filelist = list(pathlib.Path(images_dir).glob('*.tif'))
-    mask_filelist = list(pathlib.Path(masks_dir).glob('*.tif'))
+    img_filelist = sorted(
+        list(pathlib.Path(images_dir).glob('*.tif'))
+        + list(pathlib.Path(images_dir).glob('*.vrt'))
+    )
+    mask_filelist = sorted(
+        list(pathlib.Path(masks_dir).glob('*.tif'))
+        + list(pathlib.Path(masks_dir).glob('*.vrt'))
+    )

     image_count = len(img_filelist)
     mask_count = len(mask_filelist)
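The merge-and-sort idiom above is repeated for images and masks. A small helper along these lines (name hypothetical) would keep both call sites identical and make the pairing assumption explicit: the two lists are matched by position, so they must sort into the same relative order.

```
import pathlib

def rasters_sorted(directory, patterns=('*.tif', '*.vrt')):
    """Collect rasters matching any of the patterns, in sorted order."""
    paths = []
    for pattern in patterns:
        # glob() yields paths in arbitrary order, so sort the union
        paths.extend(pathlib.Path(directory).glob(pattern))
    return sorted(paths)
```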
13 changes: 10 additions & 3 deletions src/data_preparation.py
@@ -44,9 +44,16 @@ def generate_dataset_structure(data_dir, tensor_shape=(256, 256),
     dir_names = train_val_determination(val_set_pct)

     # tile and write samples
-    source_images = sorted(glob.glob(os.path.join(data_dir, '*image.tif')))
-    for i in source_images:
-        tile(i, i.replace('image.tif', 'label.tif'), tensor_shape,
+    source_images = sorted(
+        glob.glob(os.path.join(data_dir, '*image.tif'))
+        + glob.glob(os.path.join(data_dir, '*image.vrt'))
+    )
+    source_label = sorted(
+        glob.glob(os.path.join(data_dir, '*label.tif'))
+        + glob.glob(os.path.join(data_dir, '*label.vrt'))
+    )
+    for i in range(len(source_images)):
+        tile(source_images[i], source_label[i], tensor_shape,
              filter_by_class, augment, dir_names, ignore_masks)

     # check if there are some training data
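Since `source_images` and `source_label` are paired purely by position after sorting, an equivalent formulation iterates both lists together; a sketch assuming every image raster has exactly one matching label raster:

```
# zip() pairs the two sorted lists element by element; the pairing is
# only correct if both lists have the same length and relative order
for image_path, label_path in zip(source_images, source_label):
    tile(image_path, label_path, tensor_shape,
         filter_by_class, augment, dir_names, ignore_masks)
```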
14 changes: 10 additions & 4 deletions src/detect.py
@@ -16,22 +16,25 @@
 from cnn_exceptions import DatasetError


-def main(data_dir, model, in_weights_path, visualization_path, batch_size,
+def main(data_dir, label_colors, model, in_weights_path, visualization_path,
+         batch_size,
          seed, tensor_shape, force_dataset_generation, fit_memory, val_set_pct,
          filter_by_class, backbone=None, ignore_masks=False):
     utils.print_device_info()

     if ignore_masks is False:
         # check if labels are provided
         import glob
-        if len(glob.glob(os.path.join(data_dir, '*label.tif'))) == 0:
+        if len(
+            glob.glob(os.path.join(data_dir, '*label.tif'))
+            + glob.glob(os.path.join(data_dir, '*label.vrt'))
+        ) == 0:
             raise DatasetError('No labels provided in the dataset.')

     # get nr of bands
     nr_bands = utils.get_nr_of_bands(data_dir)

     label_codes, label_names, id2code = utils.get_codings(
-        os.path.join(data_dir, 'label_colors.txt'))
+        os.path.join(data_dir, label_colors))

     # set TensorFlow seed
     tf.random.set_seed(seed)
@@ -121,6 +124,9 @@ def get_geoinfo(data_dir):
     parser.add_argument(
         '--data_dir', type=str, required=True,
         help='Path to the directory containing images and labels')
+    parser.add_argument(
+        '--label_colors', type=str, default='label_colors.txt',
+        help='Name of the label colors txt file (located at the top level '
+             'of --data_dir)')
     parser.add_argument(
         '--model', type=str, default='U-Net',
         choices=('U-Net', 'SegNet', 'DeepLab'),
@@ -185,7 +191,7 @@ def get_geoinfo(data_dir):
         raise parser.error(
             'Argument validation_set_percentage must be greater or equal to 0')

-    main(args.data_dir, args.model, args.weights_path, args.visualization_path,
+    main(args.data_dir, args.label_colors, args.model, args.weights_path,
+         args.visualization_path,
          args.batch_size, args.seed, (args.tensor_height, args.tensor_width),
          args.force_dataset_generation, args.fit_dataset_in_memory,
          args.validation_set_percentage, args.filter_by_classes,
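With the new `--label_colors` argument, a detection run can point at a color definition file with a non-default name. A hypothetical invocation with placeholder paths (the `--weights_path` and `--visualization_path` spellings are assumed from the `args.*` names above):

```
python3 src/detect.py \
    --data_dir /path/to/my/dataset \
    --label_colors my_label_colors.txt \
    --model U-Net \
    --weights_path /path/to/weights.h5 \
    --visualization_path /tmp
```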
109 changes: 98 additions & 11 deletions src/train.py
@@ -2,6 +2,7 @@

 import os
 import argparse
+import sys

 import numpy as np
 import tensorflow as tf
@@ -21,23 +22,85 @@ def rescale_image(input_image, input_mask):

     return input_image, input_mask

+
+def load_pretrained_model(model, id2code, nr_bands,
+                          tensor_shape, loss_function, tversky_alpha,
+                          tversky_beta, dropout_rate_input,
+                          dropout_rate_hidden, backbone, name,
+                          in_weights_path, model_new,
+                          finetune_old_inp_dim, finetune_old_out_dim):
+    # if the input or output dimension changed w.r.t. the pretrained model
+    if finetune_old_inp_dim or finetune_old_out_dim:
+        if model == 'U-Net':
+            # set the dimensions used to create the pretrained model;
+            # nr_bands defaults to the band count of the new model
+            if finetune_old_inp_dim:
+                nr_bands = finetune_old_inp_dim
+            if finetune_old_out_dim:
+                num_class = finetune_old_out_dim
+            else:
+                num_class = len(id2code)
+
+            # create a model with the dimensions of the pretrained model
+            # NOTE: do not call create_model with verbose=False;
+            # model.summary() has to run once, otherwise the model
+            # dimensions are not set
+            print('----------------------------------------')
+            print('-- Start: dimensions of the OLD model --')
+            print('----------------------------------------')
+            model_old = create_model(
+                model, num_class, nr_bands, tensor_shape, nr_filters=32,
+                loss=loss_function, alpha=tversky_alpha, beta=tversky_beta,
+                dropout_rate_input=dropout_rate_input,
+                dropout_rate_hidden=dropout_rate_hidden, backbone=backbone,
+                name=name)
+            print('--------------------------------------')
+            print('-- End: dimensions of the OLD model --')
+            print('--------------------------------------')
+            # load the weights of the pretrained model
+            model_old.load_weights(in_weights_path)
+
+            # set the weights of the new model to those of the pretrained
+            # model
+            # NOTE: model.layers returns the list of model layers, BUT not
+            # necessarily in the correct order, so the indices of the first
+            # and the last layer have to be looked up explicitly
+            # get all layer names
+            layer_names = [layer.name for layer in model_new.layers]
+            # get the layer index of the first downsampling block
+            chlayer_first = model_new.ds_blocks[0].name
+            ind_chlayer_first = layer_names.index(chlayer_first)
+            # get the layer index of the last layer of the model
+            chlayer_last = 'classifier_layer'
+            ind_chlayer_last = layer_names.index(chlayer_last)
+            # iterate over all layers to set the weights
+            for ind in range(len(model_new.layers)):
+                # if the input dimension changed, do not set the weights
+                # for this layer in the new model
+                if ind == ind_chlayer_first and finetune_old_inp_dim:
+                    continue
+                # if the output dimension changed, do not set the weights
+                # for this layer in the new model
+                if ind == ind_chlayer_last and finetune_old_out_dim:
+                    continue
+                # set the weights from the pretrained model for all
+                # remaining layers
+                model_new.layers[ind].set_weights(
+                    model_old.layers[ind].get_weights())
+        else:
+            sys.exit('ERROR: Changing the input or output dimensions w.r.t. '
+                     'the pretrained model is only supported for U-Net so '
+                     'far (parameter --finetune_old_inp_dim or '
+                     '--finetune_old_out_dim)')
+    else:
+        # if the model dimensions did not change, load the weights of the
+        # complete model
+        model_new.load_weights(in_weights_path)
+
+
-def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None,
+def main(operation, data_dir, label_colors, output_dir, model, model_fn,
+         in_weights_path=None,
          visualization_path='/tmp', nr_epochs=1, initial_epoch=0, batch_size=1,
          loss_function='dice', seed=1, patience=100, tensor_shape=(256, 256),
          monitored_value='val_accuracy', force_dataset_generation=False,
          fit_memory=False, augment=False, tversky_alpha=0.5,
          tversky_beta=0.5, dropout_rate_input=None, dropout_rate_hidden=None,
-         val_set_pct=0.2, filter_by_class=None, backbone=None, name='model',
-         verbose=1):
+         val_set_pct=0.2, filter_by_class=None, backbone=None,
+         finetune_old_inp_dim=None, finetune_old_out_dim=None,
+         name='model', verbose=1):
     if verbose > 0:
         utils.print_device_info()

     # get nr of bands
     nr_bands = utils.get_nr_of_bands(data_dir)

     label_codes, label_names, id2code = utils.get_codings(
-        os.path.join(data_dir, 'label_colors.txt'))
+        os.path.join(data_dir, label_colors))

     # set TensorFlow seed
     if seed is not None:
@@ -48,7 +111,7 @@ def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None,
         tf.keras.utils.set_random_seed(seed)

     # tinyunet: nr_filters=32
-    model = create_model(
+    model_new = create_model(
         model, len(id2code), nr_bands, tensor_shape, nr_filters=32, loss=loss_function,
         alpha=tversky_alpha, beta=tversky_beta,
         dropout_rate_input=dropout_rate_input,
@@ -78,9 +141,15 @@ def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None,
             num_parallel_calls=tf.data.AUTOTUNE)
         .repeat())

-    # load weights if the model is supposed to do so
+    # load weights if the model is supposed to do so (i.e. fine-tune mode)
     if operation == 'fine-tune':
-        model.load_weights(in_weights_path)
+        load_pretrained_model(
+            model, id2code, nr_bands,
+            tensor_shape, loss_function, tversky_alpha, tversky_beta,
+            dropout_rate_input, dropout_rate_hidden, backbone, name,
+            in_weights_path, model_new,
+            finetune_old_inp_dim, finetune_old_out_dim)

     # train_generator = AugmentGenerator(
     #     data_dir, batch_size, 'train', fit_memory=fit_memory,
@@ -105,7 +174,7 @@ def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None,
         .map(Augment())
         .prefetch(buffer_size=tf.data.AUTOTUNE))

-    train(model, train_generator, train_nr_samples, val_generator, val_nr_samples, id2code, batch_size,
+    train(model_new, train_generator, train_nr_samples, val_generator, val_nr_samples, id2code, batch_size,
           output_dir, visualization_path, model_fn, nr_epochs,
           initial_epoch, seed=seed, patience=patience,
           monitored_value=monitored_value, verbose=verbose)
@@ -202,6 +271,9 @@ def train(model, train_generator, train_nr_samples, val_generator, val_nr_sample
     parser.add_argument(
         '--data_dir', type=str, required=True,
         help='Path to the directory containing images and labels')
+    parser.add_argument(
+        '--label_colors', type=str, default='label_colors.txt',
+        help='Name of the label colors txt file (located at the top level '
+             'of --data_dir)')
     parser.add_argument(
         '--output_dir', type=str, required=True, default=None,
         help='Path where logs and the model will be saved')
@@ -294,13 +366,27 @@ def train(model, train_generator, train_nr_samples, val_generator, val_nr_sample
         '--backbone', type=str, default=None,
         choices=('ResNet50', 'ResNet101', 'ResNet152'),
         help='Backbone architecture')
+    parser.add_argument(
+        '--finetune_old_inp_dim', type=int, default=None,
+        help='Input dimension of the pretrained model, used for '
+             'fine-tuning. Set it if the dimension changed in the newly '
+             'trained model.')
+    parser.add_argument(
+        '--finetune_old_out_dim', type=int, default=None,
+        help='Output dimension of the pretrained model, used for '
+             'fine-tuning. Set it if the dimension changed in the newly '
+             'trained model.')

     args = parser.parse_args()

     # check required arguments by individual operations
     if args.operation == 'fine-tune' and args.weights_path is None:
         raise parser.error(
             'Argument weights_path required for operation == fine-tune')
+    if ((args.finetune_old_inp_dim or args.finetune_old_out_dim)
+            and args.operation != 'fine-tune'):
+        raise parser.error(
+            'Arguments finetune_old_inp_dim and finetune_old_out_dim are '
+            'only supported for operation == fine-tune')
     if args.operation == 'train' and args.initial_epoch != 0:
         raise parser.error(
             'Argument initial_epoch must be 0 for operation == train')
@@ -316,7 +402,7 @@ def train(model, train_generator, train_nr_samples, val_generator, val_nr_sample
             'Argument validation_set_percentage must be greater or equal to '
             '0 and smaller or equal than 1')

-    main(args.operation, args.data_dir, args.output_dir,
+    main(args.operation, args.data_dir, args.label_colors, args.output_dir,
          args.model, args.model_fn, args.weights_path, args.visualization_path,
          args.nr_epochs, args.initial_epoch, args.batch_size,
          args.loss_function, args.seed, args.patience,
@@ -325,4 +411,5 @@ def train(model, train_generator, train_nr_samples, val_generator, val_nr_sample
          args.augment_training_dataset, args.tversky_alpha,
          args.tversky_beta, args.dropout_rate_input,
          args.dropout_rate_hidden, args.validation_set_percentage,
-         args.filter_by_classes, args.backbone)
+         args.filter_by_classes, args.backbone,
+         args.finetune_old_inp_dim, args.finetune_old_out_dim)
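At its core, `load_pretrained_model` performs a layer-by-layer weight copy that skips the layers whose shapes changed. Stripped of the model-construction details, and assuming both models enumerate their layers in the same order, the idea reduces to a sketch like this (helper name hypothetical):

```
def transfer_weights(old_model, new_model, skip_layer_names=()):
    """Copy weights layer by layer, skipping layers whose shapes changed."""
    for old_layer, new_layer in zip(old_model.layers, new_model.layers):
        # e.g. skip the first downsampling block (input dimension changed)
        # or the classifier layer (output dimension changed)
        if new_layer.name in skip_layer_names:
            continue
        new_layer.set_weights(old_layer.get_weights())
```

A fine-tuning run that changes both dimensions would then combine the new flags with the existing ones; a hypothetical invocation with placeholder paths:

```
python3 src/train.py \
    --operation fine-tune \
    --data_dir /path/to/my/dataset \
    --output_dir /path/to/logs \
    --model U-Net \
    --weights_path /path/to/pretrained_weights.h5 \
    --finetune_old_inp_dim 3 \
    --finetune_old_out_dim 5
```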
12 changes: 11 additions & 1 deletion src/utils.py
@@ -3,6 +3,7 @@
 import os
 import glob
 import argparse
+import sys

 import tensorflow as tf
@@ -24,11 +25,20 @@ def get_codings(description_file):


 def get_nr_of_bands(data_dir):
-    """Get number of bands in the first *image.tif raster in a directory.
+    """Get number of bands in the first *image.tif or *.vrt raster in a directory.

     :param data_dir: directory with images for training or detection
     """
+    # check for tif
     images = glob.glob(os.path.join(data_dir, '*image.tif'))
+    # if no tif is found, check for vrt
+    if not images:
+        images = glob.glob(os.path.join(data_dir, '*.vrt'))
+    # if still nothing is found, exit with an error message
+    if not images:
+        sys.exit('ERROR: No *image.tif or *.vrt found on the top level of '
+                 'the training directory. Needed to determine the number '
+                 'of bands of the input')

     dataset_image = gdal.Open(images[0], gdal.GA_ReadOnly)
     nr_bands = dataset_image.RasterCount
     dataset_image = None
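GDAL opens a `.vrt` like any other raster, which is why the fallback above needs no special handling once a file is found; a minimal check (filename hypothetical):

```
from osgeo import gdal

dataset = gdal.Open('im1_image.vrt', gdal.GA_ReadOnly)
print(dataset.RasterCount)  # the band count is resolved through the VRT
dataset = None  # dereference to close the dataset
```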