-
Notifications
You must be signed in to change notification settings - Fork 30
/
detector_mean.py
77 lines (63 loc) · 2.64 KB
/
detector_mean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for Mean and Weighted Mean scoring functions."""
from typing import Optional
import jax.numpy as jnp
def mean_score(
    g_values: jnp.ndarray,
    mask: jnp.ndarray,
) -> jnp.ndarray:
  """Averages the unmasked g-values of each sequence in the batch.

  Every g-value at a position whose mask entry is 1 contributes equally; the
  per-sequence total is divided by the number of contributing g-values, i.e.
  `watermarking_depth` times the number of unmasked positions.

  Args:
    g_values: g-values of shape [batch_size, seq_len, watermarking_depth].
    mask: binary array of shape [batch_size, seq_len]. Positions with mask
      value 0 are excluded from the average.

  Returns:
    Array of shape [batch_size] holding the mean of the unmasked g-values of
    each sequence. A sequence with no unmasked positions has a zero
    denominator, so its score is NaN.
  """
  depth = g_values.shape[-1]

  # Number of positions kept per sequence, shape [batch_size].
  positions_kept = jnp.sum(mask, axis=1)

  # Give the mask a trailing singleton axis so that it broadcasts across the
  # depth axis and zeroes out every g-value at a discarded position.
  depth_mask = jnp.expand_dims(mask, 2)
  masked_total = jnp.sum(g_values * depth_mask, axis=(1, 2))

  values_kept = depth * positions_kept
  return masked_total / values_kept
def weighted_mean_score(
    g_values: jnp.ndarray,
    mask: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
  """Averages the unmasked g-values with a per-depth weighting.

  The weights are first rescaled so that they sum to `watermarking_depth`.
  Each g-value is then multiplied by the weight of its depth index, and the
  masked per-sequence total is divided by `watermarking_depth` times the
  number of unmasked positions.

  Args:
    g_values: g-values of shape [batch_size, seq_len, watermarking_depth].
    mask: binary array of shape [batch_size, seq_len]. Positions with mask
      value 0 are excluded from the average.
    weights: non-negative floats of shape [watermarking_depth], one weight per
      depth index. When omitted, weights decrease linearly from 10 to 1 across
      the depth axis.

  Returns:
    Array of shape [batch_size] holding the weighted mean of the unmasked
    g-values of each sequence. A sequence with no unmasked positions has a
    zero denominator, so its score is NaN.
  """
  depth = g_values.shape[-1]

  if weights is None:
    weights = jnp.linspace(start=10, stop=1, num=depth)

  # Rescale so the weights sum to `depth`; uniform weights therefore become
  # all ones.
  rescale = depth / jnp.sum(weights)
  unit_mean_weights = weights * rescale

  # Two leading singleton axes let the weights broadcast over the batch and
  # sequence axes.
  weighted_g_values = g_values * jnp.expand_dims(unit_mean_weights, axis=(0, 1))

  # Number of positions kept per sequence, shape [batch_size].
  positions_kept = jnp.sum(mask, axis=1)

  depth_mask = jnp.expand_dims(mask, 2)
  masked_total = jnp.sum(weighted_g_values * depth_mask, axis=(1, 2))

  values_kept = depth * positions_kept
  return masked_total / values_kept