Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Неверные аннотации размерностей в my_multihead_attention в пятом семинаре #12

Open
proshian opened this issue Jun 24, 2023 · 2 comments · May be fixed by #13

Comments

@proshian
Copy link

proshian commented Jun 24, 2023

Я считаю, что в пятом семинаре в функции my_multihead_attention следует заменить ValuesLen на QueriesLen абсолютно везде. Насколько я понимаю, в общем случае (не self-attention) количество запросов не обязательно равно количеству значений. Количество значений обязательно равно количеству ключей (не количеству запросов). Соответственно, вторая размерность resulting_features и attention_map тоже будет = QueriesLen

После исправления функция будет выглядеть так:

def my_multihead_attention(queries, keys, values,
                           keys_padding_mask, dependency_mask,
                           is_training,
                           weights_dropout):
    """
    queries - BatchSize x QueriesLen x HeadN x KeySize
    keys - BatchSize x KeysLen x HeadN x KeySize
    values - BatchSize x KeysLen x HeadN x ValueSize
    keys_padding_mask - BatchSize x KeysLen
    dependency_mask - ValuesLen x KeysLen
    is_training - bool
    weights_dropout - float
    
    result - tuple of two:
        - BatchSize x QueriesLen x HeadN x ValueSize - resulting features
        - BatchSize x QueriesLen x KeysLen x HeadN - attention map
    """

    # BatchSize x QueriesLen x KeysLen x HeadN
    relevances = torch.einsum('bqhs,bkhs->bqkh', (queries, keys))
    
    # замаскировать элементы, выходящие за длины последовательностей ключей
    padding_mask_expanded = keys_padding_mask[:, None, :, None].expand_as(relevances)
    relevances.masked_fill_(padding_mask_expanded, float('-inf'))
    
    # замаскировать пары <выходная позиция, входная позиция>
    relevances = relevances + dependency_mask[None, :, :, None].expand_as(relevances)
    
    normed_rels = F.softmax(relevances, dim=2)    
    normed_rels = F.dropout(normed_rels, weights_dropout, is_training)
    
    # BatchSize x QueriesLen x KeysLen x HeadN x 1
    normed_rels_expanded = normed_rels.unsqueeze(-1)
    
    # BatchSize x 1 x KeysLen x HeadN x ValueSize
    values_expanded = values.unsqueeze(1)
    
    # BatchSize x QueriesLen x KeysLen x HeadN x ValueSize
    weighted_values = normed_rels_expanded * values_expanded
    result = weighted_values.sum(2)  # BatchSize x QueriesLen x HeadN x ValueSize
    
    return result, normed_rels

И на всякий случай визуальное пояснение. Для простоты рассмотрен случай с одной головой и batch_size = 1
визуализация

@proshian proshian linked a pull request Jun 24, 2023 that will close this issue
@proshian
Copy link
Author

Я только что заметил, что там не ValueSize, а ValuesLen. ValuesLen действительно равен QuerySize

@proshian
Copy link
Author

proshian commented Jul 16, 2023

Мне все-таки кажется, что размерности в в функции my_multihead_attention пятого семинара не совсем корректны

Я считаю, что должно быть либо:

queries - BatchSize x QueriesLen x HeadN x KeySize
result - tuple of two:
- BatchSize x QueriesLen x HeadN x ValueSize - resulting features
- BatchSize x QueriesLen x KeysLen x HeadN - attention map


либо:

queries - BatchSize x ResultLen x HeadN x KeySize
result - tuple of two:
- BatchSize x ResultLen x HeadN x ValueSize - resulting features
- BatchSize x ResultLen x KeysLen x HeadN - attention map


Но не как в текущей версии:

queries - BatchSize x ValuesLen x HeadN x KeySize
result - tuple of two:
- BatchSize x ValuesLen x HeadN x ValueSize - resulting features
- BatchSize x ValuesLen x KeysLen x HeadN - attention map

Потому что под ValuesLen логичнее понимать число элементов в матрице Values. То есть
ValuesLen = KesyLen
QueriesLen = ResultLen
ValuesLen != QueriesLen

@proshian proshian reopened this Jul 16, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

1 participant