-
Notifications
You must be signed in to change notification settings - Fork 243
/
modeling_fuyu.py
186 lines (158 loc) · 8.73 KB
/
modeling_fuyu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import torch
from transformers import FuyuPreTrainedModel, FuyuConfig, AutoModelForCausalLM
# Prefer the package-local Persimmon implementation when it is importable;
# otherwise fall back to the stock implementation shipped with transformers.
# The printed message records which of the two was picked at import time.
try:
    from .modeling_persimmon import PersimmonForCausalLM
    print("Using local PersimmonForCausalLM with Flash Attention")
except ImportError:
    from transformers import PersimmonForCausalLM
    print("Using transformers PersimmonForCausalLM without Flash Attention")
from typing import List, Optional, Tuple, Union
from transformers.modeling_outputs import BaseModelOutputWithPast
import torch.nn as nn
class FuyuForCausalLM(FuyuPreTrainedModel):
    """Causal language model that feeds text and image-patch embeddings to a Persimmon decoder.

    Image patches are projected to the hidden size by a single linear layer
    (`vision_embed_tokens`) and written into the token-embedding sequence at the
    positions given by `image_patches_indices`; the combined sequence is then
    processed by the wrapped `PersimmonForCausalLM`.

    Args:
        config: FuyuConfig
    """

    def __init__(self, config: FuyuConfig):
        """Build the wrapped language model and the linear patch projection from `config`."""
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.language_model = PersimmonForCausalLM._from_config(config.text_config)
        # One flattened patch (patch_size * patch_size * num_channels values) -> one hidden vector.
        self.vision_embed_tokens = nn.Linear(
            config.patch_size * config.patch_size * config.num_channels, config.hidden_size
        )
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding module of the wrapped language model."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the token-embedding module of the wrapped language model."""
        self.language_model.set_input_embeddings(value)

    def gather_continuous_embeddings(
        self,
        word_embeddings: torch.Tensor,
        continuous_embeddings: List[torch.Tensor],
        image_patch_input_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Place `continuous_embeddings` into a copy of `word_embeddings`.

        The destination positions are the entries of `image_patch_input_indices` that are
        non-negative; the value stored at each such position selects the source row in
        `continuous_embeddings`. Different batch elements can have different numbers of
        continuous embeddings. The input `word_embeddings` tensor is not modified.

        Args:
            word_embeddings: Tensor of word embeddings. Shape: [b, s, h]
            continuous_embeddings:
                Continuous embeddings, one entry per batch element. Each entry is
                shape [num_image_embeddings, hidden], and num_image_embeddings must be at
                least the number of non-negative indices in image_patch_input_indices for
                that batch element.
            image_patch_input_indices: Tensor of indices of the image patches in the
                input_ids tensor. Shape: [b, s]

        Returns:
            A new tensor shaped like `word_embeddings` with the selected positions replaced.

        Raises:
            ValueError: If the batch sizes differ, or a batch element has more non-negative
                indices than it has continuous embeddings.
        """
        if not (word_embeddings.shape[0] == len(continuous_embeddings)):
            raise ValueError(f"Batch sizes must match! Got {len(continuous_embeddings)=} and {word_embeddings.shape[0]=}")

        output_embeddings = word_embeddings.clone()
        for batch_idx in range(word_embeddings.shape[0]):
            # First, find the positions of all the non-negative values in image_patch_input_indices, those are the
            # positions in word_embeddings that we want to replace with content from continuous_embeddings.
            dst_indices = torch.nonzero(image_patch_input_indices[batch_idx] >= 0, as_tuple=True)[0]
            # Next look up those indices in image_patch_input_indices to find the indices in continuous_embeddings that we
            # want to use to replace the values in word_embeddings.
            src_indices = image_patch_input_indices[batch_idx][dst_indices]
            # Check if we have more indices than embeddings. Note that we could have fewer indices if images got truncated.
            if src_indices.shape[0] > continuous_embeddings[batch_idx].shape[0]:
                raise ValueError(f"Number of continuous embeddings {continuous_embeddings[batch_idx].shape=} does not match " f"number of continuous token ids {src_indices.shape=} in batch element {batch_idx}.")
            output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices]
        return output_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        image_patches: Optional[torch.Tensor] = None,  # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ]
        image_patches_indices: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs, merge in image-patch embeddings, and run the language model.

        Exactly one of `input_ids` / `inputs_embeds` must be given. Image patches are only
        merged when the embeddings are computed here from `input_ids` and there is no
        `past_key_values` (i.e. on the first step of generation). `labels` and the output
        flags are forwarded to the wrapped language model unchanged.

        Args:
            input_ids: Token ids, shape [batch, seq].
            labels: Forwarded to the language model as `labels`.
            image_patches: Iterated along its first dimension, one entry per batch element;
                each entry is projected by `vision_embed_tokens`.
            image_patches_indices: Shape [batch, seq]; see `gather_continuous_embeddings`.
                Required whenever `image_patches` is merged.
            attention_mask: Forwarded to the language model.
            position_ids: Built as a contiguous range offset by the cache length when omitted.
            past_key_values: Cached key/values; `past_key_values[0][0].shape[2]` is read as
                the cached sequence length.
            inputs_embeds: Precomputed embeddings, shape [batch, seq, hidden].
            use_cache, output_attentions, output_hidden_states, return_dict:
                Default to the corresponding `self.config` values when `None`.

        Returns:
            The language model's output object, or (when `return_dict` is false) a tuple of
            its non-`None` elements.

        Raises:
            ValueError: If both or neither of `input_ids` / `inputs_embeds` are given, or
                if `image_patches` must be merged but `image_patches_indices` is missing.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds determines the sequence length.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            seq_length = input_ids.shape[1]
        elif inputs_embeds is not None:
            seq_length = inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            # Positions continue after the cached prefix.
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
            if image_patches is not None and past_key_values is None:
                if image_patches_indices is None:
                    raise ValueError("image_patches_indices must be provided together with image_patches")
                # Project each batch element separately so that elements may carry different
                # numbers of patches; squeeze(0) drops a leading singleton dimension if present.
                projection_dtype = self.vision_embed_tokens.weight.dtype
                patch_embeddings = [
                    self.vision_embed_tokens(patch.to(projection_dtype)).squeeze(0) for patch in image_patches
                ]
                inputs_embeds = self.gather_continuous_embeddings(
                    word_embeddings=inputs_embeds,
                    continuous_embeddings=patch_embeddings,
                    image_patch_input_indices=image_patches_indices,
                )

        # Forward the resolved output flags as well: without `output_hidden_states` the
        # request is silently ignored, and without `return_dict` the inner model may return
        # a dict-like output whose iteration yields key names rather than values.
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            labels=labels,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            return_dict=return_dict,
        )
        if not return_dict:
            return tuple(v for v in outputs if v is not None)
        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        image_patches=None,
        image_patches_indices=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one `forward` call during generation.

        Once a cache exists only the last token id (and last position) is kept, and the
        image inputs are replaced by `None` because they are only consumed on the first step.

        Returns:
            A dict with either `inputs_embeds` (first step, when supplied) or `input_ids`,
            plus `position_ids`, `past_key_values`, `use_cache`, `attention_mask`,
            `image_patches_indices` and `image_patches`.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        first_step = past_key_values is None
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "image_patches_indices": image_patches_indices if first_step else None,
                "image_patches": image_patches if first_step else None,
            }
        )
        return model_inputs