diff --git a/cleverhans/tf2/attacks/carlini_wagner_l2.py b/cleverhans/tf2/attacks/carlini_wagner_l2.py
index 60d97ec65..53e528cc9 100644
--- a/cleverhans/tf2/attacks/carlini_wagner_l2.py
+++ b/cleverhans/tf2/attacks/carlini_wagner_l2.py
@@ -150,6 +150,17 @@ def _attack(self, x):
         lower_bound = tf.zeros(shape[:1])
         upper_bound = tf.ones(shape[:1]) * 1e10
 
+        # manually broadcast the per-example bounds to the rank of the input
+        def explicit_broadcast(tensor):
+            while len(tensor.shape) < len(shape):
+                tensor = tf.expand_dims(tensor, -1)
+            return tensor
+
+        lower_bound = explicit_broadcast(lower_bound)
+        upper_bound = explicit_broadcast(upper_bound)
+
+        assert len(lower_bound.shape) == len(upper_bound.shape) == len(shape)
+
         const = tf.ones(shape) * self.initial_const
 
         # placeholder variables for best values
@@ -230,7 +241,7 @@ def _attack(self, x):
 
             # mask is of shape [batch_size]; best_attack is [batch_size, image_size]
             # need to expand
-            mask = tf.reshape(mask, [-1, 1, 1, 1])
+            mask = tf.reshape(mask, [-1] + [1] * (shape.ndims - 1))
             mask = tf.tile(mask, [1, *best_attack.shape[1:]])
 
             best_attack = set_with_mask(best_attack, x_new, mask)
@@ -244,6 +255,8 @@ def _attack(self, x):
                 compare_fn(best_score, lab),
                 tf.not_equal(best_score, -1),
             )
+            upper_mask = explicit_broadcast(upper_mask)
+
             upper_bound = set_with_mask(
                 upper_bound, tf.math.minimum(upper_bound, const), upper_mask
             )
@@ -332,7 +345,10 @@ def loss_fn(
 
     # sum up losses
     loss_2 = tf.reduce_sum(l2_dist)
-    loss_1 = tf.reduce_sum(const * loss_1)
+    if len(loss_1.shape) == 1 and loss_1.shape[0] == const.shape[0]:
+        loss_1 = tf.reduce_sum(tf.transpose(const) * loss_1)
+    else:
+        loss_1 = tf.reduce_sum(const * loss_1)
     loss = loss_1 + loss_2
 
     return loss, l2_dist
diff --git a/cleverhans/tf2/attacks/spsa.py b/cleverhans/tf2/attacks/spsa.py
index c900d5f6e..f50833bdd 100644
--- a/cleverhans/tf2/attacks/spsa.py
+++ b/cleverhans/tf2/attacks/spsa.py
@@ -1,6 +1,7 @@
 # pylint: disable=missing-docstring
 import tensorflow as tf
+import inspect
 
 tf_dtype = tf.as_dtype("float32")
 
@@ -59,6 +60,7 @@ def loss_fn(x, label):
         Margin logit loss, with correct sign for targeted vs untargeted loss.
         """
         logits = model_fn(x)
+        logits = tf.cast(logits, tf_dtype)
         loss_multiplier = 1 if targeted else -1
         return loss_multiplier * margin_logit_loss(
             logits, label, nb_classes=logits.get_shape()[-1]
@@ -95,13 +97,21 @@ def __init__(
         num_iters=1,
         compare_to_analytic_grad=False,
     ):
-        super(SPSAAdam, self).__init__(lr=lr)
+        lr_long_name = "learning_rate" in inspect.signature(tf.optimizers.Adam).parameters
+        lr_key = "learning_rate" if lr_long_name else "lr"
+        super(SPSAAdam, self).__init__(**{lr_key: lr})
         assert num_samples % 2 == 0, "number of samples must be even"
         self._delta = delta
         self._num_samples = num_samples // 2  # Since we mirror +/- delta later
         self._num_iters = num_iters
         self._compare_to_analytic_grad = compare_to_analytic_grad
 
+    def _get_lr(self):
+        if hasattr(self, "learning_rate"):
+            return self.learning_rate
+        else:
+            return self.lr
+
     def _get_delta(self, x, delta):
         x_shape = x.get_shape().as_list()
         delta_x = delta * tf.sign(
@@ -190,7 +200,7 @@ def _apply_gradients(self, grads, x, optim_state):
             new_optim_state["u"][i] = self.beta_2 * u_old + (1.0 - self.beta_2) * g * g
             m_hat = new_optim_state["m"][i] / (1.0 - tf.pow(self.beta_1, t))
             u_hat = new_optim_state["u"][i] / (1.0 - tf.pow(self.beta_2, t))
-            new_x[i] = x[i] - self.lr * m_hat / (tf.sqrt(u_hat) + self.epsilon)
+            new_x[i] = x[i] - self._get_lr() * m_hat / (tf.sqrt(u_hat) + self.epsilon)
         return new_x, new_optim_state
 
     def init_state(self, x):
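
Note (not part of the patch): a minimal standalone sketch of what the explicit_broadcast helper added above does, assuming a batch of four 32x32 RGB inputs; the shape tuple here stands in for x.shape inside _attack.

import tensorflow as tf

shape = (4, 32, 32, 3)  # stands in for x.shape inside _attack

def explicit_broadcast(tensor):
    # Append trailing singleton dimensions until the rank matches `shape`,
    # so per-example values of shape [batch_size] broadcast against
    # image-shaped tensors of shape [batch_size, height, width, channels].
    while len(tensor.shape) < len(shape):
        tensor = tf.expand_dims(tensor, -1)
    return tensor

lower_bound = tf.zeros(shape[:1])             # shape (4,)
print(explicit_broadcast(lower_bound).shape)  # -> (4, 1, 1, 1)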