Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]: Shashank/flexattention #1675

Draft
wants to merge 41 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
f416539
adding flex attention
ShashankMosaicML Nov 18, 2024
ac3a884
registrifying score mods
ShashankMosaicML Nov 18, 2024
31b27e2
registrifying attention mask mods
ShashankMosaicML Nov 18, 2024
c8fffa5
Merge branch 'mosaicml:main' into shashank/flexattention
ShashankMosaicML Nov 18, 2024
86dce3b
bug_fix
ShashankMosaicML Nov 19, 2024
cb8f4a6
bug_fix
ShashankMosaicML Nov 19, 2024
902850a
lint
ShashankMosaicML Nov 19, 2024
9c9708d
configuring test
ShashankMosaicML Nov 19, 2024
f1ff430
configuring tests
ShashankMosaicML Nov 19, 2024
e537f5a
bug fix
ShashankMosaicML Nov 19, 2024
c527dd7
fixing alibi
ShashankMosaicML Nov 19, 2024
15e303e
Merge branch 'mosaicml:main' into shashank/flexattention
ShashankMosaicML Nov 19, 2024
c4ef5d9
configuring further tests
ShashankMosaicML Nov 19, 2024
6b37427
refactoring
ShashankMosaicML Nov 19, 2024
e30fe7a
adding warnings and errors
ShashankMosaicML Nov 19, 2024
924a53c
gating tests on torch version
ShashankMosaicML Nov 19, 2024
57048e3
Merge branch 'mosaicml:main' into shashank/flexattention
ShashankMosaicML Nov 19, 2024
67a2aea
reorganizing function defs
ShashankMosaicML Nov 19, 2024
04f3a62
refactoring
ShashankMosaicML Nov 19, 2024
ab6c58c
passing in dicts of mask and score mods
ShashankMosaicML Nov 19, 2024
3b3827d
making mask and score mods configurable via yaml
ShashankMosaicML Nov 19, 2024
be43e8d
Merge branch 'mosaicml:main' into shashank/flexattention
ShashankMosaicML Nov 19, 2024
2264f91
adding torch.compile
ShashankMosaicML Nov 20, 2024
e274d9f
..
ShashankMosaicML Nov 20, 2024
a26bb4f
..
ShashankMosaicML Nov 20, 2024
d5ab7d3
undoing comment out
ShashankMosaicML Nov 20, 2024
d40e978
Merge branch 'mosaicml:main' into shashank/flexattention
ShashankMosaicML Nov 26, 2024
5f13e7b
adding torch comile
ShashankMosaicML Nov 26, 2024
ca8e173
temporary commit commenting out block mask and score mod
ShashankMosaicML Nov 26, 2024
f5486ff
undoing prev temp commit
ShashankMosaicML Nov 26, 2024
fdced3a
Merge branch 'mosaicml:main' into shashank/flexattention
ShashankMosaicML Nov 26, 2024
c53db63
speeding up block mask generation
ShashankMosaicML Nov 27, 2024
ec5900d
precompilining create block mask
ShashankMosaicML Nov 27, 2024
02ad3b6
minor
ShashankMosaicML Nov 27, 2024
13a5fc8
compiling mask and flex attn once for the entire model
ShashankMosaicML Nov 27, 2024
2ae6027
..
ShashankMosaicML Nov 27, 2024
0c5150a
..
ShashankMosaicML Nov 27, 2024
ff28304
making sequence id transforms configurable
ShashankMosaicML Nov 27, 2024
23ba20f
..
ShashankMosaicML Nov 27, 2024
72c45ae
..
ShashankMosaicML Nov 27, 2024
73066a4
..
ShashankMosaicML Nov 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions llmfoundry/layers_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,63 @@
description=_attention_implementations_description,
)

_flex_attention_score_mods_description = (
"""The flex_attention_score_mods registry is used to register functions that implement flex attention score mods.

One example is 'alibi'. See attention.py for examples.

Args:
kwargs: Dict[str, Any]: Additional keyword arguments the implementation accepts.
Returns:
Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], Tensor]: The score mod function (see https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py)
"""
)
flex_attention_score_mods = create_registry(
'llmfoundry',
'flex_attention_score_mods',
generic_type=Callable,
entry_points=True,
description=_flex_attention_score_mods_description,
)

_flex_attention_mask_mods_description = (
"""The flex_attention_masks registry is used to register functions that implement flex attention mask mods.

One example is 'sequence_id'. See attention.py for examples.

Args:
kwargs: Dict[str, Any]: Additional keyword arguments the implementation accepts.
Returns:
Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]: The mask mod function (see https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py)
"""
)
flex_attention_mask_mods = create_registry(
'llmfoundry',
'flex_attention_mask_mods',
generic_type=Callable,
entry_points=True,
description=_flex_attention_mask_mods_description,
)

_sequence_id_transformer_registry = (
"""The sequence_id_transformer_registry registry is used to register functions that implement sequence id transformations.

One example is 'attention_mask_in_length' in modeling_mpt.py.

Args:
torch.Tensor: The sequence id tensor.
Returns:
Any: The sequence id transformed.
"""
)
sequence_id_transformer_registry = create_registry(
'llmfoundry',
'sequence_id_transformer_registry',
generic_type=Callable,
entry_points=True,
description=_sequence_id_transformer_registry,
)

_param_init_fns_description = (
"""The param_init_fns registry is used to register functions that initialize parameters.

Expand Down Expand Up @@ -231,5 +288,8 @@
'ffns_with_megablocks',
'attention_classes',
'attention_implementations',
'flex_attention_score_mods',
'flex_attention_mask_mods',
'sequence_id_transformer_registry',
'fcs',
]
Loading
Loading