
Add CLIP text model #643

Open · wants to merge 2 commits into main
Conversation

@sogartar (Contributor) commented Dec 4, 2024

Ports the CLIP text model from Hugging Face. This is the first iteration, so not much is changed from the original model; things like dropout and checkpointing are removed.
Adds numeric verification tests for the various components of the stack when executing in eager mode. Verifications are made for float32 and bfloat16. There are tests for toy-sized components and the whole model, as well as the Large pretrained variant.
These tests do not include testing with IREE.

Functionality for mask creation is not yet ported.

@sogartar marked this pull request as ready for review December 4, 2024 02:16
if bias_name in self.theta.keys:
    self.bias = self.theta_tensor(bias_name)
else:
    self.bias = None

Contributor:
Don't bother with the if/else; just set self.bias = None prior to the if statement.
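A minimal sketch of the suggested simplification, mirroring the quoted hunk (it still relies on the surrounding layer's self.theta and theta_tensor, so it is not standalone):

```python
# Default to None up front; overwrite only when the bias tensor exists.
self.bias = None
if bias_name in self.theta.keys:
    self.bias = self.theta_tensor(bias_name)
```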

Contributor Author (@sogartar):

Done.

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
    "position_ids",
    torch.arange(config.max_position_embeddings).expand((1, -1)),

Contributor:
You should use unsqueeze instead of expand
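For illustration, the two ways of building the buffer; the value of max_position_embeddings below is a hypothetical stand-in for config.max_position_embeddings:

```python
import torch

max_position_embeddings = 77  # hypothetical value for illustration

# Current form: expand to an explicit (1, n) view of the index vector.
position_ids_expand = torch.arange(max_position_embeddings).expand((1, -1))

# Suggested form: unsqueeze just adds the leading batch dimension.
position_ids_unsqueeze = torch.arange(max_position_embeddings).unsqueeze(0)

# Both produce a (1, max_position_embeddings) tensor of position indices.
assert torch.equal(position_ids_expand, position_ids_unsqueeze)
```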

Contributor Author (@sogartar):
I forgot to mention in the PR description that this is the initial port of the model and I did not attempt to optimize anything. The main goal was to put it under test. I would rather do these modifications later, as they would then be tracked in separate commits and it would be clear what changed compared to the original.

Contributor:
If you're directly referencing the huggingface code for implementation, do you want to link it in a comment?

f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5

Contributor:
Avoid using ** for inverse powers. It's better to do 1.0 / math.sqrt(self.head_dim); the numerical precision of pow operators is usually significantly worse.
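A small sketch of the suggested change; head_dim here is a hypothetical value for illustration:

```python
import math

head_dim = 64  # hypothetical head dimension

# Current form: fractional negative exponent via the pow operator.
scale_pow = head_dim**-0.5

# Suggested form: explicit reciprocal square root.
scale_sqrt = 1.0 / math.sqrt(head_dim)

# The two are mathematically identical; the sqrt form avoids relying on
# how a fractional pow gets lowered numerically.
```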

)

# apply the causal_attention_mask first
if causal_attention_mask is not None:

Contributor:
Why do you separate the causal attention mask from the regular attention mask? They should just occur together. Even the causal attention mask should really just be a bool.
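A minimal sketch of folding the two masks into a single additive mask before attention; the names and shapes are illustrative, not taken from the PR:

```python
def combine_masks(attention_mask, causal_attention_mask):
    # Both inputs are assumed to be additive float masks of shape
    # (bsz, 1, tgt_len, src_len): 0.0 where attention is allowed and a
    # large negative value where it is masked.
    if attention_mask is None:
        return causal_attention_mask
    if causal_attention_mask is None:
        return attention_mask
    # Summing additive masks applies both constraints in one tensor.
    return attention_mask + causal_attention_mask
```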

)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

attn_weights = ops.softmax(attn_weights, dim=-1)

Contributor:
Rather than decomposing, we should use the ops.scaled_dot_product_attention operation. Attention is attention, so we should avoid replicating the decomposed version everywhere.
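For reference, a sketch of the fused call using the PyTorch builtin; whether the sharktank ops.scaled_dot_product_attention wrapper takes exactly these arguments is an assumption here:

```python
import torch.nn.functional as F

def attention(query, key, value, attn_mask=None):
    # query/key/value: (bsz, num_heads, seq_len, head_dim).
    # Scaling, masking, softmax, and the value matmul all happen inside the
    # fused op instead of a hand-rolled decomposition.
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
```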

    return_dict if return_dict is not None else self.config.use_return_dict
)

encoder_states = () if output_hidden_states else None

Contributor:
This logic of "empty tuple vs. None" plus tuple concatenation makes it unclear what it is actually attempting to do. Relying on None + tuple feels weird.
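A sketch of the more explicit alternative the comment is hinting at, with illustrative names (self.layers and hidden_states stand in for the encoder's actual loop variables):

```python
# Accumulate into a plain list and convert once at the end, instead of
# threading an "empty tuple or None" accumulator through the loop.
collected = [] if output_hidden_states else None

for layer in self.layers:
    if collected is not None:
        collected.append(hidden_states)
    hidden_states = layer(hidden_states)

encoder_states = tuple(collected) if collected is not None else None
```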

def size(self, dim: Optional[int] = None) -> tuple[int]:
    if dim is None:
        return tuple(self.shape)
    else:

Contributor:
No else condition required. Just include the return when the condition is not taken.
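A sketch of the early-return form; the elided branch is assumed to return the size along dim:

```python
from typing import Optional

def size(self, dim: Optional[int] = None) -> tuple[int]:
    # Early return covers the no-argument case; no else branch is needed.
    if dim is None:
        return tuple(self.shape)
    return self.shape[dim]
```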

Contributor Author (@sogartar):
Done.

    return embeddings


class ClipAttention(BaseLayer):

Contributor:
At a quick glance it seems much of this is reused from the punet Attention (outside the decomposed SDPA, which we should try avoiding), i.e. class AttentionLayer(ThetaLayer). Can we just extend and reuse that implementation?

last_hidden_state = self.final_layer_norm(last_hidden_state)

if self.eos_token_id == 2:
    # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.

Contributor:
This PR is, I assume, a reference to diffusers? It is confusing to have in our repo.


@with_clip_data
def testSmokeExportLargeF32FromHuggingFace(self):
    repo_id = "openai/clip-vit-large-patch14"

Contributor:
We might want to avoid downloading full models for developers running locally. Should we implement a toy model to accommodate? Thoughts @rsuderman?
