Skip to content

Commit

Permalink
fix(ivy): Continues f431947; typehint corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTz committed Sep 5, 2023
1 parent f431947 commit 615d1bc
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
2 changes: 1 addition & 1 deletion ivy/data_classes/array/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def multi_head_attention(
average_attention_weights: bool = True,
dropout: float = 0.0,
training: bool = False,
out: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
return ivy.multi_head_attention(
self._data,
Expand Down
7 changes: 3 additions & 4 deletions ivy/data_classes/container/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,12 +1059,11 @@ def _static_multi_head_attention(
average_attention_weights: Union[bool, ivy.Container] = True,
dropout: Union[float, ivy.Container] = 0.0,
training: Union[bool, ivy.Container] = False,
key_chains: Optional[
Union[List[str], Dict[str, str], ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
out: Optional[Union[ivy.Array, ivy.Container]] = None,
) -> ivy.Container:
return ContainerBase.cont_multi_map_in_function(
"multi_head_attention",
Expand Down Expand Up @@ -1132,7 +1131,7 @@ def multi_head_attention(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[Union[ivy.Array, ivy.NativeArray, ivy.Container]] = None,
out: Optional[Union[ivy.Array, ivy.Container]] = None,
) -> ivy.Container:
return self._static_multi_head_attention(
self,
Expand Down
4 changes: 2 additions & 2 deletions ivy/functional/ivy/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def multi_head_attention(
key: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
value: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
num_heads: int = 8,
scale: float = None,
scale: Optional[float] = None,
attention_mask: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
in_proj_weights: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
q_proj_weights: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
Expand All @@ -729,7 +729,7 @@ def multi_head_attention(
average_attention_weights: bool = True,
dropout: float = 0.0,
training: bool = False,
out: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
out: Optional[ivy.Array] = None,
) -> Union[ivy.Array, ivy.NativeArray]:
"""
Apply multi-head attention to inputs x. This is an implementation of multi-headed
Expand Down

0 comments on commit 615d1bc

Please sign in to comment.