# from apache_beam.internal import pickler
# pickler.set_library(pickler.USE_CLOUDPICKLE)
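
# Beam pipeline that captions images with BLIP and ranks the candidate captions
# with CLIP: image URLs are read from a text file, each image gets
# NUM_CAPTIONS_PER_IMAGE candidate captions from BLIP, CLIP scores each caption
# against its image, and the top-ranked captions are written to the output file.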
import argparse

import apache_beam as beam
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from transformers import CLIPProcessor
from transformers import CLIPTokenizer
from transformers import CLIPModel
from transformers import CLIPConfig
from transformers import CLIPFeatureExtractor

from blip_processing import PreprocessBLIPInput, PostprocessBLIPOutput, BLIPWrapper
from clip_processing import PreprocessCLIPInput, RankCLIPOutput, CLIPWrapper
from image_utils import ReadImagesFromUrl, ReadImagesFromGcsUrl, FormatCaptions
from model_handler import PytorchModelHandlerKeyedTensor, PytorchNoBatchModelHandlerKeyedTensor
from models.blip import blip_decoder


def main(parser, save_main_session=True):
  # Increasing the number of beams in beam search might improve the quality of
  # the captions, but it also increases compute time.
  NUM_BEAMS = 5
  # Number of candidate captions generated per image.
  NUM_CAPTIONS_PER_IMAGE = 10
  # Number of top-ranked captions to display per image.
  NUM_TOP_CAPTIONS_TO_DISPLAY = 3
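
  # Paths to the locally staged CLIP files. These are absolute paths, so they
  # are expected to exist on the workers (for example, copied into a custom
  # container image built for this pipeline).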
  clip_feature_extractor_config_path = '/captioning/clip-vit-base-patch32/preprocessor_config.json'
  clip_tokenizer_vocab_config_path = '/captioning/clip-vit-base-patch32/vocab.json'
  clip_merges_config_path = '/captioning/clip-vit-base-patch32/merges.txt'
  clip_model_config_path = '/captioning/clip-vit-base-patch32/config.json'
  clip_state_dict_path = '/captioning/clip-vit-base-patch32/pytorch_model.bin'
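
  # parse_known_args splits the command line into the arguments declared on the
  # parser (dataset and output filenames) and the remaining Beam pipeline options.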
  known_args, pipeline_args = parser.parse_known_args()
  dataset_filename = known_args.dataset_filename
  output_filename = known_args.output_filename

  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
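
  # Wrap CLIP in a keyed model handler so that RunInference keeps each image key
  # attached to its prediction. PytorchNoBatchModelHandlerKeyedTensor is the
  # project-local handler imported above; judging by its name, it avoids batching
  # elements together before inference.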
  CLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(
      state_dict_path=clip_state_dict_path,
      model_class=CLIPWrapper,
      model_params={'config': CLIPConfig.from_pretrained(clip_model_config_path)},
      device='GPU')
  CLIP_keyed_model_handler = KeyedModelHandler(CLIP_model_handler)
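
  # BLIP configuration: the decoder weights are loaded from a local state dict,
  # and generated captions are constrained to the length range below.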
  blip_state_dict_path = 'blip_state_dict.pth'
  MAX_CAPTION_LENGTH = 80
  MIN_CAPTION_LENGTH = 10
  # Increasing the number of beams in beam search might improve the quality of
  # the captions, but it also increases compute time. Note that this overrides
  # the NUM_BEAMS value set at the top of main().
  NUM_BEAMS = 1
  BLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(
      state_dict_path=blip_state_dict_path,
      model_class=BLIPWrapper,
      model_params={'base_model': blip_decoder, 'num_beams': NUM_BEAMS,
                    'max_length': MAX_CAPTION_LENGTH, 'min_length': MIN_CAPTION_LENGTH},
      device='GPU')
  BLIP_keyed_model_handler = KeyedModelHandler(BLIP_model_handler)
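
  # Pipeline layout: read image URLs -> load images -> generate candidate captions
  # with BLIP -> join each image with its captions -> rank the captions with CLIP
  # -> keep the top captions per image -> write the formatted results.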
  with beam.Pipeline(options=pipeline_options) as pipeline:
    read_images = (
        pipeline
        | "ReadUrl" >> beam.io.ReadFromText(dataset_filename)
        | "ReadImages" >> beam.ParDo(ReadImagesFromGcsUrl()))
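
    # BLIP branch: preprocess each keyed image (requesting NUM_CAPTIONS_PER_IMAGE
    # candidates), run the BLIP decoder via RunInference, and extract the caption
    # strings from the model output.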
    blip_caption_generation = (
        read_images
        | "PreprocessBlipInput" >> beam.ParDo(PreprocessBLIPInput(NUM_CAPTIONS_PER_IMAGE))
        | "GenerateCaptions" >> RunInference(BLIP_keyed_model_handler)
        | "PostprocessCaptions" >> beam.ParDo(PostprocessBLIPOutput()))
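
    # CLIP branch: CoGroupByKey joins each image with its candidate captions by
    # key. Preprocessing tokenizes the captions and extracts image features using
    # the local CLIP config files, RunInference computes image-text similarity
    # logits, and RankCLIPOutput orders the captions by those logits.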
    clip_captions_ranking = (
        ({'image': read_images, 'captions': blip_caption_generation})
        | "CreateImageCaptionPair" >> beam.CoGroupByKey()
        | "PreprocessClipInput" >> beam.ParDo(
            PreprocessCLIPInput(
                clip_feature_extractor_config_path,
                clip_tokenizer_vocab_config_path,
                clip_merges_config_path))
        | "GetRankingLogits" >> RunInference(CLIP_keyed_model_handler)
        | "RankClipOutput" >> beam.ParDo(RankCLIPOutput()))
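
    # Keep the NUM_TOP_CAPTIONS_TO_DISPLAY highest-ranked captions per image and
    # write them to the output location as text.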
    results = (
        clip_captions_ranking
        | "FormatCaptions" >> beam.ParDo(FormatCaptions(NUM_TOP_CAPTIONS_TO_DISPLAY)))

    results | "Write Results" >> beam.io.WriteToText(output_filename)


if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Pass your arguments")
  parser.add_argument(
      "--dataset-filename",
      type=str,
      required=True,
      help="Dataset filename location. Ex: gs://<project-id>/dataset.txt"
  )
  parser.add_argument(
      "--output-filename",
      type=str,
      required=True,
      help="Output file to write the ranked captions to. Ex: gs://<project-id>/metadata.jsonl"
  )
  main(parser)
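
# Example invocation (illustrative; the runner options depend on your environment):
#   python pipeline.py \
#     --dataset-filename gs://<project-id>/dataset.txt \
#     --output-filename gs://<project-id>/metadata.jsonl \
#     --runner DataflowRunner \
#     --project <project-id> \
#     --region <region> \
#     --temp_location gs://<project-id>/tmp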